From 98305fdf479b85828ee851011b19d410e6dbcc20 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Joseph=20Hopfm=C3=BCller?= Date: Sun, 29 Dec 2024 16:00:36 +0100 Subject: [PATCH] update dataset configurations, add rotation module, and refine model settings for training, new hyperparameter tuning run for corrected datasets --- data/single_core_regen.db | 4 +- .../hypertraining/hypertraining.py | 195 ++++++++-------- src/single-core-regen/hypertraining/models.py | 49 ++-- .../hypertraining/settings.py | 2 + .../hypertraining/training.py | 218 +++++++++++------- src/single-core-regen/regen.py | 93 ++++++-- src/single-core-regen/regen_no_hyper.py | 34 ++- src/single-core-regen/train_pol_estimator.py | 28 ++- src/single-core-regen/util/complexNN.py | 25 +- src/single-core-regen/util/datasets.py | 218 ++++++++++++------ 10 files changed, 561 insertions(+), 305 deletions(-) diff --git a/data/single_core_regen.db b/data/single_core_regen.db index 4084f2c..cd729de 100644 --- a/data/single_core_regen.db +++ b/data/single_core_regen.db @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:746ea83013870351296e01e294905ee027291ef79fd78c1e6b69dd9ebaa1cba0 -size 10240000 +oid sha256:76934d1d202aea1311ba67f5ea35eeb99a9c5c856f491565032e7d54ca6f072d +size 13598720 diff --git a/src/single-core-regen/hypertraining/hypertraining.py b/src/single-core-regen/hypertraining/hypertraining.py index 758331e..1202d51 100644 --- a/src/single-core-regen/hypertraining/hypertraining.py +++ b/src/single-core-regen/hypertraining/hypertraining.py @@ -26,6 +26,8 @@ import torch import torch.optim as optim import torch.utils.data +import hypertraining.models as models + from torch.utils.tensorboard import SummaryWriter import multiprocessing @@ -253,14 +255,17 @@ class HyperTraining: model_kwargs = { "dims": (input_dim, *hidden_dims, self.model_settings.output_dim), "layer_function": layer_func, - "layer_parametrizations": layer_parametrizations, - "activation_function": afunc, + "layer_func_kwargs": self.model_settings.model_layer_kwargs, + "act_function": afunc, + "act_func_kwargs": None, + "parametrizations": layer_parametrizations, "dtype": dtype, - "droupout_prob": self.model_settings.dropout_prob, - "scale": scale_layers, + "dropout_prob": self.model_settings.dropout_prob, + "scale_layers": scale_layers, + "rotate": False, } - model = util.complexNN.regenerator(**model_kwargs) + model = models.regenerator(*model_kwargs.pop("dims"), **model_kwargs) n_nodes = sum(hidden_dims) if writer is not None: @@ -381,7 +386,10 @@ class HyperTraining: running_loss = 0.0 model.train() loader_len = len(train_loader) - for batch_idx, (x, y, _) in enumerate(train_loader): + for batch_idx, batch in enumerate(train_loader): + x = batch["x"] + y = batch["y"] + if batch_idx >= self.optuna_settings._n_train_batches: break model.zero_grad(set_to_none=True) @@ -390,7 +398,7 @@ class HyperTraining: y.to(self.pytorch_settings.device), ) y_pred = model(x) - loss = util.complexNN.complex_mse_loss(y_pred, y) + loss = util.complexNN.complex_mse_loss(y_pred, y, power=True) loss_value = loss.item() loss.backward() optimizer.step() @@ -444,7 +452,9 @@ class HyperTraining: model.eval() running_error = 0 with torch.no_grad(): - for batch_idx, (x, y, _) in enumerate(valid_loader): + for batch_idx, batch in enumerate(valid_loader): + x = batch["x"] + y = batch["y"] if batch_idx >= self.optuna_settings._n_valid_batches: break x, y = ( @@ -452,50 +462,44 @@ class HyperTraining: y.to(self.pytorch_settings.device), ) y_pred = model(x) - error = 
util.complexNN.complex_mse_loss(y_pred, y) + error = util.complexNN.complex_mse_loss(y_pred, y, power=True) error_value = error.item() running_error += error_value running_error /= min(len(valid_loader), self.optuna_settings._n_valid_batches) if writer is not None: - title_append, subtitle = self.build_title(trial) - writer.add_figure( - "fiber response", - self.plot_model_response( - trial, - model=model, - title_append=title_append, - subtitle=subtitle, - show=False, - ), - epoch + 1, - ) - writer.add_figure( - "eye diagram", - self.plot_model_response( - trial, - model=self.model, - title_append=title_append, - subtitle=subtitle, - show=False, - mode="eye", - ), - epoch + 1, + writer.add_scalar( + "eval loss", + running_error, + epoch, ) + # if (epoch + 1) % 10 == 0 or epoch < 10: + # # plotting is slow, so only do it every 10 epochs + # title_append, subtitle = self.build_title(trial) + # head_fig, eye_fig, powers_fig = self.plot_model_response( + # model=model, + # title_append=title_append, + # subtitle=subtitle, + # show=False, + # ) + # writer.add_figure( + # "fiber response", + # head_fig, + # epoch + 1, + # ) + # writer.add_figure( + # "eye diagram", + # eye_fig, + # epoch + 1, + # ) - writer.add_figure( - "powers", - self.plot_model_response( - trial, - model=self.model, - title_append=title_append, - subtitle=subtitle, - mode="powers", - show=False, - ), - epoch + 1, - ) + # writer.add_figure( + # "powers", + # powers_fig, + # epoch + 1, + # ) + # writer.flush() # if enable_progress: # progress.stop() @@ -511,15 +515,18 @@ class HyperTraining: with torch.no_grad(): model = model.to(self.pytorch_settings.device) - for x, y, timestamp in loader: + for batch in loader: + x = batch["x"] + y = batch["y"] + timestamp = batch["timestamp"] x, y = ( x.to(self.pytorch_settings.device), y.to(self.pytorch_settings.device), ) if trace_powers: - y_pred, powers = model(x, trace_powers).cpu() + y_pred, powers = model(x, trace_powers=True).cpu() else: - y_pred = model(x, trace_powers).cpu() + y_pred = model(x, trace_powers=True).cpu() # x = x.cpu() # y = y.cpu() y_pred = y_pred.view(y_pred.shape[0], -1, 2) @@ -539,7 +546,7 @@ class HyperTraining: return fiber_in, fiber_out, regen, timestamps, powers return fiber_in, fiber_out, regen, timestamps - def objective(self, trial: optuna.Trial, plot_before=False): + def objective(self, trial: optuna.Trial): if self.stop_study: trial.study.stop() model = None @@ -555,54 +562,54 @@ class HyperTraining: title_append, subtitle = self.build_title(trial) - writer.add_figure( - "fiber response", - self.plot_model_response( - trial, - model=model, - title_append=title_append, - subtitle=subtitle, - show=False, - ), - 0, - ) - writer.add_figure( - "eye diagram", - self.plot_model_response( - trial, - model=self.model, - title_append=title_append, - subtitle=subtitle, - mode="eye", - show=False, - ), - 0, - ) + # writer.add_figure( + # "fiber response", + # self.plot_model_response( + # trial, + # model=model, + # title_append=title_append, + # subtitle=subtitle, + # show=False, + # ), + # 0, + # ) + # writer.add_figure( + # "eye diagram", + # self.plot_model_response( + # trial, + # model=self.model, + # title_append=title_append, + # subtitle=subtitle, + # mode="eye", + # show=False, + # ), + # 0, + # ) - writer.add_figure( - "powers", - self.plot_model_response( - trial, - model=self.model, - title_append=title_append, - subtitle=subtitle, - mode="powers", - show=False, - ), - 0, - ) + # writer.add_figure( + # "powers", + # self.plot_model_response( + # 
trial, + # model=self.model, + # title_append=title_append, + # subtitle=subtitle, + # mode="powers", + # show=False, + # ), + # 0, + # ) train_loader, valid_loader = self.get_sliced_data(trial) optimizer_name = trial.suggest_categorical_optional("optimizer", self.optimizer_settings.optimizer) - lr = trial.suggest_float_optional("lr", self.optimizer_settings.learning_rate, log=True) + lr = trial.suggest_float_optional("lr", self.optimizer_settings.optimizer_kwargs["lr"], log=True) optimizer = getattr(optim, optimizer_name)(model.parameters(), lr=lr) - if self.optimizer_settings.scheduler is not None: - scheduler = getattr(optim.lr_scheduler, self.optimizer_settings.scheduler)( - optimizer, **self.optimizer_settings.scheduler_kwargs - ) + # if self.optimizer_settings.scheduler is not None: + # scheduler = getattr(optim.lr_scheduler, self.optimizer_settings.scheduler)( + # optimizer, **self.optimizer_settings.scheduler_kwargs + # ) for epoch in range(self.pytorch_settings.epochs): trial.set_user_attr("epoch", epoch) @@ -628,8 +635,8 @@ class HyperTraining: writer, # enable_progress=enable_progress, ) - if self.optimizer_settings.scheduler is not None: - scheduler.step(error) + # if self.optimizer_settings.scheduler is not None: + # scheduler.step(error) trial.set_user_attr("mse", error) trial.set_user_attr("log_mse", np.log10(error + np.finfo(float).eps)) @@ -645,10 +652,10 @@ class HyperTraining: if self.optuna_settings._multi_objective: return -np.log10(error + np.finfo(float).eps), trial.user_attrs.get("model_n_nodes", -1) - if self.pytorch_settings.save_models and model is not None: - save_path = Path(self.pytorch_settings.model_dir) / f"{self.optuna_settings.study_name}_{trial.number}.pth" - save_path.parent.mkdir(parents=True, exist_ok=True) - torch.save(model, save_path) + # if self.pytorch_settings.save_models and model is not None: + # save_path = Path(self.pytorch_settings.model_dir) / f"{self.optuna_settings.study_name}_{trial.number}.pth" + # save_path.parent.mkdir(parents=True, exist_ok=True) + # torch.save(model, save_path) return error diff --git a/src/single-core-regen/hypertraining/models.py b/src/single-core-regen/hypertraining/models.py index f9b8b03..22240b1 100644 --- a/src/single-core-regen/hypertraining/models.py +++ b/src/single-core-regen/hypertraining/models.py @@ -8,7 +8,8 @@ from util.complexNN import ( photodiode, EOActivation, polarimeter, - normalize_by_first + # normalize_by_first, + rotate, ) @@ -19,11 +20,11 @@ class polarisation_estimator2(Module): polarimeter(), torch.nn.Linear(4, 4), torch.nn.ReLU(), - torch.nn.Dropout(p=0.01), + # torch.nn.Dropout(p=0.01), torch.nn.Linear(4, 4), torch.nn.ReLU(), - torch.nn.Dropout(p=0.01), - torch.nn.Linear(4, 4), + # torch.nn.Dropout(p=0.01), + torch.nn.Linear(4, 1), ) def forward(self, x): @@ -124,6 +125,7 @@ class regenerator(Module): dtype=torch.float64, dropout_prob=0.01, scale_layers=False, + rotate=False, ): super(regenerator, self).__init__() self._n_hidden_layers = len(dims) - 2 @@ -131,6 +133,8 @@ class regenerator(Module): layer_func_kwargs = layer_func_kwargs or {} act_func_kwargs = act_func_kwargs or {} + self.rotation = rotate + self.build_model(dims, layer_function, layer_func_kwargs, act_function, act_func_kwargs, parametrizations, dtype, dropout_prob, scale_layers) def build_model(self, dims, layer_function, layer_func_kwargs, act_function, act_func_kwargs, parametrizations, dtype, dropout_prob, scale_layers): @@ -146,8 +150,9 @@ class regenerator(Module): module = act_function(size=dims[i + 1], 
**act_func_kwargs) self.get_submodule(f"layer_{i}").add_module("activation", module) - module = DropoutComplex(p=dropout_prob) - self.get_submodule(f"layer_{i}").add_module("dropout", module) + if dropout_prob is not None and dropout_prob > 0: + module = DropoutComplex(p=dropout_prob) + self.get_submodule(f"layer_{i}").add_module("dropout", module) self.add_module(f"layer_{self._n_hidden_layers}", Sequential()) @@ -160,6 +165,10 @@ class regenerator(Module): module = act_function(size=dims[-1], **act_func_kwargs) self.get_submodule(f"layer_{self._n_hidden_layers}").add_module("activation", module) + if self.rotation: + module = rotate() + self.add_module("rotate", module) + # module = Scale(size=dims[-1]) # self.get_submodule(f"layer_{self._n_hidden_layers}").add_module("out_scale", module) @@ -190,15 +199,27 @@ class regenerator(Module): powers.append(x.abs().square().sum()) return powers - def forward(self, x, trace_powers=False): + def forward(self, x, angle=None, pre_rot=False, trace_powers=False): powers = self._trace_powers(trace_powers, x) - x = self.layer_0(x) - powers = self._trace_powers(trace_powers, x, powers) - for i in range(1, self._n_hidden_layers): + # x = self.layer_0(x) + # powers = self._trace_powers(trace_powers, x, powers) + for i in range(0, self._n_hidden_layers): x = getattr(self, f"layer_{i}")(x) powers = self._trace_powers(trace_powers, x, powers) x = getattr(self, f"layer_{self._n_hidden_layers}")(x) - powers = self._trace_powers(trace_powers, x, powers) - if trace_powers: - return x, powers - return x \ No newline at end of file + if self.rotation: + try: + x_rot = self.rotate(x, angle) + except AttributeError: + pass + powers = self._trace_powers(trace_powers, x_rot, powers) + else: + x_rot = x + + if pre_rot and trace_powers: + return x_rot, x, powers + if pre_rot and not trace_powers: + return x_rot, x + if not pre_rot and trace_powers: + return x_rot, powers + return x_rot \ No newline at end of file diff --git a/src/single-core-regen/hypertraining/settings.py b/src/single-core-regen/hypertraining/settings.py index 797cb89..1e144ff 100644 --- a/src/single-core-regen/hypertraining/settings.py +++ b/src/single-core-regen/hypertraining/settings.py @@ -22,6 +22,8 @@ class DataSettings: train_split: float = 0.8 polarisations: tuple | list = (0,) randomise_polarisations: bool = False + osnr: float | int = None + seed: int = None """ change to: diff --git a/src/single-core-regen/hypertraining/training.py b/src/single-core-regen/hypertraining/training.py index fb9a7a0..22fb705 100644 --- a/src/single-core-regen/hypertraining/training.py +++ b/src/single-core-regen/hypertraining/training.py @@ -2,7 +2,6 @@ import copy from datetime import datetime from pathlib import Path import random -from typing import Literal import matplotlib from matplotlib.colors import LinearSegmentedColormap import torch.nn.utils.parametrize @@ -60,33 +59,35 @@ def traverse_dict_update(target, source): except TypeError: target.__dict__[k] = v + def get_parameter_names_and_values(model): def is_parametrized(module): if hasattr(module, "parametrizations"): return True return False - def _get_param_info(module, prefix='', parametrization=False): + def _get_param_info(module, prefix="", parametrization=False): param_list = [] - for name, param in module.named_parameters(recurse = parametrization): + for name, param in module.named_parameters(recurse=parametrization): if parametrization and name.startswith("parametrizations"): - name_parts = name.split('.') + name_parts = name.split(".") name = 
name_parts[1] param = getattr(module, name) - full_name = prefix + ('.' if prefix else '') + name + full_name = prefix + ("." if prefix else "") + name param_value = param.data param_list.append((full_name, param_value)) - + for child_name, child_module in module.named_children(): - child_prefix = prefix + ('.' if prefix else '') + child_name + child_prefix = prefix + ("." if prefix else "") + child_name if child_name == "parametrizations": continue param_list.extend(_get_param_info(child_module, child_prefix, is_parametrized(child_module))) - + return param_list return _get_param_info(model) + class PolarizationTrainer: def __init__( self, @@ -101,7 +102,7 @@ class PolarizationTrainer: settings_override=None, reset_epoch=False, ): - self.mod = torch.pi/2 + self.mod = torch.pi / 2 self.resume = checkpoint_path is not None torch.serialization.add_safe_globals([ *util.complexNN.__all__, @@ -219,7 +220,7 @@ class PolarizationTrainer: # dims = self.model_kwargs.pop("dims") model_kwargs = copy.deepcopy(self.model_kwargs) - self.model = models.polarisation_estimator(*model_kwargs.pop('dims'),**model_kwargs) + self.model = models.polarisation_estimator(*model_kwargs.pop("dims"), **model_kwargs) # self.model = models.polarisation_estimator2() if self.writer is not None: @@ -336,17 +337,20 @@ class PolarizationTrainer: write_div = 0 loss_div = 0 for batch_idx, batch in enumerate(train_loader): - x = batch["x"] - y = batch["sop"] + x = batch["angle_data2"] + y = batch["center_angle"] self.model.zero_grad(set_to_none=True) x, y = ( x.to(self.pytorch_settings.device), y.to(self.pytorch_settings.device), ) - y_pred = self.model(x) + y_pred = self.model(x).abs().real + # y_pred = torch.fmod(y_pred, self.mod) + y = y.abs().real + # y = torch.fmod(y, self.mod) + # loss = torch.nn.functional.mse_loss(torch.cos(y_pred), torch.cos(y)) # loss = torch.nn.functional.smooth_l1_loss(torch.cos(torch.fmod(y_pred, torch.pi/2)).squeeze(), torch.cos(torch.fmod(y, torch.pi/2)).squeeze(), beta=0.5) - loss = torch.nn.functional.mse_loss(y_pred, y) - # loss = util.complexNN.naive_angle_loss(y_pred, y, mod=torch.pi/2) + loss = util.complexNN.naive_angle_loss(y_pred, y, mod=self.mod) loss_value = loss.item() loss.backward() optimizer.step() @@ -356,7 +360,7 @@ class PolarizationTrainer: loss_div += 1 if enable_progress: - progress.update(task, advance=1, description=f"{loss_value:.3e}") + progress.update(task, advance=1, description=f"{loss_value/np.pi*180:.3e} °") if batch_idx % self.pytorch_settings.write_every == 0: self.writer.add_scalar( @@ -395,24 +399,28 @@ class PolarizationTrainer: loss_div = 0 with torch.no_grad(): for _, batch in enumerate(valid_loader): - x = batch["x"] - y = batch["sop"] + # x = batch["angle_data2"] + x = batch["angle_data2"] + y = batch["center_angle"] x, y = ( x.to(self.pytorch_settings.device), y.to(self.pytorch_settings.device), ) - y_pred = self.model(x) + y_pred = self.model(x).abs().real + # y_pred = torch.fmod(y_pred, self.mod) + y = y.abs().real + # y = torch.fmod(y, self.mod) + # loss = torch.nn.functional.mse_loss(torch.cos(y_pred), torch.cos(y)) # loss = torch.nn.functional.smooth_l1_loss(torch.cos(torch.fmod(y_pred, torch.pi/2)).squeeze(), torch.cos(torch.fmod(y, torch.pi/2)).squeeze(), beta=0.5) - loss = torch.nn.functional.mse_loss(y_pred, y) - # loss = util.complexNN.naive_angle_loss(y_pred, y, mod=torch.pi/2) + loss = util.complexNN.naive_angle_loss(y_pred, y, mod=self.mod) loss_value = loss.item() running_loss += loss_value loss_div += 1 if enable_progress: - 
progress.update(task, advance=1, description=f"{loss_value:.3e}") + progress.update(task, advance=1, description=f"{loss_value/np.pi*180:.3e} °") - running_loss = running_loss/loss_div + running_loss = running_loss / loss_div self.writer.add_scalar( "eval loss", @@ -506,19 +514,19 @@ class PolarizationTrainer: for i, config_path in enumerate(self.data_settings.config_path): paths = Path.cwd().glob(config_path) for j, path in enumerate(paths): - text = str(path) + '\n' - with open(path, 'r') as f: + text = str(path) + "\n" + with open(path, "r") as f: text += f.read() - text += '\n' - self.writer.add_text(f"config_{i*len(self.data_settings.config_path)+j}", text) + text += "\n" + self.writer.add_text(f"config_{i * len(self.data_settings.config_path) + j}", text) elif isinstance(self.data_settings.config_path, str): paths = Path.cwd().glob(self.data_settings.config_path) for j, path in enumerate(paths): - text = str(path) + '\n' - with open(path, 'r') as f: + text = str(path) + "\n" + with open(path, "r") as f: text += f.read() - text += '\n' + text += "\n" self.writer.add_text(f"config_{j}", text) self.writer.flush() @@ -571,7 +579,8 @@ class PolarizationTrainer: if loss < self.best["loss"]: self.best = checkpoint save_path = ( - Path(self.pytorch_settings.model_dir) / f"best_pol_{self.writer.get_logdir().split('/')[-1]}.tar" + Path(self.pytorch_settings.model_dir) + / f"best_pol_{self.writer.get_logdir().split('/')[-1]}.tar" ) save_path.parent.mkdir(parents=True, exist_ok=True) self.save_checkpoint(self.best, save_path) @@ -580,6 +589,7 @@ class PolarizationTrainer: self.writer.close() return self.best + class RegenerationTrainer: def __init__( self, @@ -636,6 +646,10 @@ class RegenerationTrainer: self.model_settings: ModelSettings = model_settings self.optimizer_settings: OptimizerSettings = optimizer_settings + if self.global_settings.seed is not None: + random.seed(self.global_settings.seed) + np.random.seed(self.global_settings.seed) + self.console = console or Console() self.writer = None @@ -706,10 +720,12 @@ class RegenerationTrainer: # dims = self.model_kwargs.pop("dims") model_kwargs = copy.deepcopy(self.model_kwargs) - self.model = models.regenerator(*model_kwargs.pop('dims'),**model_kwargs) + self.model = models.regenerator(*model_kwargs.pop("dims"), **model_kwargs) if self.writer is not None: - self.writer.add_graph(self.model, torch.zeros(1, input_dim, dtype=dtype)) + self.writer.add_graph( + self.model, (torch.rand(1, input_dim, dtype=dtype), torch.rand(1, 1, dtype=dtype.to_real())) + ) self.model = self.model.to(self.pytorch_settings.device) if self.resume: @@ -728,12 +744,12 @@ class RegenerationTrainer: num_symbols = None config_path = self.data_settings.config_path - polarisations = self.data_settings.polarisations randomise_polarisations = self.data_settings.randomise_polarisations + osnr = self.data_settings.osnr if override is not None: num_symbols = override.get("num_symbols", None) config_path = override.get("config_path", config_path) - polarisations = override.get("polarisations", polarisations) + # polarisations = override.get("polarisations", polarisations) randomise_polarisations = override.get("randomise_polarisation", randomise_polarisations) # get dataset dataset = FiberRegenerationDataset( @@ -746,8 +762,8 @@ class RegenerationTrainer: dtype=dtype, real=not dtype.is_complex, num_symbols=num_symbols, - polarisations=polarisations, randomise_polarisations=randomise_polarisations, + osnr = osnr, ) dataset_size = len(dataset) @@ -819,12 +835,14 @@ class 
RegenerationTrainer: for batch_idx, batch in enumerate(train_loader): x = batch["x"] y = batch["y"] + angles = batch["mean_angle"] self.model.zero_grad(set_to_none=True) - x, y = ( + x, y, angles = ( x.to(self.pytorch_settings.device), y.to(self.pytorch_settings.device), + angles.to(self.pytorch_settings.device), ) - y_pred = self.model(x) + y_pred = self.model(x, -angles) loss = util.complexNN.complex_mse_loss(y_pred, y, power=True) loss_value = loss.item() loss.backward() @@ -872,11 +890,13 @@ class RegenerationTrainer: for _, batch in enumerate(valid_loader): x = batch["x"] y = batch["y"] - x, y = ( + angles = batch["mean_angle"] + x, y, angles = ( x.to(self.pytorch_settings.device), y.to(self.pytorch_settings.device), + angles.to(self.pytorch_settings.device), ) - y_pred = self.model(x) + y_pred = self.model(x, -angles) error = util.complexNN.complex_mse_loss(y_pred, y, power=True) error_value = error.item() running_error += error_value @@ -884,7 +904,7 @@ class RegenerationTrainer: if enable_progress: progress.update(task, advance=1, description=f"{error_value:.3e}") - running_error = running_error/len(valid_loader) + running_error = running_error / len(valid_loader) self.writer.add_scalar( "eval loss", @@ -928,45 +948,65 @@ class RegenerationTrainer: def run_model(self, model, loader, trace_powers=False): model.eval() fiber_out = [] + fiber_out_rot = [] fiber_in = [] regen = [] timestamps = [] + angles = [] with torch.no_grad(): model = model.to(self.pytorch_settings.device) for batch in loader: x = batch["x"] y = batch["y"] + plot_target = batch["plot_target"] + angle = batch["mean_angle"] + center_angle = batch["center_angle"] timestamp = batch["timestamp"] plot_data = batch["plot_data"] - x, y = ( + plot_data_rot = batch["plot_data_rot"] + x, y, angle = ( x.to(self.pytorch_settings.device), y.to(self.pytorch_settings.device), + angle.to(self.pytorch_settings.device), ) if trace_powers: - y_pred, powers = model(x, trace_powers).cpu() + y_pred, powers = model(x, angle, True).cpu() else: - y_pred = model(x, trace_powers).cpu() + y_pred = model(x, angle).cpu() # x = x.cpu() # y = y.cpu() y_pred = y_pred.view(y_pred.shape[0], -1, 2) + y_pred = y_pred[:, y_pred.shape[1]//2, :] y = y.view(y.shape[0], -1, 2) - plot_data = plot_data.view(plot_data.shape[0], -1, 2) + # plot_data = plot_data.view(plot_data.shape[0], -1, 2) + # c = torch.cos(-angle).cpu() + # s = torch.sin(-angle).cpu() + # rot = torch.stack([torch.stack([c, -s], dim=1), torch.stack([s, c], dim=1)], dim=2).squeeze(-1) + # plot_data = torch.bmm(plot_data, rot.to(dtype=plot_data.dtype)) + # plot_data = plot_data + # sines = torch.sin(-angle.cpu()) + # cosines = torch.cos(-angle.cpu()) + # plot_data = torch.stack((plot_data[..., 0] * cosines - plot_data[..., 1] * sines, plot_data[..., 0] * sines + plot_data[..., 1] * cosines), dim=-1) # x = x.view(x.shape[0], -1, 2) # timestamp = timestamp.view(-1, 1) fiber_out.append(plot_data.squeeze()) - fiber_in.append(y.squeeze()) + fiber_out_rot.append(plot_data_rot.squeeze()) + fiber_in.append(plot_target.squeeze()) regen.append(y_pred.squeeze()) timestamps.append(timestamp.squeeze()) + angles.append(center_angle.squeeze()) fiber_out = torch.vstack(fiber_out).cpu() + fiber_out_rot = torch.vstack(fiber_out_rot).cpu() fiber_in = torch.vstack(fiber_in).cpu() regen = torch.vstack(regen).cpu() + angles = torch.vstack(angles).cpu() timestamps = torch.concat(timestamps).cpu() if trace_powers: - return fiber_in, fiber_out, regen, timestamps, powers - return fiber_in, fiber_out, regen, timestamps 
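[review note] Two things in the run_model() changes above, before this return: (1) passing True as the third positional argument in model(x, angle, True) binds to pre_rot, not trace_powers, in the new forward(self, x, angle=None, pre_rot=False, trace_powers=False); (2) when power tracing is on, the model returns a (tensor, list) tuple, so calling .cpu() on the call result raises AttributeError. A minimal sketch of the corrected branch, under this patch's signature:

    # sketch: pass trace_powers by keyword and move .cpu() onto the tensor,
    # not onto the returned (y_pred, powers) tuple
    if trace_powers:
        y_pred, powers = model(x, angle, trace_powers=True)
        y_pred = y_pred.cpu()
    else:
        y_pred = model(x, angle).cpu()

The return tuple below also widens from four to six elements (seven with powers), so call sites such as plot_model_response must unpack in lockstep.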
+ return fiber_in, fiber_out, fiber_out_rot, angles, regen, timestamps, powers + return fiber_in, fiber_out, fiber_out_rot, angles, regen, timestamps def write_parameters(self, epoch, attributes: list[str] | tuple[str] = None): parameter_list = get_parameter_names_and_values(self.model) @@ -1027,18 +1067,18 @@ class RegenerationTrainer: for i, config_path in enumerate(self.data_settings.config_path): paths = Path.cwd().glob(config_path) for j, path in enumerate(paths): - text = str(path) + '\n' - with open(path, 'r') as f: + text = str(path) + "\n" + with open(path, "r") as f: text += f.read() - text += '\n' - self.writer.add_text(f"config_{i*len(self.data_settings.config_path)+j}", text) + text += "\n" + self.writer.add_text(f"config_{i * len(self.data_settings.config_path) + j}", text) elif isinstance(self.data_settings.config_path, str): paths = Path.cwd().glob(self.data_settings.config_path) for j, path in enumerate(paths): - text = str(path) + '\n' - with open(path, 'r') as f: + text = str(path) + "\n" + with open(path, "r") as f: text += f.read() - text += '\n' + text += "\n" self.writer.add_text(f"config_{j}", text) self.writer.flush() @@ -1116,6 +1156,7 @@ class RegenerationTrainer: powers = [power / powers[0] for power in powers] fig, ax = plt.subplots() fig.set_figwidth(18) + fig.set_figheight(4) fig.suptitle( f"Energy conservation{f' {title_append}' if title_append else ''}{f'\n{subtitle}' if subtitle else ''}" ) @@ -1184,6 +1225,7 @@ class RegenerationTrainer: fig, axs = plt.subplots(1, 2 * len(signals), sharex=True, sharey=True) fig.set_figwidth(18) + fig.set_figheight(4) fig.suptitle(f"Eye diagram{f' {title_append}' if title_append else ''}{f'\n{subtitle}' if subtitle else ''}") # xaxis = timestamps / sps # xaxis = np.arange(2 * sps) / sps @@ -1253,7 +1295,7 @@ class RegenerationTrainer: xaxis = timestamps / sps else: xaxis = timestamps - ax.plot(xaxis, np.abs(signal[:, i]) ** 2, label=label) + ax.plot(xaxis, np.abs(signal[:, i]) ** 2, label=label, alpha=0.7) ax.set_xlabel("Sample" if sps is None else "Symbol") ax.set_ylabel("normalized power") ax.minorticks_on() @@ -1269,7 +1311,7 @@ class RegenerationTrainer: def plot_model_response( self, - model:torch.nn.Module=None, + model: torch.nn.Module = None, title_append="", subtitle="", # mode: Literal["eye", "head", "powers"] = "head", @@ -1281,7 +1323,9 @@ class RegenerationTrainer: model = model.to(self.pytorch_settings.device) model.eval() with torch.no_grad(): - _, powers = model(input_data, trace_powers=True) + _, powers = model( + input_data, torch.zeros(input_data.shape[0], 1).to(self.pytorch_settings.device), trace_powers=True + ) powers = [power.item() for power in powers] layer_names = [name for (name, _) in model.named_children()] @@ -1296,29 +1340,42 @@ class RegenerationTrainer: self.data_settings.shuffle = False self.data_settings.train_split = 1.0 self.pytorch_settings.batchsize = max(self.pytorch_settings.head_symbols, self.pytorch_settings.eye_symbols) - config_path = random.choice(self.data_settings.config_path) if isinstance(self.data_settings.config_path, (list, tuple)) else self.data_settings.config_path - fiber_length = int(float(str(config_path).split('-')[4])/1000) + config_path = ( + random.choice(self.data_settings.config_path) + if isinstance(self.data_settings.config_path, (list, tuple)) + else self.data_settings.config_path + ) + fiber_length = int(float(str(config_path).split("-")[4]) / 1000) if not hasattr(self, "_plot_loader"): self._plot_loader, _ = self.get_sliced_data( override={ "num_symbols": 
self.pytorch_settings.batchsize, "config_path": config_path, "shuffle": False, - "polarisations": (np.random.rand(1)*np.pi*2,), - "randomise_polarisation": False, + "polarisations": (np.random.rand(1) * np.pi * 2,), + "randomise_polarisation": self.data_settings.randomise_polarisations, } ) self._sps = self._plot_loader.dataset.samples_per_symbol self.data_settings = data_settings_backup self.pytorch_settings = pytorch_settings_backup - fiber_in, fiber_out, regen, timestamps = self.run_model(model, self._plot_loader) + fiber_in, fiber_out, fiber_out_rot, angles, regen, timestamps = self.run_model(model, self._plot_loader) fiber_in = fiber_in.view(-1, 2) fiber_out = fiber_out.view(-1, 2) + fiber_out_rot = fiber_out_rot.view(-1, 2) + angles = angles.view(-1, 1) + angles = angles.real + angles = torch.fmod(angles, 2 * torch.pi) + angles = torch.div(angles, 2*torch.pi) + angles = torch.repeat_interleave(angles, 2, dim=1) + regen = regen.view(-1, 2) fiber_in = fiber_in.numpy() fiber_out = fiber_out.numpy() + fiber_out_rot = fiber_out_rot.numpy() + angles = angles.numpy() regen = regen.numpy() timestamps = timestamps.numpy() @@ -1327,28 +1384,29 @@ class RegenerationTrainer: import gc head_fig = self._plot_model_response_head( - fiber_in[:self.pytorch_settings.head_symbols*self._sps], - fiber_out[:self.pytorch_settings.head_symbols*self._sps], - regen[:self.pytorch_settings.head_symbols*self._sps], - timestamps=timestamps[:self.pytorch_settings.head_symbols*self._sps], + fiber_out_rot[: self.pytorch_settings.head_symbols * self._sps], + fiber_in[: self.pytorch_settings.head_symbols * self._sps], + regen[: self.pytorch_settings.head_symbols * self._sps], + angles[: self.pytorch_settings.head_symbols * self._sps], + timestamps=timestamps[: self.pytorch_settings.head_symbols * self._sps], + labels=("fiber out", "fiber in", "regen", "normed angle"), + sps=self._sps, + title_append=title_append + f" ({fiber_length} km)", + subtitle=subtitle, + show=show, + ) + # raise NotImplementedError("Eye diagram not implemented") + eye_fig = self._plot_model_response_eye( + fiber_in[: self.pytorch_settings.eye_symbols * self._sps], + fiber_out_rot[: self.pytorch_settings.eye_symbols * self._sps], + regen[: self.pytorch_settings.eye_symbols * self._sps], + timestamps=timestamps[: self.pytorch_settings.eye_symbols * self._sps], labels=("fiber in", "fiber out", "regen"), sps=self._sps, title_append=title_append + f" ({fiber_length} km)", subtitle=subtitle, show=show, ) - # raise NotImplementedError("Eye diagram not implemented") - eye_fig = self._plot_model_response_eye( - fiber_in[:self.pytorch_settings.eye_symbols*self._sps], - fiber_out[:self.pytorch_settings.eye_symbols*self._sps], - regen[:self.pytorch_settings.eye_symbols*self._sps], - timestamps=timestamps[:self.pytorch_settings.eye_symbols*self._sps], - labels=("fiber in", "fiber out", "regen"), - sps=self._sps, - title_append=title_append + f" ({fiber_length} km)", - subtitle=subtitle, - show=show, - ) gc.collect() return head_fig, eye_fig, power_fig @@ -1361,7 +1419,7 @@ class RegenerationTrainer: self.model_settings.overrides.get(f"n_hidden_nodes_{i}", -1) for i in range(model_n_hidden_layers) ] model_dims.insert(0, input_dim) - model_dims.append(2) + model_dims.append(self.model_settings.output_dim) model_dims = [str(dim) for dim in model_dims] model_activation_func = self.model_settings.model_activation_func model_dtype = self.data_settings.dtype diff --git a/src/single-core-regen/regen.py b/src/single-core-regen/regen.py index 73096e8..ccd7077 100644 
--- a/src/single-core-regen/regen.py +++ b/src/single-core-regen/regen.py @@ -1,6 +1,8 @@ from datetime import datetime import optuna +import torch +import util from hypertraining.hypertraining import HyperTraining from hypertraining.settings import ( GlobalSettings, @@ -16,24 +18,29 @@ global_settings = GlobalSettings( ) data_settings = DataSettings( - config_path="data/*-128-16384-100000-0-0-17-0-PAM4-0.ini", + # config_path="data/*-128-16384-100000-0-0-17-0-PAM4-0.ini", + config_path="data/20241204-131003-128-16384-100000-0-0-17-0-PAM4-0.ini", dtype="complex64", # symbols = (9, 20), # 13 symbol @ 10GBd <-> 1.3ns <-> 0.26m of fiber - symbols=13, # study: single_core_regen_20241123_011232 + # symbols=13, # study: single_core_regen_20241123_011232 + # symbols = (3, 13), + symbols=4, # output_size = (11, 32), # ballpark 26 taps -> 2 taps per input symbol -> 1 tap every 0.01m (model has 52 inputs) - output_size=26, # study: single_core_regen_20241123_011232 (model_input_dim/2) + # output_size=26, # study: single_core_regen_20241123_011232 (model_input_dim/2) + output_size=(8, 30), shuffle=True, in_out_delay=0, xy_delay=0, - drop_first=128 * 100, + drop_first=256, train_split=0.8, + randomise_polarisations=False, ) pytorch_settings = PytorchSettings( - epochs=10000, + epochs=10, batchsize=2**10, device="cuda", - dataloader_workers=12, + dataloader_workers=4, dataloader_prefetch=4, summary_dir=".runs", write_every=2**5, @@ -43,28 +50,70 @@ pytorch_settings = PytorchSettings( model_settings = ModelSettings( output_dim=2, - # n_hidden_layers = (3, 8), - n_hidden_layers=4, - overrides={ - "n_hidden_nodes_0": 8, - "n_hidden_nodes_1": 6, - "n_hidden_nodes_2": 4, - "n_hidden_nodes_3": 8, - }, - model_activation_func="Mag", - # satabsT0=(1e-6, 1), + n_hidden_layers = (2, 5), + n_hidden_nodes=(2, 16), + model_activation_func="EOActivation", + dropout_prob=0, + model_layer_function="ONNRect", + model_layer_kwargs={"square": True}, + # scale=(False, True), + scale=False, + model_layer_parametrizations=[ + { + "tensor_name": "weight", + "parametrization": util.complexNN.energy_conserving, + }, + { + "tensor_name": "alpha", + "parametrization": util.complexNN.clamp, + }, + { + "tensor_name": "gain", + "parametrization": util.complexNN.clamp, + "kwargs": { + "min": 0, + "max": float("inf"), + }, + }, + { + "tensor_name": "phase_bias", + "parametrization": util.complexNN.clamp, + "kwargs": { + "min": 0, + "max": 2 * torch.pi, + }, + }, + { + "tensor_name": "scales", + "parametrization": util.complexNN.clamp, + }, + { + "tensor_name": "angle", + "parametrization": util.complexNN.clamp, + "kwargs": { + "min": -torch.pi, + "max": torch.pi, + }, + }, + { + "tensor_name": "loss", + "parametrization": util.complexNN.clamp, + }, + ], ) optimizer_settings = OptimizerSettings( - optimizer="Adam", - # learning_rate = (1e-5, 1e-1), - learning_rate=5e-3 - # learning_rate=5e-4, + optimizer="AdamW", + optimizer_kwargs={ + "lr": 5e-3, + "amsgrad": True, + # "weight_decay": 1e-7, + }, ) optuna_settings = OptunaSettings( - n_trials=1, - n_workers=1, + n_trials=1024, + n_workers=8, timeout=3600, directions=("minimize",), metrics_names=("mse",), diff --git a/src/single-core-regen/regen_no_hyper.py b/src/single-core-regen/regen_no_hyper.py index 5c64be2..84ca548 100644 --- a/src/single-core-regen/regen_no_hyper.py +++ b/src/single-core-regen/regen_no_hyper.py @@ -26,24 +26,26 @@ global_settings = GlobalSettings( ) data_settings = DataSettings( - config_path="data/*-128-16384-10000-0-0-0-0-PAM4-0-0.4.ini", + # 
config_path="data/*-128-16384-1-0-0-0-0-PAM4-0-0.ini", + config_path="data/*-128-16384-10000-0-0-17-0-PAM4-0.ini", # config_path=[f"data/20241202-*-128-16384-{length}-0-0-17-0-PAM4-0.ini" for length in range(48000, 53000, 1000)], dtype="complex64", # symbols = (9, 20), # 13 symbol @ 10GBd <-> 1.3ns <-> 0.26m of fiber - symbols=13, # study: single_core_regen_20241123_011232 + symbols=4, # study: single_core_regen_20241123_011232 # output_size = (11, 32), # 26 taps -> 2 taps per input symbol -> 1 tap every 1cm (model has 52 inputs (x/y)) - output_size=26, # study: single_core_regen_20241123_011232 (model_input_dim/2) + output_size=20, # study: single_core_regen_20241123_011232 (model_input_dim/2) shuffle=True, drop_first=64, train_split=0.8, randomise_polarisations=True, + osnr=10, ) pytorch_settings = PytorchSettings( epochs=10000, batchsize=2**14, device="cuda", - dataloader_workers=16, + dataloader_workers=24, dataloader_prefetch=8, summary_dir=".runs", write_every=2**5, @@ -53,17 +55,17 @@ pytorch_settings = PytorchSettings( model_settings = ModelSettings( output_dim=2, - n_hidden_layers=5, + n_hidden_layers=3, overrides={ # "hidden_layer_dims": (8, 8, 4, 4), - "n_hidden_nodes_0": 8, + "n_hidden_nodes_0": 16, "n_hidden_nodes_1": 8, - "n_hidden_nodes_2": 4, - "n_hidden_nodes_3": 4, - "n_hidden_nodes_4": 2, + "n_hidden_nodes_2": 8, + # "n_hidden_nodes_3": 4, + # "n_hidden_nodes_4": 2, }, model_activation_func="EOActivation", - dropout_prob=0.01, + dropout_prob=0, model_layer_function="ONNRect", model_layer_kwargs={"square": True}, scale=False, @@ -126,7 +128,7 @@ model_settings = ModelSettings( optimizer_settings = OptimizerSettings( optimizer="AdamW", optimizer_kwargs={ - "lr": 0.01, + "lr": 0.005, "amsgrad": True, # "weight_decay": 1e-7, }, @@ -242,7 +244,15 @@ if __name__ == "__main__": pytorch_settings=pytorch_settings, model_settings=model_settings, optimizer_settings=optimizer_settings, - # checkpoint_path=".models/best_20241205_235929.tar", + checkpoint_path=".models/best_20241216_221359.tar", + reset_epoch=True, + # settings_override={ + # "optimizer_settings": { + # "optimizer_kwargs": { + # "lr": 0.01, + # }, + # } + # } # 20241202_143149 ) trainer.train() diff --git a/src/single-core-regen/train_pol_estimator.py b/src/single-core-regen/train_pol_estimator.py index 153f807..1da060e 100644 --- a/src/single-core-regen/train_pol_estimator.py +++ b/src/single-core-regen/train_pol_estimator.py @@ -26,7 +26,7 @@ global_settings = GlobalSettings( ) data_settings = DataSettings( - config_path="data/*-128-16384-10000-0-0-0-0-PAM4-0-0.4.ini", + config_path="data/20241211-105524-128-16384-1-0-0-0-0-PAM4-0-0.ini", # config_path=[f"data/20241202-*-128-16384-{length}-0-0-17-0-PAM4-0.ini" for length in range(48000, 53000, 1000)], dtype="complex64", # symbols = (9, 20), # 13 symbol @ 10GBd <-> 1.3ns <-> 0.26m of fiber @@ -53,14 +53,14 @@ pytorch_settings = PytorchSettings( ) model_settings = ModelSettings( - output_dim=3, + output_dim=1, n_hidden_layers=3, overrides={ - "n_hidden_nodes_0": 2, - "n_hidden_nodes_1": 2, - "n_hidden_nodes_2": 2, + "n_hidden_nodes_0": 4, + "n_hidden_nodes_1": 4, + "n_hidden_nodes_2": 4, }, - dropout_prob=0.01, + dropout_prob=0, model_layer_function="ONNRect", model_activation_func="EOActivation", model_layer_kwargs={"square": True}, @@ -110,20 +110,24 @@ model_settings = ModelSettings( ) optimizer_settings = OptimizerSettings( - optimizer="AdamW", + optimizer="RMSprop", + # optimizer="AdamW", optimizer_kwargs={ - "lr": 0.005, - "amsgrad": True, + "lr": 0.01, + 
"alpha": 0.9, + "momentum": 0.1, + "eps": 1e-8, + "centered": True, + # "amsgrad": True, # "weight_decay": 1e-7, }, - # learning_rate=0.05, scheduler="ReduceLROnPlateau", scheduler_kwargs={ - "patience": 2**6, + "patience": 2**5, "factor": 0.75, # "threshold": 1e-3, "min_lr": 1e-6, - "cooldown": 10, + # "cooldown": 10, }, ) diff --git a/src/single-core-regen/util/complexNN.py b/src/single-core-regen/util/complexNN.py index cdcf4d6..9e2cfe0 100644 --- a/src/single-core-regen/util/complexNN.py +++ b/src/single-core-regen/util/complexNN.py @@ -319,6 +319,29 @@ class normalize_by_first(nn.Module): def forward(self, data): return data / data[:, 0].unsqueeze(1) + +class rotate(nn.Module): + def __init__(self): + super(rotate, self).__init__() + + def forward(self, data, angle): + # data -> (batch, n*2) + # angle -> (batch, n) + data_ = data + if angle.ndim == 1: + angle_ = angle.unsqueeze(1) + else: + angle_ = angle + angle_ = angle_.expand(-1, data_.shape[1]//2) + c = torch.cos(angle_) + s = torch.sin(angle_) + rot = torch.stack([torch.stack([c, -s], dim=2), + torch.stack([s, c], dim=2)], dim=3) + d = torch.bmm(data_.reshape(-1, 1, 2), rot.view(-1, 2, 2).to(dtype=data_.dtype)).reshape(*data.shape) + # d = torch.bmm(data.unsqueeze(-1).mT, rot.to(dtype=data.dtype).mT).mT.squeeze(-1) + + return d + class photodiode(nn.Module): def __init__(self, size, bias=True): @@ -487,7 +510,7 @@ class MZISingle(nn.Module): return x * torch.exp(1j * self.phi) * torch.sin(self.omega + self.func(x)) def naive_angle_loss(x: torch.Tensor, target: torch.Tensor, mod=2*torch.pi): - return torch.fmod((x - target), mod).square().mean() + return torch.fmod((x.abs().real - target.abs().real), mod).abs().mean() def cosine_loss(x: torch.Tensor, target: torch.Tensor): return (2*(1 - torch.cos(x - target))).mean() diff --git a/src/single-core-regen/util/datasets.py b/src/single-core-regen/util/datasets.py index d56bafe..4967639 100644 --- a/src/single-core-regen/util/datasets.py +++ b/src/single-core-regen/util/datasets.py @@ -5,6 +5,7 @@ from torch.utils.data import Dataset # from torch.utils.data import Sampler import numpy as np import configparser +import multiprocessing as mp # class SubsetSampler(Sampler[int]): # """ @@ -113,8 +114,9 @@ class FiberRegenerationDataset(Dataset): dtype: torch.dtype = None, real: bool = False, device=None, - polarisations: tuple | list = (0,), + osnr: float = None, randomise_polarisations: bool = False, + repeat_randoms: int = 1, **kwargs, ): """ @@ -190,18 +192,20 @@ class FiberRegenerationDataset(Dataset): files.append(config["data"]["file"].strip('"')) self.config["data"]["file"] = str(files) - for i, angle in enumerate(torch.tensor(np.array(polarisations))): - data_raw_copy = data_raw.clone() - if angle == 0: - continue - sine = torch.sin(angle) - cosine = torch.cos(angle) - data_raw_copy[:, 2] = data_raw[:, 2] * cosine - data_raw[:, 3] * sine - data_raw_copy[:, 3] = data_raw[:, 2] * sine + data_raw[:, 3] * cosine - if i == 0: - data_raw = data_raw_copy - else: - data_raw = torch.cat([data_raw, data_raw_copy], dim=0) + # if polarisations is not None: + # self.angles = torch.tensor(polarisations).repeat(len(data_raw), 1) + # for i, angle in enumerate(torch.tensor(np.array(polarisations))): + # data_raw_copy = data_raw.clone() + # if angle == 0: + # continue + # sine = torch.sin(angle) + # cosine = torch.cos(angle) + # data_raw_copy[:, 2] = data_raw[:, 2] * cosine - data_raw[:, 3] * sine + # data_raw_copy[:, 3] = data_raw[:, 2] * sine + data_raw[:, 3] * cosine + # if i == 0: + # data_raw 
= data_raw_copy + # else: + # data_raw = torch.cat([data_raw, data_raw_copy], dim=0) self.device = data_raw.device @@ -278,23 +282,61 @@ class FiberRegenerationDataset(Dataset): timestamps = data_raw[4, :] data_raw = data_raw[:4, :] data_raw = data_raw.view(2, 2, -1) - timestamps_doubled = torch.cat([timestamps.unsqueeze(dim=0), timestamps.unsqueeze(dim=0)], dim=0).unsqueeze( - dim=1 - ) - data_raw = torch.cat([data_raw, timestamps_doubled], dim=1) + fiber_in = data_raw[0, :, :] + fiber_out = data_raw[1, :, :] + # timestamps_doubled = torch.cat([timestamps.unsqueeze(dim=0), timestamps.unsqueeze(dim=0)], dim=0).unsqueeze( + # dim=1 + # ) + fiber_in = torch.cat([fiber_in, timestamps.unsqueeze(0)], dim=0) + fiber_out = torch.cat([fiber_out, timestamps.unsqueeze(0)], dim=0) + + if repeat_randoms > 1: + fiber_in = fiber_in.repeat(1, 1, repeat_randoms) + fiber_out = fiber_out.repeat(1, 1, repeat_randoms) + # review: potential problems with repeated timestamps when plotting + else: + repeat_randoms = 1 + + if self.randomise_polarisations: + angles = torch.fmod(torch.randperm(data_raw.shape[-1]*repeat_randoms), 2) * torch.pi + # start_angle = torch.rand(1) * 2 * torch.pi + # angles = start_angle + torch.cumsum(torch.randn(data_raw.shape[-1])/333, dim=0) * torch.pi * 2 # random walk + # self.angles = torch.rand(self.data.shape[0]) * 2 * torch.pi + else: + angles = torch.zeros(data_raw.shape[-1]) + + sin = torch.sin(angles) + cos = torch.cos(angles) + rot = torch.stack([torch.stack([cos, -sin], dim=1), torch.stack([sin, cos], dim=1)], dim=2) + data_rot = torch.bmm(fiber_out[:2, :].T.unsqueeze(1), rot.to(dtype=data_raw.dtype)).squeeze(1).T + fiber_out = torch.cat((fiber_out, data_rot), dim=0) + fiber_out = torch.cat([fiber_out, angles.unsqueeze(0)], dim=0) + + if osnr is not None: + popt = torch.mean(fiber_out[:2, :, :].abs().flatten(), dim=-1) + noise = torch.randn_like(fiber_out[:2, :, :]) + pn = torch.mean(noise.abs().flatten(), dim=-1) + noise = noise * (popt / pn) * 10 ** (-osnr / 20) + fiber_out[:2, :, :] = torch.add(fiber_out[:2, :, :], noise) + + + + # data_raw = torch.cat([data_raw, timestamps_doubled], dim=1) # data layout # [ [E_in_x, E_in_y, timestamps], # [E_out_x, E_out_y, timestamps] ] - self.data = data_raw.unfold(dimension=-1, size=self.samples_per_slice, step=1) - self.data = self.data.movedim(-2, 0) + self.fiber_in = fiber_in.unfold(dimension=-1, size=self.samples_per_slice, step=1) + self.fiber_in = self.fiber_in.movedim(-2, 0) - if randomise_polarisations: - self.angles = torch.rand(self.data.shape[0]) * np.pi * 2 - # self.data[:, 1, :2, :] = self.rotate(self.data[:, 1, :2, :], self.angles) - else: - self.angles = torch.zeros(self.data.shape[0]) + self.fiber_out = fiber_out.unfold(dimension=-1, size=self.samples_per_slice, step=1) + self.fiber_out = self.fiber_out.movedim(-2, 0) + + # self.data = data_raw.unfold(dimension=-1, size=self.samples_per_slice, step=1) + # self.data = self.data.movedim(-2, 0) + # self.angles = torch.zeros(self.data.shape[0]) + ... # ... # -> [no_slices, 2, 3, samples_per_slice] @@ -305,51 +347,56 @@ class FiberRegenerationDataset(Dataset): # ... # ] -> [no_slices, 2, 3, samples_per_slice] - ... 
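[review note] In the new __init__ path above, torch.fmod(torch.randperm(N), 2) * torch.pi only ever yields the two angles {0, pi}; if a continuous random state of polarisation is intended, something like torch.rand(N) * 2 * torch.pi (as in the commented-out random-walk lines) is the usual choice. Also worth a check: fiber_out is 2-D here unless repeat_randoms > 1, so the fiber_out[:2, :, :] indexing in the osnr branch (which adds white noise scaled to mean|E| * 10**(-osnr / 20)) will raise on the default path. For reference, a self-contained sketch of the batched Jones rotation the code builds with torch.bmm (shapes and names are mine, mirroring the patch's construction):

    import torch

    n = 8                                              # time samples
    field = torch.randn(2, n, dtype=torch.complex64)   # rows: E_x, E_y
    angles = torch.rand(n) * 2 * torch.pi              # one angle per sample

    c, s = torch.cos(angles), torch.sin(angles)
    rot = torch.stack([torch.stack([c, -s], dim=1),
                       torch.stack([s, c], dim=1)], dim=2)      # (n, 2, 2)

    # row-vector convention, as in the patch: (n, 1, 2) @ (n, 2, 2)
    rotated = torch.bmm(field.T.unsqueeze(1),
                        rot.to(field.dtype)).squeeze(1).T       # back to (2, n)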
- def __len__(self): - return self.data.shape[0] + return self.fiber_in.shape[0] def __getitem__(self, idx): if isinstance(idx, slice): return [self.__getitem__(i) for i in range(*idx.indices(len(self)))] else: - data_slice = self.data[idx].squeeze() - - data_slice = data_slice[:, :, : data_slice.shape[2] // self.output_dim * self.output_dim] + # fiber in: [E_in_x, E_in_y, timestamps] + # fiber out: [E_out_x, E_out_y, timestamps, E_out_x_rot, E_out_y_rot, angle] - data_slice = data_slice.view(data_slice.shape[0], data_slice.shape[1], self.output_dim, -1) + fiber_in = self.fiber_in[idx].squeeze() + fiber_out = self.fiber_out[idx].squeeze() + + fiber_in = fiber_in[..., : fiber_in.shape[-1] // self.output_dim * self.output_dim] + fiber_out = fiber_out[..., : fiber_out.shape[-1] // self.output_dim * self.output_dim] + + fiber_in = fiber_in.view(fiber_in.shape[0], self.output_dim, -1) + fiber_out = fiber_out.view(fiber_out.shape[0], self.output_dim, -1) + + # data_slice = data_slice[:, :, : data_slice.shape[2] // self.output_dim * self.output_dim] + + # data_slice = data_slice.view(data_slice.shape[0], data_slice.shape[1], self.output_dim, -1) + + # angle = self.angles[idx] + + center_angle = fiber_out[5, self.output_dim // 2, 0] + angles = fiber_out[5, :, 0] + plot_data = fiber_out[:2, self.output_dim // 2, 0].detach().clone() + plot_data_rot = fiber_out[3:5, self.output_dim // 2, 0].detach().clone() + data = fiber_out[3:5, :, 0] # if self.randomise_polarisations: - # angle = torch.rand(1) * torch.pi * 2 - # sine = torch.sin(angle) - # cosine = torch.cos(angle) - # data_slice_ = data_slice[1] - # data_slice[1, 0] = data_slice_[0] * cosine - data_slice_[1] * sine - # data_slice[1,1] = data_slice_[0] * sine + data_slice_[1] * cosine - # else: - # angle = torch.zeros(1) + # data = data.mT + # c = torch.cos(angle).unsqueeze(-1) + # s = torch.sin(angle).unsqueeze(-1) + # rot = torch.stack([torch.stack([c, -s], dim=1), torch.stack([s, c], dim=1)], dim=2).squeeze(-1) + # data = torch.bmm(data.mT.unsqueeze(0), rot.to(dtype=data.dtype)).squeeze(-1) + ... - # data = data_slice[1, :2, :, 0] - - angle = self.angles[idx] - - data_index = 1 - - data_slice[1, :2, :, :] = self.rotate(data_slice[data_index, :2, :, :], angle) - - data = data_slice[1, :2, :, 0] - # data = self.rotate(data, angle) + # angle = torch.zeros_like(angle) # for both polarisations (x/y), calculate the mean of the signal around the current symbol (-> corresponding to a lowpass filter) - angle_data = data_slice[1, :2, :, :].reshape(2, -1).mean(dim=1) - angle_data2 = self.complex_max(data_slice[1, :2, :, :].reshape(2, -1)) - plot_data = data_slice[1, :2, self.output_dim // 2, 0] - sop = self.polarimeter(plot_data) + angle_data = fiber_out[:2, :, :].reshape(2, -1).mean(dim=1).repeat(1, self.output_dim) + angle_data2 = self.complex_max(fiber_out[:2, :, :].reshape(2, -1)).repeat(1, self.output_dim) + # sop = self.polarimeter(plot_data) # angle_data = data_slice[1, :2, :self.output_dim//2, :].squeeze().reshape(2, -1).mean(dim=-1) # angle = data_slice[1, 3, self.output_dim // 2, 0].real - target = data_slice[0, :2, self.output_dim // 2, 0] - target_timestamp = data_slice[0, 2, self.output_dim // 2, 0].real + target = fiber_in[:2, self.output_dim // 2, 0] + plot_target = fiber_in[:2, self.output_dim // 2, 0].detach().clone() + target_timestamp = fiber_in[2, self.output_dim // 2, 0].real ... 
# data_timestamps = data[-1,:].real @@ -360,22 +407,39 @@ class FiberRegenerationDataset(Dataset): # transpose to interleave the x and y data in the output tensor data = data.transpose(0, 1).flatten().squeeze() - angle_data = angle_data.flatten().squeeze() - angle_data2 = angle_data.flatten().squeeze() - angle = angle.flatten().squeeze() + angle_data = angle_data.transpose(0, 1).flatten().squeeze() + angle_data2 = angle_data2.transpose(0,1).flatten().squeeze() + center_angle = center_angle.flatten().squeeze() + angles = angles.flatten().squeeze() # data_timestamps = data_timestamps.flatten().squeeze() + # target = target.transpose(0,1).flatten().squeeze() target = target.flatten().squeeze() target_timestamp = target_timestamp.flatten().squeeze() + plot_target = plot_target.flatten().squeeze() + plot_data = plot_data.flatten().squeeze() + plot_data_rot = plot_data_rot.flatten().squeeze() + + return { + "x": data, + "y": target, + "center_angle": center_angle, + "angles": angles, + "mean_angle": angles.mean(), + # "sop": sop, + "angle_data": angle_data, + "angle_data2": angle_data2, + "timestamp": target_timestamp, + "plot_target": plot_target, + "plot_data": plot_data, + "plot_data_rot": plot_data_rot, + } - return {"x": data, "y": target, "angle": angle, "sop": sop, "angle_data": angle_data, "angle_data2": angle_data2, "timestamp": target_timestamp, "plot_data": plot_data} - def complex_max(self, data, dim=-1): # returns element(s) with the maximum absolute value along a given dimension # ind = torch.argmax(data.abs(), dim=dim, keepdim=True) # max_values = torch.gather(data, dim, ind).squeeze(dim=dim) # return max_values return torch.gather(data, dim, torch.argmax(data.abs(), dim=dim, keepdim=True)).squeeze(dim=dim) - def rotate(self, data, angle): # rotates a 2d tensor by a given angle @@ -388,7 +452,25 @@ class FiberRegenerationDataset(Dataset): cosine = torch.cos(angle) return torch.stack([data[0] * cosine - data[1] * sine, data[0] * sine + data[1] * cosine], dim=0) - + + def rotate_all(self): + def do_rotation(j, num_processes): + for i in range(len(self) // num_processes): + index = i * num_processes + j + self.data[index, 1, :2, :] = self.rotate(self.data[index, 1, :2, :], self.angles[index]) + + self.processes = [] + + for j in range(mp.cpu_count()): + self.processes.append(mp.Process(target=do_rotation, args=(j, mp.cpu_count()))) + self.processes[-1].start() + + for p in self.processes: + p.join() + + for i in range(len(self) // mp.cpu_count() * mp.cpu_count(), len(self)): + self.data[i, 1, :2, :] = self.rotate(self.data[i, 1, :2, :], self.angles[i]) + def polarimeter(self, data): # data: [2, ...] -> x, y # returns [4] -> S0, S1, S2, S3 @@ -396,12 +478,12 @@ class FiberRegenerationDataset(Dataset): y = data[1].mean() I_X = x.abs().square() I_Y = y.abs().square() - I_45 = (x+y).abs().square() - I_RHC = (x + 1j*y).abs().square() + I_45 = (x + y).abs().square() + I_RHC = (x + 1j * y).abs().square() S0 = I_X + I_Y - S1 = (2*I_X - S0) / S0 - S2 = (2*I_45 - S0) / S0 - S3 = (2*I_RHC - S0) / S0 + S1 = (2 * I_X - S0) / S0 + S2 = (2 * I_45 - S0) / S0 + S3 = (2 * I_RHC - S0) / S0 - return torch.stack([S1, S2, S3], dim=0) \ No newline at end of file + return torch.stack([S0, S1, S2, S3], dim=0)
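[review note] On the polarimeter() change at the end: returning S0 alongside the normalised S1..S3 is reasonable, but for textbook Stokes parameters the analyser intensities carry a factor 1/2, i.e. I_45 = |x + y|^2 / 2 and I_RHC = |x + iy|^2 / 2; as written (already so before this patch), S2 and S3 come out scaled. A self-contained sketch with the halved intensities, checked against known states (my reconstruction, not code from this patch):

    import torch

    def stokes(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Stokes vector [S0, S1, S2, S3] of a Jones vector (x, y), S1..S3 normalised."""
        I_X = x.abs().square()
        I_Y = y.abs().square()
        I_45 = 0.5 * (x + y).abs().square()        # 45-deg linear analyser
        I_RHC = 0.5 * (x + 1j * y).abs().square()  # circular analyser
        S0 = I_X + I_Y
        return torch.stack([S0, (2 * I_X - S0) / S0,
                            (2 * I_45 - S0) / S0, (2 * I_RHC - S0) / S0])

    one = torch.tensor(1 + 0j)
    r = torch.tensor(2 ** -0.5 + 0j)
    print(stokes(one, torch.tensor(0j)))  # horizontal -> [1, 1, 0, 0]
    print(stokes(r, r))                   # 45 deg     -> [1, 0, 1, 0]
    print(stokes(r, -1j * r))             # circular   -> [1, 0, 0, 1]

Separately, rotate_all() mutates self.data from mp.Process children; under the default fork start method those writes land in copy-on-write child memory and never reach the parent. Calling self.data.share_memory_() first, or rotating in-process, would be the usual fix.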