From 39ae13d0afa03bf347fe168a33bbd9333e0bc145 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Joseph=20Hopfm=C3=BCller?= Date: Wed, 11 Dec 2024 09:48:38 +0100 Subject: [PATCH] add training script for polarization estimation, refactor model definitions, randomised polarisation support in data_loader --- .../hypertraining/lighning_models.py | 443 ++++++++++ src/single-core-regen/hypertraining/models.py | 204 +++++ .../hypertraining/settings.py | 49 +- .../hypertraining/training.py | 778 +++++++++++++++--- src/single-core-regen/regen_no_hyper.py | 144 +++- src/single-core-regen/train_pol_estimator.py | 230 ++++++ src/single-core-regen/util/complexNN.py | 183 ++-- src/single-core-regen/util/datasets.py | 127 ++- 8 files changed, 1899 insertions(+), 259 deletions(-) create mode 100644 src/single-core-regen/hypertraining/lighning_models.py create mode 100644 src/single-core-regen/hypertraining/models.py create mode 100644 src/single-core-regen/train_pol_estimator.py diff --git a/src/single-core-regen/hypertraining/lighning_models.py b/src/single-core-regen/hypertraining/lighning_models.py new file mode 100644 index 0000000..f99e5ef --- /dev/null +++ b/src/single-core-regen/hypertraining/lighning_models.py @@ -0,0 +1,443 @@ +from typing import Any +import lightning as L +import numpy as np +import torch +import torch.nn as nn +# import torch.nn.functional as F + +from util.complexNN import DropoutComplex, Scale, ONNRect, EOActivation, energy_conserving, clamp, complex_mse_loss +from util.datasets import FiberRegenerationDataset + + +class regeneratorData(L.LightningDataModule): + def __init__( + self, + config_globs, + output_symbols, + output_dim, + dtype, + drop_first, + shuffle=True, + train_split=None, + batch_size=None, + loader_settings=None, + seed=None, + num_symbols=None, + test_globs=None, + ): + super().__init__() + self._config_globs = config_globs + self._test_globs = test_globs + self._test_data_available = test_globs is not None + if self._test_data_available: + self.test_dataloader = self._test_dataloader + self._output_symbols = output_symbols + self._output_dim = output_dim + self._dtype = dtype + self._drop_first = drop_first + self._seed = seed + self._shuffle = shuffle + self._num_symbols = num_symbols + self._train_split = train_split if train_split is not None else 0.8 + self.batch_size = batch_size if batch_size is not None else 1024 + self._loader_settings = loader_settings if loader_settings is not None else {} + + def _get_data(self): + self._data_train = FiberRegenerationDataset( + file_path=self._config_globs, + symbols=self._output_symbols, + output_dim=self._output_dim, + dtype=self._dtype, + real=not self._dtype.is_complex, + drop_first=self._drop_first, + num_symbols=self._num_symbols, + ) + # self._data_plot = FiberRegenerationDataset( + # file_path=self._config_globs, + # symbols=self._output_symbols, + # output_dim=self._output_dim, + # dtype=self._dtype, + # real=not self._dtype.is_complex, + # drop_first=self._drop_first, + # num_symbols=400, + # ) + if self._test_data_available: + self._data_test = FiberRegenerationDataset( + file_path=self._test_globs, + symbols=self._output_symbols, + output_dim=self._output_dim, + dtype=self._dtype, + real=not self._dtype.is_complex, + drop_first=self._drop_first, + num_symbols=self._num_symbols, + ) + return self._data_train, self._data_test + return self._data_train + + def _split_data(self, stage="fit", split=None, shuffle=None): + _split = split if split is not None else self._train_split + _shuffle = shuffle if shuffle is 
not None else self._shuffle + + dataset_size = len(self._data_train) + indices = list(range(dataset_size)) + split_index = int(np.floor(_split * dataset_size)) + train_indices, valid_indices = indices[:split_index], indices[split_index:] + if _shuffle: + np.random.seed(self._seed) + np.random.shuffle(train_indices) + + + if _shuffle: + if stage == "fit" or stage == "predict": + self._train_sampler = torch.utils.data.SubsetRandomSampler(train_indices) + # if stage == "fit" or stage == "validate": + # self._valid_sampler = torch.utils.data.SubsetRandomSampler(valid_indices) + else: + if stage == "fit" or stage == "predict": + self._train_sampler = train_indices + if stage == "fit" or stage == "validate": + self._valid_sampler = valid_indices + + if stage == "fit": + return self._train_sampler, self._valid_sampler + elif stage == "validate": + return self._valid_sampler + elif stage == "predict": + return self._train_sampler + + def prepare_data(self): + self._get_data() + + def setup(self, stage=None): + stage = stage or "fit" + self._split_data(stage=stage) + + def train_dataloader(self): + return torch.utils.data.DataLoader( + self._data_train, + batch_size=self.batch_size, + sampler=self._train_sampler, + **self._loader_settings + ) + + def val_dataloader(self): + return torch.utils.data.DataLoader( + self._data_train, + batch_size=self.batch_size, + sampler=self._valid_sampler, + **self._loader_settings + ) + + def _test_dataloader(self): + return torch.utils.data.DataLoader( + self._data_test, + shuffle=self._shuffle, + batch_size=self.batch_size, + **self._loader_settings + ) + + def predict_dataloader(self): + return torch.utils.data.DataLoader( + self._data_plot, + shuffle=False, + batch_size=40, + pin_memory=True, + drop_last=True, + num_workers=4, + prefetch_factor=2, + ) + + # def plot_dataloader(self): + + + +class regenerator(L.LightningModule): + def __init__( + self, + *dims, + layer_function=ONNRect, + layer_func_kwargs: dict | None = {"square": True}, + act_function=EOActivation, + act_func_kwargs: dict | None = None, + parametrizations: list[dict] | None = [ + { + "tensor_name": "weight", + "parametrization": energy_conserving, + }, + { + "tensor_name": "alpha", + "parametrization": clamp, + }, + { + "tensor_name": "alpha", + "parametrization": clamp, + }, + ], + dtype=torch.complex64, + dropout_prob=0.01, + scale_layers=False, + optimizer=torch.optim.AdamW, + optimizer_kwargs: dict | None = { + "lr": 0.01, + "amsgrad": True, + }, + lr_scheduler=None, + lr_scheduler_kwargs: dict | None = { + "patience": 20, + "factor": 0.5, + "min_lr": 1e-6, + "cooldown": 10, + }, + sps = 128, + # **kwargs, + ): + torch.set_float32_matmul_precision('high') + layer_func_kwargs = layer_func_kwargs if layer_func_kwargs is not None else {} + act_func_kwargs = act_func_kwargs if act_func_kwargs is not None else {} + optimizer_kwargs = optimizer_kwargs if optimizer_kwargs is not None else {} + lr_scheduler_kwargs = lr_scheduler_kwargs if lr_scheduler_kwargs is not None else {} + super().__init__() + + self.example_input_array = torch.randn(1, dims[0], dtype=dtype) + self._sps = sps + + self.optimizer_settings = { + "optimizer": optimizer, + "optimizer_kwargs": optimizer_kwargs, + "lr_scheduler": lr_scheduler, + "lr_scheduler_kwargs": lr_scheduler_kwargs, + } + + # if len(dims) == 0: + # try: + # dims = kwargs["dims"] + # except KeyError: + # raise ValueError("dims must be provided") + self._n_hidden_layers = len(dims) - 2 + + self.build_model(dims, layer_function, layer_func_kwargs, act_function, 
act_func_kwargs, parametrizations, dtype, dropout_prob, scale_layers) + + def build_model(self, dims, layer_function, layer_func_kwargs, act_function, act_func_kwargs, parametrizations, dtype, dropout_prob, scale_layers): + input_layer = nn.Sequential( + layer_function(dims[0], dims[1], dtype=dtype, **layer_func_kwargs), + act_function(size=dims[1], **act_func_kwargs), + DropoutComplex(p=dropout_prob), + ) + + if scale_layers: + input_layer = nn.Sequential(Scale(dims[0]), input_layer) + + self.layer_0 = input_layer + + for i in range(1, self._n_hidden_layers): + layer = nn.Sequential( + layer_function(dims[i], dims[i + 1], dtype=dtype, **layer_func_kwargs), + act_function(size=dims[i + 1], **act_func_kwargs), + DropoutComplex(p=dropout_prob), + ) + if scale_layers: + layer = nn.Sequential(Scale(dims[i]), layer) + setattr(self, f"layer_{i}", layer) + + output_layer = nn.Sequential( + layer_function(dims[-2], dims[-1], dtype=dtype, **layer_func_kwargs), + act_function(size=dims[-1], **act_func_kwargs), + Scale(dims[-1]), + ) + setattr(self, f"layer_{self._n_hidden_layers}", output_layer) + + if parametrizations is not None: + self._apply_parametrizations(self, parametrizations) + + def _apply_parametrizations(self, layer, parametrizations): + for sub_layer in layer.children(): + if len(sub_layer._modules) > 0: + self._apply_parametrizations(sub_layer, parametrizations) + else: + for parametrization in parametrizations: + tensor_name = parametrization.get("tensor_name", None) + if tensor_name is None: + continue + parametrization_func = parametrization.get("parametrization", None) + if parametrization_func is None: + continue + param_kwargs = parametrization.get("kwargs", {}) + if tensor_name in sub_layer._parameters: + parametrization_func(sub_layer, tensor_name, **param_kwargs) + + def _trace_powers(self, enable, x, powers=None): + if not enable: + return + if powers is None: + powers = [] + powers.append(x.abs().square().sum()) + return powers + + # def plot(self, mode): + # self.predict_step() + + # def validation_epoch_end(self, outputs): + # x = torch.vstack([output['x'].view(output['x'].shape[0], -1, 2)[:, output['x'].shape[1]//2, :].squeeze() for output in outputs]) + # y = torch.vstack([output['y'].view(output['y'].shape[0], -1, 2).squeeze() for output in outputs]) + # y_hat = torch.vstack([output['y_hat'].view(output['y_hat'].shape[0], -1, 2).squeeze() for output in outputs]) + # timesteps = torch.vstack([output['timesteps'].squeeze() for output in outputs]) + # powers = torch.vstack([output['powers'] for output in outputs]) + + # return {'x': x, 'y': y, 'y_hat': y_hat, 'timesteps': timesteps, 'powers': powers} + + def on_validation_epoch_end(self): + if self.current_epoch % 10 == 0 or self.current_epoch == self.trainer.max_epochs - 1 or self.current_epoch < 10: + x = self.val_outputs['x'] + # x = x.view(x.shape[0], -1, 2) + # x = x[:, x.shape[1]//2, :].squeeze() + y = self.val_outputs['y'] + # y = y.view(y.shape[0], -1, 2).squeeze() + y_hat = self.val_outputs['y_hat'] + # y_hat = y_hat.view(y_hat.shape[0], -1, 2).squeeze() + timesteps = self.val_outputs['timesteps'] + # timesteps = timesteps.squeeze() + powers = self.val_outputs['powers'] + # powers = powers.squeeze() + + fiber_in = x.detach().cpu().numpy() + fiber_out = y.detach().cpu().numpy() + regen = y_hat.detach().cpu().numpy() + timesteps = timesteps.detach().cpu().numpy() + # powers = np.array([power.detach().cpu().numpy() for power in powers]) + + # fiber_in = np.concat(fiber_in, axis=0) + # fiber_out = 
np.concat(fiber_out, axis=0) + # regen = np.concat(regen, axis=0) + # timesteps = np.concat(timesteps, axis=0) + # powers = powers.detach().cpu().numpy() + + + import gc + + fig = self.plot_model_head(fiber_in, fiber_out, regen, timesteps, sps=self._sps) + + self.logger.experiment.add_figure("model response", fig, self.current_epoch) + + # fig = self.plot_model_eye(fiber_in, fiber_out, regen, timesteps, sps=self._sps) + + # self.logger.experiment.add_figure("model eye", fig, self.current_epoch) + + # fig = self.plot_model_powers(powers) + + # self.logger.experiment.add_figure("powers", fig, self.current_epoch) + + gc.collect() + # x, y, y_hat, timesteps, powers = self.validation_epoch_end(self.outputs) + # self.plot(x, y, y_hat, timesteps, powers) + + def plot_model_head(self, fiber_in, fiber_out, regen, timesteps, sps): + import matplotlib + matplotlib.use("TkCairo") + import matplotlib.pyplot as plt + + ordering = np.argsort(timesteps) + signals = [signal[ordering] for signal in [fiber_in, fiber_out, regen]] + timesteps = timesteps[ordering] + + signals = [signal[:sps*40] for signal in signals] + timesteps = timesteps[:sps*40] + + fig, axs = plt.subplots(1, 2, sharex=True, sharey=True) + fig.set_figwidth(16) + fig.set_figheight(4) + + for i, ax in enumerate(axs): + for j, signal in enumerate(signals): + ax.plot(timesteps / sps, np.square(np.abs(signal[:,i])), label=["fiber in", "fiber out", "regen"][j] + [" x", " y"][i]) + ax.set_xlabel("symbol") + ax.set_ylabel("amplitude") + ax.minorticks_on() + ax.tick_params(axis="y", which="minor", left=False, right=False) + ax.grid(which="major", axis="x") + ax.grid(which="minor", axis="x", linestyle=":") + ax.grid(which="major", axis="y") + ax.legend(loc="upper right") + fig.tight_layout() + + return fig + + def plot_model_eye(self, fiber_in, fiber_out, regen, timesteps, sps): + ... + + def plot_model_powers(self, powers): + ... 
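    # --- illustrative sketch only, not part of this patch: one way the plot_model_eye
    # --- stub above could be filled in. It assumes the same (fiber_in, fiber_out, regen,
    # --- timesteps, sps) numpy arrays that plot_model_head receives, and overlays the
    # --- x-polarisation power traces folded onto a two-symbol window.
    # def plot_model_eye(self, fiber_in, fiber_out, regen, timesteps, sps):
    #     import numpy as np
    #     import matplotlib.pyplot as plt
    #
    #     fig, axs = plt.subplots(1, 3, sharex=True, sharey=True, figsize=(16, 4))
    #     for ax, signal, title in zip(axs, (fiber_in, fiber_out, regen),
    #                                  ("fiber in", "fiber out", "regen")):
    #         # fold the sample index onto a two-symbol window so successive traces overlay
    #         symbol_pos = np.mod(timesteps, 2 * sps) / sps
    #         ax.plot(symbol_pos, np.square(np.abs(signal[:, 0])), ".", markersize=1)
    #         ax.set_title(title)
    #         ax.set_xlabel("symbol")
    #         ax.grid(which="major")
    #     axs[0].set_ylabel("power (x pol.)")
    #     fig.tight_layout()
    #     return fig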
+ + def forward(self, x, trace_powers=False): + powers = self._trace_powers(trace_powers, x) + x = self.layer_0(x) + powers = self._trace_powers(trace_powers, x, powers) + for i in range(1, self._n_hidden_layers): + x = getattr(self, f"layer_{i}")(x) + powers = self._trace_powers(trace_powers, x, powers) + x = getattr(self, f"layer_{self._n_hidden_layers}")(x) + powers = self._trace_powers(trace_powers, x, powers) + if trace_powers: + return x, powers + return x + + def configure_optimizers(self): + optimizer = self.optimizer_settings["optimizer"]( + self.parameters(), **self.optimizer_settings["optimizer_kwargs"] + ) + if self.optimizer_settings["lr_scheduler"] is not None: + lr_scheduler = self.optimizer_settings["lr_scheduler"]( + optimizer, **self.optimizer_settings["lr_scheduler_kwargs"] + ) + return { + "optimizer": optimizer, + "lr_scheduler": { + "scheduler": lr_scheduler, + "monitor": "val_loss", + } + } + return {"optimizer": optimizer} + + def training_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0): + x, y, timesteps = batch + y_hat = self(x) + loss = complex_mse_loss(y_hat, y, power=True) + self.log("train_loss", loss, on_epoch=True, on_step=True) + return loss + + def validation_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0): + x, y, timesteps = batch + if batch_idx == 0: + y_hat, powers = self.forward(x, trace_powers=True) + else: + y_hat = self.forward(x) + loss = complex_mse_loss(y_hat, y, power=True) + self.log("val_loss", loss, on_epoch=True) + y = y.view(y.shape[0], -1, 2).squeeze() + x = x.view(x.shape[0], -1, 2) + x = x[:, x.shape[1]//2, :].squeeze() + y_hat = y_hat.view(y_hat.shape[0], -1, 2).squeeze() + timesteps = timesteps.squeeze() + if batch_idx == 0: + powers = np.array([power.detach().cpu() for power in powers]) + self.val_outputs = {"y": y, "x": x, "y_hat": y_hat, "timesteps": timesteps, "powers": powers} + else: + self.val_outputs["y"] = torch.vstack([self.val_outputs["y"], y]) + self.val_outputs["x"] = torch.vstack([self.val_outputs["x"], x]) + self.val_outputs["y_hat"] = torch.vstack([self.val_outputs["y_hat"], y_hat]) + self.val_outputs["timesteps"] = torch.concat([self.val_outputs["timesteps"], timesteps], dim=0) + return loss + + def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0): + x, y, timesteps = batch + y_hat = self(x) + loss = complex_mse_loss(y_hat, y, power=True) + self.log("test_loss", loss, on_epoch=True) + return loss + + # def predict_step(self, batch, batch_idx): + # x, y, timesteps = batch + # y_hat = self(x) + # return y, x, y_hat, timesteps + + + diff --git a/src/single-core-regen/hypertraining/models.py b/src/single-core-regen/hypertraining/models.py new file mode 100644 index 0000000..f9b8b03 --- /dev/null +++ b/src/single-core-regen/hypertraining/models.py @@ -0,0 +1,204 @@ +import torch +from torch.nn import Module, Sequential + +from util.complexNN import ( + DropoutComplex, + Scale, + ONNRect, + photodiode, + EOActivation, + polarimeter, + normalize_by_first +) + + +class polarisation_estimator2(Module): + def __init__(self): + super(polarisation_estimator2, self).__init__() + self.layers = Sequential( + polarimeter(), + torch.nn.Linear(4, 4), + torch.nn.ReLU(), + torch.nn.Dropout(p=0.01), + torch.nn.Linear(4, 4), + torch.nn.ReLU(), + torch.nn.Dropout(p=0.01), + torch.nn.Linear(4, 4), + ) + + def forward(self, x): + # x = self.polarimeter(x) + for layer in self.layers: + x = layer(x) + return x + +class polarisation_estimator(Module): + def __init__( + self, + *dims, + 
layer_function=ONNRect, + layer_func_kwargs: dict | None = None, + output_layer_function=photodiode, + # output_layer_func_kwargs: dict | None = None, + act_function=EOActivation, + act_func_kwargs: dict | None = None, + parametrizations: list[dict] = None, + dtype=torch.float64, + dropout_prob=0.01, + scale_layers=False, + ): + super(polarisation_estimator, self).__init__() + self._n_hidden_layers = len(dims) - 2 + + layer_func_kwargs = layer_func_kwargs or {} + act_func_kwargs = act_func_kwargs or {} + + self.build_model(dims, layer_function, layer_func_kwargs, output_layer_function, act_function, act_func_kwargs, parametrizations, dtype, dropout_prob, scale_layers) + + def forward(self, x): + x = self.layer_0(x) + for i in range(1, self._n_hidden_layers): + x = getattr(self, f"layer_{i}")(x) + x = getattr(self, f"layer_{self._n_hidden_layers}")(x) + x = torch.remainder(x, torch.ones_like(x) * 2 * torch.pi) + return x.squeeze() + + def build_model(self, dims, layer_function, layer_func_kwargs, output_layer_function, act_function, act_func_kwargs, parametrizations, dtype, dropout_prob, scale_layers): + for i in range(0, self._n_hidden_layers): + self.add_module(f"layer_{i}", Sequential()) + + if scale_layers: + self.get_submodule(f"layer_{i}").add_module("scale", Scale(dims[i])) + + module = layer_function(dims[i], dims[i + 1], dtype=dtype, **layer_func_kwargs) + self.get_submodule(f"layer_{i}").add_module("ONN", module) + + module = act_function(size=dims[i + 1], **act_func_kwargs) + self.get_submodule(f"layer_{i}").add_module("activation", module) + + module = DropoutComplex(p=dropout_prob) + self.get_submodule(f"layer_{i}").add_module("dropout", module) + + self.add_module(f"layer_{self._n_hidden_layers}", Sequential()) + + if scale_layers: + self.get_submodule(f"layer_{self._n_hidden_layers}").add_module("scale", Scale(dims[-2])) + + module = layer_function(dims[-2], dims[-1], dtype=dtype, **layer_func_kwargs) + self.get_submodule(f"layer_{self._n_hidden_layers}").add_module("ONN", module) + + module = output_layer_function(size=dims[-1]) + self.get_submodule(f"layer_{self._n_hidden_layers}").add_module("photodiode", module) + + # module = normalize_by_first() + # self.get_submodule(f"layer_{self._n_hidden_layers}").add_module("normalize", module) + + if parametrizations is not None: + self._apply_parametrizations(self, parametrizations) + + def _apply_parametrizations(self, layer, parametrizations): + for sub_layer in layer.children(): + if len(sub_layer._modules) > 0: + self._apply_parametrizations(sub_layer, parametrizations) + else: + for parametrization in parametrizations: + tensor_name = parametrization.get("tensor_name", None) + if tensor_name is None: + continue + parametrization_func = parametrization.get("parametrization", None) + if parametrization_func is None: + continue + param_kwargs = parametrization.get("kwargs", {}) + if tensor_name in sub_layer._parameters: + parametrization_func(sub_layer, tensor_name, **param_kwargs) + +class regenerator(Module): + def __init__( + self, + *dims, + layer_function=ONNRect, + layer_func_kwargs: dict | None = None, + act_function=EOActivation, + act_func_kwargs: dict | None = None, + parametrizations: list[dict] = None, + dtype=torch.float64, + dropout_prob=0.01, + scale_layers=False, + ): + super(regenerator, self).__init__() + self._n_hidden_layers = len(dims) - 2 + + layer_func_kwargs = layer_func_kwargs or {} + act_func_kwargs = act_func_kwargs or {} + + self.build_model(dims, layer_function, layer_func_kwargs, act_function, 
act_func_kwargs, parametrizations, dtype, dropout_prob, scale_layers) + + def build_model(self, dims, layer_function, layer_func_kwargs, act_function, act_func_kwargs, parametrizations, dtype, dropout_prob, scale_layers): + for i in range(0, self._n_hidden_layers): + self.add_module(f"layer_{i}", Sequential()) + + if scale_layers: + self.get_submodule(f"layer_{i}").add_module("scale", Scale(dims[i])) + + module = layer_function(dims[i], dims[i + 1], dtype=dtype, **layer_func_kwargs) + self.get_submodule(f"layer_{i}").add_module("ONN", module) + + module = act_function(size=dims[i + 1], **act_func_kwargs) + self.get_submodule(f"layer_{i}").add_module("activation", module) + + module = DropoutComplex(p=dropout_prob) + self.get_submodule(f"layer_{i}").add_module("dropout", module) + + self.add_module(f"layer_{self._n_hidden_layers}", Sequential()) + + if scale_layers: + self.get_submodule(f"layer_{self._n_hidden_layers}").add_module("scale", Scale(dims[-2])) + + module = layer_function(dims[-2], dims[-1], dtype=dtype, **layer_func_kwargs) + self.get_submodule(f"layer_{self._n_hidden_layers}").add_module("ONN", module) + + module = act_function(size=dims[-1], **act_func_kwargs) + self.get_submodule(f"layer_{self._n_hidden_layers}").add_module("activation", module) + + # module = Scale(size=dims[-1]) + # self.get_submodule(f"layer_{self._n_hidden_layers}").add_module("out_scale", module) + + if parametrizations is not None: + self._apply_parametrizations(self, parametrizations) + + def _apply_parametrizations(self, layer, parametrizations): + for sub_layer in layer.children(): + if len(sub_layer._modules) > 0: + self._apply_parametrizations(sub_layer, parametrizations) + else: + for parametrization in parametrizations: + tensor_name = parametrization.get("tensor_name", None) + if tensor_name is None: + continue + parametrization_func = parametrization.get("parametrization", None) + if parametrization_func is None: + continue + param_kwargs = parametrization.get("kwargs", {}) + if tensor_name in sub_layer._parameters: + parametrization_func(sub_layer, tensor_name, **param_kwargs) + + def _trace_powers(self, enable, x, powers=None): + if not enable: + return + if powers is None: + powers = [] + powers.append(x.abs().square().sum()) + return powers + + def forward(self, x, trace_powers=False): + powers = self._trace_powers(trace_powers, x) + x = self.layer_0(x) + powers = self._trace_powers(trace_powers, x, powers) + for i in range(1, self._n_hidden_layers): + x = getattr(self, f"layer_{i}")(x) + powers = self._trace_powers(trace_powers, x, powers) + x = getattr(self, f"layer_{self._n_hidden_layers}")(x) + powers = self._trace_powers(trace_powers, x, powers) + if trace_powers: + return x, powers + return x \ No newline at end of file diff --git a/src/single-core-regen/hypertraining/settings.py b/src/single-core-regen/hypertraining/settings.py index 1ceb51b..797cb89 100644 --- a/src/single-core-regen/hypertraining/settings.py +++ b/src/single-core-regen/hypertraining/settings.py @@ -20,6 +20,22 @@ class DataSettings: xy_delay: tuple | float | int = 0 drop_first: int = 1000 train_split: float = 0.8 + polarisations: tuple | list = (0,) + randomise_polarisations: bool = False + + """ + change to: + + config_path: tuple | list | None = None + dtype: torch.dtype | None = None + symbols: int | float = 1 + output_dim: int = 2 + shuffle: bool = True + drop_first: float | int = 0 + train_split: float = 0.8 + randomise_polarisations: bool = False + + """ # pytorch settings @@ -30,8 +46,8 @@ class 
PytorchSettings: device: str = "cuda" - dataloader_workers: int = 2 - dataloader_prefetch: int = 2 + dataloader_workers: int = 1 + dataloader_prefetch: int = 1 save_models: bool = True model_dir: str = ".models" @@ -56,6 +72,24 @@ class ModelSettings: model_layer_kwargs: dict | None = None model_layer_parametrizations: list= field(default_factory=list) + """ + change to: + + dims: tuple | list | None = None + layer_function: nn.Module | None = None + layer_func_kwargs: dict | None = None + activation_function: nn.Module | None = None + activation_func_kwargs: dict | None = None + output_function: nn.Module | None = None + output_func_kwargs: dict | None = None + dropout_function: nn.Module | None = None + dropout_func_kwargs: dict | None = None + scale_function: nn.Module | None = None + scale_func_kwargs: dict | None = None + parametrizations: list | None = None + + """ + @dataclass class OptimizerSettings: @@ -65,6 +99,17 @@ class OptimizerSettings: scheduler: str | None = None scheduler_kwargs: dict | None = None + """ + change to: + + optimizer: torch.optim.Optimizer | None = None + optimizer_kwargs: dict | None = None + learning_rate: float | None = None + scheduler: torch.optim.lr_scheduler | None = None + scheduler_kwargs: dict | None = None + + """ + def _pruner_default_kwargs(): # MedianPruner diff --git a/src/single-core-regen/hypertraining/training.py b/src/single-core-regen/hypertraining/training.py index 71b7bef..fb9a7a0 100644 --- a/src/single-core-regen/hypertraining/training.py +++ b/src/single-core-regen/hypertraining/training.py @@ -37,6 +37,7 @@ from rich.console import Console from util.datasets import FiberRegenerationDataset import util +import hypertraining.models as models from .settings import ( GlobalSettings, @@ -59,8 +60,527 @@ def traverse_dict_update(target, source): except TypeError: target.__dict__[k] = v +def get_parameter_names_and_values(model): + def is_parametrized(module): + if hasattr(module, "parametrizations"): + return True + return False -class Trainer: + def _get_param_info(module, prefix='', parametrization=False): + param_list = [] + for name, param in module.named_parameters(recurse = parametrization): + if parametrization and name.startswith("parametrizations"): + name_parts = name.split('.') + name = name_parts[1] + param = getattr(module, name) + full_name = prefix + ('.' if prefix else '') + name + param_value = param.data + param_list.append((full_name, param_value)) + + for child_name, child_module in module.named_children(): + child_prefix = prefix + ('.' 
if prefix else '') + child_name + if child_name == "parametrizations": + continue + param_list.extend(_get_param_info(child_module, child_prefix, is_parametrized(child_module))) + + return param_list + + return _get_param_info(model) + +class PolarizationTrainer: + def __init__( + self, + *, + global_settings=None, + data_settings=None, + pytorch_settings=None, + model_settings=None, + optimizer_settings=None, + console=None, + checkpoint_path=None, + settings_override=None, + reset_epoch=False, + ): + self.mod = torch.pi/2 + self.resume = checkpoint_path is not None + torch.serialization.add_safe_globals([ + *util.complexNN.__all__, + GlobalSettings, + DataSettings, + ModelSettings, + OptimizerSettings, + PytorchSettings, + models.regenerator, + torch.nn.utils.parametrizations.orthogonal, + ]) + if self.resume: + print(f"loading checkpoint from {checkpoint_path}") + self.checkpoint_dict = torch.load(checkpoint_path, weights_only=True) + if settings_override is not None: + traverse_dict_update(self.checkpoint_dict["settings"], settings_override) + if reset_epoch: + self.checkpoint_dict["epoch"] = -1 + + self.global_settings: GlobalSettings = self.checkpoint_dict["settings"]["global_settings"] + self.data_settings: DataSettings = self.checkpoint_dict["settings"]["data_settings"] + self.pytorch_settings: PytorchSettings = self.checkpoint_dict["settings"]["pytorch_settings"] + self.model_settings: ModelSettings = self.checkpoint_dict["settings"]["model_settings"] + self.optimizer_settings: OptimizerSettings = self.checkpoint_dict["settings"]["optimizer_settings"] + else: + if global_settings is None: + global_settings = GlobalSettings() + raise UserWarning("Global settings not provided, using default settings") + if data_settings is None: + data_settings = DataSettings() + raise UserWarning("Data settings not provided, using default settings") + if pytorch_settings is None: + pytorch_settings = PytorchSettings() + raise UserWarning("Pytorch settings not provided, using default settings") + if model_settings is None: + model_settings = ModelSettings() + raise UserWarning("Model settings not provided, using default settings") + if optimizer_settings is None: + optimizer_settings = OptimizerSettings() + raise UserWarning("Optimizer settings not provided, using default settings") + + self.global_settings: GlobalSettings = global_settings + self.data_settings: DataSettings = data_settings + self.pytorch_settings: PytorchSettings = pytorch_settings + self.model_settings: ModelSettings = model_settings + self.optimizer_settings: OptimizerSettings = optimizer_settings + + self.console = console or Console() + self.writer = None + + def setup_tb_writer(self, append=None): + log_dir = self.pytorch_settings.summary_dir + "/pol_" + (datetime.now().strftime("%Y%m%d_%H%M%S")) + if append is not None: + log_dir += "_" + str(append) + + print(f"Logging to {log_dir}") + self.writer = SummaryWriter(log_dir=log_dir) + + def save_checkpoint(self, save_dict, filename): + torch.save(save_dict, filename) + + def build_checkpoint_dict(self, loss=None, epoch=None): + return { + "epoch": -1 if epoch is None else epoch, + "loss": float("inf") if loss is None else loss, + "model_state_dict": copy.deepcopy(self.model.state_dict()), + "optimizer_state_dict": copy.deepcopy(self.optimizer.state_dict()), + "scheduler_state_dict": copy.deepcopy(self.scheduler.state_dict()) if hasattr(self, "scheduler") else None, + "model_kwargs": copy.deepcopy(self.model_kwargs), + "settings": { + "global_settings": 
copy.deepcopy(self.global_settings), + "data_settings": copy.deepcopy(self.data_settings), + "pytorch_settings": copy.deepcopy(self.pytorch_settings), + "model_settings": copy.deepcopy(self.model_settings), + "optimizer_settings": copy.deepcopy(self.optimizer_settings), + }, + } + + def define_model(self, model_kwargs=None): + if self.resume: + model_kwargs = self.checkpoint_dict["model_kwargs"] + else: + model_kwargs = model_kwargs + + if model_kwargs is None: + n_hidden_layers = self.model_settings.n_hidden_layers + + input_dim = 2 * self.data_settings.output_size + + dtype = getattr(torch, self.data_settings.dtype) + + afunc = getattr(util.complexNN, self.model_settings.model_activation_func) + + layer_func = getattr(util.complexNN, self.model_settings.model_layer_function) + + layer_parametrizations = self.model_settings.model_layer_parametrizations + + hidden_dims = [self.model_settings.overrides.get(f"n_hidden_nodes_{i}") for i in range(n_hidden_layers)] + + self.model_kwargs = { + "dims": (input_dim, *hidden_dims, self.model_settings.output_dim), + "layer_function": layer_func, + "layer_func_kwargs": self.model_settings.model_layer_kwargs, + "act_function": afunc, + "act_func_kwargs": None, + "parametrizations": layer_parametrizations, + "dtype": dtype, + "dropout_prob": self.model_settings.dropout_prob, + "scale_layers": self.model_settings.scale, + } + else: + self.model_kwargs = model_kwargs + input_dim = self.model_kwargs["dims"][0] + dtype = self.model_kwargs["dtype"] + + # dims = self.model_kwargs.pop("dims") + model_kwargs = copy.deepcopy(self.model_kwargs) + self.model = models.polarisation_estimator(*model_kwargs.pop('dims'),**model_kwargs) + # self.model = models.polarisation_estimator2() + + if self.writer is not None: + try: + self.writer.add_graph(self.model, torch.rand(1, input_dim, dtype=dtype), use_strict_trace=False) + except RuntimeError: + self.writer.add_graph(self.model, torch.rand(1, 2, dtype=dtype), use_strict_trace=False) + + self.model = self.model.to(self.pytorch_settings.device) + if self.resume: + self.model.load_state_dict(self.checkpoint_dict["model_state_dict"], strict=False) + + def get_sliced_data(self, override=None): + symbols = self.data_settings.symbols + + in_out_delay = self.data_settings.in_out_delay + + xy_delay = self.data_settings.xy_delay + + data_size = self.data_settings.output_size + + dtype = getattr(torch, self.data_settings.dtype) + + num_symbols = None + config_path = self.data_settings.config_path + polarisations = self.data_settings.polarisations + randomise_polarisations = self.data_settings.randomise_polarisations + if override is not None: + num_symbols = override.get("num_symbols", None) + config_path = override.get("config_path", config_path) + polarisations = override.get("polarisations", polarisations) + randomise_polarisations = override.get("randomise_polarisation", randomise_polarisations) + # get dataset + dataset = FiberRegenerationDataset( + file_path=config_path, + symbols=symbols, + output_dim=data_size, + target_delay=in_out_delay, + xy_delay=xy_delay, + drop_first=self.data_settings.drop_first, + dtype=dtype, + real=not dtype.is_complex, + num_symbols=num_symbols, + polarisations=polarisations, + randomise_polarisations=randomise_polarisations, + ) + + dataset_size = len(dataset) + indices = list(range(dataset_size)) + split = int(np.floor(self.data_settings.train_split * dataset_size)) + if self.data_settings.shuffle: + np.random.seed(self.global_settings.seed) + np.random.shuffle(indices) + + train_indices, 
valid_indices = indices[:split], indices[split:] + + if self.data_settings.shuffle: + train_sampler = torch.utils.data.SubsetRandomSampler(train_indices) + valid_sampler = torch.utils.data.SubsetRandomSampler(valid_indices) + else: + train_sampler = train_indices + valid_sampler = valid_indices + + train_loader = torch.utils.data.DataLoader( + dataset, + batch_size=self.pytorch_settings.batchsize, + sampler=train_sampler, + drop_last=True, + pin_memory=True, + num_workers=self.pytorch_settings.dataloader_workers, + prefetch_factor=self.pytorch_settings.dataloader_prefetch, + ) + + valid_loader = torch.utils.data.DataLoader( + dataset, + batch_size=self.pytorch_settings.batchsize, + sampler=valid_sampler, + drop_last=True, + pin_memory=True, + num_workers=self.pytorch_settings.dataloader_workers, + prefetch_factor=self.pytorch_settings.dataloader_prefetch, + ) + + return train_loader, valid_loader + + def train_model( + self, + optimizer, + train_loader, + epoch, + enable_progress=False, + ): + if enable_progress: + progress = Progress( + TextColumn("[yellow] Training..."), + TextColumn("Error: {task.description}"), + BarColumn(), + TaskProgressColumn(), + TextColumn("[green]Batch"), + MofNCompleteColumn(), + TimeRemainingColumn(), + TimeElapsedColumn(), + transient=False, + console=self.console, + refresh_per_second=10, + ) + task = progress.add_task("-.---e--", total=len(train_loader)) + progress.start() + + running_loss2 = 0.0 + running_loss = 0.0 + self.model.train() + loader_len = len(train_loader) + write_div = 0 + loss_div = 0 + for batch_idx, batch in enumerate(train_loader): + x = batch["x"] + y = batch["sop"] + self.model.zero_grad(set_to_none=True) + x, y = ( + x.to(self.pytorch_settings.device), + y.to(self.pytorch_settings.device), + ) + y_pred = self.model(x) + # loss = torch.nn.functional.smooth_l1_loss(torch.cos(torch.fmod(y_pred, torch.pi/2)).squeeze(), torch.cos(torch.fmod(y, torch.pi/2)).squeeze(), beta=0.5) + loss = torch.nn.functional.mse_loss(y_pred, y) + # loss = util.complexNN.naive_angle_loss(y_pred, y, mod=torch.pi/2) + loss_value = loss.item() + loss.backward() + optimizer.step() + running_loss += loss_value + running_loss2 += loss_value + write_div += 1 + loss_div += 1 + + if enable_progress: + progress.update(task, advance=1, description=f"{loss_value:.3e}") + + if batch_idx % self.pytorch_settings.write_every == 0: + self.writer.add_scalar( + "training loss", + running_loss2 / write_div, + epoch * loader_len + batch_idx, + ) + running_loss2 = 0.0 + write_div = 0 + + if enable_progress: + progress.stop() + + return running_loss / loss_div + + def eval_model(self, valid_loader, epoch, enable_progress=True): + if enable_progress: + progress = Progress( + TextColumn("[green]Evaluating..."), + TextColumn("Error: {task.description}"), + BarColumn(), + TaskProgressColumn(), + TextColumn("[green]Batch"), + MofNCompleteColumn(), + TimeRemainingColumn(), + TimeElapsedColumn(), + transient=False, + console=self.console, + refresh_per_second=10, + ) + progress.start() + task = progress.add_task("-.---e--", total=len(valid_loader)) + + self.model.eval() + running_loss = 0 + loss_div = 0 + with torch.no_grad(): + for _, batch in enumerate(valid_loader): + x = batch["x"] + y = batch["sop"] + x, y = ( + x.to(self.pytorch_settings.device), + y.to(self.pytorch_settings.device), + ) + y_pred = self.model(x) + # loss = torch.nn.functional.smooth_l1_loss(torch.cos(torch.fmod(y_pred, torch.pi/2)).squeeze(), torch.cos(torch.fmod(y, torch.pi/2)).squeeze(), beta=0.5) + loss = 
torch.nn.functional.mse_loss(y_pred, y) + # loss = util.complexNN.naive_angle_loss(y_pred, y, mod=torch.pi/2) + loss_value = loss.item() + running_loss += loss_value + loss_div += 1 + + if enable_progress: + progress.update(task, advance=1, description=f"{loss_value:.3e}") + + running_loss = running_loss/loss_div + + self.writer.add_scalar( + "eval loss", + running_loss, + epoch, + ) + + # self.write_parameters(epoch + 1) + self.writer.flush() + + if enable_progress: + progress.stop() + + return running_loss + + # def run_model(self, model, loader, trace_powers=False): + # model.eval() + # fiber_out = [] + # fiber_in = [] + # regen = [] + # timestamps = [] + + # with torch.no_grad(): + # model = model.to(self.pytorch_settings.device) + # for batch in loader: + # x = batch["x"] + # y = batch["angle"] + # timestamp = batch["timestamp"] + # plot_data = batch["plot_data"] + # x, y = ( + # x.to(self.pytorch_settings.device), + # y.to(self.pytorch_settings.device), + # ) + # if trace_powers: + # y_pred, powers = model(x, trace_powers).cpu() + # else: + # y_pred = model(x, trace_powers).cpu() + # # x = x.cpu() + # # y = y.cpu() + # y_pred = y_pred.view(y_pred.shape[0], -1, 2) + # y = y.view(y.shape[0], -1, 2) + # plot_data = plot_data.view(plot_data.shape[0], -1, 2) + # # x = x.view(x.shape[0], -1, 2) + + # # timestamp = timestamp.view(-1, 1) + # fiber_out.append(plot_data.squeeze()) + # fiber_in.append(y.squeeze()) + # regen.append(y_pred.squeeze()) + # timestamps.append(timestamp.squeeze()) + + # fiber_out = torch.vstack(fiber_out).cpu() + # fiber_in = torch.vstack(fiber_in).cpu() + # regen = torch.vstack(regen).cpu() + # timestamps = torch.concat(timestamps).cpu() + # if trace_powers: + # return fiber_in, fiber_out, regen, timestamps, powers + # return fiber_in, fiber_out, regen, timestamps + + def write_parameters(self, epoch, attributes: list[str] | tuple[str] = None): + parameter_list = get_parameter_names_and_values(self.model) + for name, value in parameter_list: + plot = (attributes is None) or (name in attributes) + if plot: + vals: np.ndarray = value.detach().cpu().numpy().flatten() + if vals.ndim <= 1 and len(vals) == 1: + if np.iscomplexobj(vals): + self.writer.add_scalar(f"{name} (Mag)", np.abs(vals), epoch) + self.writer.add_scalar(f"{name} (Phase)", np.angle(vals), epoch) + else: + self.writer.add_scalar(f"{name}", vals, epoch) + else: + if np.iscomplexobj(vals): + self.writer.add_histogram(f"{name} (Mag)", np.abs(vals), epoch, bins="fd") + self.writer.add_histogram(f"{name} (Phase)", np.angle(vals), epoch, bins="fd") + else: + self.writer.add_histogram(f"{name}", vals, epoch, bins="fd") + + def train(self): + if self.writer is None: + self.setup_tb_writer() + + self.define_model() + + print( + f"number of parameters (trainable): {sum(p.numel() for p in self.model.parameters())} ({sum(p.numel() for p in self.model.parameters() if p.requires_grad)})" + ) + + # self.write_parameters(0) + + if isinstance(self.data_settings.config_path, (list, tuple)): + for i, config_path in enumerate(self.data_settings.config_path): + paths = Path.cwd().glob(config_path) + for j, path in enumerate(paths): + text = str(path) + '\n' + with open(path, 'r') as f: + text += f.read() + text += '\n' + self.writer.add_text(f"config_{i*len(self.data_settings.config_path)+j}", text) + + elif isinstance(self.data_settings.config_path, str): + paths = Path.cwd().glob(self.data_settings.config_path) + for j, path in enumerate(paths): + text = str(path) + '\n' + with open(path, 'r') as f: + text += f.read() + 
text += '\n' + self.writer.add_text(f"config_{j}", text) + + self.writer.flush() + + train_loader, valid_loader = self.get_sliced_data() + + optimizer_name = self.optimizer_settings.optimizer + + self.optimizer: optim.Optimizer = getattr(optim, optimizer_name)( + self.model.parameters(), **self.optimizer_settings.optimizer_kwargs + ) + if self.optimizer_settings.scheduler is not None: + self.scheduler = getattr(optim.lr_scheduler, self.optimizer_settings.scheduler)( + self.optimizer, **self.optimizer_settings.scheduler_kwargs + ) + self.writer.add_scalar("learning rate", self.optimizer.param_groups[0]["lr"], -1) + + if not self.resume: + self.best = self.build_checkpoint_dict() + else: + self.best = self.checkpoint_dict + self.best["loss"] = float("inf") + + for epoch in range(self.best["epoch"] + 1, self.pytorch_settings.epochs): + enable_progress = True + if enable_progress: + self.console.rule(f"Epoch {epoch + 1}/{self.pytorch_settings.epochs}") + self.train_model( + self.optimizer, + train_loader, + epoch, + enable_progress=enable_progress, + ) + loss = self.eval_model( + valid_loader, + epoch, + enable_progress=enable_progress, + ) + if self.optimizer_settings.scheduler is not None: + self.scheduler.step(loss) + self.writer.add_scalar("learning rate", self.optimizer.param_groups[0]["lr"], epoch) + if self.pytorch_settings.save_models and self.model is not None: + save_path = ( + Path(self.pytorch_settings.model_dir) / f"pol_{self.writer.get_logdir().split('/')[-1]}_{epoch}.tar" + ) + save_path.parent.mkdir(parents=True, exist_ok=True) + checkpoint = self.build_checkpoint_dict(loss, epoch) + self.save_checkpoint(checkpoint, save_path) + + if loss < self.best["loss"]: + self.best = checkpoint + save_path = ( + Path(self.pytorch_settings.model_dir) / f"best_pol_{self.writer.get_logdir().split('/')[-1]}.tar" + ) + save_path.parent.mkdir(parents=True, exist_ok=True) + self.save_checkpoint(self.best, save_path) + self.writer.flush() + + self.writer.close() + return self.best + +class RegenerationTrainer: def __init__( self, *, @@ -82,10 +602,11 @@ class Trainer: ModelSettings, OptimizerSettings, PytorchSettings, - util.complexNN.regenerator, + models.regenerator, torch.nn.utils.parametrizations.orthogonal, ]) if self.resume: + print(f"loading checkpoint from {checkpoint_path}") self.checkpoint_dict = torch.load(checkpoint_path, weights_only=True) if settings_override is not None: traverse_dict_update(self.checkpoint_dict["settings"], settings_override) @@ -170,11 +691,13 @@ class Trainer: self.model_kwargs = { "dims": (input_dim, *hidden_dims, self.model_settings.output_dim), "layer_function": layer_func, - "layer_parametrizations": layer_parametrizations, - "activation_function": afunc, + "layer_func_kwargs": self.model_settings.model_layer_kwargs, + "act_function": afunc, + "act_func_kwargs": None, + "parametrizations": layer_parametrizations, "dtype": dtype, "dropout_prob": self.model_settings.dropout_prob, - "scale": self.model_settings.scale, + "scale_layers": self.model_settings.scale, } else: self.model_kwargs = model_kwargs @@ -182,7 +705,8 @@ class Trainer: dtype = self.model_kwargs["dtype"] # dims = self.model_kwargs.pop("dims") - self.model = util.complexNN.regenerator(**self.model_kwargs) + model_kwargs = copy.deepcopy(self.model_kwargs) + self.model = models.regenerator(*model_kwargs.pop('dims'),**model_kwargs) if self.writer is not None: self.writer.add_graph(self.model, torch.zeros(1, input_dim, dtype=dtype)) @@ -204,9 +728,13 @@ class Trainer: num_symbols = None config_path 
= self.data_settings.config_path + polarisations = self.data_settings.polarisations + randomise_polarisations = self.data_settings.randomise_polarisations if override is not None: num_symbols = override.get("num_symbols", None) config_path = override.get("config_path", config_path) + polarisations = override.get("polarisations", polarisations) + randomise_polarisations = override.get("randomise_polarisation", randomise_polarisations) # get dataset dataset = FiberRegenerationDataset( file_path=config_path, @@ -218,6 +746,8 @@ class Trainer: dtype=dtype, real=not dtype.is_complex, num_symbols=num_symbols, + polarisations=polarisations, + randomise_polarisations=randomise_polarisations, ) dataset_size = len(dataset) @@ -286,7 +816,9 @@ class Trainer: running_loss = 0.0 self.model.train() loader_len = len(train_loader) - for batch_idx, (x, y, _) in enumerate(train_loader): + for batch_idx, batch in enumerate(train_loader): + x = batch["x"] + y = batch["y"] self.model.zero_grad(set_to_none=True) x, y = ( x.to(self.pytorch_settings.device), @@ -307,7 +839,7 @@ class Trainer: self.writer.add_scalar( "training loss", running_loss2 / (self.pytorch_settings.write_every if batch_idx > 0 else 1), - epoch + batch_idx/loader_len, + epoch * loader_len + batch_idx, ) running_loss2 = 0.0 @@ -337,7 +869,9 @@ class Trainer: self.model.eval() running_error = 0 with torch.no_grad(): - for _, (x, y, _) in enumerate(valid_loader): + for _, batch in enumerate(valid_loader): + x = batch["x"] + y = batch["y"] x, y = ( x.to(self.pytorch_settings.device), y.to(self.pytorch_settings.device), @@ -360,37 +894,26 @@ class Trainer: if (epoch + 1) % 10 == 0 or epoch < 10: # plotting is slow, so only do it every 10 epochs title_append, subtitle = self.build_title(epoch + 1) + head_fig, eye_fig, powers_fig = self.plot_model_response( + model=self.model, + title_append=title_append, + subtitle=subtitle, + show=False, + ) self.writer.add_figure( "fiber response", - self.plot_model_response( - model=self.model, - title_append=title_append, - subtitle=subtitle, - show=False, - ), + head_fig, epoch + 1, ) self.writer.add_figure( "eye diagram", - self.plot_model_response( - model=self.model, - title_append=title_append, - subtitle=subtitle, - show=False, - mode="eye", - ), + eye_fig, epoch + 1, ) self.writer.add_figure( "powers", - self.plot_model_response( - model=self.model, - title_append=title_append, - subtitle=subtitle, - mode="powers", - show=False, - ), + powers_fig, epoch + 1, ) @@ -411,7 +934,11 @@ class Trainer: with torch.no_grad(): model = model.to(self.pytorch_settings.device) - for x, y, timestamp in loader: + for batch in loader: + x = batch["x"] + y = batch["y"] + timestamp = batch["timestamp"] + plot_data = batch["plot_data"] x, y = ( x.to(self.pytorch_settings.device), y.to(self.pytorch_settings.device), @@ -424,9 +951,11 @@ class Trainer: # y = y.cpu() y_pred = y_pred.view(y_pred.shape[0], -1, 2) y = y.view(y.shape[0], -1, 2) - x = x.view(x.shape[0], -1, 2) + plot_data = plot_data.view(plot_data.shape[0], -1, 2) + # x = x.view(x.shape[0], -1, 2) + # timestamp = timestamp.view(-1, 1) - fiber_out.append(x[:, x.shape[1] // 2, :].squeeze()) + fiber_out.append(plot_data.squeeze()) fiber_in.append(y.squeeze()) regen.append(y_pred.squeeze()) timestamps.append(timestamp.squeeze()) @@ -440,28 +969,23 @@ class Trainer: return fiber_in, fiber_out, regen, timestamps def write_parameters(self, epoch, attributes: list[str] | tuple[str] = None): - for i, layer in enumerate(self.model._layers): - tag = f"layer {i}" - if 
hasattr(layer, "parametrizations"): - attribute_pool = set(layer.parametrizations._modules) | set(layer._parameters) - else: - attribute_pool = set(layer._parameters) - for attribute in attribute_pool: - plot = (attributes is None) or (attribute in attributes) - if plot: - vals: np.ndarray = getattr(layer, attribute).detach().cpu().numpy().flatten() - if vals.ndim <= 1 and len(vals) == 1: - if np.iscomplexobj(vals): - self.writer.add_scalar(f"{tag} {attribute} (Mag)", np.abs(vals), epoch) - self.writer.add_scalar(f"{tag} {attribute} (Phase)", np.angle(vals), epoch) - else: - self.writer.add_scalar(f"{tag} {attribute}", vals, epoch) + parameter_list = get_parameter_names_and_values(self.model) + for name, value in parameter_list: + plot = (attributes is None) or (name in attributes) + if plot: + vals: np.ndarray = value.detach().cpu().numpy().flatten() + if vals.ndim <= 1 and len(vals) == 1: + if np.iscomplexobj(vals): + self.writer.add_scalar(f"{name} (Mag)", np.abs(vals), epoch) + self.writer.add_scalar(f"{name} (Phase)", np.angle(vals), epoch) else: - if np.iscomplexobj(vals): - self.writer.add_histogram(f"{tag} {attribute} (Mag)", np.abs(vals), epoch, bins="fd") - self.writer.add_histogram(f"{tag} {attribute} (Phase)", np.angle(vals), epoch, bins="fd") - else: - self.writer.add_histogram(f"{tag} {attribute}", vals, epoch, bins="fd") + self.writer.add_scalar(f"{name}", vals, epoch) + else: + if np.iscomplexobj(vals): + self.writer.add_histogram(f"{name} (Mag)", np.abs(vals), epoch, bins="fd") + self.writer.add_histogram(f"{name} (Phase)", np.angle(vals), epoch, bins="fd") + else: + self.writer.add_histogram(f"{name}", vals, epoch, bins="fd") def train(self): if self.writer is None: @@ -474,44 +998,48 @@ class Trainer: ) title_append, subtitle = self.build_title(0) - + head_fig, eye_fig, powers_fig = self.plot_model_response( + model=self.model, + title_append=title_append, + subtitle=subtitle, + show=False, + ) self.writer.add_figure( "fiber response", - self.plot_model_response( - model=self.model, - title_append=title_append, - subtitle=subtitle, - show=False, - ), + head_fig, 0, ) self.writer.add_figure( "eye diagram", - self.plot_model_response( - model=self.model, - title_append=title_append, - subtitle=subtitle, - mode="eye", - show=False, - ), + eye_fig, 0, ) self.writer.add_figure( "powers", - self.plot_model_response( - model=self.model, - title_append=title_append, - subtitle=subtitle, - mode="powers", - show=False, - ), + powers_fig, 0, ) self.write_parameters(0) - self.writer.add_text("datasets", '\n'.join(self.data_settings.config_path)) + if isinstance(self.data_settings.config_path, (list, tuple)): + for i, config_path in enumerate(self.data_settings.config_path): + paths = Path.cwd().glob(config_path) + for j, path in enumerate(paths): + text = str(path) + '\n' + with open(path, 'r') as f: + text += f.read() + text += '\n' + self.writer.add_text(f"config_{i*len(self.data_settings.config_path)+j}", text) + elif isinstance(self.data_settings.config_path, str): + paths = Path.cwd().glob(self.data_settings.config_path) + for j, path in enumerate(paths): + text = str(path) + '\n' + with open(path, 'r') as f: + text += f.read() + text += '\n' + self.writer.add_text(f"config_{j}", text) self.writer.flush() @@ -741,54 +1269,50 @@ class Trainer: def plot_model_response( self, - model=None, + model:torch.nn.Module=None, title_append="", subtitle="", - mode: Literal["eye", "head", "powers"] = "head", + # mode: Literal["eye", "head", "powers"] = "head", show=False, ): - if mode == 
"powers": - input_data = torch.ones( - 1, 2 * self.data_settings.output_size, dtype=getattr(torch, self.data_settings.dtype) - ).to(self.pytorch_settings.device) - model = model.to(self.pytorch_settings.device) - model.eval() - with torch.no_grad(): - _, powers = model(input_data, trace_powers=True) + input_data = torch.ones( + 1, 2 * self.data_settings.output_size, dtype=getattr(torch, self.data_settings.dtype) + ).to(self.pytorch_settings.device) + model = model.to(self.pytorch_settings.device) + model.eval() + with torch.no_grad(): + _, powers = model(input_data, trace_powers=True) - powers = [power.item() for power in powers] - layer_names = ["input", *[str(x).split("(")[0] for x in model._layers._modules.values()]] + powers = [power.item() for power in powers] + layer_names = [name for (name, _) in model.named_children()] - # remove dropout layers - mask = [1 if "Dropout" not in layer_name else 0 for layer_name in layer_names] - layer_names = [layer_name for layer_name, m in zip(layer_names, mask) if m] - powers = [power for power, m in zip(powers, mask) if m] - - fig = self._plot_model_response_powers( - powers, layer_names, title_append=title_append, subtitle=subtitle, show=show - ) - return fig + power_fig = self._plot_model_response_powers( + powers, layer_names, title_append=title_append, subtitle=subtitle, show=show + ) data_settings_backup = copy.deepcopy(self.data_settings) pytorch_settings_backup = copy.deepcopy(self.pytorch_settings) self.data_settings.drop_first = 99.5 + random.randint(0, 1000) self.data_settings.shuffle = False self.data_settings.train_split = 1.0 - self.pytorch_settings.batchsize = ( - self.pytorch_settings.eye_symbols if mode == "eye" else self.pytorch_settings.head_symbols - ) + self.pytorch_settings.batchsize = max(self.pytorch_settings.head_symbols, self.pytorch_settings.eye_symbols) config_path = random.choice(self.data_settings.config_path) if isinstance(self.data_settings.config_path, (list, tuple)) else self.data_settings.config_path - fiber_length = int(float(str(config_path).split('-')[-7])/1000) - plot_loader, _ = self.get_sliced_data( - override={ - "num_symbols": self.pytorch_settings.batchsize, - "config_path": config_path, - } - ) + fiber_length = int(float(str(config_path).split('-')[4])/1000) + if not hasattr(self, "_plot_loader"): + self._plot_loader, _ = self.get_sliced_data( + override={ + "num_symbols": self.pytorch_settings.batchsize, + "config_path": config_path, + "shuffle": False, + "polarisations": (np.random.rand(1)*np.pi*2,), + "randomise_polarisation": False, + } + ) + self._sps = self._plot_loader.dataset.samples_per_symbol self.data_settings = data_settings_backup self.pytorch_settings = pytorch_settings_backup - fiber_in, fiber_out, regen, timestamps = self.run_model(model, plot_loader) + fiber_in, fiber_out, regen, timestamps = self.run_model(model, self._plot_loader) fiber_in = fiber_in.view(-1, 2) fiber_out = fiber_out.view(-1, 2) regen = regen.view(-1, 2) @@ -802,36 +1326,32 @@ class Trainer: # https://github.com/matplotlib/matplotlib/issues/27713#issuecomment-1915497463 import gc - if mode == "head": - fig = self._plot_model_response_head( - fiber_in, - fiber_out, - regen, - timestamps=timestamps, - labels=("fiber in", "fiber out", "regen"), - sps=plot_loader.dataset.samples_per_symbol, - title_append=title_append + f" ({fiber_length} km)", - subtitle=subtitle, - show=show, - ) - elif mode == "eye": + head_fig = self._plot_model_response_head( + fiber_in[:self.pytorch_settings.head_symbols*self._sps], + 
fiber_out[:self.pytorch_settings.head_symbols*self._sps], + regen[:self.pytorch_settings.head_symbols*self._sps], + timestamps=timestamps[:self.pytorch_settings.head_symbols*self._sps], + labels=("fiber in", "fiber out", "regen"), + sps=self._sps, + title_append=title_append + f" ({fiber_length} km)", + subtitle=subtitle, + show=show, + ) # raise NotImplementedError("Eye diagram not implemented") - fig = self._plot_model_response_eye( - fiber_in, - fiber_out, - regen, - timestamps=timestamps, + eye_fig = self._plot_model_response_eye( + fiber_in[:self.pytorch_settings.eye_symbols*self._sps], + fiber_out[:self.pytorch_settings.eye_symbols*self._sps], + regen[:self.pytorch_settings.eye_symbols*self._sps], + timestamps=timestamps[:self.pytorch_settings.eye_symbols*self._sps], labels=("fiber in", "fiber out", "regen"), - sps=plot_loader.dataset.samples_per_symbol, + sps=self._sps, title_append=title_append + f" ({fiber_length} km)", subtitle=subtitle, show=show, ) - else: - raise ValueError(f"Unknown mode: {mode}") gc.collect() - return fig + return head_fig, eye_fig, power_fig def build_title(self, number: int): title_append = f"epoch {number}" diff --git a/src/single-core-regen/regen_no_hyper.py b/src/single-core-regen/regen_no_hyper.py index a68022b..5c64be2 100644 --- a/src/single-core-regen/regen_no_hyper.py +++ b/src/single-core-regen/regen_no_hyper.py @@ -1,7 +1,10 @@ +from datetime import datetime from pathlib import Path import matplotlib import numpy as np import torch +import torch.utils.tensorboard +import torch.utils.tensorboard.summary from hypertraining.settings import ( GlobalSettings, DataSettings, @@ -10,7 +13,7 @@ from hypertraining.settings import ( OptimizerSettings, ) -from hypertraining.training import Trainer +from hypertraining.training import RegenerationTrainer, PolarizationTrainer # import torch import json @@ -23,7 +26,7 @@ global_settings = GlobalSettings( ) data_settings = DataSettings( - config_path="data/20241204-13*-128-16384-100000-0-0-17-0-PAM4-0.ini", + config_path="data/*-128-16384-10000-0-0-0-0-PAM4-0-0.4.ini", # config_path=[f"data/20241202-*-128-16384-{length}-0-0-17-0-PAM4-0.ini" for length in range(48000, 53000, 1000)], dtype="complex64", # symbols = (9, 20), # 13 symbol @ 10GBd <-> 1.3ns <-> 0.26m of fiber @@ -31,17 +34,16 @@ data_settings = DataSettings( # output_size = (11, 32), # 26 taps -> 2 taps per input symbol -> 1 tap every 1cm (model has 52 inputs (x/y)) output_size=26, # study: single_core_regen_20241123_011232 (model_input_dim/2) shuffle=True, - in_out_delay=0, - xy_delay=0, - drop_first=128 * 64, + drop_first=64, train_split=0.8, + randomise_polarisations=True, ) pytorch_settings = PytorchSettings( epochs=10000, - batchsize=2**12, + batchsize=2**14, device="cuda", - dataloader_workers=12, + dataloader_workers=16, dataloader_prefetch=8, summary_dir=".runs", write_every=2**5, @@ -51,12 +53,14 @@ pytorch_settings = PytorchSettings( model_settings = ModelSettings( output_dim=2, - n_hidden_layers=4, + n_hidden_layers=5, overrides={ + # "hidden_layer_dims": (8, 8, 4, 4), "n_hidden_nodes_0": 8, "n_hidden_nodes_1": 8, "n_hidden_nodes_2": 4, "n_hidden_nodes_3": 4, + "n_hidden_nodes_4": 2, }, model_activation_func="EOActivation", dropout_prob=0.01, @@ -92,6 +96,14 @@ model_settings = ModelSettings( "tensor_name": "scales", "parametrization": util.complexNN.clamp, }, + { + "tensor_name": "angle", + "parametrization": util.complexNN.clamp, + "kwargs": { + "min": -torch.pi, + "max": torch.pi, + }, + }, # { # "tensor_name": "scale", # 
"parametrization": util.complexNN.clamp, @@ -143,7 +155,7 @@ def save_dict_to_file(dictionary, filename): json.dump(dictionary, f, indent=4) -def sweep_lengths(*lengths, model=None, data_glob:str=None, strategy="newest"): +def sweep_lengths(*lengths, model=None, data_glob: str = None, strategy="newest"): assert model is not None, "Model must be provided." assert data_glob is not None, "Data glob must be provided." model = model @@ -153,9 +165,9 @@ def sweep_lengths(*lengths, model=None, data_glob:str=None, strategy="newest"): regens = {} timestampss = {} - trainer = Trainer( - checkpoint_path=model, - ) + trainer = RegenerationTrainer( + checkpoint_path=model, + ) trainer.define_model() for length in lengths: @@ -165,13 +177,13 @@ def sweep_lengths(*lengths, model=None, data_glob:str=None, strategy="newest"): continue if strategy == "newest": sorted_kwargs = { - 'key': lambda x: x.stat().st_mtime, - 'reverse': True, + "key": lambda x: x.stat().st_mtime, + "reverse": True, } elif strategy == "oldest": sorted_kwargs = { - 'key': lambda x: x.stat().st_mtime, - 'reverse': False, + "key": lambda x: x.stat().st_mtime, + "reverse": False, } else: raise ValueError(f"Unknown strategy {strategy}.") @@ -186,22 +198,21 @@ def sweep_lengths(*lengths, model=None, data_glob:str=None, strategy="newest"): timestampss[length] = timestamps data = torch.zeros(2 * len(timestampss.keys()) + 2, 2, tuple(fiber_outs.values())[-1].shape[0]) - channel_names = ["" for _ in range(2 * len(timestampss.keys())+2)] + channel_names = ["" for _ in range(2 * len(timestampss.keys()) + 2)] data[1, 0, :] = timestampss[tuple(timestampss.keys())[-1]] / 128 data[1, 1, :] = fiber_ins[tuple(timestampss.keys())[-1]][:, 0].abs().square() channel_names[1] = "fiber in x" - for li, length in enumerate(timestampss.keys()): - data[2+2 * li, 0, :] = timestampss[length] / 128 - data[2+2 * li, 1, :] = fiber_outs[length][:, 0].abs().square() - data[2+2 * li + 1, 0, :] = timestampss[length] / 128 - data[2+2 * li + 1, 1, :] = regens[length][:, 0].abs().square() + data[2 + 2 * li, 0, :] = timestampss[length] / 128 + data[2 + 2 * li, 1, :] = fiber_outs[length][:, 0].abs().square() + data[2 + 2 * li + 1, 0, :] = timestampss[length] / 128 + data[2 + 2 * li + 1, 1, :] = regens[length][:, 0].abs().square() - channel_names[2+2 * li+1] = f"regen x {length}" - channel_names[2+2 * li] = f"fiber out x {length}" + channel_names[2 + 2 * li + 1] = f"regen x {length}" + channel_names[2 + 2 * li] = f"fiber out x {length}" # get current backend backend = matplotlib.get_backend() @@ -210,7 +221,7 @@ def sweep_lengths(*lengths, model=None, data_glob:str=None, strategy="newest"): eye = util.eye_diagram.eye_diagram(data.to(dtype=torch.float32).detach().cpu().numpy(), channel_names=channel_names) print_attrs = ("channel_name", "success", "min_area") - with np.printoptions(precision=3, suppress=True, formatter={'float': '{:0.3e}'.format}): + with np.printoptions(precision=3, suppress=True, formatter={"float": "{:0.3e}".format}): for result in eye.eye_stats: print_dict = {attr: result[attr] for attr in print_attrs} rprint(print_dict) @@ -221,18 +232,77 @@ def sweep_lengths(*lengths, model=None, data_glob:str=None, strategy="newest"): if __name__ == "__main__": - - lengths = range(90000, 100000+10000, 10000) + # lengths = range(90000, 100000+10000, 10000) # lengths = [100000] - sweep_lengths(*lengths, model=".models/best_20241204_132605.tar", data_glob="data/202412*-{length}-*-0.ini", strategy="newest") + # sweep_lengths(*lengths, 
model=".models/best_20241204_132605.tar", data_glob="data/202412*-{length}-*-0.ini", strategy="newest") - # trainer = Trainer( - # global_settings=global_settings, - # data_settings=data_settings, - # pytorch_settings=pytorch_settings, - # model_settings=model_settings, - # optimizer_settings=optimizer_settings, - # # checkpoint_path=".models/best_20241202_143149.tar", - # # 20241202_143149 + trainer = RegenerationTrainer( + global_settings=global_settings, + data_settings=data_settings, + pytorch_settings=pytorch_settings, + model_settings=model_settings, + optimizer_settings=optimizer_settings, + # checkpoint_path=".models/best_20241205_235929.tar", + # 20241202_143149 + ) + trainer.train() + + # from hypertraining.lighning_models import regenerator, regeneratorData + # import lightning as L + + # model = regenerator( + # 2 * data_settings.output_size, + # *model_settings.overrides["hidden_layer_dims"], + # model_settings.output_dim, + # layer_function=getattr(util.complexNN, model_settings.model_layer_function), + # layer_func_kwargs=model_settings.model_layer_kwargs, + # act_function=getattr(util.complexNN, model_settings.model_activation_func), + # act_func_kwargs=None, + # parametrizations=model_settings.model_layer_parametrizations, + # dtype=getattr(torch, data_settings.dtype), + # dropout_prob=model_settings.dropout_prob, + # scale_layers=model_settings.scale, + # optimizer=getattr(torch.optim, optimizer_settings.optimizer), + # optimizer_kwargs=optimizer_settings.optimizer_kwargs, + # lr_scheduler=getattr(torch.optim.lr_scheduler, optimizer_settings.scheduler), + # lr_scheduler_kwargs=optimizer_settings.scheduler_kwargs, # ) - # trainer.train() \ No newline at end of file + + # dm = regeneratorData( + # config_globs=data_settings.config_path, + # output_symbols=data_settings.symbols, + # output_dim=data_settings.output_size, + # dtype=getattr(torch, data_settings.dtype), + # drop_first=data_settings.drop_first, + # shuffle=data_settings.shuffle, + # train_split=data_settings.train_split, + # batch_size=pytorch_settings.batchsize, + # loader_settings={ + # "num_workers": pytorch_settings.dataloader_workers, + # "prefetch_factor": pytorch_settings.dataloader_prefetch, + # "pin_memory": True, + # "drop_last": True, + # }, + # seed=global_settings.seed, + # ) + + # # writer = L.SummaryWriter(pytorch_settings.summary_dir + f"/{datetime.now().strftime('%Y%m%d_%H%M%S')}") + + # # from torch.utils.tensorboard import SummaryWriter + + # subdir = f"{datetime.now().strftime('%Y%m%d_%H%M%S')}" + + # # writer = SummaryWriter(pytorch_settings.summary_dir + f"/{subdir}") + + # logger = L.pytorch.loggers.TensorBoardLogger(pytorch_settings.summary_dir, name=subdir, log_graph=True) + + # trainer = L.Trainer( + # fast_dev_run=False, + # # max_epochs=pytorch_settings.epochs, + # max_epochs=2, + # enable_checkpointing=True, + # default_root_dir=f".models/{subdir}/", + # logger=logger, + # ) + + # trainer.fit(model, dm) diff --git a/src/single-core-regen/train_pol_estimator.py b/src/single-core-regen/train_pol_estimator.py new file mode 100644 index 0000000..153f807 --- /dev/null +++ b/src/single-core-regen/train_pol_estimator.py @@ -0,0 +1,230 @@ +from datetime import datetime +from pathlib import Path +import matplotlib +import numpy as np +import torch +import torch.utils.tensorboard +import torch.utils.tensorboard.summary +from hypertraining.settings import ( + GlobalSettings, + DataSettings, + PytorchSettings, + ModelSettings, + OptimizerSettings, +) + +from hypertraining.training import 
RegenerationTrainer, PolarizationTrainer + +# import torch +import json +import util + +from rich import print as rprint + +global_settings = GlobalSettings( + seed=0xC0FFEE, +) + +data_settings = DataSettings( + config_path="data/*-128-16384-10000-0-0-0-0-PAM4-0-0.4.ini", + # config_path=[f"data/20241202-*-128-16384-{length}-0-0-17-0-PAM4-0.ini" for length in range(48000, 53000, 1000)], + dtype="complex64", + # symbols = (9, 20), # 13 symbol @ 10GBd <-> 1.3ns <-> 0.26m of fiber + symbols=13, # study: single_core_regen_20241123_011232 + # output_size = (11, 32), # 26 taps -> 2 taps per input symbol -> 1 tap every 1cm (model has 52 inputs (x/y)) + output_size=26, # study: single_core_regen_20241123_011232 (model_input_dim/2) + shuffle=True, + drop_first=64, + train_split=0.8, + # polarisations=tuple(np.random.rand(2)*2*np.pi), + randomise_polarisations=True, +) + +pytorch_settings = PytorchSettings( + epochs=10000, + batchsize=2**12, + device="cuda", + dataloader_workers=16, + dataloader_prefetch=8, + summary_dir=".runs", + write_every=2**5, + save_models=True, + model_dir=".models", +) + +model_settings = ModelSettings( + output_dim=3, + n_hidden_layers=3, + overrides={ + "n_hidden_nodes_0": 2, + "n_hidden_nodes_1": 2, + "n_hidden_nodes_2": 2, + }, + dropout_prob=0.01, + model_layer_function="ONNRect", + model_activation_func="EOActivation", + model_layer_kwargs={"square": True}, + scale=False, + model_layer_parametrizations=[ + { + "tensor_name": "weight", + "parametrization": util.complexNN.energy_conserving, + }, + { + "tensor_name": "alpha", + "parametrization": util.complexNN.clamp, + }, + { + "tensor_name": "gain", + "parametrization": util.complexNN.clamp, + "kwargs": { + "min": 0, + "max": float("inf"), + }, + }, + { + "tensor_name": "phase_bias", + "parametrization": util.complexNN.clamp, + "kwargs": { + "min": 0, + "max": 2 * torch.pi, + }, + }, + { + "tensor_name": "scales", + "parametrization": util.complexNN.clamp, + }, + { + "tensor_name": "angle", + "parametrization": util.complexNN.clamp, + "kwargs": { + "min": 0, + "max": 2*torch.pi, + }, + }, + { + "tensor_name": "loss", + "parametrization": util.complexNN.clamp, + }, + ], +) + +optimizer_settings = OptimizerSettings( + optimizer="AdamW", + optimizer_kwargs={ + "lr": 0.005, + "amsgrad": True, + # "weight_decay": 1e-7, + }, + # learning_rate=0.05, + scheduler="ReduceLROnPlateau", + scheduler_kwargs={ + "patience": 2**6, + "factor": 0.75, + # "threshold": 1e-3, + "min_lr": 1e-6, + "cooldown": 10, + }, +) + + +def save_dict_to_file(dictionary, filename): + """ + Save the best dictionary to a JSON file. + + :param best: Dictionary containing the best training results. + :type best: dict + :param filename: Path to the JSON file where the dictionary will be saved. + :type filename: str + """ + with open(filename, "w") as f: + json.dump(dictionary, f, indent=4) + + +def sweep_lengths(*lengths, model=None, data_glob: str = None, strategy="newest"): + assert model is not None, "Model must be provided." + assert data_glob is not None, "Data glob must be provided." 
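Aside: the model_layer_parametrizations entries above clamp named tensors such as "angle", "gain" and "phase_bias" into fixed ranges. util.complexNN.clamp lives elsewhere in this repo and its exact call signature is not shown in this patch; the sketch below only illustrates the torch.nn.utils.parametrize mechanism such a helper would typically wrap, with _Clamp as a hypothetical stand-in.

import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize

class _Clamp(nn.Module):
    # hypothetical stand-in for util.complexNN.clamp
    def __init__(self, min=0.0, max=1.0):
        super().__init__()
        self.min, self.max = min, max

    def forward(self, x):
        return x.clamp(self.min, self.max)

def apply_entry(layer, entry):
    # apply one parametrization entry only if the layer owns a tensor with that name,
    # mirroring how the entries are matched against each layer in this patch
    name = entry["tensor_name"]
    if name in dict(layer.named_parameters(recurse=False)):
        parametrize.register_parametrization(layer, name, _Clamp(**entry.get("kwargs", {})))

layer = nn.Module()
layer.angle = nn.Parameter(torch.tensor(7.0))
apply_entry(layer, {"tensor_name": "angle", "kwargs": {"min": 0.0, "max": 2 * torch.pi}})
print(layer.angle)  # reads back clamped into [0, 2*pi]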
+ model = model + + fiber_ins = {} + fiber_outs = {} + regens = {} + timestampss = {} + + trainer = RegenerationTrainer( + checkpoint_path=model, + ) + trainer.define_model() + + for length in lengths: + data_glob_length = data_glob.replace("{length}", str(length)) + files = list(Path.cwd().glob(data_glob_length)) + if len(files) == 0: + continue + if strategy == "newest": + sorted_kwargs = { + "key": lambda x: x.stat().st_mtime, + "reverse": True, + } + elif strategy == "oldest": + sorted_kwargs = { + "key": lambda x: x.stat().st_mtime, + "reverse": False, + } + else: + raise ValueError(f"Unknown strategy {strategy}.") + file = sorted(files, **sorted_kwargs)[0] + + loader, _ = trainer.get_sliced_data(override={"config_path": file}) + fiber_in, fiber_out, regen, timestamps = trainer.run_model(trainer.model, loader=loader) + + fiber_ins[length] = fiber_in + fiber_outs[length] = fiber_out + regens[length] = regen + timestampss[length] = timestamps + + data = torch.zeros(2 * len(timestampss.keys()) + 2, 2, tuple(fiber_outs.values())[-1].shape[0]) + channel_names = ["" for _ in range(2 * len(timestampss.keys()) + 2)] + + data[1, 0, :] = timestampss[tuple(timestampss.keys())[-1]] / 128 + data[1, 1, :] = fiber_ins[tuple(timestampss.keys())[-1]][:, 0].abs().square() + + channel_names[1] = "fiber in x" + + for li, length in enumerate(timestampss.keys()): + data[2 + 2 * li, 0, :] = timestampss[length] / 128 + data[2 + 2 * li, 1, :] = fiber_outs[length][:, 0].abs().square() + data[2 + 2 * li + 1, 0, :] = timestampss[length] / 128 + data[2 + 2 * li + 1, 1, :] = regens[length][:, 0].abs().square() + + channel_names[2 + 2 * li + 1] = f"regen x {length}" + channel_names[2 + 2 * li] = f"fiber out x {length}" + + # get current backend + backend = matplotlib.get_backend() + + matplotlib.use("TkCairo") + eye = util.eye_diagram.eye_diagram(data.to(dtype=torch.float32).detach().cpu().numpy(), channel_names=channel_names) + + print_attrs = ("channel_name", "success", "min_area") + with np.printoptions(precision=3, suppress=True, formatter={"float": "{:0.3e}".format}): + for result in eye.eye_stats: + print_dict = {attr: result[attr] for attr in print_attrs} + rprint(print_dict) + rprint() + + eye.plot(all_stats=False) + matplotlib.use(backend) + + +if __name__ == "__main__": + trainer = PolarizationTrainer( + global_settings=global_settings, + data_settings=data_settings, + pytorch_settings=pytorch_settings, + model_settings=model_settings, + optimizer_settings=optimizer_settings, + # checkpoint_path='.models/pol_pol_20241208_122418_1116.tar', + # reset_epoch=True + ) + trainer.train() diff --git a/src/single-core-regen/util/complexNN.py b/src/single-core-regen/util/complexNN.py index a8cd75e..cdcf4d6 100644 --- a/src/single-core-regen/util/complexNN.py +++ b/src/single-core-regen/util/complexNN.py @@ -260,12 +260,94 @@ class ONNRect(nn.Module): self.crop = lambda x: x self.crop.__doc__ = "No cropping" - def forward(self, x): - x = self.pad(x) + x = self.pad(x).to(dtype=self.weight.dtype) out = self.crop((self.weight @ x.mT).mT) return out + +class polarimeter(nn.Module): + def __init__(self): + super(polarimeter, self).__init__() + # self.input_length = input_length + + def forward(self, data): + # S0 = I + # S1 = (2*I_x - I)/I + # S2 = (2*I_45 - I)/I + # S3 = (2*I_RHC - I)/I + # # data: (batch, input_length*2) -> (batch, input_length, 2) + data = data.view(data.shape[0], -1, 2) + x = data[:, :, 0].mean(dim=1) + y = data[:, :, 1].mean(dim=1) + + # x = x.mean(dim=1) + # y = y.mean(dim=1) + + # angle = 
torch.atan2(y.abs().square().real, x.abs().square().real) + + # return torch.stack([angle, angle, angle, angle], dim=1) + + # horizontal polarisation + I_x = x.abs().square() + + # vertical polarisation + I_y = y.abs().square() + + # 45 degree polarisation + I_45 = (x + y).abs().square() + + + # right hand circular polarisation + I_RHC = (x + 1j*y).abs().square() + + # S0 = I_x + I_y + # S1 = I_x - I_y + # S2 = I_45 - I_m45 + # S3 = I_RHC - I_LHC + + S0 = (I_x + I_y) + S1 = ((2*I_x - S0)/S0) + S2 = ((2*I_45 - S0)/S0) + S3 = ((2*I_RHC - S0)/S0) + + return torch.stack([S0/S0, S1/S0, S2/S0, S3/S0], dim=1) + +class normalize_by_first(nn.Module): + def __init__(self): + super(normalize_by_first, self).__init__() + + def forward(self, data): + return data / data[:, 0].unsqueeze(1) + +class photodiode(nn.Module): + def __init__(self, size, bias=True): + super(photodiode, self).__init__() + self.input_dim = size + self.scale = nn.Parameter(torch.rand(size)) + self.pd_bias = nn.Parameter(torch.rand(size)) + + def forward(self, x): + return x.abs().square().to(dtype=x.dtype.to_real()).mul(self.scale).add(self.pd_bias) + + +class input_rotator(nn.Module): + def __init__(self, input_dim): + super(input_rotator, self).__init__() + assert input_dim % 2 == 0, "Input dimension must be even" + self.input_dim = input_dim + # self.angle = nn.Parameter(torch.randn(1, dtype=self.dtype.to_real())) + + def forward(self, x, angle=None): + # take channels (0,1), (2,3), ... and rotate them by the angle + angle = angle or self.angle + sine = torch.sin(angle) + cosine = torch.cos(angle) + rot = torch.tensor([[cosine, -sine], [sine, cosine]], dtype=self.dtype) + return torch.matmul(x.view(-1, 2), rot).view(x.shape) + + + # def __repr__(self): # return f"ONNRect({self.input_dim}, {self.output_dim})" @@ -371,7 +453,7 @@ class Identity(nn.Module): M(z) = z """ - def __init__(self): + def __init__(self, size=None): super(Identity, self).__init__() def forward(self, x): @@ -404,9 +486,28 @@ class MZISingle(nn.Module): def forward(self, x: torch.Tensor): return x * torch.exp(1j * self.phi) * torch.sin(self.omega + self.func(x)) +def naive_angle_loss(x: torch.Tensor, target: torch.Tensor, mod=2*torch.pi): + return torch.fmod((x - target), mod).square().mean() + +def cosine_loss(x: torch.Tensor, target: torch.Tensor): + return (2*(1 - torch.cos(x - target))).mean() + +def angle_mse_loss(x: torch.Tensor, target: torch.Tensor): + x = torch.fmod(x, 2*torch.pi) + target = torch.fmod(target, 2*torch.pi) + + x_cos = torch.cos(x) + x_sin = torch.sin(x) + target_cos = torch.cos(target) + target_sin = torch.sin(target) + + cos_diff = x_cos - target_cos + sin_diff = x_sin - target_sin + squared_diff = cos_diff**2 + sin_diff**2 + return squared_diff.mean() class EOActivation(nn.Module): - def __init__(self, bias, size=None): + def __init__(self, size=None): # 10.1109/SiPhotonics60897.2024.10543376 super(EOActivation, self).__init__() if size is None: @@ -569,83 +670,12 @@ class ZReLU(nn.Module): return x * (torch.angle(x) >= 0) * (torch.angle(x) <= torch.pi / 2) else: return torch.relu(x) - - -class regenerator(nn.Module): - def __init__( - self, - *dims, - layer_function=ONN, - layer_kwargs: dict | None = None, - layer_parametrizations: list[dict] = None, - activation_function=Pow, - dtype=torch.float64, - dropout_prob=0.01, - scale=False, - **kwargs, - ): - super(regenerator, self).__init__() - if len(dims) == 0: - try: - dims = kwargs["dims"] - except KeyError: - raise ValueError("dims must be provided") - self._n_hidden_layers = 
len(dims) - 2 - self._layers = nn.Sequential() - if layer_kwargs is None: - layer_kwargs = {} - # self.powers = [] - - for i in range(self._n_hidden_layers + 1): - if scale: - self._layers.append(Scale(dims[i])) - self._layers.append(layer_function(dims[i], dims[i + 1], dtype=dtype, **layer_kwargs)) - if i < self._n_hidden_layers: - if dropout_prob is not None: - self._layers.append(DropoutComplex(p=dropout_prob)) - self._layers.append(activation_function(bias=True, size=dims[i + 1])) - - self._layers.append(Scale(dims[-1])) - - # add parametrizations - if layer_parametrizations is not None: - for layer in self._layers: - for layer_parametrization in layer_parametrizations: - tensor_name = layer_parametrization.get("tensor_name", None) - parametrization = layer_parametrization.get("parametrization", None) - param_kwargs = layer_parametrization.get("kwargs", {}) - if tensor_name is not None and tensor_name in layer._parameters and parametrization is not None: - parametrization(layer, tensor_name, **param_kwargs) - - # def __call__(self, input_x, **kwargs): - # return self.forward(input_x, **kwargs) - - def forward(self, input_x, trace_powers=False): - x = input_x - - if trace_powers: - powers = [x.abs().square().sum()] - - # check if tracing - if torch.jit.is_tracing(): - for layer in self._layers: - x = layer(x) - if trace_powers: - powers.append(x.abs().square().sum()) - else: - # with torch.nn.utils.parametrize.cached(): - for layer in self._layers: - x = layer(x) - if trace_powers: - powers.append(x.abs().square().sum()) - if trace_powers: - return x, powers - return x __all__ = [ complex_sse_loss, complex_mse_loss, + angle_mse_loss, UnitaryLayer, unitary, energy_conserving, @@ -662,6 +692,7 @@ __all__ = [ ZReLU, MZISingle, EOActivation, + photodiode, # SaturableAbsorberLambertW, # SaturableAbsorber, # SpreadLayer, diff --git a/src/single-core-regen/util/datasets.py b/src/single-core-regen/util/datasets.py index 98e7781..d56bafe 100644 --- a/src/single-core-regen/util/datasets.py +++ b/src/single-core-regen/util/datasets.py @@ -54,7 +54,7 @@ def load_data(config_path, skipfirst=0, symbols=None, real=False, normalize=Fals config["glova"]["nos"] = str(symbols) - data = np.concatenate([data, timestamps.reshape(-1,1)], axis=-1) + data = np.concatenate([data, timestamps.reshape(-1, 1)], axis=-1) data = torch.tensor(data, device=device, dtype=dtype) @@ -113,6 +113,8 @@ class FiberRegenerationDataset(Dataset): dtype: torch.dtype = None, real: bool = False, device=None, + polarisations: tuple | list = (0,), + randomise_polarisations: bool = False, **kwargs, ): """ @@ -145,6 +147,8 @@ class FiberRegenerationDataset(Dataset): assert output_dim is None or output_dim > 0, "output_len must be positive or None" assert drop_first >= 0, "drop_first must be non-negative" + self.randomise_polarisations = randomise_polarisations + faux = kwargs.pop("faux", False) if faux: @@ -165,7 +169,7 @@ class FiberRegenerationDataset(Dataset): data_raw = None self.config = None files = [] - for file_path in (file_path if isinstance(file_path, (tuple, list)) else [file_path]): + for file_path in file_path if isinstance(file_path, (tuple, list)) else [file_path]: data, config = load_data( file_path, skipfirst=drop_first, @@ -185,7 +189,20 @@ class FiberRegenerationDataset(Dataset): assert self.config["glova"]["sps"] == config["glova"]["sps"], "samples per symbol must be the same" files.append(config["data"]["file"].strip('"')) self.config["data"]["file"] = str(files) - + + for i, angle in 
enumerate(torch.tensor(np.array(polarisations))): + data_raw_copy = data_raw.clone() + if angle == 0: + continue + sine = torch.sin(angle) + cosine = torch.cos(angle) + data_raw_copy[:, 2] = data_raw[:, 2] * cosine - data_raw[:, 3] * sine + data_raw_copy[:, 3] = data_raw[:, 2] * sine + data_raw[:, 3] * cosine + if i == 0: + data_raw = data_raw_copy + else: + data_raw = torch.cat([data_raw, data_raw_copy], dim=0) + self.device = data_raw.device self.samples_per_symbol = int(self.config["glova"]["sps"]) @@ -258,17 +275,27 @@ class FiberRegenerationDataset(Dataset): elif self.target_delay_samples < 0: data_raw = data_raw[:, : self.target_delay_samples] - timestamps = data_raw[-1, :] - data_raw = data_raw[:-1, :] + timestamps = data_raw[4, :] + data_raw = data_raw[:4, :] data_raw = data_raw.view(2, 2, -1) - timestamps_doubled = torch.cat([timestamps.unsqueeze(dim=0), timestamps.unsqueeze(dim=0)], dim=0).unsqueeze(dim=1) + timestamps_doubled = torch.cat([timestamps.unsqueeze(dim=0), timestamps.unsqueeze(dim=0)], dim=0).unsqueeze( + dim=1 + ) data_raw = torch.cat([data_raw, timestamps_doubled], dim=1) + # data_raw = torch.cat([data_raw, timestamps_doubled], dim=1) # data layout # [ [E_in_x, E_in_y, timestamps], # [E_out_x, E_out_y, timestamps] ] self.data = data_raw.unfold(dimension=-1, size=self.samples_per_slice, step=1) self.data = self.data.movedim(-2, 0) + + if randomise_polarisations: + self.angles = torch.rand(self.data.shape[0]) * np.pi * 2 + # self.data[:, 1, :2, :] = self.rotate(self.data[:, 1, :2, :], self.angles) + else: + self.angles = torch.zeros(self.data.shape[0]) + # ... # -> [no_slices, 2, 3, samples_per_slice] # data layout @@ -288,23 +315,93 @@ class FiberRegenerationDataset(Dataset): return [self.__getitem__(i) for i in range(*idx.indices(len(self)))] else: data_slice = self.data[idx].squeeze() - - data_slice = data_slice[:, :, :data_slice.shape[2] // self.output_dim * self.output_dim] + + data_slice = data_slice[:, :, : data_slice.shape[2] // self.output_dim * self.output_dim] data_slice = data_slice.view(data_slice.shape[0], data_slice.shape[1], self.output_dim, -1) - target = data_slice[0, :, self.output_dim//2, 0] - data = data_slice[1, :, :, 0] + # if self.randomise_polarisations: + # angle = torch.rand(1) * torch.pi * 2 + # sine = torch.sin(angle) + # cosine = torch.cos(angle) + # data_slice_ = data_slice[1] + # data_slice[1, 0] = data_slice_[0] * cosine - data_slice_[1] * sine + # data_slice[1,1] = data_slice_[0] * sine + data_slice_[1] * cosine + # else: + # angle = torch.zeros(1) + + # data = data_slice[1, :2, :, 0] + + angle = self.angles[idx] + + data_index = 1 + + data_slice[1, :2, :, :] = self.rotate(data_slice[data_index, :2, :, :], angle) + + data = data_slice[1, :2, :, 0] + # data = self.rotate(data, angle) + + # for both polarisations (x/y), calculate the mean of the signal around the current symbol (-> corresponding to a lowpass filter) + angle_data = data_slice[1, :2, :, :].reshape(2, -1).mean(dim=1) + angle_data2 = self.complex_max(data_slice[1, :2, :, :].reshape(2, -1)) + plot_data = data_slice[1, :2, self.output_dim // 2, 0] + sop = self.polarimeter(plot_data) + # angle_data = data_slice[1, :2, :self.output_dim//2, :].squeeze().reshape(2, -1).mean(dim=-1) + # angle = data_slice[1, 3, self.output_dim // 2, 0].real + target = data_slice[0, :2, self.output_dim // 2, 0] + target_timestamp = data_slice[0, 2, self.output_dim // 2, 0].real + ... 
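Aside: the polarisation rotation applied above (and by rotate() further down) is an ordinary 2x2 rotation of the x/y Jones components, so it leaves the per-sample power |x|^2 + |y|^2 unchanged. A minimal self-check of that property:

import torch

a = torch.tensor(0.7)                           # rotation angle in radians
xy = torch.randn(2, 16, dtype=torch.complex64)  # x and y field samples
rot = torch.stack([xy[0] * torch.cos(a) - xy[1] * torch.sin(a),
                   xy[0] * torch.sin(a) + xy[1] * torch.cos(a)])
# the rotation matrix is orthogonal, so per-sample power is preserved
assert torch.allclose(xy.abs().square().sum(dim=0), rot.abs().square().sum(dim=0), atol=1e-5)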
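The overlapping slices themselves come from Tensor.unfold with step=1; a small illustration of the sliding windows it produces:

import torch

sig = torch.arange(6.0)
windows = sig.unfold(dimension=-1, size=4, step=1)
print(windows)
# tensor([[0., 1., 2., 3.],
#         [1., 2., 3., 4.],
#         [2., 3., 4., 5.]])  -> one slice per starting sample, advancing one sample at a time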
             # data_timestamps = data[-1,:].real
-            data = data[:-1, :]
-            target_timestamp = target[-1].real
-            target = target[:-1]
-
+            # data = data[:-1, :]
+            # target_timestamp = target[-1].real
+            # target = target[:-1]
+            # plot_data = plot_data[:-1]
+
             # transpose to interleave the x and y data in the output tensor
             data = data.transpose(0, 1).flatten().squeeze()
+            angle_data = angle_data.flatten().squeeze()
+            angle_data2 = angle_data2.flatten().squeeze()
+            angle = angle.flatten().squeeze()
             # data_timestamps = data_timestamps.flatten().squeeze()
             target = target.flatten().squeeze()
             target_timestamp = target_timestamp.flatten().squeeze()
-        return data, target, target_timestamp
+        return {"x": data, "y": target, "angle": angle, "sop": sop, "angle_data": angle_data, "angle_data2": angle_data2, "timestamp": target_timestamp, "plot_data": plot_data}
+
+    def complex_max(self, data, dim=-1):
+        # returns the element(s) with the maximum absolute value along a given dimension
+        return torch.gather(data, dim, torch.argmax(data.abs(), dim=dim, keepdim=True)).squeeze(dim=dim)
+
+    def rotate(self, data, angle):
+        # rotates the two polarisation components of data by the given angle
+        # data: [2, ...], angle: scalar tensor
+        # returns: [2, ...]
+        sine = torch.sin(angle)
+        cosine = torch.cos(angle)
+        return torch.stack([data[0] * cosine - data[1] * sine, data[0] * sine + data[1] * cosine], dim=0)
+
+    def polarimeter(self, data):
+        # data: [2, ...] -> x and y field samples, averaged over the slice
+        # returns: [3] -> normalised S1, S2, S3
+        x = data[0].mean()
+        y = data[1].mean()
+        I_X = x.abs().square()
+        I_Y = y.abs().square()
+        I_45 = (x + y).abs().square()
+        I_RHC = (x + 1j * y).abs().square()
+
+        S0 = I_X + I_Y
+        S1 = (2 * I_X - S0) / S0
+        S2 = (2 * I_45 - S0) / S0
+        S3 = (2 * I_RHC - S0) / S0
+
+        return torch.stack([S1, S2, S3], dim=0)
\ No newline at end of file
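The new angle losses in util/complexNN.py differ mainly in how they treat the 2*pi wrap. A quick numeric check (definitions mirrored from the patch) with two angles that are only 0.1 rad apart but sit on opposite sides of the 0/2*pi seam:

import torch

def cosine_loss(x, target):
    return (2 * (1 - torch.cos(x - target))).mean()

def angle_mse_loss(x, target):
    x = torch.fmod(x, 2 * torch.pi)
    target = torch.fmod(target, 2 * torch.pi)
    cos_diff = torch.cos(x) - torch.cos(target)
    sin_diff = torch.sin(x) - torch.sin(target)
    return (cos_diff ** 2 + sin_diff ** 2).mean()

x = torch.tensor([0.05])
t = torch.tensor([2 * torch.pi - 0.05])
print(((x - t) ** 2).mean())  # plain MSE on raw angles: ~38, heavily punishes the wrap
print(cosine_loss(x, t))      # ~0.01, consistent with the true 0.1 rad separation
print(angle_mse_loss(x, t))   # ~0.01, identical to cosine_loss because
                              # cos_diff**2 + sin_diff**2 == 2*(1 - cos(x - target))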
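For reference, the "sop" target returned by the dataset is a set of normalised Stokes parameters derived from the (x, y) Jones components. Using the textbook closed forms (sign and analyser-normalisation conventions vary), a quick check on canonical polarisation states:

import torch

def stokes(x, y):
    # normalised Stokes vector (S1, S2, S3) of a Jones vector (x, y)
    s0 = x.abs().square() + y.abs().square()
    s1 = x.abs().square() - y.abs().square()
    s2 = 2 * (x * y.conj()).real
    s3 = 2 * (x * y.conj()).imag
    return torch.stack([s1, s2, s3]) / s0

print(stokes(torch.tensor(1 + 0j), torch.tensor(0 + 0j)))  # x-polarised   -> [ 1.,  0.,  0.]
print(stokes(torch.tensor(1 + 0j), torch.tensor(1 + 0j)))  # 45 deg linear -> [ 0.,  1.,  0.]
print(stokes(torch.tensor(1 + 0j), torch.tensor(0 + 1j)))  # circular      -> [ 0.,  0., -1.] (sign is a handedness convention)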