add training script for polarization estimation, refactor model definitions, add randomised polarisation support in data_loader

Joseph Hopfmüller
2024-12-11 09:48:38 +01:00
parent 0e29b87395
commit 39ae13d0af
8 changed files with 1899 additions and 259 deletions

View File

@@ -0,0 +1,443 @@
from typing import Any
import lightning as L
import numpy as np
import torch
import torch.nn as nn
# import torch.nn.functional as F
from util.complexNN import DropoutComplex, Scale, ONNRect, EOActivation, energy_conserving, clamp, complex_mse_loss
from util.datasets import FiberRegenerationDataset
class regeneratorData(L.LightningDataModule):
def __init__(
self,
config_globs,
output_symbols,
output_dim,
dtype,
drop_first,
shuffle=True,
train_split=None,
batch_size=None,
loader_settings=None,
seed=None,
num_symbols=None,
test_globs=None,
):
super().__init__()
self._config_globs = config_globs
self._test_globs = test_globs
self._test_data_available = test_globs is not None
if self._test_data_available:
self.test_dataloader = self._test_dataloader
self._output_symbols = output_symbols
self._output_dim = output_dim
self._dtype = dtype
self._drop_first = drop_first
self._seed = seed
self._shuffle = shuffle
self._num_symbols = num_symbols
self._train_split = train_split if train_split is not None else 0.8
self.batch_size = batch_size if batch_size is not None else 1024
self._loader_settings = loader_settings if loader_settings is not None else {}
def _get_data(self):
self._data_train = FiberRegenerationDataset(
file_path=self._config_globs,
symbols=self._output_symbols,
output_dim=self._output_dim,
dtype=self._dtype,
real=not self._dtype.is_complex,
drop_first=self._drop_first,
num_symbols=self._num_symbols,
)
# self._data_plot = FiberRegenerationDataset(
# file_path=self._config_globs,
# symbols=self._output_symbols,
# output_dim=self._output_dim,
# dtype=self._dtype,
# real=not self._dtype.is_complex,
# drop_first=self._drop_first,
# num_symbols=400,
# )
if self._test_data_available:
self._data_test = FiberRegenerationDataset(
file_path=self._test_globs,
symbols=self._output_symbols,
output_dim=self._output_dim,
dtype=self._dtype,
real=not self._dtype.is_complex,
drop_first=self._drop_first,
num_symbols=self._num_symbols,
)
return self._data_train, self._data_test
return self._data_train
def _split_data(self, stage="fit", split=None, shuffle=None):
_split = split if split is not None else self._train_split
_shuffle = shuffle if shuffle is not None else self._shuffle
dataset_size = len(self._data_train)
indices = list(range(dataset_size))
split_index = int(np.floor(_split * dataset_size))
train_indices, valid_indices = indices[:split_index], indices[split_index:]
if _shuffle:
np.random.seed(self._seed)
np.random.shuffle(train_indices)
if _shuffle:
if stage == "fit" or stage == "predict":
self._train_sampler = torch.utils.data.SubsetRandomSampler(train_indices)
# if stage == "fit" or stage == "validate":
# self._valid_sampler = torch.utils.data.SubsetRandomSampler(valid_indices)
else:
if stage == "fit" or stage == "predict":
self._train_sampler = train_indices
if stage == "fit" or stage == "validate":
self._valid_sampler = valid_indices
if stage == "fit":
return self._train_sampler, self._valid_sampler
elif stage == "validate":
return self._valid_sampler
elif stage == "predict":
return self._train_sampler
def prepare_data(self):
self._get_data()
def setup(self, stage=None):
stage = stage or "fit"
self._split_data(stage=stage)
def train_dataloader(self):
return torch.utils.data.DataLoader(
self._data_train,
batch_size=self.batch_size,
sampler=self._train_sampler,
**self._loader_settings
)
def val_dataloader(self):
return torch.utils.data.DataLoader(
self._data_train,
batch_size=self.batch_size,
sampler=self._valid_sampler,
**self._loader_settings
)
def _test_dataloader(self):
return torch.utils.data.DataLoader(
self._data_test,
shuffle=self._shuffle,
batch_size=self.batch_size,
**self._loader_settings
)
def predict_dataloader(self):
return torch.utils.data.DataLoader(
self._data_plot,
shuffle=False,
batch_size=40,
pin_memory=True,
drop_last=True,
num_workers=4,
prefetch_factor=2,
)
# def plot_dataloader(self):
class regenerator(L.LightningModule):
def __init__(
self,
*dims,
layer_function=ONNRect,
layer_func_kwargs: dict | None = {"square": True},
act_function=EOActivation,
act_func_kwargs: dict | None = None,
parametrizations: list[dict] | None = [
{
"tensor_name": "weight",
"parametrization": energy_conserving,
},
{
"tensor_name": "alpha",
"parametrization": clamp,
},
],
dtype=torch.complex64,
dropout_prob=0.01,
scale_layers=False,
optimizer=torch.optim.AdamW,
optimizer_kwargs: dict | None = {
"lr": 0.01,
"amsgrad": True,
},
lr_scheduler=None,
lr_scheduler_kwargs: dict | None = {
"patience": 20,
"factor": 0.5,
"min_lr": 1e-6,
"cooldown": 10,
},
sps = 128,
# **kwargs,
):
torch.set_float32_matmul_precision('high')
layer_func_kwargs = layer_func_kwargs if layer_func_kwargs is not None else {}
act_func_kwargs = act_func_kwargs if act_func_kwargs is not None else {}
optimizer_kwargs = optimizer_kwargs if optimizer_kwargs is not None else {}
lr_scheduler_kwargs = lr_scheduler_kwargs if lr_scheduler_kwargs is not None else {}
super().__init__()
self.example_input_array = torch.randn(1, dims[0], dtype=dtype)
self._sps = sps
self.optimizer_settings = {
"optimizer": optimizer,
"optimizer_kwargs": optimizer_kwargs,
"lr_scheduler": lr_scheduler,
"lr_scheduler_kwargs": lr_scheduler_kwargs,
}
# if len(dims) == 0:
# try:
# dims = kwargs["dims"]
# except KeyError:
# raise ValueError("dims must be provided")
self._n_hidden_layers = len(dims) - 2
self.build_model(dims, layer_function, layer_func_kwargs, act_function, act_func_kwargs, parametrizations, dtype, dropout_prob, scale_layers)
def build_model(self, dims, layer_function, layer_func_kwargs, act_function, act_func_kwargs, parametrizations, dtype, dropout_prob, scale_layers):
input_layer = nn.Sequential(
layer_function(dims[0], dims[1], dtype=dtype, **layer_func_kwargs),
act_function(size=dims[1], **act_func_kwargs),
DropoutComplex(p=dropout_prob),
)
if scale_layers:
input_layer = nn.Sequential(Scale(dims[0]), input_layer)
self.layer_0 = input_layer
for i in range(1, self._n_hidden_layers):
layer = nn.Sequential(
layer_function(dims[i], dims[i + 1], dtype=dtype, **layer_func_kwargs),
act_function(size=dims[i + 1], **act_func_kwargs),
DropoutComplex(p=dropout_prob),
)
if scale_layers:
layer = nn.Sequential(Scale(dims[i]), layer)
setattr(self, f"layer_{i}", layer)
output_layer = nn.Sequential(
layer_function(dims[-2], dims[-1], dtype=dtype, **layer_func_kwargs),
act_function(size=dims[-1], **act_func_kwargs),
Scale(dims[-1]),
)
setattr(self, f"layer_{self._n_hidden_layers}", output_layer)
if parametrizations is not None:
self._apply_parametrizations(self, parametrizations)
def _apply_parametrizations(self, layer, parametrizations):
for sub_layer in layer.children():
if len(sub_layer._modules) > 0:
self._apply_parametrizations(sub_layer, parametrizations)
else:
for parametrization in parametrizations:
tensor_name = parametrization.get("tensor_name", None)
if tensor_name is None:
continue
parametrization_func = parametrization.get("parametrization", None)
if parametrization_func is None:
continue
param_kwargs = parametrization.get("kwargs", {})
if tensor_name in sub_layer._parameters:
parametrization_func(sub_layer, tensor_name, **param_kwargs)
def _trace_powers(self, enable, x, powers=None):
if not enable:
return
if powers is None:
powers = []
powers.append(x.abs().square().sum())
return powers
# def plot(self, mode):
# self.predict_step()
# def validation_epoch_end(self, outputs):
# x = torch.vstack([output['x'].view(output['x'].shape[0], -1, 2)[:, output['x'].shape[1]//2, :].squeeze() for output in outputs])
# y = torch.vstack([output['y'].view(output['y'].shape[0], -1, 2).squeeze() for output in outputs])
# y_hat = torch.vstack([output['y_hat'].view(output['y_hat'].shape[0], -1, 2).squeeze() for output in outputs])
# timesteps = torch.vstack([output['timesteps'].squeeze() for output in outputs])
# powers = torch.vstack([output['powers'] for output in outputs])
# return {'x': x, 'y': y, 'y_hat': y_hat, 'timesteps': timesteps, 'powers': powers}
def on_validation_epoch_end(self):
if self.current_epoch % 10 == 0 or self.current_epoch == self.trainer.max_epochs - 1 or self.current_epoch < 10:
x = self.val_outputs['x']
# x = x.view(x.shape[0], -1, 2)
# x = x[:, x.shape[1]//2, :].squeeze()
y = self.val_outputs['y']
# y = y.view(y.shape[0], -1, 2).squeeze()
y_hat = self.val_outputs['y_hat']
# y_hat = y_hat.view(y_hat.shape[0], -1, 2).squeeze()
timesteps = self.val_outputs['timesteps']
# timesteps = timesteps.squeeze()
powers = self.val_outputs['powers']
# powers = powers.squeeze()
fiber_in = x.detach().cpu().numpy()
fiber_out = y.detach().cpu().numpy()
regen = y_hat.detach().cpu().numpy()
timesteps = timesteps.detach().cpu().numpy()
# powers = np.array([power.detach().cpu().numpy() for power in powers])
# fiber_in = np.concat(fiber_in, axis=0)
# fiber_out = np.concat(fiber_out, axis=0)
# regen = np.concat(regen, axis=0)
# timesteps = np.concat(timesteps, axis=0)
# powers = powers.detach().cpu().numpy()
import gc
fig = self.plot_model_head(fiber_in, fiber_out, regen, timesteps, sps=self._sps)
self.logger.experiment.add_figure("model response", fig, self.current_epoch)
# fig = self.plot_model_eye(fiber_in, fiber_out, regen, timesteps, sps=self._sps)
# self.logger.experiment.add_figure("model eye", fig, self.current_epoch)
# fig = self.plot_model_powers(powers)
# self.logger.experiment.add_figure("powers", fig, self.current_epoch)
gc.collect()
# x, y, y_hat, timesteps, powers = self.validation_epoch_end(self.outputs)
# self.plot(x, y, y_hat, timesteps, powers)
def plot_model_head(self, fiber_in, fiber_out, regen, timesteps, sps):
import matplotlib
matplotlib.use("TkCairo")
import matplotlib.pyplot as plt
ordering = np.argsort(timesteps)
signals = [signal[ordering] for signal in [fiber_in, fiber_out, regen]]
timesteps = timesteps[ordering]
signals = [signal[:sps*40] for signal in signals]
timesteps = timesteps[:sps*40]
fig, axs = plt.subplots(1, 2, sharex=True, sharey=True)
fig.set_figwidth(16)
fig.set_figheight(4)
for i, ax in enumerate(axs):
for j, signal in enumerate(signals):
ax.plot(timesteps / sps, np.square(np.abs(signal[:,i])), label=["fiber in", "fiber out", "regen"][j] + [" x", " y"][i])
ax.set_xlabel("symbol")
ax.set_ylabel("power (a.u.)")
ax.minorticks_on()
ax.tick_params(axis="y", which="minor", left=False, right=False)
ax.grid(which="major", axis="x")
ax.grid(which="minor", axis="x", linestyle=":")
ax.grid(which="major", axis="y")
ax.legend(loc="upper right")
fig.tight_layout()
return fig
def plot_model_eye(self, fiber_in, fiber_out, regen, timesteps, sps):
...
def plot_model_powers(self, powers):
...
def forward(self, x, trace_powers=False):
powers = self._trace_powers(trace_powers, x)
x = self.layer_0(x)
powers = self._trace_powers(trace_powers, x, powers)
for i in range(1, self._n_hidden_layers):
x = getattr(self, f"layer_{i}")(x)
powers = self._trace_powers(trace_powers, x, powers)
x = getattr(self, f"layer_{self._n_hidden_layers}")(x)
powers = self._trace_powers(trace_powers, x, powers)
if trace_powers:
return x, powers
return x
def configure_optimizers(self):
optimizer = self.optimizer_settings["optimizer"](
self.parameters(), **self.optimizer_settings["optimizer_kwargs"]
)
if self.optimizer_settings["lr_scheduler"] is not None:
lr_scheduler = self.optimizer_settings["lr_scheduler"](
optimizer, **self.optimizer_settings["lr_scheduler_kwargs"]
)
return {
"optimizer": optimizer,
"lr_scheduler": {
"scheduler": lr_scheduler,
"monitor": "val_loss",
}
}
return {"optimizer": optimizer}
def training_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0):
x, y, timesteps = batch
y_hat = self(x)
loss = complex_mse_loss(y_hat, y, power=True)
self.log("train_loss", loss, on_epoch=True, on_step=True)
return loss
def validation_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0):
x, y, timesteps = batch
if batch_idx == 0:
y_hat, powers = self.forward(x, trace_powers=True)
else:
y_hat = self.forward(x)
loss = complex_mse_loss(y_hat, y, power=True)
self.log("val_loss", loss, on_epoch=True)
y = y.view(y.shape[0], -1, 2).squeeze()
x = x.view(x.shape[0], -1, 2)
x = x[:, x.shape[1]//2, :].squeeze()
y_hat = y_hat.view(y_hat.shape[0], -1, 2).squeeze()
timesteps = timesteps.squeeze()
if batch_idx == 0:
powers = np.array([power.detach().cpu() for power in powers])
self.val_outputs = {"y": y, "x": x, "y_hat": y_hat, "timesteps": timesteps, "powers": powers}
else:
self.val_outputs["y"] = torch.vstack([self.val_outputs["y"], y])
self.val_outputs["x"] = torch.vstack([self.val_outputs["x"], x])
self.val_outputs["y_hat"] = torch.vstack([self.val_outputs["y_hat"], y_hat])
self.val_outputs["timesteps"] = torch.concat([self.val_outputs["timesteps"], timesteps], dim=0)
return loss
def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0):
x, y, timesteps = batch
y_hat = self(x)
loss = complex_mse_loss(y_hat, y, power=True)
self.log("test_loss", loss, on_epoch=True)
return loss
# def predict_step(self, batch, batch_idx):
# x, y, timesteps = batch
# y_hat = self(x)
# return y, x, y_hat, timesteps
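For orientation, a minimal way to wire the two classes above into a training run could look like the following sketch; the data glob, layer widths, batch size, and trainer options are illustrative assumptions, not part of this commit.

# Minimal usage sketch (glob, dims, and trainer options are assumptions).
import torch
import lightning as L

dm = regeneratorData(
    config_globs="data/*-PAM4-0.ini",   # assumed dataset config glob
    output_symbols=1,
    output_dim=26,                      # 26 taps per polarisation -> 52 model inputs
    dtype=torch.complex64,
    drop_first=1000,
    batch_size=4096,
)
model = regenerator(52, 8, 8, 4, 2)     # dims = (input, hidden..., output)
trainer = L.Trainer(max_epochs=100)
trainer.fit(model, datamodule=dm)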

View File

@@ -0,0 +1,204 @@
import torch
from torch.nn import Module, Sequential
from util.complexNN import (
DropoutComplex,
Scale,
ONNRect,
photodiode,
EOActivation,
polarimeter,
normalize_by_first
)
class polarisation_estimator2(Module):
def __init__(self):
super(polarisation_estimator2, self).__init__()
self.layers = Sequential(
polarimeter(),
torch.nn.Linear(4, 4),
torch.nn.ReLU(),
torch.nn.Dropout(p=0.01),
torch.nn.Linear(4, 4),
torch.nn.ReLU(),
torch.nn.Dropout(p=0.01),
torch.nn.Linear(4, 4),
)
def forward(self, x):
# x = self.polarimeter(x)
for layer in self.layers:
x = layer(x)
return x
class polarisation_estimator(Module):
def __init__(
self,
*dims,
layer_function=ONNRect,
layer_func_kwargs: dict | None = None,
output_layer_function=photodiode,
# output_layer_func_kwargs: dict | None = None,
act_function=EOActivation,
act_func_kwargs: dict | None = None,
parametrizations: list[dict] | None = None,
dtype=torch.float64,
dropout_prob=0.01,
scale_layers=False,
):
super(polarisation_estimator, self).__init__()
self._n_hidden_layers = len(dims) - 2
layer_func_kwargs = layer_func_kwargs or {}
act_func_kwargs = act_func_kwargs or {}
self.build_model(dims, layer_function, layer_func_kwargs, output_layer_function, act_function, act_func_kwargs, parametrizations, dtype, dropout_prob, scale_layers)
def forward(self, x):
x = self.layer_0(x)
for i in range(1, self._n_hidden_layers):
x = getattr(self, f"layer_{i}")(x)
x = getattr(self, f"layer_{self._n_hidden_layers}")(x)
x = torch.remainder(x, torch.ones_like(x) * 2 * torch.pi)
return x.squeeze()
def build_model(self, dims, layer_function, layer_func_kwargs, output_layer_function, act_function, act_func_kwargs, parametrizations, dtype, dropout_prob, scale_layers):
for i in range(0, self._n_hidden_layers):
self.add_module(f"layer_{i}", Sequential())
if scale_layers:
self.get_submodule(f"layer_{i}").add_module("scale", Scale(dims[i]))
module = layer_function(dims[i], dims[i + 1], dtype=dtype, **layer_func_kwargs)
self.get_submodule(f"layer_{i}").add_module("ONN", module)
module = act_function(size=dims[i + 1], **act_func_kwargs)
self.get_submodule(f"layer_{i}").add_module("activation", module)
module = DropoutComplex(p=dropout_prob)
self.get_submodule(f"layer_{i}").add_module("dropout", module)
self.add_module(f"layer_{self._n_hidden_layers}", Sequential())
if scale_layers:
self.get_submodule(f"layer_{self._n_hidden_layers}").add_module("scale", Scale(dims[-2]))
module = layer_function(dims[-2], dims[-1], dtype=dtype, **layer_func_kwargs)
self.get_submodule(f"layer_{self._n_hidden_layers}").add_module("ONN", module)
module = output_layer_function(size=dims[-1])
self.get_submodule(f"layer_{self._n_hidden_layers}").add_module("photodiode", module)
# module = normalize_by_first()
# self.get_submodule(f"layer_{self._n_hidden_layers}").add_module("normalize", module)
if parametrizations is not None:
self._apply_parametrizations(self, parametrizations)
def _apply_parametrizations(self, layer, parametrizations):
for sub_layer in layer.children():
if len(sub_layer._modules) > 0:
self._apply_parametrizations(sub_layer, parametrizations)
else:
for parametrization in parametrizations:
tensor_name = parametrization.get("tensor_name", None)
if tensor_name is None:
continue
parametrization_func = parametrization.get("parametrization", None)
if parametrization_func is None:
continue
param_kwargs = parametrization.get("kwargs", {})
if tensor_name in sub_layer._parameters:
parametrization_func(sub_layer, tensor_name, **param_kwargs)
class regenerator(Module):
def __init__(
self,
*dims,
layer_function=ONNRect,
layer_func_kwargs: dict | None = None,
act_function=EOActivation,
act_func_kwargs: dict | None = None,
parametrizations: list[dict] | None = None,
dtype=torch.float64,
dropout_prob=0.01,
scale_layers=False,
):
super(regenerator, self).__init__()
self._n_hidden_layers = len(dims) - 2
layer_func_kwargs = layer_func_kwargs or {}
act_func_kwargs = act_func_kwargs or {}
self.build_model(dims, layer_function, layer_func_kwargs, act_function, act_func_kwargs, parametrizations, dtype, dropout_prob, scale_layers)
def build_model(self, dims, layer_function, layer_func_kwargs, act_function, act_func_kwargs, parametrizations, dtype, dropout_prob, scale_layers):
for i in range(0, self._n_hidden_layers):
self.add_module(f"layer_{i}", Sequential())
if scale_layers:
self.get_submodule(f"layer_{i}").add_module("scale", Scale(dims[i]))
module = layer_function(dims[i], dims[i + 1], dtype=dtype, **layer_func_kwargs)
self.get_submodule(f"layer_{i}").add_module("ONN", module)
module = act_function(size=dims[i + 1], **act_func_kwargs)
self.get_submodule(f"layer_{i}").add_module("activation", module)
module = DropoutComplex(p=dropout_prob)
self.get_submodule(f"layer_{i}").add_module("dropout", module)
self.add_module(f"layer_{self._n_hidden_layers}", Sequential())
if scale_layers:
self.get_submodule(f"layer_{self._n_hidden_layers}").add_module("scale", Scale(dims[-2]))
module = layer_function(dims[-2], dims[-1], dtype=dtype, **layer_func_kwargs)
self.get_submodule(f"layer_{self._n_hidden_layers}").add_module("ONN", module)
module = act_function(size=dims[-1], **act_func_kwargs)
self.get_submodule(f"layer_{self._n_hidden_layers}").add_module("activation", module)
# module = Scale(size=dims[-1])
# self.get_submodule(f"layer_{self._n_hidden_layers}").add_module("out_scale", module)
if parametrizations is not None:
self._apply_parametrizations(self, parametrizations)
def _apply_parametrizations(self, layer, parametrizations):
for sub_layer in layer.children():
if len(sub_layer._modules) > 0:
self._apply_parametrizations(sub_layer, parametrizations)
else:
for parametrization in parametrizations:
tensor_name = parametrization.get("tensor_name", None)
if tensor_name is None:
continue
parametrization_func = parametrization.get("parametrization", None)
if parametrization_func is None:
continue
param_kwargs = parametrization.get("kwargs", {})
if tensor_name in sub_layer._parameters:
parametrization_func(sub_layer, tensor_name, **param_kwargs)
def _trace_powers(self, enable, x, powers=None):
if not enable:
return
if powers is None:
powers = []
powers.append(x.abs().square().sum())
return powers
def forward(self, x, trace_powers=False):
powers = self._trace_powers(trace_powers, x)
x = self.layer_0(x)
powers = self._trace_powers(trace_powers, x, powers)
for i in range(1, self._n_hidden_layers):
x = getattr(self, f"layer_{i}")(x)
powers = self._trace_powers(trace_powers, x, powers)
x = getattr(self, f"layer_{self._n_hidden_layers}")(x)
powers = self._trace_powers(trace_powers, x, powers)
if trace_powers:
return x, powers
return x
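A quick smoke test of the estimator's forward pass might look like this sketch; the layer sizes, batch size, and complex dtype are assumptions about what util.complexNN's ONNRect, EOActivation, and photodiode accept.

# Sketch only: sizes and dtype are assumed, not taken from the commit.
import torch

est = polarisation_estimator(52, 16, 4, dtype=torch.complex64)
x = torch.randn(8, 52, dtype=torch.complex64)  # batch of 8 field-sample vectors
angles = est(x)                                # final torch.remainder wraps into [0, 2*pi)
print(angles.shape)                            # expected: torch.Size([8, 4])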

View File

@@ -20,6 +20,22 @@ class DataSettings:
xy_delay: tuple | float | int = 0
drop_first: int = 1000
train_split: float = 0.8
polarisations: tuple | list = (0,)
randomise_polarisations: bool = False
"""
change to:
config_path: tuple | list | None = None
dtype: torch.dtype | None = None
symbols: int | float = 1
output_dim: int = 2
shuffle: bool = True
drop_first: float | int = 0
train_split: float = 0.8
randomise_polarisations: bool = False
"""
# pytorch settings
@@ -30,8 +46,8 @@ class PytorchSettings:
device: str = "cuda"
-dataloader_workers: int = 2
dataloader_workers: int = 1
-dataloader_prefetch: int = 2
dataloader_prefetch: int = 1
save_models: bool = True
model_dir: str = ".models"
@@ -56,6 +72,24 @@ class ModelSettings:
model_layer_kwargs: dict | None = None
model_layer_parametrizations: list = field(default_factory=list)
"""
change to:
dims: tuple | list | None = None
layer_function: nn.Module | None = None
layer_func_kwargs: dict | None = None
activation_function: nn.Module | None = None
activation_func_kwargs: dict | None = None
output_function: nn.Module | None = None
output_func_kwargs: dict | None = None
dropout_function: nn.Module | None = None
dropout_func_kwargs: dict | None = None
scale_function: nn.Module | None = None
scale_func_kwargs: dict | None = None
parametrizations: list | None = None
"""
@dataclass
class OptimizerSettings:
@@ -65,6 +99,17 @@ class OptimizerSettings:
scheduler: str | None = None
scheduler_kwargs: dict | None = None
"""
change to:
optimizer: torch.optim.Optimizer | None = None
optimizer_kwargs: dict | None = None
learning_rate: float | None = None
scheduler: torch.optim.lr_scheduler | None = None
scheduler_kwargs: dict | None = None
"""
def _pruner_default_kwargs():
# MedianPruner
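Read together, the "change to:" note above amounts to a DataSettings roughly like the sketch below; the field names and types come from the note itself, while the decorator and imports are assumed from the surrounding file.

# Sketch of the proposed (not yet applied) DataSettings refactor.
from dataclasses import dataclass
import torch

@dataclass
class DataSettings:
    config_path: tuple | list | None = None
    dtype: torch.dtype | None = None
    symbols: int | float = 1
    output_dim: int = 2
    shuffle: bool = True
    drop_first: float | int = 0
    train_split: float = 0.8
    randomise_polarisations: bool = False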

View File

@@ -37,6 +37,7 @@ from rich.console import Console
from util.datasets import FiberRegenerationDataset
import util
import hypertraining.models as models
from .settings import (
GlobalSettings,
@@ -59,8 +60,527 @@ def traverse_dict_update(target, source):
except TypeError:
target.__dict__[k] = v
def get_parameter_names_and_values(model):
def is_parametrized(module):
if hasattr(module, "parametrizations"):
return True
return False
def _get_param_info(module, prefix='', parametrization=False):
param_list = []
for name, param in module.named_parameters(recurse = parametrization):
if parametrization and name.startswith("parametrizations"):
name_parts = name.split('.')
name = name_parts[1]
param = getattr(module, name)
full_name = prefix + ('.' if prefix else '') + name
param_value = param.data
param_list.append((full_name, param_value))
for child_name, child_module in module.named_children():
child_prefix = prefix + ('.' if prefix else '') + child_name
if child_name == "parametrizations":
continue
param_list.extend(_get_param_info(child_module, child_prefix, is_parametrized(child_module)))
return param_list
return _get_param_info(model)
class PolarizationTrainer:
def __init__(
self,
*,
global_settings=None,
data_settings=None,
pytorch_settings=None,
model_settings=None,
optimizer_settings=None,
console=None,
checkpoint_path=None,
settings_override=None,
reset_epoch=False,
):
self.mod = torch.pi/2
self.resume = checkpoint_path is not None
torch.serialization.add_safe_globals([
*util.complexNN.__all__,
GlobalSettings,
DataSettings,
ModelSettings,
OptimizerSettings,
PytorchSettings,
models.regenerator,
torch.nn.utils.parametrizations.orthogonal,
])
if self.resume:
print(f"loading checkpoint from {checkpoint_path}")
self.checkpoint_dict = torch.load(checkpoint_path, weights_only=True)
if settings_override is not None:
traverse_dict_update(self.checkpoint_dict["settings"], settings_override)
if reset_epoch:
self.checkpoint_dict["epoch"] = -1
self.global_settings: GlobalSettings = self.checkpoint_dict["settings"]["global_settings"]
self.data_settings: DataSettings = self.checkpoint_dict["settings"]["data_settings"]
self.pytorch_settings: PytorchSettings = self.checkpoint_dict["settings"]["pytorch_settings"]
self.model_settings: ModelSettings = self.checkpoint_dict["settings"]["model_settings"]
self.optimizer_settings: OptimizerSettings = self.checkpoint_dict["settings"]["optimizer_settings"]
else:
if global_settings is None:
global_settings = GlobalSettings()
raise UserWarning("Global settings not provided, using default settings")
if data_settings is None:
data_settings = DataSettings()
raise UserWarning("Data settings not provided, using default settings")
if pytorch_settings is None:
pytorch_settings = PytorchSettings()
raise UserWarning("Pytorch settings not provided, using default settings")
if model_settings is None:
model_settings = ModelSettings()
raise UserWarning("Model settings not provided, using default settings")
if optimizer_settings is None:
optimizer_settings = OptimizerSettings()
raise UserWarning("Optimizer settings not provided, using default settings")
self.global_settings: GlobalSettings = global_settings
self.data_settings: DataSettings = data_settings
self.pytorch_settings: PytorchSettings = pytorch_settings
self.model_settings: ModelSettings = model_settings
self.optimizer_settings: OptimizerSettings = optimizer_settings
self.console = console or Console()
self.writer = None
def setup_tb_writer(self, append=None):
log_dir = self.pytorch_settings.summary_dir + "/pol_" + (datetime.now().strftime("%Y%m%d_%H%M%S"))
if append is not None:
log_dir += "_" + str(append)
print(f"Logging to {log_dir}")
self.writer = SummaryWriter(log_dir=log_dir)
def save_checkpoint(self, save_dict, filename):
torch.save(save_dict, filename)
def build_checkpoint_dict(self, loss=None, epoch=None):
return {
"epoch": -1 if epoch is None else epoch,
"loss": float("inf") if loss is None else loss,
"model_state_dict": copy.deepcopy(self.model.state_dict()),
"optimizer_state_dict": copy.deepcopy(self.optimizer.state_dict()),
"scheduler_state_dict": copy.deepcopy(self.scheduler.state_dict()) if hasattr(self, "scheduler") else None,
"model_kwargs": copy.deepcopy(self.model_kwargs),
"settings": {
"global_settings": copy.deepcopy(self.global_settings),
"data_settings": copy.deepcopy(self.data_settings),
"pytorch_settings": copy.deepcopy(self.pytorch_settings),
"model_settings": copy.deepcopy(self.model_settings),
"optimizer_settings": copy.deepcopy(self.optimizer_settings),
},
}
def define_model(self, model_kwargs=None):
if self.resume:
model_kwargs = self.checkpoint_dict["model_kwargs"]
else:
model_kwargs = model_kwargs
if model_kwargs is None:
n_hidden_layers = self.model_settings.n_hidden_layers
input_dim = 2 * self.data_settings.output_size
dtype = getattr(torch, self.data_settings.dtype)
afunc = getattr(util.complexNN, self.model_settings.model_activation_func)
layer_func = getattr(util.complexNN, self.model_settings.model_layer_function)
layer_parametrizations = self.model_settings.model_layer_parametrizations
hidden_dims = [self.model_settings.overrides.get(f"n_hidden_nodes_{i}") for i in range(n_hidden_layers)]
self.model_kwargs = {
"dims": (input_dim, *hidden_dims, self.model_settings.output_dim),
"layer_function": layer_func,
"layer_func_kwargs": self.model_settings.model_layer_kwargs,
"act_function": afunc,
"act_func_kwargs": None,
"parametrizations": layer_parametrizations,
"dtype": dtype,
"dropout_prob": self.model_settings.dropout_prob,
"scale_layers": self.model_settings.scale,
}
else:
self.model_kwargs = model_kwargs
input_dim = self.model_kwargs["dims"][0]
dtype = self.model_kwargs["dtype"]
# dims = self.model_kwargs.pop("dims")
model_kwargs = copy.deepcopy(self.model_kwargs)
self.model = models.polarisation_estimator(*model_kwargs.pop('dims'),**model_kwargs)
# self.model = models.polarisation_estimator2()
if self.writer is not None:
try:
self.writer.add_graph(self.model, torch.rand(1, input_dim, dtype=dtype), use_strict_trace=False)
except RuntimeError:
self.writer.add_graph(self.model, torch.rand(1, 2, dtype=dtype), use_strict_trace=False)
self.model = self.model.to(self.pytorch_settings.device)
if self.resume:
self.model.load_state_dict(self.checkpoint_dict["model_state_dict"], strict=False)
def get_sliced_data(self, override=None):
symbols = self.data_settings.symbols
in_out_delay = self.data_settings.in_out_delay
xy_delay = self.data_settings.xy_delay
data_size = self.data_settings.output_size
dtype = getattr(torch, self.data_settings.dtype)
num_symbols = None
config_path = self.data_settings.config_path
polarisations = self.data_settings.polarisations
randomise_polarisations = self.data_settings.randomise_polarisations
if override is not None:
num_symbols = override.get("num_symbols", None)
config_path = override.get("config_path", config_path)
polarisations = override.get("polarisations", polarisations)
randomise_polarisations = override.get("randomise_polarisation", randomise_polarisations)
# get dataset
dataset = FiberRegenerationDataset(
file_path=config_path,
symbols=symbols,
output_dim=data_size,
target_delay=in_out_delay,
xy_delay=xy_delay,
drop_first=self.data_settings.drop_first,
dtype=dtype,
real=not dtype.is_complex,
num_symbols=num_symbols,
polarisations=polarisations,
randomise_polarisations=randomise_polarisations,
)
dataset_size = len(dataset)
indices = list(range(dataset_size))
split = int(np.floor(self.data_settings.train_split * dataset_size))
if self.data_settings.shuffle:
np.random.seed(self.global_settings.seed)
np.random.shuffle(indices)
train_indices, valid_indices = indices[:split], indices[split:]
if self.data_settings.shuffle:
train_sampler = torch.utils.data.SubsetRandomSampler(train_indices)
valid_sampler = torch.utils.data.SubsetRandomSampler(valid_indices)
else:
train_sampler = train_indices
valid_sampler = valid_indices
train_loader = torch.utils.data.DataLoader(
dataset,
batch_size=self.pytorch_settings.batchsize,
sampler=train_sampler,
drop_last=True,
pin_memory=True,
num_workers=self.pytorch_settings.dataloader_workers,
prefetch_factor=self.pytorch_settings.dataloader_prefetch,
)
valid_loader = torch.utils.data.DataLoader(
dataset,
batch_size=self.pytorch_settings.batchsize,
sampler=valid_sampler,
drop_last=True,
pin_memory=True,
num_workers=self.pytorch_settings.dataloader_workers,
prefetch_factor=self.pytorch_settings.dataloader_prefetch,
)
return train_loader, valid_loader
def train_model(
self,
optimizer,
train_loader,
epoch,
enable_progress=False,
):
if enable_progress:
progress = Progress(
TextColumn("[yellow] Training..."),
TextColumn("Error: {task.description}"),
BarColumn(),
TaskProgressColumn(),
TextColumn("[green]Batch"),
MofNCompleteColumn(),
TimeRemainingColumn(),
TimeElapsedColumn(),
transient=False,
console=self.console,
refresh_per_second=10,
)
task = progress.add_task("-.---e--", total=len(train_loader))
progress.start()
running_loss2 = 0.0
running_loss = 0.0
self.model.train()
loader_len = len(train_loader)
write_div = 0
loss_div = 0
for batch_idx, batch in enumerate(train_loader):
x = batch["x"]
y = batch["sop"]
self.model.zero_grad(set_to_none=True)
x, y = (
x.to(self.pytorch_settings.device),
y.to(self.pytorch_settings.device),
)
y_pred = self.model(x)
# loss = torch.nn.functional.smooth_l1_loss(torch.cos(torch.fmod(y_pred, torch.pi/2)).squeeze(), torch.cos(torch.fmod(y, torch.pi/2)).squeeze(), beta=0.5)
loss = torch.nn.functional.mse_loss(y_pred, y)
# loss = util.complexNN.naive_angle_loss(y_pred, y, mod=torch.pi/2)
loss_value = loss.item()
loss.backward()
optimizer.step()
running_loss += loss_value
running_loss2 += loss_value
write_div += 1
loss_div += 1
if enable_progress:
progress.update(task, advance=1, description=f"{loss_value:.3e}")
if batch_idx % self.pytorch_settings.write_every == 0:
self.writer.add_scalar(
"training loss",
running_loss2 / write_div,
epoch * loader_len + batch_idx,
)
running_loss2 = 0.0
write_div = 0
if enable_progress:
progress.stop()
return running_loss / loss_div
def eval_model(self, valid_loader, epoch, enable_progress=True):
if enable_progress:
progress = Progress(
TextColumn("[green]Evaluating..."),
TextColumn("Error: {task.description}"),
BarColumn(),
TaskProgressColumn(),
TextColumn("[green]Batch"),
MofNCompleteColumn(),
TimeRemainingColumn(),
TimeElapsedColumn(),
transient=False,
console=self.console,
refresh_per_second=10,
)
progress.start()
task = progress.add_task("-.---e--", total=len(valid_loader))
self.model.eval()
running_loss = 0
loss_div = 0
with torch.no_grad():
for _, batch in enumerate(valid_loader):
x = batch["x"]
y = batch["sop"]
x, y = (
x.to(self.pytorch_settings.device),
y.to(self.pytorch_settings.device),
)
y_pred = self.model(x)
# loss = torch.nn.functional.smooth_l1_loss(torch.cos(torch.fmod(y_pred, torch.pi/2)).squeeze(), torch.cos(torch.fmod(y, torch.pi/2)).squeeze(), beta=0.5)
loss = torch.nn.functional.mse_loss(y_pred, y)
# loss = util.complexNN.naive_angle_loss(y_pred, y, mod=torch.pi/2)
loss_value = loss.item()
running_loss += loss_value
loss_div += 1
if enable_progress:
progress.update(task, advance=1, description=f"{loss_value:.3e}")
running_loss = running_loss/loss_div
self.writer.add_scalar(
"eval loss",
running_loss,
epoch,
)
# self.write_parameters(epoch + 1)
self.writer.flush()
if enable_progress:
progress.stop()
return running_loss
# def run_model(self, model, loader, trace_powers=False):
# model.eval()
# fiber_out = []
# fiber_in = []
# regen = []
# timestamps = []
# with torch.no_grad():
# model = model.to(self.pytorch_settings.device)
# for batch in loader:
# x = batch["x"]
# y = batch["angle"]
# timestamp = batch["timestamp"]
# plot_data = batch["plot_data"]
# x, y = (
# x.to(self.pytorch_settings.device),
# y.to(self.pytorch_settings.device),
# )
# if trace_powers:
# y_pred, powers = model(x, trace_powers).cpu()
# else:
# y_pred = model(x, trace_powers).cpu()
# # x = x.cpu()
# # y = y.cpu()
# y_pred = y_pred.view(y_pred.shape[0], -1, 2)
# y = y.view(y.shape[0], -1, 2)
# plot_data = plot_data.view(plot_data.shape[0], -1, 2)
# # x = x.view(x.shape[0], -1, 2)
# # timestamp = timestamp.view(-1, 1)
# fiber_out.append(plot_data.squeeze())
# fiber_in.append(y.squeeze())
# regen.append(y_pred.squeeze())
# timestamps.append(timestamp.squeeze())
# fiber_out = torch.vstack(fiber_out).cpu()
# fiber_in = torch.vstack(fiber_in).cpu()
# regen = torch.vstack(regen).cpu()
# timestamps = torch.concat(timestamps).cpu()
# if trace_powers:
# return fiber_in, fiber_out, regen, timestamps, powers
# return fiber_in, fiber_out, regen, timestamps
def write_parameters(self, epoch, attributes: list[str] | tuple[str] = None):
parameter_list = get_parameter_names_and_values(self.model)
for name, value in parameter_list:
plot = (attributes is None) or (name in attributes)
if plot:
vals: np.ndarray = value.detach().cpu().numpy().flatten()
if vals.ndim <= 1 and len(vals) == 1:
if np.iscomplexobj(vals):
self.writer.add_scalar(f"{name} (Mag)", np.abs(vals), epoch)
self.writer.add_scalar(f"{name} (Phase)", np.angle(vals), epoch)
else:
self.writer.add_scalar(f"{name}", vals, epoch)
else:
if np.iscomplexobj(vals):
self.writer.add_histogram(f"{name} (Mag)", np.abs(vals), epoch, bins="fd")
self.writer.add_histogram(f"{name} (Phase)", np.angle(vals), epoch, bins="fd")
else:
self.writer.add_histogram(f"{name}", vals, epoch, bins="fd")
def train(self):
if self.writer is None:
self.setup_tb_writer()
self.define_model()
print(
f"number of parameters (trainable): {sum(p.numel() for p in self.model.parameters())} ({sum(p.numel() for p in self.model.parameters() if p.requires_grad)})"
)
# self.write_parameters(0)
if isinstance(self.data_settings.config_path, (list, tuple)):
for i, config_path in enumerate(self.data_settings.config_path):
paths = Path.cwd().glob(config_path)
for j, path in enumerate(paths):
text = str(path) + '\n'
with open(path, 'r') as f:
text += f.read()
text += '\n'
self.writer.add_text(f"config_{i*len(self.data_settings.config_path)+j}", text)
elif isinstance(self.data_settings.config_path, str):
paths = Path.cwd().glob(self.data_settings.config_path)
for j, path in enumerate(paths):
text = str(path) + '\n'
with open(path, 'r') as f:
text += f.read()
text += '\n'
self.writer.add_text(f"config_{j}", text)
self.writer.flush()
train_loader, valid_loader = self.get_sliced_data()
optimizer_name = self.optimizer_settings.optimizer
self.optimizer: optim.Optimizer = getattr(optim, optimizer_name)(
self.model.parameters(), **self.optimizer_settings.optimizer_kwargs
)
if self.optimizer_settings.scheduler is not None:
self.scheduler = getattr(optim.lr_scheduler, self.optimizer_settings.scheduler)(
self.optimizer, **self.optimizer_settings.scheduler_kwargs
)
self.writer.add_scalar("learning rate", self.optimizer.param_groups[0]["lr"], -1)
if not self.resume:
self.best = self.build_checkpoint_dict()
else:
self.best = self.checkpoint_dict
self.best["loss"] = float("inf")
for epoch in range(self.best["epoch"] + 1, self.pytorch_settings.epochs):
enable_progress = True
if enable_progress:
self.console.rule(f"Epoch {epoch + 1}/{self.pytorch_settings.epochs}")
self.train_model(
self.optimizer,
train_loader,
epoch,
enable_progress=enable_progress,
)
loss = self.eval_model(
valid_loader,
epoch,
enable_progress=enable_progress,
)
if self.optimizer_settings.scheduler is not None:
self.scheduler.step(loss)
self.writer.add_scalar("learning rate", self.optimizer.param_groups[0]["lr"], epoch)
if self.pytorch_settings.save_models and self.model is not None:
save_path = (
Path(self.pytorch_settings.model_dir) / f"pol_{self.writer.get_logdir().split('/')[-1]}_{epoch}.tar"
)
save_path.parent.mkdir(parents=True, exist_ok=True)
checkpoint = self.build_checkpoint_dict(loss, epoch)
self.save_checkpoint(checkpoint, save_path)
if loss < self.best["loss"]:
self.best = checkpoint
save_path = (
Path(self.pytorch_settings.model_dir) / f"best_pol_{self.writer.get_logdir().split('/')[-1]}.tar"
)
save_path.parent.mkdir(parents=True, exist_ok=True)
self.save_checkpoint(self.best, save_path)
self.writer.flush()
self.writer.close()
return self.best
-class Trainer:
class RegenerationTrainer:
def __init__(
self,
*,
@@ -82,10 +602,11 @@ class Trainer:
ModelSettings,
OptimizerSettings,
PytorchSettings,
-util.complexNN.regenerator,
models.regenerator,
torch.nn.utils.parametrizations.orthogonal,
])
if self.resume:
print(f"loading checkpoint from {checkpoint_path}")
self.checkpoint_dict = torch.load(checkpoint_path, weights_only=True)
if settings_override is not None:
traverse_dict_update(self.checkpoint_dict["settings"], settings_override)
@@ -170,11 +691,13 @@ class Trainer:
self.model_kwargs = {
"dims": (input_dim, *hidden_dims, self.model_settings.output_dim),
"layer_function": layer_func,
-"layer_parametrizations": layer_parametrizations,
-"activation_function": afunc,
"layer_func_kwargs": self.model_settings.model_layer_kwargs,
"act_function": afunc,
"act_func_kwargs": None,
"parametrizations": layer_parametrizations,
"dtype": dtype,
"dropout_prob": self.model_settings.dropout_prob,
-"scale": self.model_settings.scale,
"scale_layers": self.model_settings.scale,
}
else:
self.model_kwargs = model_kwargs
@@ -182,7 +705,8 @@ class Trainer:
dtype = self.model_kwargs["dtype"]
# dims = self.model_kwargs.pop("dims")
-self.model = util.complexNN.regenerator(**self.model_kwargs)
model_kwargs = copy.deepcopy(self.model_kwargs)
self.model = models.regenerator(*model_kwargs.pop('dims'), **model_kwargs)
if self.writer is not None:
self.writer.add_graph(self.model, torch.zeros(1, input_dim, dtype=dtype))
@@ -204,9 +728,13 @@ class Trainer:
num_symbols = None
config_path = self.data_settings.config_path
polarisations = self.data_settings.polarisations
randomise_polarisations = self.data_settings.randomise_polarisations
if override is not None:
num_symbols = override.get("num_symbols", None)
config_path = override.get("config_path", config_path)
polarisations = override.get("polarisations", polarisations)
randomise_polarisations = override.get("randomise_polarisation", randomise_polarisations)
# get dataset
dataset = FiberRegenerationDataset(
file_path=config_path,
@@ -218,6 +746,8 @@ class Trainer:
dtype=dtype,
real=not dtype.is_complex,
num_symbols=num_symbols,
polarisations=polarisations,
randomise_polarisations=randomise_polarisations,
)
dataset_size = len(dataset)
@@ -286,7 +816,9 @@ class Trainer:
running_loss = 0.0
self.model.train()
loader_len = len(train_loader)
-for batch_idx, (x, y, _) in enumerate(train_loader):
for batch_idx, batch in enumerate(train_loader):
x = batch["x"]
y = batch["y"]
self.model.zero_grad(set_to_none=True)
x, y = (
x.to(self.pytorch_settings.device),
@@ -307,7 +839,7 @@ class Trainer:
self.writer.add_scalar(
"training loss",
running_loss2 / (self.pytorch_settings.write_every if batch_idx > 0 else 1),
-epoch + batch_idx/loader_len,
epoch * loader_len + batch_idx,
)
running_loss2 = 0.0
@@ -337,7 +869,9 @@ class Trainer:
self.model.eval()
running_error = 0
with torch.no_grad():
-for _, (x, y, _) in enumerate(valid_loader):
for _, batch in enumerate(valid_loader):
x = batch["x"]
y = batch["y"]
x, y = (
x.to(self.pytorch_settings.device),
y.to(self.pytorch_settings.device),
@@ -360,37 +894,26 @@ class Trainer:
if (epoch + 1) % 10 == 0 or epoch < 10:
# plotting is slow, so only do it every 10 epochs
title_append, subtitle = self.build_title(epoch + 1)
-self.writer.add_figure(
-"fiber response",
-self.plot_model_response(
head_fig, eye_fig, powers_fig = self.plot_model_response(
model=self.model,
title_append=title_append,
subtitle=subtitle,
show=False,
-),
)
self.writer.add_figure(
"fiber response",
head_fig,
epoch + 1,
)
self.writer.add_figure(
"eye diagram",
-self.plot_model_response(
-model=self.model,
-title_append=title_append,
-subtitle=subtitle,
-show=False,
-mode="eye",
-),
eye_fig,
epoch + 1,
)
self.writer.add_figure(
"powers",
-self.plot_model_response(
-model=self.model,
-title_append=title_append,
-subtitle=subtitle,
-mode="powers",
-show=False,
-),
powers_fig,
epoch + 1,
)
@@ -411,7 +934,11 @@ class Trainer:
with torch.no_grad():
model = model.to(self.pytorch_settings.device)
-for x, y, timestamp in loader:
for batch in loader:
x = batch["x"]
y = batch["y"]
timestamp = batch["timestamp"]
plot_data = batch["plot_data"]
x, y = (
x.to(self.pytorch_settings.device),
y.to(self.pytorch_settings.device),
@@ -424,9 +951,11 @@ class Trainer:
# y = y.cpu()
y_pred = y_pred.view(y_pred.shape[0], -1, 2)
y = y.view(y.shape[0], -1, 2)
-x = x.view(x.shape[0], -1, 2)
plot_data = plot_data.view(plot_data.shape[0], -1, 2)
# x = x.view(x.shape[0], -1, 2)
# timestamp = timestamp.view(-1, 1)
-fiber_out.append(x[:, x.shape[1] // 2, :].squeeze())
fiber_out.append(plot_data.squeeze())
fiber_in.append(y.squeeze())
regen.append(y_pred.squeeze())
timestamps.append(timestamp.squeeze())
@@ -440,28 +969,23 @@ class Trainer:
return fiber_in, fiber_out, regen, timestamps
def write_parameters(self, epoch, attributes: list[str] | tuple[str] = None):
-for i, layer in enumerate(self.model._layers):
-tag = f"layer {i}"
-if hasattr(layer, "parametrizations"):
-attribute_pool = set(layer.parametrizations._modules) | set(layer._parameters)
-else:
-attribute_pool = set(layer._parameters)
-for attribute in attribute_pool:
-plot = (attributes is None) or (attribute in attributes)
parameter_list = get_parameter_names_and_values(self.model)
for name, value in parameter_list:
plot = (attributes is None) or (name in attributes)
if plot:
-vals: np.ndarray = getattr(layer, attribute).detach().cpu().numpy().flatten()
vals: np.ndarray = value.detach().cpu().numpy().flatten()
if vals.ndim <= 1 and len(vals) == 1:
if np.iscomplexobj(vals):
-self.writer.add_scalar(f"{tag} {attribute} (Mag)", np.abs(vals), epoch)
-self.writer.add_scalar(f"{tag} {attribute} (Phase)", np.angle(vals), epoch)
self.writer.add_scalar(f"{name} (Mag)", np.abs(vals), epoch)
self.writer.add_scalar(f"{name} (Phase)", np.angle(vals), epoch)
else:
-self.writer.add_scalar(f"{tag} {attribute}", vals, epoch)
self.writer.add_scalar(f"{name}", vals, epoch)
else:
if np.iscomplexobj(vals):
-self.writer.add_histogram(f"{tag} {attribute} (Mag)", np.abs(vals), epoch, bins="fd")
-self.writer.add_histogram(f"{tag} {attribute} (Phase)", np.angle(vals), epoch, bins="fd")
self.writer.add_histogram(f"{name} (Mag)", np.abs(vals), epoch, bins="fd")
self.writer.add_histogram(f"{name} (Phase)", np.angle(vals), epoch, bins="fd")
else:
-self.writer.add_histogram(f"{tag} {attribute}", vals, epoch, bins="fd")
self.writer.add_histogram(f"{name}", vals, epoch, bins="fd")
def train(self):
if self.writer is None:
@@ -474,44 +998,48 @@ class Trainer:
)
title_append, subtitle = self.build_title(0)
-self.writer.add_figure(
-"fiber response",
-self.plot_model_response(
head_fig, eye_fig, powers_fig = self.plot_model_response(
model=self.model,
title_append=title_append,
subtitle=subtitle,
show=False,
-),
)
self.writer.add_figure(
"fiber response",
head_fig,
0,
)
self.writer.add_figure(
"eye diagram",
-self.plot_model_response(
-model=self.model,
-title_append=title_append,
-subtitle=subtitle,
-mode="eye",
-show=False,
-),
eye_fig,
0,
)
self.writer.add_figure(
"powers",
-self.plot_model_response(
-model=self.model,
-title_append=title_append,
-subtitle=subtitle,
-mode="powers",
-show=False,
-),
powers_fig,
0,
)
self.write_parameters(0)
-self.writer.add_text("datasets", '\n'.join(self.data_settings.config_path))
if isinstance(self.data_settings.config_path, (list, tuple)):
for i, config_path in enumerate(self.data_settings.config_path):
paths = Path.cwd().glob(config_path)
for j, path in enumerate(paths):
text = str(path) + '\n'
with open(path, 'r') as f:
text += f.read()
text += '\n'
self.writer.add_text(f"config_{i*len(self.data_settings.config_path)+j}", text)
elif isinstance(self.data_settings.config_path, str):
paths = Path.cwd().glob(self.data_settings.config_path)
for j, path in enumerate(paths):
text = str(path) + '\n'
with open(path, 'r') as f:
text += f.read()
text += '\n'
self.writer.add_text(f"config_{j}", text)
self.writer.flush()
@@ -741,13 +1269,12 @@ class Trainer:
def plot_model_response(
self,
-model=None,
model: torch.nn.Module = None,
title_append="",
subtitle="",
-mode: Literal["eye", "head", "powers"] = "head",
# mode: Literal["eye", "head", "powers"] = "head",
show=False,
):
-if mode == "powers":
input_data = torch.ones(
1, 2 * self.data_settings.output_size, dtype=getattr(torch, self.data_settings.dtype)
).to(self.pytorch_settings.device)
@@ -757,38 +1284,35 @@ class Trainer:
_, powers = model(input_data, trace_powers=True)
powers = [power.item() for power in powers]
-layer_names = ["input", *[str(x).split("(")[0] for x in model._layers._modules.values()]]
layer_names = [name for (name, _) in model.named_children()]
-# remove dropout layers
-mask = [1 if "Dropout" not in layer_name else 0 for layer_name in layer_names]
-layer_names = [layer_name for layer_name, m in zip(layer_names, mask) if m]
-powers = [power for power, m in zip(powers, mask) if m]
-fig = self._plot_model_response_powers(
power_fig = self._plot_model_response_powers(
powers, layer_names, title_append=title_append, subtitle=subtitle, show=show
)
-return fig
data_settings_backup = copy.deepcopy(self.data_settings)
pytorch_settings_backup = copy.deepcopy(self.pytorch_settings)
self.data_settings.drop_first = 99.5 + random.randint(0, 1000)
self.data_settings.shuffle = False
self.data_settings.train_split = 1.0
-self.pytorch_settings.batchsize = (
-self.pytorch_settings.eye_symbols if mode == "eye" else self.pytorch_settings.head_symbols
-)
self.pytorch_settings.batchsize = max(self.pytorch_settings.head_symbols, self.pytorch_settings.eye_symbols)
config_path = random.choice(self.data_settings.config_path) if isinstance(self.data_settings.config_path, (list, tuple)) else self.data_settings.config_path
-fiber_length = int(float(str(config_path).split('-')[-7])/1000)
fiber_length = int(float(str(config_path).split('-')[4])/1000)
-plot_loader, _ = self.get_sliced_data(
if not hasattr(self, "_plot_loader"):
self._plot_loader, _ = self.get_sliced_data(
override={
"num_symbols": self.pytorch_settings.batchsize,
"config_path": config_path,
"shuffle": False,
"polarisations": (np.random.rand(1)*np.pi*2,),
"randomise_polarisation": False,
}
)
self._sps = self._plot_loader.dataset.samples_per_symbol
self.data_settings = data_settings_backup
self.pytorch_settings = pytorch_settings_backup
-fiber_in, fiber_out, regen, timestamps = self.run_model(model, plot_loader)
fiber_in, fiber_out, regen, timestamps = self.run_model(model, self._plot_loader)
fiber_in = fiber_in.view(-1, 2)
fiber_out = fiber_out.view(-1, 2)
regen = regen.view(-1, 2)
@@ -802,36 +1326,32 @@ class Trainer:
# https://github.com/matplotlib/matplotlib/issues/27713#issuecomment-1915497463
import gc
-if mode == "head":
-fig = self._plot_model_response_head(
-fiber_in,
-fiber_out,
-regen,
-timestamps=timestamps,
head_fig = self._plot_model_response_head(
fiber_in[:self.pytorch_settings.head_symbols*self._sps],
fiber_out[:self.pytorch_settings.head_symbols*self._sps],
regen[:self.pytorch_settings.head_symbols*self._sps],
timestamps=timestamps[:self.pytorch_settings.head_symbols*self._sps],
labels=("fiber in", "fiber out", "regen"),
-sps=plot_loader.dataset.samples_per_symbol,
sps=self._sps,
title_append=title_append + f" ({fiber_length} km)",
subtitle=subtitle,
show=show,
)
-elif mode == "eye":
# raise NotImplementedError("Eye diagram not implemented")
-fig = self._plot_model_response_eye(
-fiber_in,
-fiber_out,
-regen,
-timestamps=timestamps,
eye_fig = self._plot_model_response_eye(
fiber_in[:self.pytorch_settings.eye_symbols*self._sps],
fiber_out[:self.pytorch_settings.eye_symbols*self._sps],
regen[:self.pytorch_settings.eye_symbols*self._sps],
timestamps=timestamps[:self.pytorch_settings.eye_symbols*self._sps],
labels=("fiber in", "fiber out", "regen"),
-sps=plot_loader.dataset.samples_per_symbol,
sps=self._sps,
title_append=title_append + f" ({fiber_length} km)",
subtitle=subtitle,
show=show,
)
-else:
-raise ValueError(f"Unknown mode: {mode}")
gc.collect()
-return fig
return head_fig, eye_fig, power_fig
def build_title(self, number: int):
title_append = f"epoch {number}"
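For context, launching the new PolarizationTrainer could look like the sketch below; every settings value is invented, and the optimizer name is resolved via getattr on torch.optim as in the train() method above.

# Hypothetical launch; all values here are assumptions, not from the commit.
trainer = PolarizationTrainer(
    global_settings=GlobalSettings(),
    data_settings=DataSettings(
        config_path="data/*-PAM4-0-0.4.ini",   # assumed glob
        randomise_polarisations=True,
    ),
    pytorch_settings=PytorchSettings(epochs=100, batchsize=2**12),
    model_settings=ModelSettings(),
    optimizer_settings=OptimizerSettings(
        optimizer="AdamW",
        optimizer_kwargs={"lr": 1e-2, "amsgrad": True},
    ),
)
best = trainer.train()   # returns the best checkpoint dict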

View File

@@ -1,7 +1,10 @@
from datetime import datetime
from pathlib import Path from pathlib import Path
import matplotlib import matplotlib
import numpy as np import numpy as np
import torch import torch
import torch.utils.tensorboard
import torch.utils.tensorboard.summary
from hypertraining.settings import ( from hypertraining.settings import (
GlobalSettings, GlobalSettings,
DataSettings, DataSettings,
@@ -10,7 +13,7 @@ from hypertraining.settings import (
OptimizerSettings, OptimizerSettings,
) )
from hypertraining.training import Trainer from hypertraining.training import RegenerationTrainer, PolarizationTrainer
# import torch # import torch
import json import json
@@ -23,7 +26,7 @@ global_settings = GlobalSettings(
) )
data_settings = DataSettings( data_settings = DataSettings(
config_path="data/20241204-13*-128-16384-100000-0-0-17-0-PAM4-0.ini", config_path="data/*-128-16384-10000-0-0-0-0-PAM4-0-0.4.ini",
# config_path=[f"data/20241202-*-128-16384-{length}-0-0-17-0-PAM4-0.ini" for length in range(48000, 53000, 1000)], # config_path=[f"data/20241202-*-128-16384-{length}-0-0-17-0-PAM4-0.ini" for length in range(48000, 53000, 1000)],
dtype="complex64", dtype="complex64",
# symbols = (9, 20), # 13 symbol @ 10GBd <-> 1.3ns <-> 0.26m of fiber # symbols = (9, 20), # 13 symbol @ 10GBd <-> 1.3ns <-> 0.26m of fiber
@@ -31,17 +34,16 @@ data_settings = DataSettings(
# output_size = (11, 32), # 26 taps -> 2 taps per input symbol -> 1 tap every 1cm (model has 52 inputs (x/y)) # output_size = (11, 32), # 26 taps -> 2 taps per input symbol -> 1 tap every 1cm (model has 52 inputs (x/y))
output_size=26, # study: single_core_regen_20241123_011232 (model_input_dim/2) output_size=26, # study: single_core_regen_20241123_011232 (model_input_dim/2)
shuffle=True, shuffle=True,
in_out_delay=0, drop_first=64,
xy_delay=0,
drop_first=128 * 64,
train_split=0.8, train_split=0.8,
randomise_polarisations=True,
) )
pytorch_settings = PytorchSettings(
    epochs=10000,
-    batchsize=2**12,
+    batchsize=2**14,
    device="cuda",
-    dataloader_workers=12,
+    dataloader_workers=16,
    dataloader_prefetch=8,
    summary_dir=".runs",
    write_every=2**5,
@@ -51,12 +53,14 @@ pytorch_settings = PytorchSettings(
model_settings = ModelSettings(
    output_dim=2,
-    n_hidden_layers=4,
+    n_hidden_layers=5,
    overrides={
-        # "hidden_layer_dims": (8, 8, 4, 4),
        "n_hidden_nodes_0": 8,
        "n_hidden_nodes_1": 8,
        "n_hidden_nodes_2": 4,
        "n_hidden_nodes_3": 4,
+        "n_hidden_nodes_4": 2,
    },
    model_activation_func="EOActivation",
    dropout_prob=0.01,
@@ -92,6 +96,14 @@ model_settings = ModelSettings(
"tensor_name": "scales", "tensor_name": "scales",
"parametrization": util.complexNN.clamp, "parametrization": util.complexNN.clamp,
}, },
{
"tensor_name": "angle",
"parametrization": util.complexNN.clamp,
"kwargs": {
"min": -torch.pi,
"max": torch.pi,
},
},
# { # {
# "tensor_name": "scale", # "tensor_name": "scale",
# "parametrization": util.complexNN.clamp, # "parametrization": util.complexNN.clamp,
@@ -143,7 +155,7 @@ def save_dict_to_file(dictionary, filename):
        json.dump(dictionary, f, indent=4)


-def sweep_lengths(*lengths, model=None, data_glob:str=None, strategy="newest"):
+def sweep_lengths(*lengths, model=None, data_glob: str = None, strategy="newest"):
    assert model is not None, "Model must be provided."
    assert data_glob is not None, "Data glob must be provided."
    model = model
@@ -153,7 +165,7 @@ def sweep_lengths(*lengths, model=None, data_glob:str=None, strategy="newest"):
    regens = {}
    timestampss = {}
-    trainer = Trainer(
+    trainer = RegenerationTrainer(
        checkpoint_path=model,
    )
    trainer.define_model()
@@ -165,13 +177,13 @@ def sweep_lengths(*lengths, model=None, data_glob:str=None, strategy="newest"):
            continue
        if strategy == "newest":
            sorted_kwargs = {
-                'key': lambda x: x.stat().st_mtime,
-                'reverse': True,
+                "key": lambda x: x.stat().st_mtime,
+                "reverse": True,
            }
        elif strategy == "oldest":
            sorted_kwargs = {
-                'key': lambda x: x.stat().st_mtime,
-                'reverse': False,
+                "key": lambda x: x.stat().st_mtime,
+                "reverse": False,
            }
        else:
            raise ValueError(f"Unknown strategy {strategy}.")
@@ -186,22 +198,21 @@ def sweep_lengths(*lengths, model=None, data_glob:str=None, strategy="newest"):
        timestampss[length] = timestamps

    data = torch.zeros(2 * len(timestampss.keys()) + 2, 2, tuple(fiber_outs.values())[-1].shape[0])
-    channel_names = ["" for _ in range(2 * len(timestampss.keys())+2)]
+    channel_names = ["" for _ in range(2 * len(timestampss.keys()) + 2)]

    data[1, 0, :] = timestampss[tuple(timestampss.keys())[-1]] / 128
    data[1, 1, :] = fiber_ins[tuple(timestampss.keys())[-1]][:, 0].abs().square()
    channel_names[1] = "fiber in x"

    for li, length in enumerate(timestampss.keys()):
-        data[2+2 * li, 0, :] = timestampss[length] / 128
-        data[2+2 * li, 1, :] = fiber_outs[length][:, 0].abs().square()
-        data[2+2 * li + 1, 0, :] = timestampss[length] / 128
-        data[2+2 * li + 1, 1, :] = regens[length][:, 0].abs().square()
-        channel_names[2+2 * li+1] = f"regen x {length}"
-        channel_names[2+2 * li] = f"fiber out x {length}"
+        data[2 + 2 * li, 0, :] = timestampss[length] / 128
+        data[2 + 2 * li, 1, :] = fiber_outs[length][:, 0].abs().square()
+        data[2 + 2 * li + 1, 0, :] = timestampss[length] / 128
+        data[2 + 2 * li + 1, 1, :] = regens[length][:, 0].abs().square()
+        channel_names[2 + 2 * li + 1] = f"regen x {length}"
+        channel_names[2 + 2 * li] = f"fiber out x {length}"

    # get current backend
    backend = matplotlib.get_backend()
@@ -210,7 +221,7 @@ def sweep_lengths(*lengths, model=None, data_glob:str=None, strategy="newest"):
    eye = util.eye_diagram.eye_diagram(data.to(dtype=torch.float32).detach().cpu().numpy(), channel_names=channel_names)
    print_attrs = ("channel_name", "success", "min_area")
-    with np.printoptions(precision=3, suppress=True, formatter={'float': '{:0.3e}'.format}):
+    with np.printoptions(precision=3, suppress=True, formatter={"float": "{:0.3e}".format}):
        for result in eye.eye_stats:
            print_dict = {attr: result[attr] for attr in print_attrs}
            rprint(print_dict)
@@ -221,18 +232,77 @@ def sweep_lengths(*lengths, model=None, data_glob:str=None, strategy="newest"):
if __name__ == "__main__": if __name__ == "__main__":
# lengths = range(90000, 100000+10000, 10000)
lengths = range(90000, 100000+10000, 10000)
# lengths = [100000] # lengths = [100000]
sweep_lengths(*lengths, model=".models/best_20241204_132605.tar", data_glob="data/202412*-{length}-*-0.ini", strategy="newest") # sweep_lengths(*lengths, model=".models/best_20241204_132605.tar", data_glob="data/202412*-{length}-*-0.ini", strategy="newest")
# trainer = Trainer( trainer = RegenerationTrainer(
# global_settings=global_settings, global_settings=global_settings,
# data_settings=data_settings, data_settings=data_settings,
# pytorch_settings=pytorch_settings, pytorch_settings=pytorch_settings,
# model_settings=model_settings, model_settings=model_settings,
# optimizer_settings=optimizer_settings, optimizer_settings=optimizer_settings,
# # checkpoint_path=".models/best_20241202_143149.tar", # checkpoint_path=".models/best_20241205_235929.tar",
# # 20241202_143149 # 20241202_143149
)
trainer.train()
# from hypertraining.lighning_models import regenerator, regeneratorData
# import lightning as L
# model = regenerator(
# 2 * data_settings.output_size,
# *model_settings.overrides["hidden_layer_dims"],
# model_settings.output_dim,
# layer_function=getattr(util.complexNN, model_settings.model_layer_function),
# layer_func_kwargs=model_settings.model_layer_kwargs,
# act_function=getattr(util.complexNN, model_settings.model_activation_func),
# act_func_kwargs=None,
# parametrizations=model_settings.model_layer_parametrizations,
# dtype=getattr(torch, data_settings.dtype),
# dropout_prob=model_settings.dropout_prob,
# scale_layers=model_settings.scale,
# optimizer=getattr(torch.optim, optimizer_settings.optimizer),
# optimizer_kwargs=optimizer_settings.optimizer_kwargs,
# lr_scheduler=getattr(torch.optim.lr_scheduler, optimizer_settings.scheduler),
# lr_scheduler_kwargs=optimizer_settings.scheduler_kwargs,
# ) # )
# trainer.train()
# dm = regeneratorData(
# config_globs=data_settings.config_path,
# output_symbols=data_settings.symbols,
# output_dim=data_settings.output_size,
# dtype=getattr(torch, data_settings.dtype),
# drop_first=data_settings.drop_first,
# shuffle=data_settings.shuffle,
# train_split=data_settings.train_split,
# batch_size=pytorch_settings.batchsize,
# loader_settings={
# "num_workers": pytorch_settings.dataloader_workers,
# "prefetch_factor": pytorch_settings.dataloader_prefetch,
# "pin_memory": True,
# "drop_last": True,
# },
# seed=global_settings.seed,
# )
# # writer = L.SummaryWriter(pytorch_settings.summary_dir + f"/{datetime.now().strftime('%Y%m%d_%H%M%S')}")
# # from torch.utils.tensorboard import SummaryWriter
# subdir = f"{datetime.now().strftime('%Y%m%d_%H%M%S')}"
# # writer = SummaryWriter(pytorch_settings.summary_dir + f"/{subdir}")
# logger = L.pytorch.loggers.TensorBoardLogger(pytorch_settings.summary_dir, name=subdir, log_graph=True)
# trainer = L.Trainer(
# fast_dev_run=False,
# # max_epochs=pytorch_settings.epochs,
# max_epochs=2,
# enable_checkpointing=True,
# default_root_dir=f".models/{subdir}/",
# logger=logger,
# )
# trainer.fit(model, dm)

View File

@@ -0,0 +1,230 @@
from datetime import datetime
from pathlib import Path
import matplotlib
import numpy as np
import torch
import torch.utils.tensorboard
import torch.utils.tensorboard.summary
from hypertraining.settings import (
GlobalSettings,
DataSettings,
PytorchSettings,
ModelSettings,
OptimizerSettings,
)
from hypertraining.training import RegenerationTrainer, PolarizationTrainer
# import torch
import json
import util
from rich import print as rprint
global_settings = GlobalSettings(
seed=0xC0FFEE,
)
data_settings = DataSettings(
config_path="data/*-128-16384-10000-0-0-0-0-PAM4-0-0.4.ini",
# config_path=[f"data/20241202-*-128-16384-{length}-0-0-17-0-PAM4-0.ini" for length in range(48000, 53000, 1000)],
dtype="complex64",
# symbols = (9, 20), # 13 symbol @ 10GBd <-> 1.3ns <-> 0.26m of fiber
symbols=13, # study: single_core_regen_20241123_011232
# output_size = (11, 32), # 26 taps -> 2 taps per input symbol -> 1 tap every 1cm (model has 52 inputs (x/y))
output_size=26, # study: single_core_regen_20241123_011232 (model_input_dim/2)
shuffle=True,
drop_first=64,
train_split=0.8,
# polarisations=tuple(np.random.rand(2)*2*np.pi),
randomise_polarisations=True,
)
pytorch_settings = PytorchSettings(
epochs=10000,
batchsize=2**12,
device="cuda",
dataloader_workers=16,
dataloader_prefetch=8,
summary_dir=".runs",
write_every=2**5,
save_models=True,
model_dir=".models",
)
model_settings = ModelSettings(
output_dim=3,
n_hidden_layers=3,
overrides={
"n_hidden_nodes_0": 2,
"n_hidden_nodes_1": 2,
"n_hidden_nodes_2": 2,
},
dropout_prob=0.01,
model_layer_function="ONNRect",
model_activation_func="EOActivation",
model_layer_kwargs={"square": True},
scale=False,
model_layer_parametrizations=[
{
"tensor_name": "weight",
"parametrization": util.complexNN.energy_conserving,
},
{
"tensor_name": "alpha",
"parametrization": util.complexNN.clamp,
},
{
"tensor_name": "gain",
"parametrization": util.complexNN.clamp,
"kwargs": {
"min": 0,
"max": float("inf"),
},
},
{
"tensor_name": "phase_bias",
"parametrization": util.complexNN.clamp,
"kwargs": {
"min": 0,
"max": 2 * torch.pi,
},
},
{
"tensor_name": "scales",
"parametrization": util.complexNN.clamp,
},
{
"tensor_name": "angle",
"parametrization": util.complexNN.clamp,
"kwargs": {
"min": 0,
"max": 2*torch.pi,
},
},
{
"tensor_name": "loss",
"parametrization": util.complexNN.clamp,
},
],
)
optimizer_settings = OptimizerSettings(
optimizer="AdamW",
optimizer_kwargs={
"lr": 0.005,
"amsgrad": True,
# "weight_decay": 1e-7,
},
# learning_rate=0.05,
scheduler="ReduceLROnPlateau",
scheduler_kwargs={
"patience": 2**6,
"factor": 0.75,
# "threshold": 1e-3,
"min_lr": 1e-6,
"cooldown": 10,
},
)
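# A minimal sketch of how these settings map onto the torch objects they name,
# assuming the trainer forwards optimizer_kwargs / scheduler_kwargs unchanged:
#
#     optimizer = torch.optim.AdamW(model.parameters(), lr=0.005, amsgrad=True)
#     scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
#         optimizer, patience=2**6, factor=0.75, min_lr=1e-6, cooldown=10
#     )
#     scheduler.step(val_loss)  # once per validation pass, with the monitored loss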
def save_dict_to_file(dictionary, filename):
"""
    Save a dictionary to a JSON file.

    :param dictionary: Dictionary to save (e.g. the best training results).
    :type dictionary: dict
:param filename: Path to the JSON file where the dictionary will be saved.
:type filename: str
"""
with open(filename, "w") as f:
json.dump(dictionary, f, indent=4)
def sweep_lengths(*lengths, model=None, data_glob: str = None, strategy="newest"):
assert model is not None, "Model must be provided."
assert data_glob is not None, "Data glob must be provided."
model = model
fiber_ins = {}
fiber_outs = {}
regens = {}
timestampss = {}
trainer = RegenerationTrainer(
checkpoint_path=model,
)
trainer.define_model()
for length in lengths:
data_glob_length = data_glob.replace("{length}", str(length))
files = list(Path.cwd().glob(data_glob_length))
if len(files) == 0:
continue
if strategy == "newest":
sorted_kwargs = {
"key": lambda x: x.stat().st_mtime,
"reverse": True,
}
elif strategy == "oldest":
sorted_kwargs = {
"key": lambda x: x.stat().st_mtime,
"reverse": False,
}
else:
raise ValueError(f"Unknown strategy {strategy}.")
file = sorted(files, **sorted_kwargs)[0]
loader, _ = trainer.get_sliced_data(override={"config_path": file})
fiber_in, fiber_out, regen, timestamps = trainer.run_model(trainer.model, loader=loader)
fiber_ins[length] = fiber_in
fiber_outs[length] = fiber_out
regens[length] = regen
timestampss[length] = timestamps
data = torch.zeros(2 * len(timestampss.keys()) + 2, 2, tuple(fiber_outs.values())[-1].shape[0])
channel_names = ["" for _ in range(2 * len(timestampss.keys()) + 2)]
data[1, 0, :] = timestampss[tuple(timestampss.keys())[-1]] / 128
data[1, 1, :] = fiber_ins[tuple(timestampss.keys())[-1]][:, 0].abs().square()
channel_names[1] = "fiber in x"
for li, length in enumerate(timestampss.keys()):
data[2 + 2 * li, 0, :] = timestampss[length] / 128
data[2 + 2 * li, 1, :] = fiber_outs[length][:, 0].abs().square()
data[2 + 2 * li + 1, 0, :] = timestampss[length] / 128
data[2 + 2 * li + 1, 1, :] = regens[length][:, 0].abs().square()
channel_names[2 + 2 * li + 1] = f"regen x {length}"
channel_names[2 + 2 * li] = f"fiber out x {length}"
# get current backend
backend = matplotlib.get_backend()
matplotlib.use("TkCairo")
eye = util.eye_diagram.eye_diagram(data.to(dtype=torch.float32).detach().cpu().numpy(), channel_names=channel_names)
print_attrs = ("channel_name", "success", "min_area")
with np.printoptions(precision=3, suppress=True, formatter={"float": "{:0.3e}".format}):
for result in eye.eye_stats:
print_dict = {attr: result[attr] for attr in print_attrs}
rprint(print_dict)
rprint()
eye.plot(all_stats=False)
matplotlib.use(backend)
if __name__ == "__main__":
trainer = PolarizationTrainer(
global_settings=global_settings,
data_settings=data_settings,
pytorch_settings=pytorch_settings,
model_settings=model_settings,
optimizer_settings=optimizer_settings,
# checkpoint_path='.models/pol_pol_20241208_122418_1116.tar',
# reset_epoch=True
)
trainer.train()

View File

@@ -260,12 +260,94 @@ class ONNRect(nn.Module):
            self.crop = lambda x: x
            self.crop.__doc__ = "No cropping"

    def forward(self, x):
-        x = self.pad(x)
+        x = self.pad(x).to(dtype=self.weight.dtype)
        out = self.crop((self.weight @ x.mT).mT)
        return out
class polarimeter(nn.Module):
def __init__(self):
super(polarimeter, self).__init__()
# self.input_length = input_length
def forward(self, data):
# S0 = I
# S1 = (2*I_x - I)/I
# S2 = (2*I_45 - I)/I
# S3 = (2*I_RHC - I)/I
# # data: (batch, input_length*2) -> (batch, input_length, 2)
data = data.view(data.shape[0], -1, 2)
x = data[:, :, 0].mean(dim=1)
y = data[:, :, 1].mean(dim=1)
# x = x.mean(dim=1)
# y = y.mean(dim=1)
# angle = torch.atan2(y.abs().square().real, x.abs().square().real)
# return torch.stack([angle, angle, angle, angle], dim=1)
        # horizontal polarisation
        I_x = x.abs().square()
        # vertical polarisation
        I_y = y.abs().square()
        # 45 degree polarisation (projection onto (x + y)/sqrt(2), hence the factor 1/2)
        I_45 = (x + y).abs().square() / 2
        # right hand circular polarisation (projection onto (x + iy)/sqrt(2))
        I_RHC = (x + 1j*y).abs().square() / 2
        # S0 = I_x + I_y
        # S1 = I_x - I_y
        # S2 = I_45 - I_m45
        # S3 = I_RHC - I_LHC
        S0 = I_x + I_y
        # S1..S3 are already normalised by S0, so no further division is needed
        S1 = (2*I_x - S0)/S0
        S2 = (2*I_45 - S0)/S0
        S3 = (2*I_RHC - S0)/S0
        return torch.stack([S0/S0, S1, S2, S3], dim=1)
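def _polarimeter_example():
    # illustrative helper (not part of the commit): for 45-degree linear light
    # the normalised Stokes vector should come out as (1, 0, 1, 0)
    import math
    pol = polarimeter()
    field = torch.full((1, 8, 2), 1 / math.sqrt(2), dtype=torch.complex64)  # x == y
    return pol(field.view(1, -1))  # approximately [[1., 0., 1., 0.]]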
class normalize_by_first(nn.Module):
def __init__(self):
super(normalize_by_first, self).__init__()
def forward(self, data):
return data / data[:, 0].unsqueeze(1)
class photodiode(nn.Module):
    # square-law detector: intensity |E|^2 with a learnable per-channel
    # responsivity (scale) and dark-current-like offset (pd_bias)
    def __init__(self, size, bias=True):
        super(photodiode, self).__init__()
        self.input_dim = size
        self.scale = nn.Parameter(torch.rand(size))
        # honour the bias flag instead of silently ignoring it
        self.pd_bias = nn.Parameter(torch.rand(size)) if bias else None

    def forward(self, x):
        out = x.abs().square().to(dtype=x.dtype.to_real()).mul(self.scale)
        return out if self.pd_bias is None else out.add(self.pd_bias)
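def _photodiode_example():
    # illustrative usage (not part of the commit): complex field in, real
    # intensity out, shape preserved
    pd = photodiode(4)
    field = torch.randn(10, 4, dtype=torch.complex64)
    out = pd(field)  # real-valued, shape (10, 4): scale * |E|^2 + pd_bias
    return out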
class input_rotator(nn.Module):
    def __init__(self, input_dim):
        super(input_rotator, self).__init__()
        assert input_dim % 2 == 0, "Input dimension must be even"
        self.input_dim = input_dim
        # learnable fallback angle, used when no angle is passed to forward()
        self.angle = nn.Parameter(torch.randn(1))

    def forward(self, x, angle=None):
        # take channels (0,1), (2,3), ... and rotate them by the angle
        # (`angle or self.angle` would fail for tensors and for angle == 0)
        angle = self.angle if angle is None else angle
        sine = torch.sin(angle).squeeze()
        cosine = torch.cos(angle).squeeze()
        # build the matrix with stack so gradients can flow through the angle
        rot = torch.stack([torch.stack([cosine, -sine]), torch.stack([sine, cosine])]).to(dtype=x.dtype)
        return torch.matmul(x.view(-1, 2), rot).view(x.shape)
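def _input_rotator_example():
    # illustrative check (not part of the commit): a rotation is unitary, so
    # the total power of the interleaved (x, y) channels must be preserved
    rotator = input_rotator(8)
    x = torch.randn(3, 8, dtype=torch.complex64)
    y = rotator(x, angle=torch.tensor(0.3))
    assert torch.allclose(x.abs().square().sum(), y.abs().square().sum(), atol=1e-5)
    return y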
    # def __repr__(self):
    #     return f"ONNRect({self.input_dim}, {self.output_dim})"

@@ -371,7 +453,7 @@ class Identity(nn.Module):
    M(z) = z
    """

-    def __init__(self):
+    def __init__(self, size=None):
        super(Identity, self).__init__()

    def forward(self, x):
@@ -404,9 +486,28 @@ class MZISingle(nn.Module):
    def forward(self, x: torch.Tensor):
        return x * torch.exp(1j * self.phi) * torch.sin(self.omega + self.func(x))
def naive_angle_loss(x: torch.Tensor, target: torch.Tensor, mod=2*torch.pi):
    # "naive" because fmod keeps the sign of (x - target): a difference of
    # 2*pi - eps is penalised heavily although the angles are only eps apart
    return torch.fmod((x - target), mod).square().mean()


def cosine_loss(x: torch.Tensor, target: torch.Tensor):
    # 2*(1 - cos(d)) is approximately d^2 for small d and exactly 2*pi-periodic
    return (2*(1 - torch.cos(x - target))).mean()
def angle_mse_loss(x: torch.Tensor, target: torch.Tensor):
    # compare angles via their positions on the unit circle, which makes the
    # loss insensitive to 2*pi wraps (the fmod is therefore redundant but harmless)
    x = torch.fmod(x, 2*torch.pi)
    target = torch.fmod(target, 2*torch.pi)
    x_cos = torch.cos(x)
    x_sin = torch.sin(x)
    target_cos = torch.cos(target)
    target_sin = torch.sin(target)
    cos_diff = x_cos - target_cos
    sin_diff = x_sin - target_sin
    # squared chord distance: cos_diff**2 + sin_diff**2 = 2*(1 - cos(x - target))
    squared_diff = cos_diff**2 + sin_diff**2
    return squared_diff.mean()
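def _angle_loss_identity_check():
    # illustrative check (not part of the commit): by the identity
    # (cos x - cos t)^2 + (sin x - sin t)^2 = 2 - 2*cos(x - t),
    # angle_mse_loss and cosine_loss agree pointwise
    x = torch.randn(1000) * 10
    t = torch.randn(1000) * 10
    assert torch.allclose(angle_mse_loss(x, t), cosine_loss(x, t), atol=1e-5)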
class EOActivation(nn.Module):
-    def __init__(self, bias, size=None):
+    def __init__(self, size=None):
        # 10.1109/SiPhotonics60897.2024.10543376
        super(EOActivation, self).__init__()
        if size is None:
@@ -571,81 +672,10 @@ class ZReLU(nn.Module):
        return torch.relu(x)
class regenerator(nn.Module):
def __init__(
self,
*dims,
layer_function=ONN,
layer_kwargs: dict | None = None,
layer_parametrizations: list[dict] = None,
activation_function=Pow,
dtype=torch.float64,
dropout_prob=0.01,
scale=False,
**kwargs,
):
super(regenerator, self).__init__()
if len(dims) == 0:
try:
dims = kwargs["dims"]
except KeyError:
raise ValueError("dims must be provided")
self._n_hidden_layers = len(dims) - 2
self._layers = nn.Sequential()
if layer_kwargs is None:
layer_kwargs = {}
# self.powers = []
for i in range(self._n_hidden_layers + 1):
if scale:
self._layers.append(Scale(dims[i]))
self._layers.append(layer_function(dims[i], dims[i + 1], dtype=dtype, **layer_kwargs))
if i < self._n_hidden_layers:
if dropout_prob is not None:
self._layers.append(DropoutComplex(p=dropout_prob))
self._layers.append(activation_function(bias=True, size=dims[i + 1]))
self._layers.append(Scale(dims[-1]))
# add parametrizations
if layer_parametrizations is not None:
for layer in self._layers:
for layer_parametrization in layer_parametrizations:
tensor_name = layer_parametrization.get("tensor_name", None)
parametrization = layer_parametrization.get("parametrization", None)
param_kwargs = layer_parametrization.get("kwargs", {})
if tensor_name is not None and tensor_name in layer._parameters and parametrization is not None:
parametrization(layer, tensor_name, **param_kwargs)
# def __call__(self, input_x, **kwargs):
# return self.forward(input_x, **kwargs)
def forward(self, input_x, trace_powers=False):
x = input_x
if trace_powers:
powers = [x.abs().square().sum()]
# check if tracing
if torch.jit.is_tracing():
for layer in self._layers:
x = layer(x)
if trace_powers:
powers.append(x.abs().square().sum())
else:
# with torch.nn.utils.parametrize.cached():
for layer in self._layers:
x = layer(x)
if trace_powers:
powers.append(x.abs().square().sum())
if trace_powers:
return x, powers
return x
__all__ = [
    complex_sse_loss,
    complex_mse_loss,
+    angle_mse_loss,
    UnitaryLayer,
    unitary,
    energy_conserving,
@@ -662,6 +692,7 @@ __all__ = [
    ZReLU,
    MZISingle,
    EOActivation,
+    photodiode,
    # SaturableAbsorberLambertW,
    # SaturableAbsorber,
    # SpreadLayer,

View File

@@ -54,7 +54,7 @@ def load_data(config_path, skipfirst=0, symbols=None, real=False, normalize=Fals
config["glova"]["nos"] = str(symbols) config["glova"]["nos"] = str(symbols)
data = np.concatenate([data, timestamps.reshape(-1,1)], axis=-1) data = np.concatenate([data, timestamps.reshape(-1, 1)], axis=-1)
data = torch.tensor(data, device=device, dtype=dtype) data = torch.tensor(data, device=device, dtype=dtype)
@@ -113,6 +113,8 @@ class FiberRegenerationDataset(Dataset):
        dtype: torch.dtype = None,
        real: bool = False,
        device=None,
+        polarisations: tuple | list = (0,),
+        randomise_polarisations: bool = False,
        **kwargs,
    ):
        """
@@ -145,6 +147,8 @@ class FiberRegenerationDataset(Dataset):
        assert output_dim is None or output_dim > 0, "output_len must be positive or None"
        assert drop_first >= 0, "drop_first must be non-negative"

+        self.randomise_polarisations = randomise_polarisations

        faux = kwargs.pop("faux", False)
        if faux:
@@ -165,7 +169,7 @@ class FiberRegenerationDataset(Dataset):
        data_raw = None
        self.config = None
        files = []
-        for file_path in (file_path if isinstance(file_path, (tuple, list)) else [file_path]):
+        for file_path in file_path if isinstance(file_path, (tuple, list)) else [file_path]:
            data, config = load_data(
                file_path,
                skipfirst=drop_first,
@@ -186,6 +190,19 @@ class FiberRegenerationDataset(Dataset):
files.append(config["data"]["file"].strip('"')) files.append(config["data"]["file"].strip('"'))
self.config["data"]["file"] = str(files) self.config["data"]["file"] = str(files)
for i, angle in enumerate(torch.tensor(np.array(polarisations))):
data_raw_copy = data_raw.clone()
if angle == 0:
continue
sine = torch.sin(angle)
cosine = torch.cos(angle)
data_raw_copy[:, 2] = data_raw[:, 2] * cosine - data_raw[:, 3] * sine
data_raw_copy[:, 3] = data_raw[:, 2] * sine + data_raw[:, 3] * cosine
if i == 0:
data_raw = data_raw_copy
else:
data_raw = torch.cat([data_raw, data_raw_copy], dim=0)
self.device = data_raw.device self.device = data_raw.device
self.samples_per_symbol = int(self.config["glova"]["sps"]) self.samples_per_symbol = int(self.config["glova"]["sps"])
@@ -258,17 +275,27 @@ class FiberRegenerationDataset(Dataset):
        elif self.target_delay_samples < 0:
            data_raw = data_raw[:, : self.target_delay_samples]

-        timestamps = data_raw[-1, :]
-        data_raw = data_raw[:-1, :]
+        timestamps = data_raw[4, :]
+        data_raw = data_raw[:4, :]

        data_raw = data_raw.view(2, 2, -1)
-        timestamps_doubled = torch.cat([timestamps.unsqueeze(dim=0), timestamps.unsqueeze(dim=0)], dim=0).unsqueeze(dim=1)
+        timestamps_doubled = torch.cat([timestamps.unsqueeze(dim=0), timestamps.unsqueeze(dim=0)], dim=0).unsqueeze(
+            dim=1
+        )
        data_raw = torch.cat([data_raw, timestamps_doubled], dim=1)
+        # data_raw = torch.cat([data_raw, timestamps_doubled], dim=1)

        # data layout
        # [ [E_in_x, E_in_y, timestamps],
        #   [E_out_x, E_out_y, timestamps] ]

        self.data = data_raw.unfold(dimension=-1, size=self.samples_per_slice, step=1)
        self.data = self.data.movedim(-2, 0)

+        if randomise_polarisations:
+            # one random rotation angle per slice, drawn once at construction
+            self.angles = torch.rand(self.data.shape[0]) * np.pi * 2
+            # self.data[:, 1, :2, :] = self.rotate(self.data[:, 1, :2, :], self.angles)
+        else:
+            self.angles = torch.zeros(self.data.shape[0])

        # ...
        # -> [no_slices, 2, 3, samples_per_slice]
# data layout # data layout
@@ -289,22 +316,92 @@ class FiberRegenerationDataset(Dataset):
        else:
            data_slice = self.data[idx].squeeze()

-        data_slice = data_slice[:, :, :data_slice.shape[2] // self.output_dim * self.output_dim]
+        data_slice = data_slice[:, :, : data_slice.shape[2] // self.output_dim * self.output_dim]
        data_slice = data_slice.view(data_slice.shape[0], data_slice.shape[1], self.output_dim, -1)

-        target = data_slice[0, :, self.output_dim//2, 0]
-        data = data_slice[1, :, :, 0]
+        # if self.randomise_polarisations:
+        #     angle = torch.rand(1) * torch.pi * 2
+        #     sine = torch.sin(angle)
+        #     cosine = torch.cos(angle)
+        #     data_slice_ = data_slice[1]
+        #     data_slice[1, 0] = data_slice_[0] * cosine - data_slice_[1] * sine
+        #     data_slice[1, 1] = data_slice_[0] * sine + data_slice_[1] * cosine
+        # else:
+        #     angle = torch.zeros(1)
+        # data = data_slice[1, :2, :, 0]
+        angle = self.angles[idx]
+        data_index = 1
+        data_slice[1, :2, :, :] = self.rotate(data_slice[data_index, :2, :, :], angle)
+        data = data_slice[1, :2, :, 0]
+        # data = self.rotate(data, angle)
+        # for both polarisations (x/y), calculate the mean of the signal around the current symbol (-> corresponding to a lowpass filter)
+        angle_data = data_slice[1, :2, :, :].reshape(2, -1).mean(dim=1)
+        angle_data2 = self.complex_max(data_slice[1, :2, :, :].reshape(2, -1))
+        plot_data = data_slice[1, :2, self.output_dim // 2, 0]
+        sop = self.polarimeter(plot_data)
+        # angle_data = data_slice[1, :2, :self.output_dim//2, :].squeeze().reshape(2, -1).mean(dim=-1)
+        # angle = data_slice[1, 3, self.output_dim // 2, 0].real
+        target = data_slice[0, :2, self.output_dim // 2, 0]
+        target_timestamp = data_slice[0, 2, self.output_dim // 2, 0].real
        ...
        # data_timestamps = data[-1,:].real
-        data = data[:-1, :]
-        target_timestamp = target[-1].real
-        target = target[:-1]
+        # data = data[:-1, :]
+        # target_timestamp = target[-1].real
+        # target = target[:-1]
+        # plot_data = plot_data[:-1]

+        # transpose to interleave the x and y data in the output tensor
        data = data.transpose(0, 1).flatten().squeeze()
+        angle_data = angle_data.flatten().squeeze()
+        angle_data2 = angle_data2.flatten().squeeze()
+        angle = angle.flatten().squeeze()
        # data_timestamps = data_timestamps.flatten().squeeze()
        target = target.flatten().squeeze()
        target_timestamp = target_timestamp.flatten().squeeze()

-        return data, target, target_timestamp
+        return {"x": data, "y": target, "angle": angle, "sop": sop, "angle_data": angle_data, "angle_data2": angle_data2, "timestamp": target_timestamp, "plot_data": plot_data}
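    # Illustrative consumption of the new dict samples (not part of the commit;
    # the config path below is hypothetical). torch's default collate turns
    # each dict key into a batched tensor:
    #
    #     ds = FiberRegenerationDataset(file_path="data/example.ini", symbols=13, output_dim=26)
    #     loader = torch.utils.data.DataLoader(ds, batch_size=32)
    #     batch = next(iter(loader))
    #     model_in, target, angle = batch["x"], batch["y"], batch["angle"]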
def complex_max(self, data, dim=-1):
# returns element(s) with the maximum absolute value along a given dimension
# ind = torch.argmax(data.abs(), dim=dim, keepdim=True)
# max_values = torch.gather(data, dim, ind).squeeze(dim=dim)
# return max_values
return torch.gather(data, dim, torch.argmax(data.abs(), dim=dim, keepdim=True)).squeeze(dim=dim)
def rotate(self, data, angle):
# rotates a 2d tensor by a given angle
# data: [2, ...]
# angle: [1]
# returns: [2, ...]
# get sine and cosine of the angle
sine = torch.sin(angle)
cosine = torch.cos(angle)
return torch.stack([data[0] * cosine - data[1] * sine, data[0] * sine + data[1] * cosine], dim=0)
    def polarimeter(self, data):
        # data: [2, ...] -> x, y
        # returns [3] -> normalised Stokes parameters S1, S2, S3
        x = data[0].mean()
        y = data[1].mean()
        I_X = x.abs().square()
        I_Y = y.abs().square()
        # projections onto (x + y)/sqrt(2) and (x + iy)/sqrt(2), hence the factor 1/2
        I_45 = (x + y).abs().square() / 2
        I_RHC = (x + 1j*y).abs().square() / 2
        S0 = I_X + I_Y
        S1 = (2*I_X - S0) / S0
        S2 = (2*I_45 - S0) / S0
        S3 = (2*I_RHC - S0) / S0
        return torch.stack([S1, S2, S3], dim=0)
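    # Illustrative relation (not part of the commit): for purely linear input
    # rotated by theta, the polarimeter should return (cos 2theta, sin 2theta, 0),
    # which is what makes the SOP a usable training target for the rotation angle:
    #
    #     field = torch.stack([torch.ones(8, dtype=torch.complex64),
    #                          torch.zeros(8, dtype=torch.complex64)])  # x-polarised
    #     rotated = self.rotate(field, torch.tensor(0.3))
    #     sop = self.polarimeter(rotated)  # approximately [cos(0.6), sin(0.6), 0]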