add training script for polarization estimation, refactor model definitions, add randomised polarisation support in data_loader

This commit is contained in:
Joseph Hopfmüller
2024-12-11 09:48:38 +01:00
parent 0e29b87395
commit 39ae13d0af
8 changed files with 1899 additions and 259 deletions

View File

@@ -0,0 +1,443 @@
from typing import Any
import lightning as L
import numpy as np
import torch
import torch.nn as nn
# import torch.nn.functional as F
from util.complexNN import DropoutComplex, Scale, ONNRect, EOActivation, energy_conserving, clamp, complex_mse_loss
from util.datasets import FiberRegenerationDataset
class regeneratorData(L.LightningDataModule):
def __init__(
self,
config_globs,
output_symbols,
output_dim,
dtype,
drop_first,
shuffle=True,
train_split=None,
batch_size=None,
loader_settings=None,
seed=None,
num_symbols=None,
test_globs=None,
):
super().__init__()
self._config_globs = config_globs
self._test_globs = test_globs
self._test_data_available = test_globs is not None
if self._test_data_available:
self.test_dataloader = self._test_dataloader
self._output_symbols = output_symbols
self._output_dim = output_dim
self._dtype = dtype
self._drop_first = drop_first
self._seed = seed
self._shuffle = shuffle
self._num_symbols = num_symbols
self._train_split = train_split if train_split is not None else 0.8
self.batch_size = batch_size if batch_size is not None else 1024
self._loader_settings = loader_settings if loader_settings is not None else {}
def _get_data(self):
self._data_train = FiberRegenerationDataset(
file_path=self._config_globs,
symbols=self._output_symbols,
output_dim=self._output_dim,
dtype=self._dtype,
real=not self._dtype.is_complex,
drop_first=self._drop_first,
num_symbols=self._num_symbols,
)
# _data_plot is used by predict_dataloader below, so it must actually be built here
self._data_plot = FiberRegenerationDataset(
file_path=self._config_globs,
symbols=self._output_symbols,
output_dim=self._output_dim,
dtype=self._dtype,
real=not self._dtype.is_complex,
drop_first=self._drop_first,
num_symbols=400,
)
if self._test_data_available:
self._data_test = FiberRegenerationDataset(
file_path=self._test_globs,
symbols=self._output_symbols,
output_dim=self._output_dim,
dtype=self._dtype,
real=not self._dtype.is_complex,
drop_first=self._drop_first,
num_symbols=self._num_symbols,
)
return self._data_train, self._data_test
return self._data_train
def _split_data(self, stage="fit", split=None, shuffle=None):
_split = split if split is not None else self._train_split
_shuffle = shuffle if shuffle is not None else self._shuffle
dataset_size = len(self._data_train)
indices = list(range(dataset_size))
split_index = int(np.floor(_split * dataset_size))
train_indices, valid_indices = indices[:split_index], indices[split_index:]
if _shuffle:
np.random.seed(self._seed)
np.random.shuffle(train_indices)
if stage == "fit" or stage == "predict":
self._train_sampler = torch.utils.data.SubsetRandomSampler(train_indices)
if stage == "fit" or stage == "validate":
# keep validation deterministic: plain indices instead of a random sampler;
# without this, stage "fit" returned an unset self._valid_sampler
self._valid_sampler = valid_indices
else:
if stage == "fit" or stage == "predict":
self._train_sampler = train_indices
if stage == "fit" or stage == "validate":
self._valid_sampler = valid_indices
if stage == "fit":
return self._train_sampler, self._valid_sampler
elif stage == "validate":
return self._valid_sampler
elif stage == "predict":
return self._train_sampler
def prepare_data(self):
self._get_data()
def setup(self, stage=None):
stage = stage or "fit"
self._split_data(stage=stage)
def train_dataloader(self):
return torch.utils.data.DataLoader(
self._data_train,
batch_size=self.batch_size,
sampler=self._train_sampler,
**self._loader_settings
)
def val_dataloader(self):
return torch.utils.data.DataLoader(
self._data_train,
batch_size=self.batch_size,
sampler=self._valid_sampler,
**self._loader_settings
)
def _test_dataloader(self):
return torch.utils.data.DataLoader(
self._data_test,
shuffle=self._shuffle,
batch_size=self.batch_size,
**self._loader_settings
)
def predict_dataloader(self):
return torch.utils.data.DataLoader(
self._data_plot,
shuffle=False,
batch_size=40,
pin_memory=True,
drop_last=True,
num_workers=4,
prefetch_factor=2,
)
# def plot_dataloader(self):
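# Usage sketch (illustrative, not part of this commit): a regeneratorData module can be
# driven by a lightning.Trainer; the glob and the `model` below are hypothetical placeholders.
# dm = regeneratorData(
#     config_globs="data/*.ini",
#     output_symbols=13,
#     output_dim=26,
#     dtype=torch.complex64,
#     drop_first=64,
#     batch_size=1024,
#     seed=0xC0FFEE,
# )
# L.Trainer(max_epochs=2).fit(model, dm)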
class regenerator(L.LightningModule):
def __init__(
self,
*dims,
layer_function=ONNRect,
layer_func_kwargs: dict | None = {"square": True},
act_function=EOActivation,
act_func_kwargs: dict | None = None,
parametrizations: list[dict] | None = [
{
"tensor_name": "weight",
"parametrization": energy_conserving,
},
{
"tensor_name": "alpha",
"parametrization": clamp,
},
],
dtype=torch.complex64,
dropout_prob=0.01,
scale_layers=False,
optimizer=torch.optim.AdamW,
optimizer_kwargs: dict | None = {
"lr": 0.01,
"amsgrad": True,
},
lr_scheduler=None,
lr_scheduler_kwargs: dict | None = {
"patience": 20,
"factor": 0.5,
"min_lr": 1e-6,
"cooldown": 10,
},
sps = 128,
# **kwargs,
):
torch.set_float32_matmul_precision('high')
layer_func_kwargs = layer_func_kwargs if layer_func_kwargs is not None else {}
act_func_kwargs = act_func_kwargs if act_func_kwargs is not None else {}
optimizer_kwargs = optimizer_kwargs if optimizer_kwargs is not None else {}
lr_scheduler_kwargs = lr_scheduler_kwargs if lr_scheduler_kwargs is not None else {}
super().__init__()
self.example_input_array = torch.randn(1, dims[0], dtype=dtype)
self._sps = sps
self.optimizer_settings = {
"optimizer": optimizer,
"optimizer_kwargs": optimizer_kwargs,
"lr_scheduler": lr_scheduler,
"lr_scheduler_kwargs": lr_scheduler_kwargs,
}
# if len(dims) == 0:
# try:
# dims = kwargs["dims"]
# except KeyError:
# raise ValueError("dims must be provided")
self._n_hidden_layers = len(dims) - 2
self.build_model(dims, layer_function, layer_func_kwargs, act_function, act_func_kwargs, parametrizations, dtype, dropout_prob, scale_layers)
def build_model(self, dims, layer_function, layer_func_kwargs, act_function, act_func_kwargs, parametrizations, dtype, dropout_prob, scale_layers):
input_layer = nn.Sequential(
layer_function(dims[0], dims[1], dtype=dtype, **layer_func_kwargs),
act_function(size=dims[1], **act_func_kwargs),
DropoutComplex(p=dropout_prob),
)
if scale_layers:
input_layer = nn.Sequential(Scale(dims[0]), input_layer)
self.layer_0 = input_layer
for i in range(1, self._n_hidden_layers):
layer = nn.Sequential(
layer_function(dims[i], dims[i + 1], dtype=dtype, **layer_func_kwargs),
act_function(size=dims[i + 1], **act_func_kwargs),
DropoutComplex(p=dropout_prob),
)
if scale_layers:
layer = nn.Sequential(Scale(dims[i]), layer)
setattr(self, f"layer_{i}", layer)
output_layer = nn.Sequential(
layer_function(dims[-2], dims[-1], dtype=dtype, **layer_func_kwargs),
act_function(size=dims[-1], **act_func_kwargs),
Scale(dims[-1]),
)
setattr(self, f"layer_{self._n_hidden_layers}", output_layer)
if parametrizations is not None:
self._apply_parametrizations(self, parametrizations)
def _apply_parametrizations(self, layer, parametrizations):
for sub_layer in layer.children():
if len(sub_layer._modules) > 0:
self._apply_parametrizations(sub_layer, parametrizations)
else:
for parametrization in parametrizations:
tensor_name = parametrization.get("tensor_name", None)
if tensor_name is None:
continue
parametrization_func = parametrization.get("parametrization", None)
if parametrization_func is None:
continue
param_kwargs = parametrization.get("kwargs", {})
if tensor_name in sub_layer._parameters:
parametrization_func(sub_layer, tensor_name, **param_kwargs)
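# Expected shape of the `parametrizations` entries (illustrative): each dict names a
# parameter tensor and a register function invoked as parametrization(module, tensor_name,
# **kwargs), mirroring torch.nn.utils.parametrizations.orthogonal; the min/max values
# here are example assumptions, not defaults.
# [
#     {"tensor_name": "weight", "parametrization": energy_conserving},
#     {"tensor_name": "alpha", "parametrization": clamp, "kwargs": {"min": 0.0, "max": 1.0}},
# ]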
def _trace_powers(self, enable, x, powers=None):
if not enable:
return
if powers is None:
powers = []
powers.append(x.abs().square().sum())
return powers
# def plot(self, mode):
# self.predict_step()
# def validation_epoch_end(self, outputs):
# x = torch.vstack([output['x'].view(output['x'].shape[0], -1, 2)[:, output['x'].shape[1]//2, :].squeeze() for output in outputs])
# y = torch.vstack([output['y'].view(output['y'].shape[0], -1, 2).squeeze() for output in outputs])
# y_hat = torch.vstack([output['y_hat'].view(output['y_hat'].shape[0], -1, 2).squeeze() for output in outputs])
# timesteps = torch.vstack([output['timesteps'].squeeze() for output in outputs])
# powers = torch.vstack([output['powers'] for output in outputs])
# return {'x': x, 'y': y, 'y_hat': y_hat, 'timesteps': timesteps, 'powers': powers}
def on_validation_epoch_end(self):
if self.current_epoch % 10 == 0 or self.current_epoch == self.trainer.max_epochs - 1 or self.current_epoch < 10:
x = self.val_outputs['x']
# x = x.view(x.shape[0], -1, 2)
# x = x[:, x.shape[1]//2, :].squeeze()
y = self.val_outputs['y']
# y = y.view(y.shape[0], -1, 2).squeeze()
y_hat = self.val_outputs['y_hat']
# y_hat = y_hat.view(y_hat.shape[0], -1, 2).squeeze()
timesteps = self.val_outputs['timesteps']
# timesteps = timesteps.squeeze()
powers = self.val_outputs['powers']
# powers = powers.squeeze()
fiber_in = x.detach().cpu().numpy()
fiber_out = y.detach().cpu().numpy()
regen = y_hat.detach().cpu().numpy()
timesteps = timesteps.detach().cpu().numpy()
# powers = np.array([power.detach().cpu().numpy() for power in powers])
# fiber_in = np.concat(fiber_in, axis=0)
# fiber_out = np.concat(fiber_out, axis=0)
# regen = np.concat(regen, axis=0)
# timesteps = np.concat(timesteps, axis=0)
# powers = powers.detach().cpu().numpy()
import gc
fig = self.plot_model_head(fiber_in, fiber_out, regen, timesteps, sps=self._sps)
self.logger.experiment.add_figure("model response", fig, self.current_epoch)
# fig = self.plot_model_eye(fiber_in, fiber_out, regen, timesteps, sps=self._sps)
# self.logger.experiment.add_figure("model eye", fig, self.current_epoch)
# fig = self.plot_model_powers(powers)
# self.logger.experiment.add_figure("powers", fig, self.current_epoch)
gc.collect()
# x, y, y_hat, timesteps, powers = self.validation_epoch_end(self.outputs)
# self.plot(x, y, y_hat, timesteps, powers)
def plot_model_head(self, fiber_in, fiber_out, regen, timesteps, sps):
import matplotlib
matplotlib.use("TkCairo")
import matplotlib.pyplot as plt
ordering = np.argsort(timesteps)
signals = [signal[ordering] for signal in [fiber_in, fiber_out, regen]]
timesteps = timesteps[ordering]
signals = [signal[:sps*40] for signal in signals]
timesteps = timesteps[:sps*40]
fig, axs = plt.subplots(1, 2, sharex=True, sharey=True)
fig.set_figwidth(16)
fig.set_figheight(4)
for i, ax in enumerate(axs):
for j, signal in enumerate(signals):
ax.plot(timesteps / sps, np.square(np.abs(signal[:,i])), label=["fiber in", "fiber out", "regen"][j] + [" x", " y"][i])
ax.set_xlabel("symbol")
ax.set_ylabel("power (a.u.)")
ax.minorticks_on()
ax.tick_params(axis="y", which="minor", left=False, right=False)
ax.grid(which="major", axis="x")
ax.grid(which="minor", axis="x", linestyle=":")
ax.grid(which="major", axis="y")
ax.legend(loc="upper right")
fig.tight_layout()
return fig
def plot_model_eye(self, fiber_in, fiber_out, regen, timesteps, sps):
...
def plot_model_powers(self, powers):
...
def forward(self, x, trace_powers=False):
powers = self._trace_powers(trace_powers, x)
x = self.layer_0(x)
powers = self._trace_powers(trace_powers, x, powers)
for i in range(1, self._n_hidden_layers):
x = getattr(self, f"layer_{i}")(x)
powers = self._trace_powers(trace_powers, x, powers)
x = getattr(self, f"layer_{self._n_hidden_layers}")(x)
powers = self._trace_powers(trace_powers, x, powers)
if trace_powers:
return x, powers
return x
def configure_optimizers(self):
optimizer = self.optimizer_settings["optimizer"](
self.parameters(), **self.optimizer_settings["optimizer_kwargs"]
)
if self.optimizer_settings["lr_scheduler"] is not None:
lr_scheduler = self.optimizer_settings["lr_scheduler"](
optimizer, **self.optimizer_settings["lr_scheduler_kwargs"]
)
return {
"optimizer": optimizer,
"lr_scheduler": {
"scheduler": lr_scheduler,
"monitor": "val_loss",
}
}
return {"optimizer": optimizer}
def training_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0):
x, y, timesteps = batch
y_hat = self(x)
loss = complex_mse_loss(y_hat, y, power=True)
self.log("train_loss", loss, on_epoch=True, on_step=True)
return loss
def validation_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0):
x, y, timesteps = batch
if batch_idx == 0:
y_hat, powers = self.forward(x, trace_powers=True)
else:
y_hat = self.forward(x)
loss = complex_mse_loss(y_hat, y, power=True)
self.log("val_loss", loss, on_epoch=True)
y = y.view(y.shape[0], -1, 2).squeeze()
x = x.view(x.shape[0], -1, 2)
x = x[:, x.shape[1]//2, :].squeeze()
y_hat = y_hat.view(y_hat.shape[0], -1, 2).squeeze()
timesteps = timesteps.squeeze()
if batch_idx == 0:
powers = np.array([power.detach().cpu() for power in powers])
self.val_outputs = {"y": y, "x": x, "y_hat": y_hat, "timesteps": timesteps, "powers": powers}
else:
self.val_outputs["y"] = torch.vstack([self.val_outputs["y"], y])
self.val_outputs["x"] = torch.vstack([self.val_outputs["x"], x])
self.val_outputs["y_hat"] = torch.vstack([self.val_outputs["y_hat"], y_hat])
self.val_outputs["timesteps"] = torch.concat([self.val_outputs["timesteps"], timesteps], dim=0)
return loss
def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0):
x, y, timesteps = batch
y_hat = self(x)
loss = complex_mse_loss(y_hat, y, power=True)
self.log("test_loss", loss, on_epoch=True)
return loss
# def predict_step(self, batch, batch_idx):
# x, y, timesteps = batch
# y_hat = self(x)
# return y, x, y_hat, timesteps
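# Instantiation sketch (illustrative): dims are positional, e.g. a 52 -> 8 -> 8 -> 2 network.
# The sizes and the scheduler choice are assumptions; the default lr_scheduler_kwargs
# (patience/factor/min_lr/cooldown) match torch.optim.lr_scheduler.ReduceLROnPlateau.
# model = regenerator(
#     52, 8, 8, 2,
#     dtype=torch.complex64,
#     lr_scheduler=torch.optim.lr_scheduler.ReduceLROnPlateau,
# )
# y_hat = model(torch.randn(4, 52, dtype=torch.complex64))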

View File

@@ -0,0 +1,204 @@
import torch
from torch.nn import Module, Sequential
from util.complexNN import (
DropoutComplex,
Scale,
ONNRect,
photodiode,
EOActivation,
polarimeter,
normalize_by_first
)
class polarisation_estimator2(Module):
def __init__(self):
super(polarisation_estimator2, self).__init__()
self.layers = Sequential(
polarimeter(),
torch.nn.Linear(4, 4),
torch.nn.ReLU(),
torch.nn.Dropout(p=0.01),
torch.nn.Linear(4, 4),
torch.nn.ReLU(),
torch.nn.Dropout(p=0.01),
torch.nn.Linear(4, 4),
)
def forward(self, x):
# x = self.polarimeter(x)
for layer in self.layers:
x = layer(x)
return x
class polarisation_estimator(Module):
def __init__(
self,
*dims,
layer_function=ONNRect,
layer_func_kwargs: dict | None = None,
output_layer_function=photodiode,
# output_layer_func_kwargs: dict | None = None,
act_function=EOActivation,
act_func_kwargs: dict | None = None,
parametrizations: list[dict] | None = None,
dtype=torch.float64,
dropout_prob=0.01,
scale_layers=False,
):
super(polarisation_estimator, self).__init__()
self._n_hidden_layers = len(dims) - 2
layer_func_kwargs = layer_func_kwargs or {}
act_func_kwargs = act_func_kwargs or {}
self.build_model(dims, layer_function, layer_func_kwargs, output_layer_function, act_function, act_func_kwargs, parametrizations, dtype, dropout_prob, scale_layers)
def forward(self, x):
x = self.layer_0(x)
for i in range(1, self._n_hidden_layers):
x = getattr(self, f"layer_{i}")(x)
x = getattr(self, f"layer_{self._n_hidden_layers}")(x)
x = torch.remainder(x, 2 * torch.pi)
return x.squeeze()
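# The remainder above wraps the raw outputs onto [0, 2*pi), e.g.
# torch.remainder(torch.tensor([-0.1, 7.0]), 2 * torch.pi) -> tensor([6.1832, 0.7168])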
def build_model(self, dims, layer_function, layer_func_kwargs, output_layer_function, act_function, act_func_kwargs, parametrizations, dtype, dropout_prob, scale_layers):
for i in range(0, self._n_hidden_layers):
self.add_module(f"layer_{i}", Sequential())
if scale_layers:
self.get_submodule(f"layer_{i}").add_module("scale", Scale(dims[i]))
module = layer_function(dims[i], dims[i + 1], dtype=dtype, **layer_func_kwargs)
self.get_submodule(f"layer_{i}").add_module("ONN", module)
module = act_function(size=dims[i + 1], **act_func_kwargs)
self.get_submodule(f"layer_{i}").add_module("activation", module)
module = DropoutComplex(p=dropout_prob)
self.get_submodule(f"layer_{i}").add_module("dropout", module)
self.add_module(f"layer_{self._n_hidden_layers}", Sequential())
if scale_layers:
self.get_submodule(f"layer_{self._n_hidden_layers}").add_module("scale", Scale(dims[-2]))
module = layer_function(dims[-2], dims[-1], dtype=dtype, **layer_func_kwargs)
self.get_submodule(f"layer_{self._n_hidden_layers}").add_module("ONN", module)
module = output_layer_function(size=dims[-1])
self.get_submodule(f"layer_{self._n_hidden_layers}").add_module("photodiode", module)
# module = normalize_by_first()
# self.get_submodule(f"layer_{self._n_hidden_layers}").add_module("normalize", module)
if parametrizations is not None:
self._apply_parametrizations(self, parametrizations)
def _apply_parametrizations(self, layer, parametrizations):
for sub_layer in layer.children():
if len(sub_layer._modules) > 0:
self._apply_parametrizations(sub_layer, parametrizations)
else:
for parametrization in parametrizations:
tensor_name = parametrization.get("tensor_name", None)
if tensor_name is None:
continue
parametrization_func = parametrization.get("parametrization", None)
if parametrization_func is None:
continue
param_kwargs = parametrization.get("kwargs", {})
if tensor_name in sub_layer._parameters:
parametrization_func(sub_layer, tensor_name, **param_kwargs)
class regenerator(Module):
def __init__(
self,
*dims,
layer_function=ONNRect,
layer_func_kwargs: dict | None = None,
act_function=EOActivation,
act_func_kwargs: dict | None = None,
parametrizations: list[dict] | None = None,
dtype=torch.float64,
dropout_prob=0.01,
scale_layers=False,
):
super(regenerator, self).__init__()
self._n_hidden_layers = len(dims) - 2
layer_func_kwargs = layer_func_kwargs or {}
act_func_kwargs = act_func_kwargs or {}
self.build_model(dims, layer_function, layer_func_kwargs, act_function, act_func_kwargs, parametrizations, dtype, dropout_prob, scale_layers)
def build_model(self, dims, layer_function, layer_func_kwargs, act_function, act_func_kwargs, parametrizations, dtype, dropout_prob, scale_layers):
for i in range(0, self._n_hidden_layers):
self.add_module(f"layer_{i}", Sequential())
if scale_layers:
self.get_submodule(f"layer_{i}").add_module("scale", Scale(dims[i]))
module = layer_function(dims[i], dims[i + 1], dtype=dtype, **layer_func_kwargs)
self.get_submodule(f"layer_{i}").add_module("ONN", module)
module = act_function(size=dims[i + 1], **act_func_kwargs)
self.get_submodule(f"layer_{i}").add_module("activation", module)
module = DropoutComplex(p=dropout_prob)
self.get_submodule(f"layer_{i}").add_module("dropout", module)
self.add_module(f"layer_{self._n_hidden_layers}", Sequential())
if scale_layers:
self.get_submodule(f"layer_{self._n_hidden_layers}").add_module("scale", Scale(dims[-2]))
module = layer_function(dims[-2], dims[-1], dtype=dtype, **layer_func_kwargs)
self.get_submodule(f"layer_{self._n_hidden_layers}").add_module("ONN", module)
module = act_function(size=dims[-1], **act_func_kwargs)
self.get_submodule(f"layer_{self._n_hidden_layers}").add_module("activation", module)
# module = Scale(size=dims[-1])
# self.get_submodule(f"layer_{self._n_hidden_layers}").add_module("out_scale", module)
if parametrizations is not None:
self._apply_parametrizations(self, parametrizations)
def _apply_parametrizations(self, layer, parametrizations):
for sub_layer in layer.children():
if len(sub_layer._modules) > 0:
self._apply_parametrizations(sub_layer, parametrizations)
else:
for parametrization in parametrizations:
tensor_name = parametrization.get("tensor_name", None)
if tensor_name is None:
continue
parametrization_func = parametrization.get("parametrization", None)
if parametrization_func is None:
continue
param_kwargs = parametrization.get("kwargs", {})
if tensor_name in sub_layer._parameters:
parametrization_func(sub_layer, tensor_name, **param_kwargs)
def _trace_powers(self, enable, x, powers=None):
if not enable:
return
if powers is None:
powers = []
powers.append(x.abs().square().sum())
return powers
def forward(self, x, trace_powers=False):
powers = self._trace_powers(trace_powers, x)
x = self.layer_0(x)
powers = self._trace_powers(trace_powers, x, powers)
for i in range(1, self._n_hidden_layers):
x = getattr(self, f"layer_{i}")(x)
powers = self._trace_powers(trace_powers, x, powers)
x = getattr(self, f"layer_{self._n_hidden_layers}")(x)
powers = self._trace_powers(trace_powers, x, powers)
if trace_powers:
return x, powers
return x
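# Instantiation sketch (illustrative): the dims and dtype below are hypothetical,
# assuming photodiode preserves the (batch, size) shape of its input.
# est = polarisation_estimator(52, 16, 8, 4, dtype=torch.complex64)
# angles = est(torch.randn(32, 52, dtype=torch.complex64))  # values wrapped to [0, 2*pi)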

View File

@@ -20,6 +20,22 @@ class DataSettings:
xy_delay: tuple | float | int = 0
drop_first: int = 1000
train_split: float = 0.8
polarisations: tuple | list = (0,)
randomise_polarisations: bool = False
"""
change to:
config_path: tuple | list | None = None
dtype: torch.dtype | None = None
symbols: int | float = 1
output_dim: int = 2
shuffle: bool = True
drop_first: float | int = 0
train_split: float = 0.8
randomise_polarisations: bool = False
"""
# pytorch settings
@@ -30,8 +46,8 @@ class PytorchSettings:
device: str = "cuda"
dataloader_workers: int = 2
dataloader_prefetch: int = 2
dataloader_workers: int = 1
dataloader_prefetch: int = 1
save_models: bool = True
model_dir: str = ".models"
@@ -56,6 +72,24 @@ class ModelSettings:
model_layer_kwargs: dict | None = None
model_layer_parametrizations: list= field(default_factory=list)
"""
change to:
dims: tuple | list | None = None
layer_function: nn.Module | None = None
layer_func_kwargs: dict | None = None
activation_function: nn.Module | None = None
activation_func_kwargs: dict | None = None
output_function: nn.Module | None = None
output_func_kwargs: dict | None = None
dropout_function: nn.Module | None = None
dropout_func_kwargs: dict | None = None
scale_function: nn.Module | None = None
scale_func_kwargs: dict | None = None
parametrizations: list | None = None
"""
@dataclass
class OptimizerSettings:
@@ -65,6 +99,17 @@ class OptimizerSettings:
scheduler: str | None = None
scheduler_kwargs: dict | None = None
"""
change to:
optimizer: torch.optim.Optimizer | None = None
optimizer_kwargs: dict | None = None
learning_rate: float | None = None
scheduler: torch.optim.lr_scheduler | None = None
scheduler_kwargs: dict | None = None
"""
def _pruner_default_kwargs():
# MedianPruner

View File

@@ -37,6 +37,7 @@ from rich.console import Console
from util.datasets import FiberRegenerationDataset
import util
import hypertraining.models as models
from .settings import (
GlobalSettings,
@@ -59,8 +60,527 @@ def traverse_dict_update(target, source):
except TypeError:
target.__dict__[k] = v
def get_parameter_names_and_values(model):
def is_parametrized(module):
if hasattr(module, "parametrizations"):
return True
return False
def _get_param_info(module, prefix='', parametrization=False):
param_list = []
for name, param in module.named_parameters(recurse = parametrization):
if parametrization and name.startswith("parametrizations"):
name_parts = name.split('.')
name = name_parts[1]
param = getattr(module, name)
full_name = prefix + ('.' if prefix else '') + name
param_value = param.data
param_list.append((full_name, param_value))
for child_name, child_module in module.named_children():
child_prefix = prefix + ('.' if prefix else '') + child_name
if child_name == "parametrizations":
continue
param_list.extend(_get_param_info(child_module, child_prefix, is_parametrized(child_module)))
return param_list
return _get_param_info(model)
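# get_parameter_names_and_values returns flat (qualified_name, tensor) pairs such as
# ("layer_0.ONN.weight", ...); for parametrized modules the constrained value is read
# back via getattr(module, name) rather than the raw tensors under module.parametrizations.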
class PolarizationTrainer:
def __init__(
self,
*,
global_settings=None,
data_settings=None,
pytorch_settings=None,
model_settings=None,
optimizer_settings=None,
console=None,
checkpoint_path=None,
settings_override=None,
reset_epoch=False,
):
self.mod = torch.pi/2
self.resume = checkpoint_path is not None
torch.serialization.add_safe_globals([
*util.complexNN.__all__,
GlobalSettings,
DataSettings,
ModelSettings,
OptimizerSettings,
PytorchSettings,
models.regenerator,
torch.nn.utils.parametrizations.orthogonal,
])
if self.resume:
print(f"loading checkpoint from {checkpoint_path}")
self.checkpoint_dict = torch.load(checkpoint_path, weights_only=True)
if settings_override is not None:
traverse_dict_update(self.checkpoint_dict["settings"], settings_override)
if reset_epoch:
self.checkpoint_dict["epoch"] = -1
self.global_settings: GlobalSettings = self.checkpoint_dict["settings"]["global_settings"]
self.data_settings: DataSettings = self.checkpoint_dict["settings"]["data_settings"]
self.pytorch_settings: PytorchSettings = self.checkpoint_dict["settings"]["pytorch_settings"]
self.model_settings: ModelSettings = self.checkpoint_dict["settings"]["model_settings"]
self.optimizer_settings: OptimizerSettings = self.checkpoint_dict["settings"]["optimizer_settings"]
else:
if global_settings is None:
global_settings = GlobalSettings()
warnings.warn("Global settings not provided, using default settings")  # requires `import warnings`
if data_settings is None:
data_settings = DataSettings()
warnings.warn("Data settings not provided, using default settings")
if pytorch_settings is None:
pytorch_settings = PytorchSettings()
warnings.warn("Pytorch settings not provided, using default settings")
if model_settings is None:
model_settings = ModelSettings()
warnings.warn("Model settings not provided, using default settings")
if optimizer_settings is None:
optimizer_settings = OptimizerSettings()
warnings.warn("Optimizer settings not provided, using default settings")
self.global_settings: GlobalSettings = global_settings
self.data_settings: DataSettings = data_settings
self.pytorch_settings: PytorchSettings = pytorch_settings
self.model_settings: ModelSettings = model_settings
self.optimizer_settings: OptimizerSettings = optimizer_settings
self.console = console or Console()
self.writer = None
def setup_tb_writer(self, append=None):
log_dir = self.pytorch_settings.summary_dir + "/pol_" + (datetime.now().strftime("%Y%m%d_%H%M%S"))
if append is not None:
log_dir += "_" + str(append)
print(f"Logging to {log_dir}")
self.writer = SummaryWriter(log_dir=log_dir)
def save_checkpoint(self, save_dict, filename):
torch.save(save_dict, filename)
def build_checkpoint_dict(self, loss=None, epoch=None):
return {
"epoch": -1 if epoch is None else epoch,
"loss": float("inf") if loss is None else loss,
"model_state_dict": copy.deepcopy(self.model.state_dict()),
"optimizer_state_dict": copy.deepcopy(self.optimizer.state_dict()),
"scheduler_state_dict": copy.deepcopy(self.scheduler.state_dict()) if hasattr(self, "scheduler") else None,
"model_kwargs": copy.deepcopy(self.model_kwargs),
"settings": {
"global_settings": copy.deepcopy(self.global_settings),
"data_settings": copy.deepcopy(self.data_settings),
"pytorch_settings": copy.deepcopy(self.pytorch_settings),
"model_settings": copy.deepcopy(self.model_settings),
"optimizer_settings": copy.deepcopy(self.optimizer_settings),
},
}
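# Reload sketch (illustrative, hypothetical path): these .tar checkpoints can be restored
# with weights_only=True because the settings classes are registered via
# torch.serialization.add_safe_globals in __init__.
# ckpt = torch.load(".models/best_pol_<timestamp>.tar", weights_only=True)
# self.model.load_state_dict(ckpt["model_state_dict"])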
def define_model(self, model_kwargs=None):
if self.resume:
model_kwargs = self.checkpoint_dict["model_kwargs"]
if model_kwargs is None:
n_hidden_layers = self.model_settings.n_hidden_layers
input_dim = 2 * self.data_settings.output_size
dtype = getattr(torch, self.data_settings.dtype)
afunc = getattr(util.complexNN, self.model_settings.model_activation_func)
layer_func = getattr(util.complexNN, self.model_settings.model_layer_function)
layer_parametrizations = self.model_settings.model_layer_parametrizations
hidden_dims = [self.model_settings.overrides.get(f"n_hidden_nodes_{i}") for i in range(n_hidden_layers)]
self.model_kwargs = {
"dims": (input_dim, *hidden_dims, self.model_settings.output_dim),
"layer_function": layer_func,
"layer_func_kwargs": self.model_settings.model_layer_kwargs,
"act_function": afunc,
"act_func_kwargs": None,
"parametrizations": layer_parametrizations,
"dtype": dtype,
"dropout_prob": self.model_settings.dropout_prob,
"scale_layers": self.model_settings.scale,
}
else:
self.model_kwargs = model_kwargs
input_dim = self.model_kwargs["dims"][0]
dtype = self.model_kwargs["dtype"]
# dims = self.model_kwargs.pop("dims")
model_kwargs = copy.deepcopy(self.model_kwargs)
self.model = models.polarisation_estimator(*model_kwargs.pop('dims'),**model_kwargs)
# self.model = models.polarisation_estimator2()
if self.writer is not None:
try:
self.writer.add_graph(self.model, torch.rand(1, input_dim, dtype=dtype), use_strict_trace=False)
except RuntimeError:
self.writer.add_graph(self.model, torch.rand(1, 2, dtype=dtype), use_strict_trace=False)
self.model = self.model.to(self.pytorch_settings.device)
if self.resume:
self.model.load_state_dict(self.checkpoint_dict["model_state_dict"], strict=False)
def get_sliced_data(self, override=None):
symbols = self.data_settings.symbols
in_out_delay = self.data_settings.in_out_delay
xy_delay = self.data_settings.xy_delay
data_size = self.data_settings.output_size
dtype = getattr(torch, self.data_settings.dtype)
num_symbols = None
config_path = self.data_settings.config_path
polarisations = self.data_settings.polarisations
randomise_polarisations = self.data_settings.randomise_polarisations
if override is not None:
num_symbols = override.get("num_symbols", None)
config_path = override.get("config_path", config_path)
polarisations = override.get("polarisations", polarisations)
randomise_polarisations = override.get("randomise_polarisation", randomise_polarisations)
# get dataset
dataset = FiberRegenerationDataset(
file_path=config_path,
symbols=symbols,
output_dim=data_size,
target_delay=in_out_delay,
xy_delay=xy_delay,
drop_first=self.data_settings.drop_first,
dtype=dtype,
real=not dtype.is_complex,
num_symbols=num_symbols,
polarisations=polarisations,
randomise_polarisations=randomise_polarisations,
)
dataset_size = len(dataset)
indices = list(range(dataset_size))
split = int(np.floor(self.data_settings.train_split * dataset_size))
if self.data_settings.shuffle:
np.random.seed(self.global_settings.seed)
np.random.shuffle(indices)
train_indices, valid_indices = indices[:split], indices[split:]
if self.data_settings.shuffle:
train_sampler = torch.utils.data.SubsetRandomSampler(train_indices)
valid_sampler = torch.utils.data.SubsetRandomSampler(valid_indices)
else:
train_sampler = train_indices
valid_sampler = valid_indices
train_loader = torch.utils.data.DataLoader(
dataset,
batch_size=self.pytorch_settings.batchsize,
sampler=train_sampler,
drop_last=True,
pin_memory=True,
num_workers=self.pytorch_settings.dataloader_workers,
prefetch_factor=self.pytorch_settings.dataloader_prefetch,
)
valid_loader = torch.utils.data.DataLoader(
dataset,
batch_size=self.pytorch_settings.batchsize,
sampler=valid_sampler,
drop_last=True,
pin_memory=True,
num_workers=self.pytorch_settings.dataloader_workers,
prefetch_factor=self.pytorch_settings.dataloader_prefetch,
)
return train_loader, valid_loader
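# Note: both loaders wrap the same dataset object; the split is done on index lists,
# and the (seeded) shuffle permutes the indices before the split, so the two samplers
# draw from disjoint random subsets.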
def train_model(
self,
optimizer,
train_loader,
epoch,
enable_progress=False,
):
if enable_progress:
progress = Progress(
TextColumn("[yellow] Training..."),
TextColumn("Error: {task.description}"),
BarColumn(),
TaskProgressColumn(),
TextColumn("[green]Batch"),
MofNCompleteColumn(),
TimeRemainingColumn(),
TimeElapsedColumn(),
transient=False,
console=self.console,
refresh_per_second=10,
)
task = progress.add_task("-.---e--", total=len(train_loader))
progress.start()
running_loss2 = 0.0
running_loss = 0.0
self.model.train()
loader_len = len(train_loader)
write_div = 0
loss_div = 0
for batch_idx, batch in enumerate(train_loader):
x = batch["x"]
y = batch["sop"]
self.model.zero_grad(set_to_none=True)
x, y = (
x.to(self.pytorch_settings.device),
y.to(self.pytorch_settings.device),
)
y_pred = self.model(x)
# loss = torch.nn.functional.smooth_l1_loss(torch.cos(torch.fmod(y_pred, torch.pi/2)).squeeze(), torch.cos(torch.fmod(y, torch.pi/2)).squeeze(), beta=0.5)
loss = torch.nn.functional.mse_loss(y_pred, y)
# loss = util.complexNN.naive_angle_loss(y_pred, y, mod=torch.pi/2)
loss_value = loss.item()
loss.backward()
optimizer.step()
running_loss += loss_value
running_loss2 += loss_value
write_div += 1
loss_div += 1
if enable_progress:
progress.update(task, advance=1, description=f"{loss_value:.3e}")
if batch_idx % self.pytorch_settings.write_every == 0:
self.writer.add_scalar(
"training loss",
running_loss2 / write_div,
epoch * loader_len + batch_idx,
)
running_loss2 = 0.0
write_div = 0
if enable_progress:
progress.stop()
return running_loss / loss_div
def eval_model(self, valid_loader, epoch, enable_progress=True):
if enable_progress:
progress = Progress(
TextColumn("[green]Evaluating..."),
TextColumn("Error: {task.description}"),
BarColumn(),
TaskProgressColumn(),
TextColumn("[green]Batch"),
MofNCompleteColumn(),
TimeRemainingColumn(),
TimeElapsedColumn(),
transient=False,
console=self.console,
refresh_per_second=10,
)
progress.start()
task = progress.add_task("-.---e--", total=len(valid_loader))
self.model.eval()
running_loss = 0
loss_div = 0
with torch.no_grad():
for _, batch in enumerate(valid_loader):
x = batch["x"]
y = batch["sop"]
x, y = (
x.to(self.pytorch_settings.device),
y.to(self.pytorch_settings.device),
)
y_pred = self.model(x)
# loss = torch.nn.functional.smooth_l1_loss(torch.cos(torch.fmod(y_pred, torch.pi/2)).squeeze(), torch.cos(torch.fmod(y, torch.pi/2)).squeeze(), beta=0.5)
loss = torch.nn.functional.mse_loss(y_pred, y)
# loss = util.complexNN.naive_angle_loss(y_pred, y, mod=torch.pi/2)
loss_value = loss.item()
running_loss += loss_value
loss_div += 1
if enable_progress:
progress.update(task, advance=1, description=f"{loss_value:.3e}")
running_loss = running_loss/loss_div
self.writer.add_scalar(
"eval loss",
running_loss,
epoch,
)
# self.write_parameters(epoch + 1)
self.writer.flush()
if enable_progress:
progress.stop()
return running_loss
# def run_model(self, model, loader, trace_powers=False):
# model.eval()
# fiber_out = []
# fiber_in = []
# regen = []
# timestamps = []
# with torch.no_grad():
# model = model.to(self.pytorch_settings.device)
# for batch in loader:
# x = batch["x"]
# y = batch["angle"]
# timestamp = batch["timestamp"]
# plot_data = batch["plot_data"]
# x, y = (
# x.to(self.pytorch_settings.device),
# y.to(self.pytorch_settings.device),
# )
# if trace_powers:
# y_pred, powers = model(x, trace_powers).cpu()
# else:
# y_pred = model(x, trace_powers).cpu()
# # x = x.cpu()
# # y = y.cpu()
# y_pred = y_pred.view(y_pred.shape[0], -1, 2)
# y = y.view(y.shape[0], -1, 2)
# plot_data = plot_data.view(plot_data.shape[0], -1, 2)
# # x = x.view(x.shape[0], -1, 2)
# # timestamp = timestamp.view(-1, 1)
# fiber_out.append(plot_data.squeeze())
# fiber_in.append(y.squeeze())
# regen.append(y_pred.squeeze())
# timestamps.append(timestamp.squeeze())
# fiber_out = torch.vstack(fiber_out).cpu()
# fiber_in = torch.vstack(fiber_in).cpu()
# regen = torch.vstack(regen).cpu()
# timestamps = torch.concat(timestamps).cpu()
# if trace_powers:
# return fiber_in, fiber_out, regen, timestamps, powers
# return fiber_in, fiber_out, regen, timestamps
def write_parameters(self, epoch, attributes: list[str] | tuple[str] = None):
parameter_list = get_parameter_names_and_values(self.model)
for name, value in parameter_list:
plot = (attributes is None) or (name in attributes)
if plot:
vals: np.ndarray = value.detach().cpu().numpy().flatten()
if vals.ndim <= 1 and len(vals) == 1:
if np.iscomplexobj(vals):
self.writer.add_scalar(f"{name} (Mag)", np.abs(vals), epoch)
self.writer.add_scalar(f"{name} (Phase)", np.angle(vals), epoch)
else:
self.writer.add_scalar(f"{name}", vals, epoch)
else:
if np.iscomplexobj(vals):
self.writer.add_histogram(f"{name} (Mag)", np.abs(vals), epoch, bins="fd")
self.writer.add_histogram(f"{name} (Phase)", np.angle(vals), epoch, bins="fd")
else:
self.writer.add_histogram(f"{name}", vals, epoch, bins="fd")
def train(self):
if self.writer is None:
self.setup_tb_writer()
self.define_model()
print(
f"number of parameters (trainable): {sum(p.numel() for p in self.model.parameters())} ({sum(p.numel() for p in self.model.parameters() if p.requires_grad)})"
)
# self.write_parameters(0)
if isinstance(self.data_settings.config_path, (list, tuple)):
for i, config_path in enumerate(self.data_settings.config_path):
paths = Path.cwd().glob(config_path)
for j, path in enumerate(paths):
text = str(path) + '\n'
with open(path, 'r') as f:
text += f.read()
text += '\n'
self.writer.add_text(f"config_{i*len(self.data_settings.config_path)+j}", text)
elif isinstance(self.data_settings.config_path, str):
paths = Path.cwd().glob(self.data_settings.config_path)
for j, path in enumerate(paths):
text = str(path) + '\n'
with open(path, 'r') as f:
text += f.read()
text += '\n'
self.writer.add_text(f"config_{j}", text)
self.writer.flush()
train_loader, valid_loader = self.get_sliced_data()
optimizer_name = self.optimizer_settings.optimizer
self.optimizer: optim.Optimizer = getattr(optim, optimizer_name)(
self.model.parameters(), **self.optimizer_settings.optimizer_kwargs
)
if self.optimizer_settings.scheduler is not None:
self.scheduler = getattr(optim.lr_scheduler, self.optimizer_settings.scheduler)(
self.optimizer, **self.optimizer_settings.scheduler_kwargs
)
self.writer.add_scalar("learning rate", self.optimizer.param_groups[0]["lr"], -1)
if not self.resume:
self.best = self.build_checkpoint_dict()
else:
self.best = self.checkpoint_dict
self.best["loss"] = float("inf")
for epoch in range(self.best["epoch"] + 1, self.pytorch_settings.epochs):
enable_progress = True
if enable_progress:
self.console.rule(f"Epoch {epoch + 1}/{self.pytorch_settings.epochs}")
self.train_model(
self.optimizer,
train_loader,
epoch,
enable_progress=enable_progress,
)
loss = self.eval_model(
valid_loader,
epoch,
enable_progress=enable_progress,
)
if self.optimizer_settings.scheduler is not None:
self.scheduler.step(loss)
self.writer.add_scalar("learning rate", self.optimizer.param_groups[0]["lr"], epoch)
if self.pytorch_settings.save_models and self.model is not None:
save_path = (
Path(self.pytorch_settings.model_dir) / f"pol_{self.writer.get_logdir().split('/')[-1]}_{epoch}.tar"
)
save_path.parent.mkdir(parents=True, exist_ok=True)
checkpoint = self.build_checkpoint_dict(loss, epoch)
self.save_checkpoint(checkpoint, save_path)
if loss < self.best["loss"]:
self.best = checkpoint
save_path = (
Path(self.pytorch_settings.model_dir) / f"best_pol_{self.writer.get_logdir().split('/')[-1]}.tar"
)
save_path.parent.mkdir(parents=True, exist_ok=True)
self.save_checkpoint(self.best, save_path)
self.writer.flush()
self.writer.close()
return self.best
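# Usage sketch (illustrative): mirrors RegenerationTrainer below; the settings objects
# are assumed to be constructed as in the training scripts.
# trainer = PolarizationTrainer(
#     global_settings=global_settings,
#     data_settings=data_settings,
#     pytorch_settings=pytorch_settings,
#     model_settings=model_settings,
#     optimizer_settings=optimizer_settings,
# )
# best_checkpoint = trainer.train()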
class RegenerationTrainer:
def __init__(
self,
*,
@@ -82,10 +602,11 @@ class Trainer:
ModelSettings,
OptimizerSettings,
PytorchSettings,
-util.complexNN.regenerator,
+models.regenerator,
torch.nn.utils.parametrizations.orthogonal,
])
if self.resume:
+print(f"loading checkpoint from {checkpoint_path}")
self.checkpoint_dict = torch.load(checkpoint_path, weights_only=True)
if settings_override is not None:
traverse_dict_update(self.checkpoint_dict["settings"], settings_override)
@@ -170,11 +691,13 @@ class Trainer:
self.model_kwargs = {
"dims": (input_dim, *hidden_dims, self.model_settings.output_dim),
"layer_function": layer_func,
-"layer_parametrizations": layer_parametrizations,
-"activation_function": afunc,
+"layer_func_kwargs": self.model_settings.model_layer_kwargs,
+"act_function": afunc,
+"act_func_kwargs": None,
+"parametrizations": layer_parametrizations,
"dtype": dtype,
"dropout_prob": self.model_settings.dropout_prob,
-"scale": self.model_settings.scale,
+"scale_layers": self.model_settings.scale,
}
else:
self.model_kwargs = model_kwargs
@@ -182,7 +705,8 @@ class Trainer:
dtype = self.model_kwargs["dtype"]
# dims = self.model_kwargs.pop("dims")
-self.model = util.complexNN.regenerator(**self.model_kwargs)
+model_kwargs = copy.deepcopy(self.model_kwargs)
+self.model = models.regenerator(*model_kwargs.pop('dims'),**model_kwargs)
if self.writer is not None:
self.writer.add_graph(self.model, torch.zeros(1, input_dim, dtype=dtype))
@@ -204,9 +728,13 @@ class Trainer:
num_symbols = None
config_path = self.data_settings.config_path
polarisations = self.data_settings.polarisations
randomise_polarisations = self.data_settings.randomise_polarisations
if override is not None:
num_symbols = override.get("num_symbols", None)
config_path = override.get("config_path", config_path)
polarisations = override.get("polarisations", polarisations)
randomise_polarisations = override.get("randomise_polarisation", randomise_polarisations)
# get dataset
dataset = FiberRegenerationDataset(
file_path=config_path,
@@ -218,6 +746,8 @@ class Trainer:
dtype=dtype,
real=not dtype.is_complex,
num_symbols=num_symbols,
polarisations=polarisations,
randomise_polarisations=randomise_polarisations,
)
dataset_size = len(dataset)
@@ -286,7 +816,9 @@ class Trainer:
running_loss = 0.0
self.model.train()
loader_len = len(train_loader)
-for batch_idx, (x, y, _) in enumerate(train_loader):
+for batch_idx, batch in enumerate(train_loader):
+x = batch["x"]
+y = batch["y"]
self.model.zero_grad(set_to_none=True)
x, y = (
x.to(self.pytorch_settings.device),
@@ -307,7 +839,7 @@ class Trainer:
self.writer.add_scalar(
"training loss",
running_loss2 / (self.pytorch_settings.write_every if batch_idx > 0 else 1),
-epoch + batch_idx/loader_len,
+epoch * loader_len + batch_idx,
)
running_loss2 = 0.0
@@ -337,7 +869,9 @@ class Trainer:
self.model.eval()
running_error = 0
with torch.no_grad():
-for _, (x, y, _) in enumerate(valid_loader):
+for _, batch in enumerate(valid_loader):
+x = batch["x"]
+y = batch["y"]
x, y = (
x.to(self.pytorch_settings.device),
y.to(self.pytorch_settings.device),
@@ -360,37 +894,26 @@ class Trainer:
if (epoch + 1) % 10 == 0 or epoch < 10:
# plotting is slow, so only do it every 10 epochs
title_append, subtitle = self.build_title(epoch + 1)
+head_fig, eye_fig, powers_fig = self.plot_model_response(
+model=self.model,
+title_append=title_append,
+subtitle=subtitle,
+show=False,
+)
self.writer.add_figure(
"fiber response",
-self.plot_model_response(
-model=self.model,
-title_append=title_append,
-subtitle=subtitle,
-show=False,
-),
+head_fig,
epoch + 1,
)
self.writer.add_figure(
"eye diagram",
-self.plot_model_response(
-model=self.model,
-title_append=title_append,
-subtitle=subtitle,
-show=False,
-mode="eye",
-),
+eye_fig,
epoch + 1,
)
self.writer.add_figure(
"powers",
-self.plot_model_response(
-model=self.model,
-title_append=title_append,
-subtitle=subtitle,
-mode="powers",
-show=False,
-),
+powers_fig,
epoch + 1,
)
@@ -411,7 +934,11 @@ class Trainer:
with torch.no_grad():
model = model.to(self.pytorch_settings.device)
-for x, y, timestamp in loader:
+for batch in loader:
+x = batch["x"]
+y = batch["y"]
+timestamp = batch["timestamp"]
+plot_data = batch["plot_data"]
x, y = (
x.to(self.pytorch_settings.device),
y.to(self.pytorch_settings.device),
@@ -424,9 +951,11 @@ class Trainer:
# y = y.cpu()
y_pred = y_pred.view(y_pred.shape[0], -1, 2)
y = y.view(y.shape[0], -1, 2)
-x = x.view(x.shape[0], -1, 2)
+plot_data = plot_data.view(plot_data.shape[0], -1, 2)
+# x = x.view(x.shape[0], -1, 2)
# timestamp = timestamp.view(-1, 1)
-fiber_out.append(x[:, x.shape[1] // 2, :].squeeze())
+fiber_out.append(plot_data.squeeze())
fiber_in.append(y.squeeze())
regen.append(y_pred.squeeze())
timestamps.append(timestamp.squeeze())
@@ -440,28 +969,23 @@ class Trainer:
return fiber_in, fiber_out, regen, timestamps
def write_parameters(self, epoch, attributes: list[str] | tuple[str] = None):
-for i, layer in enumerate(self.model._layers):
-tag = f"layer {i}"
-if hasattr(layer, "parametrizations"):
-attribute_pool = set(layer.parametrizations._modules) | set(layer._parameters)
-else:
-attribute_pool = set(layer._parameters)
-for attribute in attribute_pool:
-plot = (attributes is None) or (attribute in attributes)
-if plot:
-vals: np.ndarray = getattr(layer, attribute).detach().cpu().numpy().flatten()
-if vals.ndim <= 1 and len(vals) == 1:
-if np.iscomplexobj(vals):
-self.writer.add_scalar(f"{tag} {attribute} (Mag)", np.abs(vals), epoch)
-self.writer.add_scalar(f"{tag} {attribute} (Phase)", np.angle(vals), epoch)
-else:
-self.writer.add_scalar(f"{tag} {attribute}", vals, epoch)
-else:
-if np.iscomplexobj(vals):
-self.writer.add_histogram(f"{tag} {attribute} (Mag)", np.abs(vals), epoch, bins="fd")
-self.writer.add_histogram(f"{tag} {attribute} (Phase)", np.angle(vals), epoch, bins="fd")
-else:
-self.writer.add_histogram(f"{tag} {attribute}", vals, epoch, bins="fd")
+parameter_list = get_parameter_names_and_values(self.model)
+for name, value in parameter_list:
+plot = (attributes is None) or (name in attributes)
+if plot:
+vals: np.ndarray = value.detach().cpu().numpy().flatten()
+if vals.ndim <= 1 and len(vals) == 1:
+if np.iscomplexobj(vals):
+self.writer.add_scalar(f"{name} (Mag)", np.abs(vals), epoch)
+self.writer.add_scalar(f"{name} (Phase)", np.angle(vals), epoch)
+else:
+self.writer.add_scalar(f"{name}", vals, epoch)
+else:
+if np.iscomplexobj(vals):
+self.writer.add_histogram(f"{name} (Mag)", np.abs(vals), epoch, bins="fd")
+self.writer.add_histogram(f"{name} (Phase)", np.angle(vals), epoch, bins="fd")
+else:
+self.writer.add_histogram(f"{name}", vals, epoch, bins="fd")
def train(self):
if self.writer is None:
@@ -474,44 +998,48 @@ class Trainer:
)
title_append, subtitle = self.build_title(0)
+head_fig, eye_fig, powers_fig = self.plot_model_response(
+model=self.model,
+title_append=title_append,
+subtitle=subtitle,
+show=False,
+)
self.writer.add_figure(
"fiber response",
-self.plot_model_response(
-model=self.model,
-title_append=title_append,
-subtitle=subtitle,
-show=False,
-),
+head_fig,
0,
)
self.writer.add_figure(
"eye diagram",
-self.plot_model_response(
-model=self.model,
-title_append=title_append,
-subtitle=subtitle,
-mode="eye",
-show=False,
-),
+eye_fig,
0,
)
self.writer.add_figure(
"powers",
-self.plot_model_response(
-model=self.model,
-title_append=title_append,
-subtitle=subtitle,
-mode="powers",
-show=False,
-),
+powers_fig,
0,
)
self.write_parameters(0)
self.writer.add_text("datasets", '\n'.join(self.data_settings.config_path))
if isinstance(self.data_settings.config_path, (list, tuple)):
for i, config_path in enumerate(self.data_settings.config_path):
paths = Path.cwd().glob(config_path)
for j, path in enumerate(paths):
text = str(path) + '\n'
with open(path, 'r') as f:
text += f.read()
text += '\n'
self.writer.add_text(f"config_{i*len(self.data_settings.config_path)+j}", text)
elif isinstance(self.data_settings.config_path, str):
paths = Path.cwd().glob(self.data_settings.config_path)
for j, path in enumerate(paths):
text = str(path) + '\n'
with open(path, 'r') as f:
text += f.read()
text += '\n'
self.writer.add_text(f"config_{j}", text)
self.writer.flush()
@@ -741,54 +1269,50 @@ class Trainer:
def plot_model_response(
self,
-model=None,
+model:torch.nn.Module=None,
title_append="",
subtitle="",
-mode: Literal["eye", "head", "powers"] = "head",
+# mode: Literal["eye", "head", "powers"] = "head",
show=False,
):
if mode == "powers":
input_data = torch.ones(
1, 2 * self.data_settings.output_size, dtype=getattr(torch, self.data_settings.dtype)
).to(self.pytorch_settings.device)
model = model.to(self.pytorch_settings.device)
model.eval()
with torch.no_grad():
_, powers = model(input_data, trace_powers=True)
input_data = torch.ones(
1, 2 * self.data_settings.output_size, dtype=getattr(torch, self.data_settings.dtype)
).to(self.pytorch_settings.device)
model = model.to(self.pytorch_settings.device)
model.eval()
with torch.no_grad():
_, powers = model(input_data, trace_powers=True)
powers = [power.item() for power in powers]
layer_names = ["input", *[str(x).split("(")[0] for x in model._layers._modules.values()]]
powers = [power.item() for power in powers]
layer_names = [name for (name, _) in model.named_children()]
# remove dropout layers
mask = [1 if "Dropout" not in layer_name else 0 for layer_name in layer_names]
layer_names = [layer_name for layer_name, m in zip(layer_names, mask) if m]
powers = [power for power, m in zip(powers, mask) if m]
fig = self._plot_model_response_powers(
powers, layer_names, title_append=title_append, subtitle=subtitle, show=show
)
return fig
power_fig = self._plot_model_response_powers(
powers, layer_names, title_append=title_append, subtitle=subtitle, show=show
)
data_settings_backup = copy.deepcopy(self.data_settings)
pytorch_settings_backup = copy.deepcopy(self.pytorch_settings)
self.data_settings.drop_first = 99.5 + random.randint(0, 1000)
self.data_settings.shuffle = False
self.data_settings.train_split = 1.0
-self.pytorch_settings.batchsize = (
-self.pytorch_settings.eye_symbols if mode == "eye" else self.pytorch_settings.head_symbols
-)
+self.pytorch_settings.batchsize = max(self.pytorch_settings.head_symbols, self.pytorch_settings.eye_symbols)
config_path = random.choice(self.data_settings.config_path) if isinstance(self.data_settings.config_path, (list, tuple)) else self.data_settings.config_path
-fiber_length = int(float(str(config_path).split('-')[-7])/1000)
-plot_loader, _ = self.get_sliced_data(
-override={
-"num_symbols": self.pytorch_settings.batchsize,
-"config_path": config_path,
-}
-)
+fiber_length = int(float(str(config_path).split('-')[4])/1000)
+if not hasattr(self, "_plot_loader"):
+self._plot_loader, _ = self.get_sliced_data(
+override={
+"num_symbols": self.pytorch_settings.batchsize,
+"config_path": config_path,
+"shuffle": False,
+"polarisations": (np.random.rand(1)*np.pi*2,),
+"randomise_polarisation": False,
+}
+)
+self._sps = self._plot_loader.dataset.samples_per_symbol
self.data_settings = data_settings_backup
self.pytorch_settings = pytorch_settings_backup
-fiber_in, fiber_out, regen, timestamps = self.run_model(model, plot_loader)
+fiber_in, fiber_out, regen, timestamps = self.run_model(model, self._plot_loader)
fiber_in = fiber_in.view(-1, 2)
fiber_out = fiber_out.view(-1, 2)
regen = regen.view(-1, 2)
@@ -802,36 +1326,32 @@ class Trainer:
# https://github.com/matplotlib/matplotlib/issues/27713#issuecomment-1915497463
import gc
if mode == "head":
fig = self._plot_model_response_head(
fiber_in,
fiber_out,
regen,
timestamps=timestamps,
labels=("fiber in", "fiber out", "regen"),
sps=plot_loader.dataset.samples_per_symbol,
title_append=title_append + f" ({fiber_length} km)",
subtitle=subtitle,
show=show,
)
elif mode == "eye":
head_fig = self._plot_model_response_head(
fiber_in[:self.pytorch_settings.head_symbols*self._sps],
fiber_out[:self.pytorch_settings.head_symbols*self._sps],
regen[:self.pytorch_settings.head_symbols*self._sps],
timestamps=timestamps[:self.pytorch_settings.head_symbols*self._sps],
labels=("fiber in", "fiber out", "regen"),
sps=self._sps,
title_append=title_append + f" ({fiber_length} km)",
subtitle=subtitle,
show=show,
)
# raise NotImplementedError("Eye diagram not implemented")
-fig = self._plot_model_response_eye(
-fiber_in,
-fiber_out,
-regen,
-timestamps=timestamps,
+eye_fig = self._plot_model_response_eye(
+fiber_in[:self.pytorch_settings.eye_symbols*self._sps],
+fiber_out[:self.pytorch_settings.eye_symbols*self._sps],
+regen[:self.pytorch_settings.eye_symbols*self._sps],
+timestamps=timestamps[:self.pytorch_settings.eye_symbols*self._sps],
labels=("fiber in", "fiber out", "regen"),
-sps=plot_loader.dataset.samples_per_symbol,
+sps=self._sps,
title_append=title_append + f" ({fiber_length} km)",
subtitle=subtitle,
show=show,
)
-else:
-raise ValueError(f"Unknown mode: {mode}")
gc.collect()
-return fig
+return head_fig, eye_fig, power_fig
def build_title(self, number: int):
title_append = f"epoch {number}"

View File

@@ -1,7 +1,10 @@
from datetime import datetime
from pathlib import Path
import matplotlib
import numpy as np
import torch
import torch.utils.tensorboard
import torch.utils.tensorboard.summary
from hypertraining.settings import (
GlobalSettings,
DataSettings,
@@ -10,7 +13,7 @@ from hypertraining.settings import (
OptimizerSettings,
)
-from hypertraining.training import Trainer
+from hypertraining.training import RegenerationTrainer, PolarizationTrainer
# import torch
import json
@@ -23,7 +26,7 @@ global_settings = GlobalSettings(
)
data_settings = DataSettings(
config_path="data/20241204-13*-128-16384-100000-0-0-17-0-PAM4-0.ini",
config_path="data/*-128-16384-10000-0-0-0-0-PAM4-0-0.4.ini",
# config_path=[f"data/20241202-*-128-16384-{length}-0-0-17-0-PAM4-0.ini" for length in range(48000, 53000, 1000)],
dtype="complex64",
# symbols = (9, 20), # 13 symbol @ 10GBd <-> 1.3ns <-> 0.26m of fiber
@@ -31,17 +34,16 @@ data_settings = DataSettings(
# output_size = (11, 32), # 26 taps -> 2 taps per input symbol -> 1 tap every 1cm (model has 52 inputs (x/y))
output_size=26, # study: single_core_regen_20241123_011232 (model_input_dim/2)
shuffle=True,
-in_out_delay=0,
-xy_delay=0,
-drop_first=128 * 64,
+drop_first=64,
train_split=0.8,
+randomise_polarisations=True,
)
pytorch_settings = PytorchSettings(
epochs=10000,
-batchsize=2**12,
+batchsize=2**14,
device="cuda",
-dataloader_workers=12,
+dataloader_workers=16,
dataloader_prefetch=8,
summary_dir=".runs",
write_every=2**5,
@@ -51,12 +53,14 @@ pytorch_settings = PytorchSettings(
model_settings = ModelSettings(
output_dim=2,
-n_hidden_layers=4,
+n_hidden_layers=5,
overrides={
# "hidden_layer_dims": (8, 8, 4, 4),
"n_hidden_nodes_0": 8,
"n_hidden_nodes_1": 8,
"n_hidden_nodes_2": 4,
"n_hidden_nodes_3": 4,
"n_hidden_nodes_4": 2,
},
model_activation_func="EOActivation",
dropout_prob=0.01,
@@ -92,6 +96,14 @@ model_settings = ModelSettings(
"tensor_name": "scales",
"parametrization": util.complexNN.clamp,
},
{
"tensor_name": "angle",
"parametrization": util.complexNN.clamp,
"kwargs": {
"min": -torch.pi,
"max": torch.pi,
},
},
# {
# "tensor_name": "scale",
# "parametrization": util.complexNN.clamp,
@@ -143,7 +155,7 @@ def save_dict_to_file(dictionary, filename):
json.dump(dictionary, f, indent=4)
-def sweep_lengths(*lengths, model=None, data_glob:str=None, strategy="newest"):
+def sweep_lengths(*lengths, model=None, data_glob: str = None, strategy="newest"):
assert model is not None, "Model must be provided."
assert data_glob is not None, "Data glob must be provided."
model = model
@@ -153,9 +165,9 @@ def sweep_lengths(*lengths, model=None, data_glob:str=None, strategy="newest"):
regens = {}
timestampss = {}
-trainer = Trainer(
-checkpoint_path=model,
-)
+trainer = RegenerationTrainer(
+checkpoint_path=model,
+)
trainer.define_model()
for length in lengths:
@@ -165,13 +177,13 @@ def sweep_lengths(*lengths, model=None, data_glob:str=None, strategy="newest"):
continue
if strategy == "newest":
sorted_kwargs = {
-'key': lambda x: x.stat().st_mtime,
-'reverse': True,
+"key": lambda x: x.stat().st_mtime,
+"reverse": True,
}
elif strategy == "oldest":
sorted_kwargs = {
-'key': lambda x: x.stat().st_mtime,
-'reverse': False,
+"key": lambda x: x.stat().st_mtime,
+"reverse": False,
}
else:
raise ValueError(f"Unknown strategy {strategy}.")
@@ -186,22 +198,21 @@ def sweep_lengths(*lengths, model=None, data_glob:str=None, strategy="newest"):
timestampss[length] = timestamps
data = torch.zeros(2 * len(timestampss.keys()) + 2, 2, tuple(fiber_outs.values())[-1].shape[0])
channel_names = ["" for _ in range(2 * len(timestampss.keys())+2)]
channel_names = ["" for _ in range(2 * len(timestampss.keys()) + 2)]
data[1, 0, :] = timestampss[tuple(timestampss.keys())[-1]] / 128
data[1, 1, :] = fiber_ins[tuple(timestampss.keys())[-1]][:, 0].abs().square()
channel_names[1] = "fiber in x"
for li, length in enumerate(timestampss.keys()):
-data[2+2 * li, 0, :] = timestampss[length] / 128
-data[2+2 * li, 1, :] = fiber_outs[length][:, 0].abs().square()
-data[2+2 * li + 1, 0, :] = timestampss[length] / 128
-data[2+2 * li + 1, 1, :] = regens[length][:, 0].abs().square()
+data[2 + 2 * li, 0, :] = timestampss[length] / 128
+data[2 + 2 * li, 1, :] = fiber_outs[length][:, 0].abs().square()
+data[2 + 2 * li + 1, 0, :] = timestampss[length] / 128
+data[2 + 2 * li + 1, 1, :] = regens[length][:, 0].abs().square()
-channel_names[2+2 * li+1] = f"regen x {length}"
-channel_names[2+2 * li] = f"fiber out x {length}"
+channel_names[2 + 2 * li + 1] = f"regen x {length}"
+channel_names[2 + 2 * li] = f"fiber out x {length}"
# get current backend
backend = matplotlib.get_backend()
@@ -210,7 +221,7 @@ def sweep_lengths(*lengths, model=None, data_glob:str=None, strategy="newest"):
eye = util.eye_diagram.eye_diagram(data.to(dtype=torch.float32).detach().cpu().numpy(), channel_names=channel_names)
print_attrs = ("channel_name", "success", "min_area")
-with np.printoptions(precision=3, suppress=True, formatter={'float': '{:0.3e}'.format}):
+with np.printoptions(precision=3, suppress=True, formatter={"float": "{:0.3e}".format}):
for result in eye.eye_stats:
print_dict = {attr: result[attr] for attr in print_attrs}
rprint(print_dict)
@@ -221,18 +232,77 @@ def sweep_lengths(*lengths, model=None, data_glob:str=None, strategy="newest"):
if __name__ == "__main__":
-lengths = range(90000, 100000+10000, 10000)
+# lengths = range(90000, 100000+10000, 10000)
# lengths = [100000]
-sweep_lengths(*lengths, model=".models/best_20241204_132605.tar", data_glob="data/202412*-{length}-*-0.ini", strategy="newest")
+# sweep_lengths(*lengths, model=".models/best_20241204_132605.tar", data_glob="data/202412*-{length}-*-0.ini", strategy="newest")
-# trainer = Trainer(
-# global_settings=global_settings,
-# data_settings=data_settings,
-# pytorch_settings=pytorch_settings,
-# model_settings=model_settings,
-# optimizer_settings=optimizer_settings,
-# # checkpoint_path=".models/best_20241202_143149.tar",
-# # 20241202_143149
+trainer = RegenerationTrainer(
+global_settings=global_settings,
+data_settings=data_settings,
+pytorch_settings=pytorch_settings,
+model_settings=model_settings,
+optimizer_settings=optimizer_settings,
+# checkpoint_path=".models/best_20241205_235929.tar",
+# 20241202_143149
+)
+trainer.train()
# from hypertraining.lighning_models import regenerator, regeneratorData
# import lightning as L
# model = regenerator(
# 2 * data_settings.output_size,
# *model_settings.overrides["hidden_layer_dims"],
# model_settings.output_dim,
# layer_function=getattr(util.complexNN, model_settings.model_layer_function),
# layer_func_kwargs=model_settings.model_layer_kwargs,
# act_function=getattr(util.complexNN, model_settings.model_activation_func),
# act_func_kwargs=None,
# parametrizations=model_settings.model_layer_parametrizations,
# dtype=getattr(torch, data_settings.dtype),
# dropout_prob=model_settings.dropout_prob,
# scale_layers=model_settings.scale,
# optimizer=getattr(torch.optim, optimizer_settings.optimizer),
# optimizer_kwargs=optimizer_settings.optimizer_kwargs,
# lr_scheduler=getattr(torch.optim.lr_scheduler, optimizer_settings.scheduler),
# lr_scheduler_kwargs=optimizer_settings.scheduler_kwargs,
# )
# trainer.train()
# dm = regeneratorData(
# config_globs=data_settings.config_path,
# output_symbols=data_settings.symbols,
# output_dim=data_settings.output_size,
# dtype=getattr(torch, data_settings.dtype),
# drop_first=data_settings.drop_first,
# shuffle=data_settings.shuffle,
# train_split=data_settings.train_split,
# batch_size=pytorch_settings.batchsize,
# loader_settings={
# "num_workers": pytorch_settings.dataloader_workers,
# "prefetch_factor": pytorch_settings.dataloader_prefetch,
# "pin_memory": True,
# "drop_last": True,
# },
# seed=global_settings.seed,
# )
# # writer = L.SummaryWriter(pytorch_settings.summary_dir + f"/{datetime.now().strftime('%Y%m%d_%H%M%S')}")
# # from torch.utils.tensorboard import SummaryWriter
# subdir = f"{datetime.now().strftime('%Y%m%d_%H%M%S')}"
# # writer = SummaryWriter(pytorch_settings.summary_dir + f"/{subdir}")
# logger = L.pytorch.loggers.TensorBoardLogger(pytorch_settings.summary_dir, name=subdir, log_graph=True)
# trainer = L.Trainer(
# fast_dev_run=False,
# # max_epochs=pytorch_settings.epochs,
# max_epochs=2,
# enable_checkpointing=True,
# default_root_dir=f".models/{subdir}/",
# logger=logger,
# )
# trainer.fit(model, dm)

View File

@@ -0,0 +1,230 @@
from datetime import datetime
from pathlib import Path
import matplotlib
import numpy as np
import torch
import torch.utils.tensorboard
import torch.utils.tensorboard.summary
from hypertraining.settings import (
GlobalSettings,
DataSettings,
PytorchSettings,
ModelSettings,
OptimizerSettings,
)
from hypertraining.training import RegenerationTrainer, PolarizationTrainer
# import torch
import json
import util
from rich import print as rprint
global_settings = GlobalSettings(
seed=0xC0FFEE,
)
data_settings = DataSettings(
config_path="data/*-128-16384-10000-0-0-0-0-PAM4-0-0.4.ini",
# config_path=[f"data/20241202-*-128-16384-{length}-0-0-17-0-PAM4-0.ini" for length in range(48000, 53000, 1000)],
dtype="complex64",
# symbols = (9, 20), # 13 symbol @ 10GBd <-> 1.3ns <-> 0.26m of fiber
symbols=13, # study: single_core_regen_20241123_011232
# output_size = (11, 32), # 26 taps -> 2 taps per input symbol -> 1 tap every 1cm (model has 52 inputs (x/y))
output_size=26, # study: single_core_regen_20241123_011232 (model_input_dim/2)
shuffle=True,
drop_first=64,
train_split=0.8,
# polarisations=tuple(np.random.rand(2)*2*np.pi),
randomise_polarisations=True,
)
pytorch_settings = PytorchSettings(
epochs=10000,
batchsize=2**12,
device="cuda",
dataloader_workers=16,
dataloader_prefetch=8,
summary_dir=".runs",
write_every=2**5,
save_models=True,
model_dir=".models",
)
model_settings = ModelSettings(
output_dim=3,
n_hidden_layers=3,
overrides={
"n_hidden_nodes_0": 2,
"n_hidden_nodes_1": 2,
"n_hidden_nodes_2": 2,
},
dropout_prob=0.01,
model_layer_function="ONNRect",
model_activation_func="EOActivation",
model_layer_kwargs={"square": True},
scale=False,
model_layer_parametrizations=[
{
"tensor_name": "weight",
"parametrization": util.complexNN.energy_conserving,
},
{
"tensor_name": "alpha",
"parametrization": util.complexNN.clamp,
},
{
"tensor_name": "gain",
"parametrization": util.complexNN.clamp,
"kwargs": {
"min": 0,
"max": float("inf"),
},
},
{
"tensor_name": "phase_bias",
"parametrization": util.complexNN.clamp,
"kwargs": {
"min": 0,
"max": 2 * torch.pi,
},
},
{
"tensor_name": "scales",
"parametrization": util.complexNN.clamp,
},
{
"tensor_name": "angle",
"parametrization": util.complexNN.clamp,
"kwargs": {
"min": 0,
"max": 2*torch.pi,
},
},
{
"tensor_name": "loss",
"parametrization": util.complexNN.clamp,
},
],
)
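def _clamp_parametrization_demo():
    # Editor's sketch (hypothetical, not part of this commit): how one of the
    # "clamp" entries above could be applied, assuming util.complexNN.clamp
    # follows torch's register_parametrization pattern; _ClampDemo is a stand-in.
    import torch.nn.utils.parametrize as parametrize

    class _ClampDemo(torch.nn.Module):
        def __init__(self, min=0.0, max=1.0):
            super().__init__()
            self.min, self.max = min, max

        def forward(self, x):
            # the visible tensor is the clamped view of the raw parameter
            return x.clamp(self.min, self.max)

    layer = torch.nn.Linear(4, 4)
    parametrize.register_parametrization(layer, "weight", _ClampDemo(0.0, 1.0))
    assert layer.weight.min() >= 0  # weights are now read through the clamp
    return layer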
optimizer_settings = OptimizerSettings(
optimizer="AdamW",
optimizer_kwargs={
"lr": 0.005,
"amsgrad": True,
# "weight_decay": 1e-7,
},
# learning_rate=0.05,
scheduler="ReduceLROnPlateau",
scheduler_kwargs={
"patience": 2**6,
"factor": 0.75,
# "threshold": 1e-3,
"min_lr": 1e-6,
"cooldown": 10,
},
)
def save_dict_to_file(dictionary, filename):
"""
Save the best dictionary to a JSON file.
:param best: Dictionary containing the best training results.
:type best: dict
:param filename: Path to the JSON file where the dictionary will be saved.
:type filename: str
"""
with open(filename, "w") as f:
json.dump(dictionary, f, indent=4)
def sweep_lengths(*lengths, model=None, data_glob: str = None, strategy="newest"):
assert model is not None, "Model must be provided."
assert data_glob is not None, "Data glob must be provided."
fiber_ins = {}
fiber_outs = {}
regens = {}
timestampss = {}
trainer = RegenerationTrainer(
checkpoint_path=model,
)
trainer.define_model()
for length in lengths:
data_glob_length = data_glob.replace("{length}", str(length))
files = list(Path.cwd().glob(data_glob_length))
if len(files) == 0:
continue
if strategy == "newest":
sorted_kwargs = {
"key": lambda x: x.stat().st_mtime,
"reverse": True,
}
elif strategy == "oldest":
sorted_kwargs = {
"key": lambda x: x.stat().st_mtime,
"reverse": False,
}
else:
raise ValueError(f"Unknown strategy {strategy}.")
file = sorted(files, **sorted_kwargs)[0]
loader, _ = trainer.get_sliced_data(override={"config_path": file})
fiber_in, fiber_out, regen, timestamps = trainer.run_model(trainer.model, loader=loader)
fiber_ins[length] = fiber_in
fiber_outs[length] = fiber_out
regens[length] = regen
timestampss[length] = timestamps
data = torch.zeros(2 * len(timestampss.keys()) + 2, 2, tuple(fiber_outs.values())[-1].shape[0])
channel_names = ["" for _ in range(2 * len(timestampss.keys()) + 2)]
data[1, 0, :] = timestampss[tuple(timestampss.keys())[-1]] / 128
data[1, 1, :] = fiber_ins[tuple(timestampss.keys())[-1]][:, 0].abs().square()
channel_names[1] = "fiber in x"
for li, length in enumerate(timestampss.keys()):
data[2 + 2 * li, 0, :] = timestampss[length] / 128
data[2 + 2 * li, 1, :] = fiber_outs[length][:, 0].abs().square()
data[2 + 2 * li + 1, 0, :] = timestampss[length] / 128
data[2 + 2 * li + 1, 1, :] = regens[length][:, 0].abs().square()
channel_names[2 + 2 * li + 1] = f"regen x {length}"
channel_names[2 + 2 * li] = f"fiber out x {length}"
# get current backend
backend = matplotlib.get_backend()
matplotlib.use("TkCairo")
eye = util.eye_diagram.eye_diagram(data.to(dtype=torch.float32).detach().cpu().numpy(), channel_names=channel_names)
print_attrs = ("channel_name", "success", "min_area")
with np.printoptions(precision=3, suppress=True, formatter={"float": "{:0.3e}".format}):
for result in eye.eye_stats:
print_dict = {attr: result[attr] for attr in print_attrs}
rprint(print_dict)
rprint()
eye.plot(all_stats=False)
matplotlib.use(backend)
if __name__ == "__main__":
trainer = PolarizationTrainer(
global_settings=global_settings,
data_settings=data_settings,
pytorch_settings=pytorch_settings,
model_settings=model_settings,
optimizer_settings=optimizer_settings,
# checkpoint_path='.models/pol_pol_20241208_122418_1116.tar',
# reset_epoch=True
)
trainer.train()

View File

@@ -260,12 +260,94 @@ class ONNRect(nn.Module):
self.crop = lambda x: x
self.crop.__doc__ = "No cropping"
def forward(self, x):
x = self.pad(x)
x = self.pad(x).to(dtype=self.weight.dtype)
out = self.crop((self.weight @ x.mT).mT)
return out
class polarimeter(nn.Module):
    def __init__(self):
        super(polarimeter, self).__init__()

    def forward(self, data):
        # S0 = I
        # S1 = (2*I_x - I)/I
        # S2 = (2*I_45 - I)/I
        # S3 = (2*I_RHC - I)/I
        # data: (batch, input_length*2) -> (batch, input_length, 2)
        data = data.view(data.shape[0], -1, 2)
        x = data[:, :, 0].mean(dim=1)
        y = data[:, :, 1].mean(dim=1)
        # horizontal polarisation
        I_x = x.abs().square()
        # vertical polarisation
        I_y = y.abs().square()
        # 45 degree polarisation (note: |x+y|^2 is twice the intensity a 45°
        # polariser would measure, so S2/S3 come out with a constant offset)
        I_45 = (x + y).abs().square()
        # right hand circular polarisation
        I_RHC = (x + 1j * y).abs().square()
        S0 = I_x + I_y
        # S1..S3 are already divided by S0 here, so no further normalisation
        S1 = (2 * I_x - S0) / S0
        S2 = (2 * I_45 - S0) / S0
        S3 = (2 * I_RHC - S0) / S0
        return torch.stack([S0 / S0, S1, S2, S3], dim=1)
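def _polarimeter_demo():
    # editor's sketch (not part of the original commit): sanity-check the class
    # above; for 45° linear light (x == y) the S1 component must vanish
    pm = polarimeter()
    x = torch.ones(1, 8, dtype=torch.complex64)
    lin45 = torch.stack([x, x], dim=-1).reshape(1, -1)  # interleaved x/y channels
    S = pm(lin45)
    assert torch.allclose(S[:, 1], torch.zeros(1))  # S1 == 0 for 45° linear light
    return S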
class normalize_by_first(nn.Module):
def __init__(self):
super(normalize_by_first, self).__init__()
def forward(self, data):
return data / data[:, 0].unsqueeze(1)
class photodiode(nn.Module):
    def __init__(self, size, bias=True):
        super(photodiode, self).__init__()
        self.input_dim = size
        self.scale = nn.Parameter(torch.rand(size))
        # only allocate the bias when requested
        self.pd_bias = nn.Parameter(torch.rand(size)) if bias else None

    def forward(self, x):
        # square-law detection: per-channel scaled (and optionally biased) |E|^2
        out = x.abs().square().to(dtype=x.dtype.to_real()).mul(self.scale)
        return out if self.pd_bias is None else out.add(self.pd_bias)
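def _photodiode_demo():
    # editor's sketch (not from the commit): the photodiode maps a complex field
    # to a real, per-channel scaled-and-biased power reading of the same shape
    pd = photodiode(4)
    field = torch.randn(8, 4, dtype=torch.complex64)
    current = pd(field)
    assert not current.is_complex() and current.shape == field.shape
    return current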
class input_rotator(nn.Module):
    def __init__(self, input_dim):
        super(input_rotator, self).__init__()
        assert input_dim % 2 == 0, "Input dimension must be even"
        self.input_dim = input_dim
        # learnable default angle, used when no angle is passed to forward()
        self.angle = nn.Parameter(torch.randn(1))

    def forward(self, x, angle=None):
        # take channels (0,1), (2,3), ... and rotate them by the angle
        # (angle: tensor of shape (1,))
        angle = angle if angle is not None else self.angle
        sine = torch.sin(angle)
        cosine = torch.cos(angle)
        # build the 2x2 rotation from tensors so gradients can flow through angle
        rot = torch.stack(
            [torch.cat([cosine, -sine]), torch.cat([sine, cosine])]
        ).to(dtype=x.dtype)
        return torch.matmul(x.view(-1, 2), rot).view(x.shape)
@@ -371,7 +453,7 @@ class Identity(nn.Module):
M(z) = z
"""
def __init__(self):
def __init__(self, size=None):
super(Identity, self).__init__()
def forward(self, x):
@@ -404,9 +486,28 @@ class MZISingle(nn.Module):
def forward(self, x: torch.Tensor):
return x * torch.exp(1j * self.phi) * torch.sin(self.omega + self.func(x))
def naive_angle_loss(x: torch.Tensor, target: torch.Tensor, mod=2*torch.pi):
return torch.fmod((x - target), mod).square().mean()
def cosine_loss(x: torch.Tensor, target: torch.Tensor):
return (2*(1 - torch.cos(x - target))).mean()
def angle_mse_loss(x: torch.Tensor, target: torch.Tensor):
x = torch.fmod(x, 2*torch.pi)
target = torch.fmod(target, 2*torch.pi)
x_cos = torch.cos(x)
x_sin = torch.sin(x)
target_cos = torch.cos(target)
target_sin = torch.sin(target)
cos_diff = x_cos - target_cos
sin_diff = x_sin - target_sin
squared_diff = cos_diff**2 + sin_diff**2
return squared_diff.mean()
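def _angle_loss_identity_demo():
    # editor's note: (cos x - cos t)^2 + (sin x - sin t)^2 = 2*(1 - cos(x - t)),
    # so angle_mse_loss and cosine_loss compute the same value (the fmod above
    # is redundant because sin/cos are already 2*pi-periodic)
    x, t = torch.randn(100), torch.randn(100)
    assert torch.allclose(angle_mse_loss(x, t), cosine_loss(x, t))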
class EOActivation(nn.Module):
def __init__(self, bias, size=None):
def __init__(self, size=None):
# 10.1109/SiPhotonics60897.2024.10543376
super(EOActivation, self).__init__()
if size is None:
@@ -571,81 +672,10 @@ class ZReLU(nn.Module):
return torch.relu(x)
class regenerator(nn.Module):
def __init__(
self,
*dims,
layer_function=ONN,
layer_kwargs: dict | None = None,
layer_parametrizations: list[dict] = None,
activation_function=Pow,
dtype=torch.float64,
dropout_prob=0.01,
scale=False,
**kwargs,
):
super(regenerator, self).__init__()
if len(dims) == 0:
try:
dims = kwargs["dims"]
except KeyError:
raise ValueError("dims must be provided")
self._n_hidden_layers = len(dims) - 2
self._layers = nn.Sequential()
if layer_kwargs is None:
layer_kwargs = {}
# self.powers = []
for i in range(self._n_hidden_layers + 1):
if scale:
self._layers.append(Scale(dims[i]))
self._layers.append(layer_function(dims[i], dims[i + 1], dtype=dtype, **layer_kwargs))
if i < self._n_hidden_layers:
if dropout_prob is not None:
self._layers.append(DropoutComplex(p=dropout_prob))
self._layers.append(activation_function(bias=True, size=dims[i + 1]))
self._layers.append(Scale(dims[-1]))
# add parametrizations
if layer_parametrizations is not None:
for layer in self._layers:
for layer_parametrization in layer_parametrizations:
tensor_name = layer_parametrization.get("tensor_name", None)
parametrization = layer_parametrization.get("parametrization", None)
param_kwargs = layer_parametrization.get("kwargs", {})
if tensor_name is not None and tensor_name in layer._parameters and parametrization is not None:
parametrization(layer, tensor_name, **param_kwargs)
# def __call__(self, input_x, **kwargs):
# return self.forward(input_x, **kwargs)
def forward(self, input_x, trace_powers=False):
x = input_x
if trace_powers:
powers = [x.abs().square().sum()]
# check if tracing
if torch.jit.is_tracing():
for layer in self._layers:
x = layer(x)
if trace_powers:
powers.append(x.abs().square().sum())
else:
# with torch.nn.utils.parametrize.cached():
for layer in self._layers:
x = layer(x)
if trace_powers:
powers.append(x.abs().square().sum())
if trace_powers:
return x, powers
return x
__all__ = [
complex_sse_loss,
complex_mse_loss,
angle_mse_loss,
UnitaryLayer,
unitary,
energy_conserving,
@@ -662,6 +692,7 @@ __all__ = [
ZReLU,
MZISingle,
EOActivation,
photodiode,
# SaturableAbsorberLambertW,
# SaturableAbsorber,
# SpreadLayer,

View File

@@ -54,7 +54,7 @@ def load_data(config_path, skipfirst=0, symbols=None, real=False, normalize=Fals
config["glova"]["nos"] = str(symbols)
data = np.concatenate([data, timestamps.reshape(-1,1)], axis=-1)
data = np.concatenate([data, timestamps.reshape(-1, 1)], axis=-1)
data = torch.tensor(data, device=device, dtype=dtype)
@@ -113,6 +113,8 @@ class FiberRegenerationDataset(Dataset):
dtype: torch.dtype = None,
real: bool = False,
device=None,
polarisations: tuple | list = (0,),
randomise_polarisations: bool = False,
**kwargs,
):
"""
@@ -145,6 +147,8 @@ class FiberRegenerationDataset(Dataset):
assert output_dim is None or output_dim > 0, "output_dim must be positive or None"
assert drop_first >= 0, "drop_first must be non-negative"
self.randomise_polarisations = randomise_polarisations
faux = kwargs.pop("faux", False)
if faux:
@@ -165,7 +169,7 @@ class FiberRegenerationDataset(Dataset):
data_raw = None
self.config = None
files = []
for file_path in (file_path if isinstance(file_path, (tuple, list)) else [file_path]):
for file_path in file_path if isinstance(file_path, (tuple, list)) else [file_path]:
data, config = load_data(
file_path,
skipfirst=drop_first,
@@ -186,6 +190,19 @@ class FiberRegenerationDataset(Dataset):
files.append(config["data"]["file"].strip('"'))
self.config["data"]["file"] = str(files)
polarised = []
for angle in torch.as_tensor(np.array(polarisations)):
    if angle == 0:
        polarised.append(data_raw)
        continue
    rotated = data_raw.clone()
    sine = torch.sin(angle)
    cosine = torch.cos(angle)
    # rotate the fields (columns 2/3) of the original, unrotated data so the
    # angles do not compound across entries
    rotated[:, 2] = data_raw[:, 2] * cosine - data_raw[:, 3] * sine
    rotated[:, 3] = data_raw[:, 2] * sine + data_raw[:, 3] * cosine
    polarised.append(rotated)
data_raw = polarised[0] if len(polarised) == 1 else torch.cat(polarised, dim=0)
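# editor's note: every entry in `polarisations` contributes one full copy of the
# trace, so the dataset length scales with len(polarisations); the rotation is a
# Jones rotation of the output fields and leaves |E_x|^2 + |E_y|^2 unchanged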
self.device = data_raw.device
self.samples_per_symbol = int(self.config["glova"]["sps"])
@@ -258,17 +275,27 @@ class FiberRegenerationDataset(Dataset):
elif self.target_delay_samples < 0:
data_raw = data_raw[:, : self.target_delay_samples]
timestamps = data_raw[-1, :]
data_raw = data_raw[:-1, :]
timestamps = data_raw[4, :]
data_raw = data_raw[:4, :]
data_raw = data_raw.view(2, 2, -1)
timestamps_doubled = torch.cat([timestamps.unsqueeze(dim=0), timestamps.unsqueeze(dim=0)], dim=0).unsqueeze(dim=1)
timestamps_doubled = torch.cat([timestamps.unsqueeze(dim=0), timestamps.unsqueeze(dim=0)], dim=0).unsqueeze(
dim=1
)
data_raw = torch.cat([data_raw, timestamps_doubled], dim=1)
# data_raw = torch.cat([data_raw, timestamps_doubled], dim=1)
# data layout
# [ [E_in_x, E_in_y, timestamps],
# [E_out_x, E_out_y, timestamps] ]
self.data = data_raw.unfold(dimension=-1, size=self.samples_per_slice, step=1)
self.data = self.data.movedim(-2, 0)
if randomise_polarisations:
self.angles = torch.rand(self.data.shape[0]) * np.pi * 2
# self.data[:, 1, :2, :] = self.rotate(self.data[:, 1, :2, :], self.angles)
else:
self.angles = torch.zeros(self.data.shape[0])
# ...
# -> [no_slices, 2, 3, samples_per_slice]
# data layout
@@ -289,22 +316,92 @@ class FiberRegenerationDataset(Dataset):
else:
data_slice = self.data[idx].squeeze()
data_slice = data_slice[:, :, :data_slice.shape[2] // self.output_dim * self.output_dim]
data_slice = data_slice[:, :, : data_slice.shape[2] // self.output_dim * self.output_dim]
data_slice = data_slice.view(data_slice.shape[0], data_slice.shape[1], self.output_dim, -1)
target = data_slice[0, :, self.output_dim//2, 0]
data = data_slice[1, :, :, 0]
# if self.randomise_polarisations:
# angle = torch.rand(1) * torch.pi * 2
# sine = torch.sin(angle)
# cosine = torch.cos(angle)
# data_slice_ = data_slice[1]
# data_slice[1, 0] = data_slice_[0] * cosine - data_slice_[1] * sine
# data_slice[1,1] = data_slice_[0] * sine + data_slice_[1] * cosine
# else:
# angle = torch.zeros(1)
# data = data_slice[1, :2, :, 0]
angle = self.angles[idx]
data_index = 1
data_slice[1, :2, :, :] = self.rotate(data_slice[data_index, :2, :, :], angle)
data = data_slice[1, :2, :, 0]
# data = self.rotate(data, angle)
# for both polarisations (x/y), calculate the mean of the signal around the current symbol (-> corresponding to a lowpass filter)
angle_data = data_slice[1, :2, :, :].reshape(2, -1).mean(dim=1)
angle_data2 = self.complex_max(data_slice[1, :2, :, :].reshape(2, -1))
plot_data = data_slice[1, :2, self.output_dim // 2, 0]
sop = self.polarimeter(plot_data)
# angle_data = data_slice[1, :2, :self.output_dim//2, :].squeeze().reshape(2, -1).mean(dim=-1)
# angle = data_slice[1, 3, self.output_dim // 2, 0].real
target = data_slice[0, :2, self.output_dim // 2, 0]
target_timestamp = data_slice[0, 2, self.output_dim // 2, 0].real
# data_timestamps = data[-1,:].real
data = data[:-1, :]
target_timestamp = target[-1].real
target = target[:-1]
# data = data[:-1, :]
# target_timestamp = target[-1].real
# target = target[:-1]
# plot_data = plot_data[:-1]
# transpose to interleave the x and y data in the output tensor
data = data.transpose(0, 1).flatten().squeeze()
angle_data = angle_data.flatten().squeeze()
angle_data2 = angle_data2.flatten().squeeze()
angle = angle.flatten().squeeze()
# data_timestamps = data_timestamps.flatten().squeeze()
target = target.flatten().squeeze()
target_timestamp = target_timestamp.flatten().squeeze()
return data, target, target_timestamp
return {"x": data, "y": target, "angle": angle, "sop": sop, "angle_data": angle_data, "angle_data2": angle_data2, "timestamp": target_timestamp, "plot_data": plot_data}
def complex_max(self, data, dim=-1):
# returns element(s) with the maximum absolute value along a given dimension
# ind = torch.argmax(data.abs(), dim=dim, keepdim=True)
# max_values = torch.gather(data, dim, ind).squeeze(dim=dim)
# return max_values
return torch.gather(data, dim, torch.argmax(data.abs(), dim=dim, keepdim=True)).squeeze(dim=dim)
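# editor's sketch of the gather/argmax pattern above, on a free-standing tensor:
#   data = torch.tensor([[1 + 1j, -3 + 0j, 0.5j]])
#   idx = torch.argmax(data.abs(), dim=-1, keepdim=True)  # index of max |.|
#   torch.gather(data, -1, idx).squeeze(-1)               # -> tensor([-3.+0.j])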
def rotate(self, data, angle):
# rotates a 2d tensor by a given angle
# data: [2, ...]
# angle: [1]
# returns: [2, ...]
# get sine and cosine of the angle
sine = torch.sin(angle)
cosine = torch.cos(angle)
return torch.stack([data[0] * cosine - data[1] * sine, data[0] * sine + data[1] * cosine], dim=0)
def polarimeter(self, data):
    # data: [2, ...] -> x, y
    # returns: [3] -> normalised Stokes parameters S1, S2, S3 (S0 is omitted)
    x = data[0].mean()
    y = data[1].mean()
    I_X = x.abs().square()
    I_Y = y.abs().square()
    # note: |x+y|^2 and |x+1j*y|^2 are twice the intensities a 45° polariser or
    # a circular analyser would measure, so S2 and S3 carry a constant offset
    I_45 = (x + y).abs().square()
    I_RHC = (x + 1j * y).abs().square()
    S0 = I_X + I_Y
    S1 = (2 * I_X - S0) / S0
    S2 = (2 * I_45 - S0) / S0
    S3 = (2 * I_RHC - S0) / S0
    return torch.stack([S1, S2, S3], dim=0)
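def _stokes_demo():
    # editor's sketch (not part of the commit): purely horizontal light (y = 0)
    # must give S1 == 1; mirrors the polarimeter method above on a hand-built
    # Jones vector outside the dataset class
    data = torch.stack([torch.ones(4), torch.zeros(4)]).to(torch.complex64)
    x, y = data[0].mean(), data[1].mean()
    S0 = x.abs().square() + y.abs().square()
    S1 = (2 * x.abs().square() - S0) / S0
    assert torch.allclose(S1, torch.ones(()))
    return S1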