From 297e9e8d7fa4eddb8215c8775e48de0bb163faec Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Joseph=20Hopfm=C3=BCller?= Date: Mon, 2 Dec 2024 18:50:43 +0100 Subject: [PATCH] update submodule configuration and enhance model settings; add eye diagram functionality --- .gitmodules | 1 + .../hypertraining/hypertraining.py | 4 +- .../hypertraining/settings.py | 9 +- .../hypertraining/training.py | 423 +++++++++++++----- src/single-core-regen/regen_no_hyper.py | 170 +++++-- src/single-core-regen/util/__init__.py | 4 +- src/single-core-regen/util/complexNN.py | 264 +++++++---- 7 files changed, 626 insertions(+), 249 deletions(-) diff --git a/.gitmodules b/.gitmodules index 89f8125..d759e75 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,3 +1,4 @@ [submodule "pypho"] path = pypho url = git@gitlab.lrz.de:000000003B9B3E61/pypho.git + branch = main diff --git a/src/single-core-regen/hypertraining/hypertraining.py b/src/single-core-regen/hypertraining/hypertraining.py index d14c308..bed9bcc 100644 --- a/src/single-core-regen/hypertraining/hypertraining.py +++ b/src/single-core-regen/hypertraining/hypertraining.py @@ -258,12 +258,12 @@ class HyperTraining: f"model_hidden_dim_{i}", self.model_settings.n_hidden_nodes, ) - layers.append(util.complexNN.SemiUnitaryLayer(last_dim, hidden_dim, dtype=dtype)) + layers.append(util.complexNN.ONNRect(last_dim, hidden_dim, dtype=dtype)) last_dim = hidden_dim layers.append(getattr(util.complexNN, afunc)()) n_nodes += last_dim - layers.append(util.complexNN.SemiUnitaryLayer(last_dim, self.model_settings.output_dim, dtype=dtype)) + layers.append(util.complexNN.ONNRect(last_dim, self.model_settings.output_dim, dtype=dtype)) model = nn.Sequential(*layers) diff --git a/src/single-core-regen/hypertraining/settings.py b/src/single-core-regen/hypertraining/settings.py index 5b3eebd..1ceb51b 100644 --- a/src/single-core-regen/hypertraining/settings.py +++ b/src/single-core-regen/hypertraining/settings.py @@ -11,7 +11,7 @@ class GlobalSettings: # data settings @dataclass class DataSettings: - config_path: str # = "data/*-128-16384-100000-0-0-17-0-PAM4-0.ini" + config_path: tuple | list | str # = "data/*-128-16384-100000-0-0-17-0-PAM4-0.ini" dtype: tuple = ("complex64", "float64") symbols: tuple | float | int = 8 output_size: tuple | float | int = 64 @@ -39,7 +39,7 @@ class PytorchSettings: summary_dir: str = ".runs" write_every: int = 10 head_symbols: int = 40 - eye_symbols: int = 400 + eye_symbols: int = 1000 # model settings @@ -52,13 +52,16 @@ class ModelSettings: overrides: dict = field(default_factory=dict) dropout_prob: float | None = None model_layer_function: str | None = None + scale: bool = False + model_layer_kwargs: dict | None = None model_layer_parametrizations: list= field(default_factory=list) @dataclass class OptimizerSettings: optimizer: tuple | str = ("Adam", "RMSprop", "SGD") - learning_rate: tuple | float = (1e-5, 1e-1) + optimizer_kwargs: dict | None = None + # learning_rate: tuple | float = (1e-5, 1e-1) scheduler: str | None = None scheduler_kwargs: dict | None = None diff --git a/src/single-core-regen/hypertraining/training.py b/src/single-core-regen/hypertraining/training.py index 0a7f668..fc27098 100644 --- a/src/single-core-regen/hypertraining/training.py +++ b/src/single-core-regen/hypertraining/training.py @@ -1,8 +1,10 @@ import copy from datetime import datetime from pathlib import Path +import random from typing import Literal import matplotlib +from matplotlib.colors import LinearSegmentedColormap import torch.nn.utils.parametrize try: @@ -50,6 
+52,7 @@ class regenerator(nn.Module): self, *dims, layer_function=util.complexNN.ONN, + layer_kwargs: dict | None = None, layer_parametrizations: list[dict] = None, # [ # { @@ -64,6 +67,7 @@ class regenerator(nn.Module): activation_function=util.complexNN.Pow, dtype=torch.float64, dropout_prob=0.01, + scale=False, **kwargs, ): super(regenerator, self).__init__() @@ -74,39 +78,57 @@ class regenerator(nn.Module): raise ValueError("dims must be provided") self._n_hidden_layers = len(dims) - 2 self._layers = nn.Sequential() + if layer_kwargs is None: + layer_kwargs = {} + # self.powers = [] for i in range(self._n_hidden_layers + 1): - self._layers.append(layer_function(dims[i], dims[i + 1], dtype=dtype)) + if scale: + self._layers.append(util.complexNN.Scale(dims[i])) + self._layers.append(layer_function(dims[i], dims[i + 1], dtype=dtype, **layer_kwargs)) if i < self._n_hidden_layers: if dropout_prob is not None: self._layers.append(util.complexNN.DropoutComplex(p=dropout_prob)) - self._layers.append(activation_function()) - - # add parametrizations - if layer_parametrizations is not None: + self._layers.append(activation_function(bias=True, size=dims[i + 1])) + + self._layers.append(util.complexNN.Scale(dims[-1])) + + # add parametrizations + if layer_parametrizations is not None: + for layer in self._layers: for layer_parametrization in layer_parametrizations: tensor_name = layer_parametrization.get("tensor_name", None) parametrization = layer_parametrization.get("parametrization", None) param_kwargs = layer_parametrization.get("kwargs", {}) - if ( - tensor_name is not None - and tensor_name in self._layers[-1]._parameters - and parametrization is not None - ): - parametrization(self._layers[-1], tensor_name, **param_kwargs) + if tensor_name is not None and tensor_name in layer._parameters and parametrization is not None: + parametrization(layer, tensor_name, **param_kwargs) - def forward(self, input_x): + # def __call__(self, input_x, **kwargs): + # return self.forward(input_x, **kwargs) + + def forward(self, input_x, trace_powers=False): x = input_x + + if trace_powers: + powers = [x.abs().square().sum()] + # check if tracing if torch.jit.is_tracing(): for layer in self._layers: x = layer(x) + if trace_powers: + powers.append(x.abs().square().sum()) else: # with torch.nn.utils.parametrize.cached(): for layer in self._layers: x = layer(x) + if trace_powers: + powers.append(x.abs().square().sum()) + if trace_powers: + return x, powers return x + def traverse_dict_update(target, source): for k, v in source.items(): if isinstance(v, dict): @@ -119,6 +141,7 @@ def traverse_dict_update(target, source): except TypeError: target.__dict__[k] = v + class Trainer: def __init__( self, @@ -142,7 +165,7 @@ class Trainer: OptimizerSettings, PytorchSettings, regenerator, - torch.nn.utils.parametrizations.orthogonal + torch.nn.utils.parametrizations.orthogonal, ]) if self.resume: self.checkpoint_dict = torch.load(checkpoint_path, weights_only=True) @@ -167,7 +190,7 @@ class Trainer: raise ValueError("model_settings must be provided") if optimizer_settings is None: raise ValueError("optimizer_settings must be provided") - + self.global_settings: GlobalSettings = global_settings self.data_settings: DataSettings = data_settings self.pytorch_settings: PytorchSettings = pytorch_settings @@ -206,6 +229,11 @@ class Trainer: } def define_model(self, model_kwargs=None): + if self.resume: + model_kwargs = self.checkpoint_dict["model_kwargs"] + else: + model_kwargs = model_kwargs + if model_kwargs is None: 
n_hidden_layers = self.model_settings.n_hidden_layers @@ -228,6 +256,7 @@ class Trainer: "activation_function": afunc, "dtype": dtype, "dropout_prob": self.model_settings.dropout_prob, + "scale": self.model_settings.scale, } else: self.model_kwargs = model_kwargs @@ -237,9 +266,12 @@ class Trainer: # dims = self.model_kwargs.pop("dims") self.model = regenerator(**self.model_kwargs) - self.writer.add_graph(self.model, torch.zeros(1, input_dim, dtype=dtype)) + if self.writer is not None: + self.writer.add_graph(self.model, torch.zeros(1, input_dim, dtype=dtype)) self.model = self.model.to(self.pytorch_settings.device) + if self.resume: + self.model.load_state_dict(self.checkpoint_dict["model_state_dict"], strict=False) def get_sliced_data(self, override=None): symbols = self.data_settings.symbols @@ -253,11 +285,13 @@ class Trainer: dtype = getattr(torch, self.data_settings.dtype) num_symbols = None + config_path = self.data_settings.config_path if override is not None: num_symbols = override.get("num_symbols", None) + config_path = override.get("config_path", config_path) # get dataset dataset = FiberRegenerationDataset( - file_path=self.data_settings.config_path, + file_path=config_path, symbols=symbols, output_dim=data_size, target_delay=in_out_delay, @@ -330,10 +364,11 @@ class Trainer: task = progress.add_task("-.---e--", total=len(train_loader)) progress.start() - running_loss2 = 0.0 + # running_loss2 = 0.0 running_loss = 0.0 self.model.train() - for batch_idx, (x, y) in enumerate(train_loader): + loader_len = len(train_loader) + for batch_idx, (x, y, _) in enumerate(train_loader): self.model.zero_grad(set_to_none=True) x, y = ( x.to(self.pytorch_settings.device), @@ -344,24 +379,23 @@ class Trainer: loss_value = loss.item() loss.backward() optimizer.step() - running_loss2 += loss_value + # running_loss2 += loss_value running_loss += loss_value if enable_progress: - progress.update(task, advance=1, description=f"{loss_value:.3e}") + progress.update(task, advance=1, description=f"{running_loss/(batch_idx+1):.3e}") if batch_idx % self.pytorch_settings.write_every == 0: self.writer.add_scalar( "training loss", - running_loss2 / (self.pytorch_settings.write_every if batch_idx > 0 else 1), - epoch * len(train_loader) + batch_idx, + running_loss / (batch_idx + 1), + epoch * loader_len + batch_idx, ) - running_loss2 = 0.0 if enable_progress: progress.stop() - return running_loss / len(train_loader) + return running_loss / (batch_idx + 1) def eval_model(self, valid_loader, epoch, enable_progress=True): if enable_progress: @@ -384,7 +418,7 @@ class Trainer: self.model.eval() running_error = 0 with torch.no_grad(): - for batch_idx, (x, y) in enumerate(valid_loader): + for batch_idx, (x, y, _) in enumerate(valid_loader): x, y = ( x.to(self.pytorch_settings.device), y.to(self.pytorch_settings.device), @@ -395,76 +429,107 @@ class Trainer: running_error += error_value if enable_progress: - progress.update(task, advance=1, description=f"{error_value:.3e}") + progress.update(task, advance=1, description=f"{error_value/(batch_idx+1):.3e}") + running_error /= (batch_idx+1) - running_error /= len(valid_loader) self.writer.add_scalar( "eval loss", running_error, epoch, ) - title_append, subtitle = self.build_title(epoch + 1) - self.writer.add_figure( - "fiber response", - self.plot_model_response( - model=self.model, - title_append=title_append, - subtitle=subtitle, - show=False, - ), - epoch + 1, - ) - self.writer.add_figure( - "eye diagram", - self.plot_model_response( - model=self.model, - 
title_append=title_append, - subtitle=subtitle, - show=False, - mode="eye", - ), - epoch + 1, - ) - self.writer_histograms(epoch + 1) + if (epoch + 1) % 10 == 0 or epoch < 10: + # plotting is slow, so only do it every 10 epochs + title_append, subtitle = self.build_title(epoch + 1) + self.writer.add_figure( + "fiber response", + self.plot_model_response( + model=self.model, + title_append=title_append, + subtitle=subtitle, + show=False, + ), + epoch + 1, + ) + self.writer.add_figure( + "eye diagram", + self.plot_model_response( + model=self.model, + title_append=title_append, + subtitle=subtitle, + show=False, + mode="eye", + ), + epoch + 1, + ) + + self.writer.add_figure( + "powers", + self.plot_model_response( + model=self.model, + title_append=title_append, + subtitle=subtitle, + mode="powers", + show=False, + ), + epoch + 1, + ) + + self.write_parameters(epoch + 1) + self.writer.flush() if enable_progress: progress.stop() return running_error - def run_model(self, model, loader): + def run_model(self, model, loader, trace_powers=False): model.eval() - xs = [] - ys = [] - y_preds = [] + fiber_out = [] + fiber_in = [] + regen = [] + timestamps = [] + with torch.no_grad(): model = model.to(self.pytorch_settings.device) - for x, y in loader: + for x, y, timestamp in loader: x, y = ( x.to(self.pytorch_settings.device), y.to(self.pytorch_settings.device), ) - y_pred = model(x).cpu() + if trace_powers: + y_pred, powers = model(x, trace_powers).cpu() + else: + y_pred = model(x, trace_powers).cpu() # x = x.cpu() # y = y.cpu() y_pred = y_pred.view(y_pred.shape[0], -1, 2) y = y.view(y.shape[0], -1, 2) x = x.view(x.shape[0], -1, 2) - xs.append(x[:, 0, :].squeeze()) - ys.append(y.squeeze()) - y_preds.append(y_pred.squeeze()) + # timestamp = timestamp.view(-1, 1) + fiber_out.append(x[:, x.shape[1] // 2, :].squeeze()) + fiber_in.append(y.squeeze()) + regen.append(y_pred.squeeze()) + timestamps.append(timestamp.squeeze()) - xs = torch.vstack(xs).cpu() - ys = torch.vstack(ys).cpu() - y_preds = torch.vstack(y_preds).cpu() - return ys, xs, y_preds + fiber_out = torch.vstack(fiber_out).cpu() + fiber_in = torch.vstack(fiber_in).cpu() + regen = torch.vstack(regen).cpu() + timestamps = torch.concat(timestamps).cpu() + if trace_powers: + return fiber_in, fiber_out, regen, timestamps, powers + return fiber_in, fiber_out, regen, timestamps - def writer_histograms(self, epoch, attributes=["weight", "weight_U", "weight_V", "bias", "sigma", "scale"]): + def write_parameters(self, epoch, attributes: list[str] | tuple[str] = None): for i, layer in enumerate(self.model._layers): tag = f"layer {i}" - for attribute in attributes: - if hasattr(layer, attribute): + if hasattr(layer, "parametrizations"): + attribute_pool = set(layer.parametrizations._modules) | set(layer._parameters) + else: + attribute_pool = set(layer._parameters) + for attribute in attribute_pool: + plot = (attributes is None) or (attribute in attributes) + if plot: vals: np.ndarray = getattr(layer, attribute).detach().cpu().numpy().flatten() if vals.ndim <= 1 and len(vals) == 1: if np.iscomplexobj(vals): @@ -483,14 +548,11 @@ class Trainer: if self.writer is None: self.setup_tb_writer() - if self.resume: - model_kwargs = self.checkpoint_dict["model_kwargs"] - else: - model_kwargs = None + self.define_model() - self.define_model(model_kwargs=model_kwargs) - - print(f"number of parameters (trainable): {sum(p.numel() for p in self.model.parameters())} ({sum(p.numel() for p in self.model.parameters() if p.requires_grad)})") + print( + f"number of 
parameters (trainable): {sum(p.numel() for p in self.model.parameters())} ({sum(p.numel() for p in self.model.parameters() if p.requires_grad)})" + ) title_append, subtitle = self.build_title(0) @@ -515,36 +577,55 @@ class Trainer: ), 0, ) - self.writer_histograms(0) + + self.writer.add_figure( + "powers", + self.plot_model_response( + model=self.model, + title_append=title_append, + subtitle=subtitle, + mode="powers", + show=False, + ), + 0, + ) + + self.write_parameters(0) + + self.writer.add_text("datasets", '\n'.join(self.data_settings.config_path)) + + self.writer.flush() train_loader, valid_loader = self.get_sliced_data() optimizer_name = self.optimizer_settings.optimizer - lr = self.optimizer_settings.learning_rate + # lr = self.optimizer_settings.learning_rate - self.optimizer: optim.Optimizer = getattr(optim, optimizer_name)(self.model.parameters(), lr=lr) + self.optimizer: optim.Optimizer = getattr(optim, optimizer_name)( + self.model.parameters(), **self.optimizer_settings.optimizer_kwargs + ) if self.optimizer_settings.scheduler is not None: self.scheduler = getattr(optim.lr_scheduler, self.optimizer_settings.scheduler)( self.optimizer, **self.optimizer_settings.scheduler_kwargs ) - if self.resume: - try: - self.scheduler.load_state_dict(self.checkpoint_dict["scheduler_state_dict"]) - except ValueError: - pass - self.writer.add_scalar("learning rate", self.scheduler.get_last_lr()[0], -1) - + # if self.resume: + # try: + # self.scheduler.load_state_dict(self.checkpoint_dict["scheduler_state_dict"]) + # except ValueError: + # pass + self.writer.add_scalar("learning rate", self.optimizer.param_groups[0]["lr"], -1) if not self.resume: self.best = self.build_checkpoint_dict() else: self.best = self.checkpoint_dict - self.model.load_state_dict(self.best["model_state_dict"], strict=False) - try: - self.optimizer.load_state_dict(self.best["optimizer_state_dict"]) - except ValueError: - pass + self.best["loss"] = float("inf") + # self.model.load_state_dict(self.best["model_state_dict"], strict=False) + # try: + # self.optimizer.load_state_dict(self.best["optimizer_state_dict"]) + # except ValueError: + # pass for epoch in range(self.best["epoch"] + 1, self.pytorch_settings.epochs): enable_progress = True @@ -562,12 +643,8 @@ class Trainer: enable_progress=enable_progress, ) if self.optimizer_settings.scheduler is not None: - lr_old = self.scheduler.get_last_lr() self.scheduler.step(loss) - lr_new = self.scheduler.get_last_lr() - if lr_old[0] != lr_new[0]: - self.writer.add_scalar("learning rate", lr_new[0], epoch) - + self.writer.add_scalar("learning rate", self.optimizer.param_groups[0]["lr"], epoch) if self.pytorch_settings.save_models and self.model is not None: save_path = ( Path(self.pytorch_settings.model_dir) / f"{self.writer.get_logdir().split('/')[-1]}_{epoch}.tar" @@ -588,7 +665,28 @@ class Trainer: self.writer.close() return self.best - def _plot_model_response_eye(self, *signals, labels=None, sps=None, title_append="", subtitle="", show=True): + def _plot_model_response_powers(self, powers, layer_names, title_append="", subtitle="", show=True): + powers = [power / powers[0] for power in powers] + fig, ax = plt.subplots() + fig.set_figwidth(18) + fig.suptitle( + f"Energy conservation{f' {title_append}' if title_append else ''}{f'\n{subtitle}' if subtitle else ''}" + ) + ax.semilogy(powers, marker="o") + ax.set_xticks(range(len(layer_names)), layer_names, rotation=90) + ax.set_xlabel("Layer") + ax.set_ylabel("Normalized Power") + ax.grid(which="major", axis="x") + 
ax.grid(which="major", axis="y") + ax.grid(which="minor", axis="y", linestyle=":") + fig.tight_layout() + if show: + plt.show() + return fig + + def _plot_model_response_eye( + self, *signals, timestamps, labels=None, sps=None, title_append="", subtitle="", show=True + ): if sps is None: raise ValueError("sps must be provided") if not hasattr(labels, "__iter__") or isinstance(labels, (str, type(None))): @@ -603,22 +701,73 @@ class Trainer: if not any(labels): labels = [f"signal {i + 1}" for i in range(len(signals))] + x_bins = np.linspace(0, 2, 2 * sps, endpoint=False) + y_bins = np.zeros((2 * len(signals), 1000)) + eye_data = np.zeros((2 * len(signals), 1000, 2 * sps)) + # signals = [signal.cpu().numpy() for signal in signals] + for i in range(len(signals) * 2): + eye_signal = signals[i // 2][:, i % 2] # x, y, x, y, ... + eye_signal = np.real(np.square(np.abs(eye_signal))) + data_min = np.min(eye_signal) + data_max = np.max(eye_signal) + y_bins[i] = np.linspace(data_min, data_max, 1000, endpoint=False) + for j in range(len(timestamps)): + t = timestamps[j] / sps + val = eye_signal[j] + x = np.digitize(t % 2, x_bins) - 1 + y = np.digitize(val, y_bins[i]) - 1 + eye_data[i][y][x] += 1 + + cmap = LinearSegmentedColormap.from_list( + "eyemap", + [ + (0, "white"), + (0.001, "dodgerblue"), + (0.1, "blue"), + (0.2, "cyan"), + (0.5, "lime"), + (0.8, "gold"), + (1, "red"), + ], + ) + + # ordering = np.argsort(timestamps) + # signals = [signal[ordering] for signal in signals] + # timestamps = timestamps[ordering] + fig, axs = plt.subplots(1, 2 * len(signals), sharex=True, sharey=True) fig.set_figwidth(18) fig.suptitle(f"Eye diagram{f' {title_append}' if title_append else ''}{f'\n{subtitle}' if subtitle else ''}") - xaxis = np.linspace(0, 2, 2 * sps, endpoint=False) - for j, (label, signal) in enumerate(zip(labels, signals)): + # xaxis = timestamps / sps + # xaxis = np.arange(2 * sps) / sps + for j, label in enumerate(labels): + x = eye_data[2 * j] + y = eye_data[2 * j + 1] + # x, y = signal.T # signal = signal.cpu().numpy() - for i in range(len(signal) // sps - 1): - x, y = signal[i * sps : (i + 2) * sps].T - axs[0 + 2 * j].plot(xaxis, np.abs(x) ** 2, color=f"C{j}", alpha=1 / (len(signal) // sps) * 10) - axs[1 + 2 * j].plot(xaxis, np.abs(y) ** 2, color=f"C{j}", alpha=1 / (len(signal) // sps) * 10) - axs[0 + 2 * j].set_title(label + " x") - axs[1 + 2 * j].set_title(label + " y") - axs[0 + 2 * j].set_xlabel("Symbol") - axs[1 + 2 * j].set_xlabel("Symbol") - axs[0 + 2 * j].set_box_aspect(1) - axs[1 + 2 * j].set_box_aspect(1) + # for i in range(len(signal) // sps - 1): + # x, y = signal[i * sps : (i + 2) * sps].T + # axs[0 + 2 * j].scatter((timestamps/sps) % 2, np.abs(x) ** 2, color=f"C{j}", alpha=1 / (len(signal) // sps) * 10, s=1) + # axs[1 + 2 * j].scatter((timestamps/sps) % 2, np.abs(y) ** 2, color=f"C{j}", alpha=1 / (len(signal) // sps) * 10, s=1) + axs[0 + 2 * j].imshow( + x, aspect="auto", cmap=cmap, origin="lower", extent=[0, 2, y_bins[2 * j][0], y_bins[2 * j][-1]] + ) + axs[1 + 2 * j].imshow( + y, aspect="auto", cmap=cmap, origin="lower", extent=[0, 2, y_bins[2 * j + 1][0], y_bins[2 * j + 1][-1]] + ) + axs[0 + 2 * j].set_xlim((x_bins[0], x_bins[-1])) + axs[1 + 2 * j].set_xlim((x_bins[0], x_bins[-1])) + ymin = np.min(y_bins[:, 0]) + ymax = np.max(y_bins[:, -1]) + ydiff = ymax - ymin + axs[0 + 2 * j].set_ylim((ymin - 0.05 * ydiff, ymax + 0.05 * ydiff)) + axs[1 + 2 * j].set_ylim((ymin - 0.05 * ydiff, ymax + 0.05 * ydiff)) + axs[0 + 2 * j].set_title(label + " x") + axs[1 + 2 * j].set_title(label + " 
y") + axs[0 + 2 * j].set_xlabel("Symbol") + axs[1 + 2 * j].set_xlabel("Symbol") + axs[0 + 2 * j].set_box_aspect(1) + axs[1 + 2 * j].set_box_aspect(1) axs[0].set_ylabel("normalized power") fig.tight_layout() # axs[1+2*len(labels)-1].set_ylabel("normalized power") @@ -627,7 +776,9 @@ class Trainer: plt.show() return fig - def _plot_model_response_head(self, *signals, labels=None, sps=None, title_append="", subtitle="", show=True): + def _plot_model_response_head( + self, *signals, timestamps, labels=None, sps=None, title_append="", subtitle="", show=True + ): if not hasattr(labels, "__iter__") or isinstance(labels, (str, type(None))): labels = [labels] else: @@ -640,19 +791,29 @@ class Trainer: if not any(labels): labels = [f"signal {i + 1}" for i in range(len(signals))] + ordering = np.argsort(timestamps) + signals = [signal[ordering] for signal in signals] + timestamps = timestamps[ordering] + fig, axs = plt.subplots(1, 2, sharex=True, sharey=True) fig.set_figwidth(18) fig.set_figheight(4) fig.suptitle(f"Fiber response{f' {title_append}' if title_append else ''}{f'\n{subtitle}' if subtitle else ''}") for i, ax in enumerate(axs): + ax: plt.Axes for signal, label in zip(signals, labels): if sps is not None: - xaxis = np.linspace(0, len(signal) / sps, len(signal), endpoint=False) + xaxis = timestamps / sps else: - xaxis = np.arange(len(signal)) + xaxis = timestamps ax.plot(xaxis, np.abs(signal[:, i]) ** 2, label=label) ax.set_xlabel("Sample" if sps is None else "Symbol") ax.set_ylabel("normalized power") + ax.minorticks_on() + ax.tick_params(axis="y", which="minor", left=False, right=False) + ax.grid(which="major", axis="x") + ax.grid(which="minor", axis="x", linestyle=":") + ax.grid(which="major", axis="y") ax.legend(loc="upper right") fig.tight_layout() if show: @@ -664,22 +825,51 @@ class Trainer: model=None, title_append="", subtitle="", - mode: Literal["eye", "head"] = "head", + mode: Literal["eye", "head", "powers"] = "head", show=False, ): + if mode == "powers": + input_data = torch.ones( + 1, 2 * self.data_settings.output_size, dtype=getattr(torch, self.data_settings.dtype) + ).to(self.pytorch_settings.device) + model = model.to(self.pytorch_settings.device) + model.eval() + with torch.no_grad(): + _, powers = model(input_data, trace_powers=True) + + powers = [power.item() for power in powers] + layer_names = ["input", *[str(x).split("(")[0] for x in model._layers._modules.values()]] + + # remove dropout layers + mask = [1 if "Dropout" not in layer_name else 0 for layer_name in layer_names] + layer_names = [layer_name for layer_name, m in zip(layer_names, mask) if m] + powers = [power for power, m in zip(powers, mask) if m] + + fig = self._plot_model_response_powers( + powers, layer_names, title_append=title_append, subtitle=subtitle, show=show + ) + return fig + data_settings_backup = copy.deepcopy(self.data_settings) pytorch_settings_backup = copy.deepcopy(self.pytorch_settings) - self.data_settings.drop_first = 100 * 128 + self.data_settings.drop_first = 99.5 + random.randint(0, 1000) self.data_settings.shuffle = False self.data_settings.train_split = 1.0 self.pytorch_settings.batchsize = ( self.pytorch_settings.eye_symbols if mode == "eye" else self.pytorch_settings.head_symbols ) - plot_loader, _ = self.get_sliced_data(override={"num_symbols": self.pytorch_settings.batchsize}) + config_path = random.choice(self.data_settings.config_path) + fiber_length = int(float(str(config_path).split('-')[-7])/1000) + plot_loader, _ = self.get_sliced_data( + override={ + "num_symbols": 
self.pytorch_settings.batchsize, + "config_path": config_path, + } + ) self.data_settings = data_settings_backup self.pytorch_settings = pytorch_settings_backup - fiber_in, fiber_out, regen = self.run_model(model, plot_loader) + fiber_in, fiber_out, regen, timestamps = self.run_model(model, plot_loader) fiber_in = fiber_in.view(-1, 2) fiber_out = fiber_out.view(-1, 2) regen = regen.view(-1, 2) @@ -687,6 +877,7 @@ class Trainer: fiber_in = fiber_in.numpy() fiber_out = fiber_out.numpy() regen = regen.numpy() + timestamps = timestamps.numpy() # https://github.com/matplotlib/matplotlib/issues/27713#issue-2104110987 # https://github.com/matplotlib/matplotlib/issues/27713#issuecomment-1915497463 @@ -697,9 +888,10 @@ class Trainer: fiber_in, fiber_out, regen, + timestamps=timestamps, labels=("fiber in", "fiber out", "regen"), sps=plot_loader.dataset.samples_per_symbol, - title_append=title_append, + title_append=title_append + f" ({fiber_length} km)", subtitle=subtitle, show=show, ) @@ -709,9 +901,10 @@ class Trainer: fiber_in, fiber_out, regen, + timestamps=timestamps, labels=("fiber in", "fiber out", "regen"), sps=plot_loader.dataset.samples_per_symbol, - title_append=title_append, + title_append=title_append + f" ({fiber_length} km)", subtitle=subtitle, show=show, ) diff --git a/src/single-core-regen/regen_no_hyper.py b/src/single-core-regen/regen_no_hyper.py index 6d84c3e..70aee0b 100644 --- a/src/single-core-regen/regen_no_hyper.py +++ b/src/single-core-regen/regen_no_hyper.py @@ -1,3 +1,6 @@ +import matplotlib +import numpy as np +import torch from hypertraining.settings import ( GlobalSettings, DataSettings, @@ -7,16 +10,20 @@ from hypertraining.settings import ( ) from hypertraining.training import Trainer -import torch + +# import torch import json import util +from rich import print as rprint + global_settings = GlobalSettings( - seed=42, + seed=0xC0FFEE, ) data_settings = DataSettings( - config_path="data/*-128-16384-100000-0-0-17-0-PAM4-0.ini", + # config_path="data/*-128-16384-50000-0-0-17-0-PAM4-0.ini", + config_path=[f"data/*-128-16384-{length}-0-0-17-0-PAM4-0.ini" for length in (40000, 50000, 60000)], dtype="complex64", # symbols = (9, 20), # 13 symbol @ 10GBd <-> 1.3ns <-> 0.26m of fiber symbols=13, # study: single_core_regen_20241123_011232 @@ -25,7 +32,7 @@ data_settings = DataSettings( shuffle=True, in_out_delay=0, xy_delay=0, - drop_first=128*64, + drop_first=128 * 64, train_split=0.8, ) @@ -45,55 +52,83 @@ model_settings = ModelSettings( output_dim=2, n_hidden_layers=4, overrides={ - "n_hidden_nodes_0": 8, - "n_hidden_nodes_1": 8, + "n_hidden_nodes_0": 4, + "n_hidden_nodes_1": 4, "n_hidden_nodes_2": 4, - "n_hidden_nodes_3": 6, + "n_hidden_nodes_3": 4, }, - model_activation_func="PowScale", - # dropout_prob=0.01, - model_layer_function="ONN", + model_activation_func="EOActivation", + dropout_prob=0.01, + model_layer_function="ONNRect", + model_layer_kwargs={"square": True}, + scale=True, model_layer_parametrizations=[ { "tensor_name": "weight", - "parametrization": torch.nn.utils.parametrizations.orthogonal, + "parametrization": util.complexNN.energy_conserving, + }, + { + "tensor_name": "alpha", + "parametrization": util.complexNN.clamp, + }, + { + "tensor_name": "gain", + "parametrization": util.complexNN.clamp, + "kwargs": { + "min": 0, + "max": float("inf"), + }, + }, + { + "tensor_name": "phase_bias", + "parametrization": util.complexNN.clamp, + "kwargs": { + "min": 0, + "max": 2 * torch.pi, + }, }, { "tensor_name": "scales", "parametrization": util.complexNN.clamp, }, - 
{ - "tensor_name": "scale", - "parametrization": util.complexNN.clamp, - }, - { - "tensor_name": "bias", - "parametrization": util.complexNN.clamp, - }, + # { + # "tensor_name": "scale", + # "parametrization": util.complexNN.clamp, + # }, + # { + # "tensor_name": "bias", + # "parametrization": util.complexNN.clamp, + # }, # { # "tensor_name": "V", # "parametrization": torch.nn.utils.parametrizations.orthogonal, # }, - # { - # "tensor_name": "S", - # "parametrization": util.complexNN.clamp, - # }, + { + "tensor_name": "loss", + "parametrization": util.complexNN.clamp, + }, ], ) optimizer_settings = OptimizerSettings( - optimizer="Adam", - learning_rate=0.05, + optimizer="AdamW", + optimizer_kwargs={ + "lr": 0.05, + "amsgrad": True, + # "weight_decay": 1e-7, + }, + # learning_rate=0.05, scheduler="ReduceLROnPlateau", scheduler_kwargs={ - "patience": 2**6, - "factor": 0.9, + "patience": 2**6, + "factor": 0.75, # "threshold": 1e-3, "min_lr": 1e-6, "cooldown": 10, - }, + }, ) + def save_dict_to_file(dictionary, filename): """ Save the best dictionary to a JSON file. @@ -103,28 +138,79 @@ def save_dict_to_file(dictionary, filename): :param filename: Path to the JSON file where the dictionary will be saved. :type filename: str """ - with open(filename, 'w') as f: + with open(filename, "w") as f: json.dump(dictionary, f, indent=4) +def sweep_lengths(*lengths, model=None): + assert model is not None, "Model must be provided." + model = model + + fiber_ins = {} + fiber_outs = {} + regens = {} + timestampss = {} + + for length in lengths: + trainer = Trainer( + checkpoint_path=model, + settings_override={ + "data_settings": { + "config_path": f"data/*-128-16384-{length}-0-0-17-0-PAM4-0.ini", + "train_split": 1, + "shuffle": True, + } + }, + ) + trainer.define_model() + loader, _ = trainer.get_sliced_data() + fiber_in, fiber_out, regen, timestamps = trainer.run_model(trainer.model, loader=loader) + + fiber_ins[length] = fiber_in + fiber_outs[length] = fiber_out + regens[length] = regen + timestampss[length] = timestamps + + data = torch.zeros(2 * len(lengths), 2, fiber_out.shape[0]) + channel_names = ["" for _ in range(2 * len(lengths))] + + for li, length in enumerate(lengths): + data[2 * li, 0, :] = timestampss[length] / 128 + data[2 * li, 1, :] = regens[length][:, 0].abs().square() + data[2 * li + 1, 0, :] = timestampss[length] / 128 + data[2 * li + 1, 1, :] = regens[length][:, 1].abs().square() + + channel_names[2 * li] = f"regen x {length}" + channel_names[2 * li + 1] = f"regen y {length}" + + # get current backend + backend = matplotlib.get_backend() + + matplotlib.use("TkCairo") + eye = util.eye_diagram.eye_diagram(data.to(dtype=torch.float32).detach().cpu().numpy(), channel_names=channel_names) + + print_attrs = ("channel", "success", "min_area") + with np.printoptions(precision=3, suppress=True, formatter={'float': '{:0.3e}'.format}): + for result in eye.eye_stats: + print_dict = {attr: result[attr] for attr in print_attrs} + rprint(print_dict) + rprint() + + eye.plot() + matplotlib.use(backend) + + if __name__ == "__main__": + + # sweep_lengths(30000, 40000, 50000, 60000, 70000, model=".models/best_20241202_143149.tar") + trainer = Trainer( global_settings=global_settings, data_settings=data_settings, pytorch_settings=pytorch_settings, model_settings=model_settings, optimizer_settings=optimizer_settings, - checkpoint_path='.models/20241128_084935_8885.tar', - settings_override={ - "model_settings": { - # "model_activation_func": "PowScale", - "dropout_prob": 0, - } - }, - reset_epoch=True, + 
# checkpoint_path=".models/best_20241202_143149.tar", + # 20241202_143149 ) - - best = trainer.train() - save_dict_to_file(best, ".models/best_results.json") - - ... + trainer.train() \ No newline at end of file diff --git a/src/single-core-regen/util/__init__.py b/src/single-core-regen/util/__init__.py index 842276e..d4c4f7d 100644 --- a/src/single-core-regen/util/__init__.py +++ b/src/single-core-regen/util/__init__.py @@ -16,4 +16,6 @@ from . import complexNN # noqa: F401 # from .complexNN import complex_mse_loss # noqa: F401 # from .complexNN import complex_sse_loss # noqa: F401 -from . import misc # noqa: F401 \ No newline at end of file +from . import misc # noqa: F401 + +from . import eye_diagram # noqa: F401 \ No newline at end of file diff --git a/src/single-core-regen/util/complexNN.py b/src/single-core-regen/util/complexNN.py index 5383e6c..c47fcd4 100644 --- a/src/single-core-regen/util/complexNN.py +++ b/src/single-core-regen/util/complexNN.py @@ -4,23 +4,36 @@ import torch.nn.functional as F # from torchlambertw.special import lambertw -def complex_mse_loss(input, target, power=False, reduction="mean"): +def complex_mse_loss(input, target, power=False, normalize=False, reduction="mean"): """ Compute the mean squared error between two complex tensors. If power is set to True, the loss is computed as |input|^2 - |target|^2 """ reduce = getattr(torch, reduction) + power_penalty = 0 if power: input = (input * input.conj()).real.to(dtype=input.dtype.to_real()) target = (target * target.conj()).real.to(dtype=target.dtype.to_real()) + if normalize: + power_penalty = ((input.max() - input.min()) - (target.max() - target.min())) ** 2 + power_penalty += (input.min() - target.min()) ** 2 + input = input - input.min() + input = input / input.max() + target = target - target.min() + target = target / target.max() + else: + if normalize: + power_penalty = (input.abs().max() - target.abs().max()) ** 2 + input = input / input.abs().max() + target = target / target.abs().max() if input.is_complex() and target.is_complex(): - return reduce(torch.square(input.real - target.real) + torch.square(input.imag - target.imag)) + return reduce(torch.square(input.real - target.real) + torch.square(input.imag - target.imag)) + power_penalty elif input.is_complex() or target.is_complex(): raise ValueError("Input and target must have the same type (real or complex)") else: - return F.mse_loss(input, target, reduction=reduction) + return F.mse_loss(input, target, reduction=reduction) + power_penalty def complex_sse_loss(input, target): @@ -53,23 +66,19 @@ class UnitaryLayer(nn.Module): return f"UnitaryLayer({self.in_features}, {self.out_features})" - class _Unitary(nn.Module): - def forward(self, X:torch.Tensor): + def forward(self, X: torch.Tensor): if X.ndim < 2: - raise ValueError( - "Only tensors with 2 or more dimensions are supported. " - f"Got a tensor of shape {X.shape}" - ) + raise ValueError(f"Only tensors with 2 or more dimensions are supported. Got a tensor of shape {X.shape}") n, k = X.size(-2), X.size(-1) - transpose = n nn.Module: weight = getattr(module, name, None) if not isinstance(weight, torch.Tensor): @@ -87,27 +97,29 @@ def unitary(module: nn.Module, name: str = "weight") -> nn.Module: if weight.ndim < 2: raise ValueError(f"Expected a matrix or batch of matrices. Got a tensor of {weight.ndim} dimensions.") - + if weight.shape[-2] != weight.shape[-1]: raise ValueError(f"Expected a square matrix or batch of square matrices. 
Got a tensor of shape {weight.shape}") - + unit = _Unitary() nn.utils.parametrize.register_parametrization(module, name, unit) return module + class _SpecialUnitary(nn.Module): def __init__(self): super().__init__() - def forward(self, X:torch.Tensor): + def forward(self, X: torch.Tensor): n, k = X.size(-2), X.size(-1) if n != k: raise ValueError(f"Expected a square matrix. Got a tensor of shape {X.shape}") q, _ = torch.linalg.qr(X) - q = q / torch.linalg.det(q).pow(1/n) - + q = q / torch.linalg.det(q).pow(1 / n) + return q + def special_unitary(module: nn.Module, name: str = "weight") -> nn.Module: weight = getattr(module, name, None) if not isinstance(weight, torch.Tensor): @@ -115,73 +127,61 @@ def special_unitary(module: nn.Module, name: str = "weight") -> nn.Module: if weight.ndim < 2: raise ValueError(f"Expected a matrix or batch of matrices. Got a tensor of {weight.ndim} dimensions.") - + if weight.shape[-2] != weight.shape[-1]: raise ValueError(f"Expected a square matrix or batch of square matrices. Got a tensor of shape {weight.shape}") - + unit = _SpecialUnitary() nn.utils.parametrize.register_parametrization(module, name, unit) return module + class _Clamp(nn.Module): def __init__(self, min, max): super(_Clamp, self).__init__() self.min = min self.max = max + def forward(self, x): if x.is_complex(): # clamp magnitude, ignore phase return torch.clamp(x.abs(), self.min, self.max) * x / x.abs() return torch.clamp(x, self.min, self.max) - + def clamp(module: nn.Module, name: str = "scale", min=0, max=1) -> nn.Module: scale = getattr(module, name, None) if not isinstance(scale, torch.Tensor): raise ValueError(f"Module '{module}' has no parameter or buffer '{name}'") - + cl = _Clamp(min, max) nn.utils.parametrize.register_parametrization(module, name, cl) return module -class ONNMiller(nn.Module): - def __init__(self, input_dim, output_dim, dtype=None) -> None: - super(ONNMiller, self).__init__() - self.input_dim = input_dim - self.output_dim = output_dim - self.dtype = dtype +class _EnergyConserving(nn.Module): + def __init__(self): + super(_EnergyConserving, self).__init__() - self.dim = max(input_dim, output_dim) + def forward(self, X: torch.Tensor): + if X.ndim == 2: + X = X.unsqueeze(0) + spectral_norm = torch.linalg.svdvals(X)[:, 0] + return (X / spectral_norm).squeeze() - # zero pad input to internal size if smaller - if self.input_dim < self.dim: - self.pad = lambda x: F.pad(x, ((self.dim - self.input_dim) // 2, (self.dim - self.input_dim + 1) // 2)) - else: - self.pad = lambda x: x - self.pad.__doc__ = f"Zero pad input from {self.input_dim} to {self.dim}" - # crop output to desired size - if self.output_dim < self.dim: - self.crop = lambda x: x[:, (self.dim - self.output_dim) // 2 : (x.shape[1] - (self.dim - self.output_dim + 1) // 2)] - else: - self.crop = lambda x: x - self.crop.__doc__ = f"Crop output from {self.dim} to {self.output_dim}" +def energy_conserving(module: nn.Module, name: str = "weight") -> nn.Module: + param = getattr(module, name, None) + if not isinstance(param, torch.Tensor): + raise ValueError(f"Module '{module}' has no parameter or buffer '{name}'") - self.U = nn.Parameter(torch.randn(self.dim, self.dim, dtype=self.dtype)) # -> parametrization: Unitary - self.S = nn.Parameter(torch.randn(self.dim, dtype=self.dtype)) # -> parametrization: Clamp (magnitude 0..1) - self.V = nn.Parameter(torch.randn(self.dim, self.dim, dtype=self.dtype)) # -> parametrization: Unitary - self.register_buffer("MZI_scale", torch.tensor(2, dtype=self.dtype.to_real()).sqrt()) - # 
V is actually V.H, but + if not (2 <= param.ndim <= 3): + raise ValueError(f"Expected a matrix or batch of matrices. Got a tensor of {param.ndim} dimensions.") + + unit = _EnergyConserving() + nn.utils.parametrize.register_parametrization(module, name, unit) + return module - def forward(self, x_in): - x = x_in - x = self.pad(x) - x = x @ self.U - x = x * (self.S.squeeze() / self.MZI_scale) - x = x @ self.V - x = self.crop(x) - return x class ONN(nn.Module): def __init__(self, input_dim, output_dim, dtype=None) -> None: @@ -202,56 +202,72 @@ class ONN(nn.Module): # crop output to desired size if self.output_dim < self.dim: - self.crop = lambda x: x[:, (self.dim - self.output_dim) // 2 : (x.shape[1] - (self.dim - self.output_dim + 1) // 2)] + self.crop = lambda x: x[ + :, (self.dim - self.output_dim) // 2 : (x.shape[1] - (self.dim - self.output_dim + 1) // 2) + ] self.crop.__doc__ = f"Crop output from {self.dim} to {self.output_dim}" else: self.crop = lambda x: x self.crop.__doc__ = f"Output size equals internal size {self.dim}" - self.weight = nn.Parameter(torch.randn(self.dim, self.dim, dtype=self.dtype)) + # self.scale = nn.Parameter(torch.randn(1, dtype=self.dtype.to_real())+0.5) def reset_parameters(self): q, _ = torch.linalg.qr(self.weight) self.weight.data = q + # def get_M(self): - # return self.U @ self.sigma @ self.V + # return self.U @ self.sigma @ self.V def forward(self, x): return self.crop(self.pad(x) @ self.weight) -class SemiUnitaryLayer(nn.Module): - def __init__(self, input_dim, output_dim, dtype=None): - super(SemiUnitaryLayer, self).__init__() +class ONNRect(nn.Module): + def __init__(self, input_dim, output_dim, square=False, dtype=None): + super(ONNRect, self).__init__() self.input_dim = input_dim self.output_dim = output_dim - # Create a larger square matrix for QR decomposition - self.weight = nn.Parameter(torch.randn(max(input_dim, output_dim), max(input_dim, output_dim), dtype=dtype)) - self.scale = nn.Parameter(torch.tensor(1.0, dtype=dtype.to_real())) - self.reset_parameters() + if square: + dim = max(input_dim, output_dim) + self.weight = nn.Parameter(torch.randn(dim, dim, dtype=dtype)) + + # zero pad input to internal size if smaller + if self.input_dim < dim: + self.pad = lambda x: F.pad(x, ((dim - self.input_dim) // 2, (dim - self.input_dim + 1) // 2)) + self.pad.__doc__ = f"Zero pad input from {self.input_dim} to {dim}" + else: + self.pad = lambda x: x + self.pad.__doc__ = f"Input size equals internal size {dim}" + + # crop output to desired size + if self.output_dim < dim: + self.crop = lambda x: x[ + :, (dim - self.output_dim) // 2 : (x.shape[1] - (dim - self.output_dim + 1) // 2) + ] + self.crop.__doc__ = f"Crop output from {dim} to {self.output_dim}" + else: + self.crop = lambda x: x + self.crop.__doc__ = f"Output size equals internal size {dim}" - def reset_parameters(self): - # Ensure the weights are unitary by QR decomposition - q, _ = torch.linalg.qr(self.weight) - # A = QR with A being a complex square matrix -> Q is unitary, R is upper triangular - # truncate the matrix to the desired size - if self.input_dim > self.output_dim: - self.weight.data = q[: self.input_dim, : self.output_dim] else: - self.weight.data = q[: self.output_dim, : self.input_dim].t() - ... 
+ self.weight = nn.Parameter(torch.randn(output_dim, input_dim, dtype=dtype)) + self.pad = lambda x: x + self.pad.__doc__ = "No padding" + self.crop = lambda x: x + self.crop.__doc__ = "No cropping" + def forward(self, x): - with torch.no_grad(): - scale = torch.clamp(self.scale, 0.0, 1.0) - out = torch.matmul(x, scale * self.weight) + x = self.pad(x) + out = self.crop((self.weight @ x.mT).mT) return out - def __repr__(self): - return f"SemiUnitaryLayer({self.input_dim}, {self.output_dim})" + # def __repr__(self): + # return f"ONNRect({self.input_dim}, {self.output_dim})" # class SaturableAbsorberLambertW(nn.Module): @@ -336,6 +352,19 @@ class DropoutComplex(nn.Module): return self.dropout(x) +class Scale(nn.Module): + def __init__(self, size): + super(Scale, self).__init__() + self.size = size + self.scale = nn.Parameter(torch.ones(size, dtype=torch.float32)) + + def forward(self, x): + return x * self.scale + + def __repr__(self): + return f"Scale({self.size})" + + class Identity(nn.Module): """ implements the "activation" function @@ -348,6 +377,7 @@ class Identity(nn.Module): def forward(self, x): return x + class PowRot(nn.Module): def __init__(self, bias=False): super(PowRot, self).__init__() @@ -359,15 +389,75 @@ class PowRot(nn.Module): def forward(self, x: torch.Tensor): if x.is_complex(): - return x * torch.exp(-self.scale*1j*x.abs().square()+self.bias.to(dtype=x.dtype)) + return x * torch.exp(-self.scale * 1j * x.abs().square() + self.bias.to(dtype=x.dtype)) else: - return x + return x + + +class MZISingle(nn.Module): + def __init__(self, bias, size, func=None): + super(MZISingle, self).__init__() + self.omega = nn.Parameter(torch.randn(size)) + self.phi = nn.Parameter(torch.randn(size)) + self.func = func or (lambda x: x.abs().square()) # default to |z|^2 + + def forward(self, x: torch.Tensor): + return x * torch.exp(1j * self.phi) * torch.sin(self.omega + self.func(x)) + + +class EOActivation(nn.Module): + def __init__(self, bias, size=None): + # 10.1109/SiPhotonics60897.2024.10543376 + super(EOActivation, self).__init__() + if size is None: + raise ValueError("Size must be specified") + self.size = size + self.alpha = nn.Parameter(torch.ones(size)) + self.V_bias = nn.Parameter(torch.ones(size)) + self.gain = nn.Parameter(torch.ones(size)) + # if bias: + # self.phase_bias = nn.Parameter(torch.zeros(size)) + # else: + # self.register_buffer("phase_bias", torch.zeros(size)) + self.register_buffer("phase_bias", torch.clamp(torch.ones(size) + torch.randn(size)*0.1, 0, 1)*torch.pi) + self.register_buffer("responsivity", torch.ones(size)*0.9) + self.register_buffer("V_pi", torch.ones(size)*3) + + self.reset_weights() + + def reset_weights(self): + if "alpha" in self._parameters: + self.alpha.data = torch.ones(self.size)*0.5 + if "V_pi" in self._parameters: + self.V_pi.data = torch.ones(self.size)*3 + if "V_bias" in self._parameters: + self.V_bias.data = torch.zeros(self.size) + if "gain" in self._parameters: + self.gain.data = torch.ones(self.size) + if "responsivity" in self._parameters: + self.responsivity.data = torch.ones(self.size)*0.9 + if "bias" in self._parameters: + self.phase_bias.data = torch.zeros(self.size) + + def forward(self, x: torch.Tensor): + phi_b = torch.pi * self.V_bias / (self.V_pi + 1e-8) + g_phi = torch.pi * (self.alpha * self.gain * self.responsivity) / (self.V_pi + 1e-8) + intermediate = g_phi * x.abs().square() + phi_b + return ( + 1j + * torch.sqrt(1 - self.alpha) + * torch.exp(-0.5j * (intermediate + self.phase_bias)) + * torch.cos(0.5 * intermediate) + 
* x + ) + class Pow(nn.Module): """ implements the activation function M(z) = ||z||^2 + b """ + def __init__(self, bias=False): super(Pow, self).__init__() if bias: @@ -375,7 +465,6 @@ class Pow(nn.Module): else: self.register_buffer("bias", torch.tensor(0.0)) - def forward(self, x: torch.Tensor): return x.abs().square().add(self.bias).to(dtype=x.dtype) @@ -395,7 +484,7 @@ class Mag(nn.Module): def forward(self, x: torch.Tensor): return x.abs().add(self.bias).to(dtype=x.dtype) - + class MagScale(nn.Module): def __init__(self, bias=False): @@ -404,10 +493,11 @@ class MagScale(nn.Module): self.bias = nn.Parameter(torch.tensor(0.0)) else: self.register_buffer("bias", torch.tensor(0.0)) - + def forward(self, x: torch.Tensor): return x.abs().add(self.bias).to(dtype=x.dtype).sin().mul(x) - + + class PowScale(nn.Module): def __init__(self, bias=False): super(PowScale, self).__init__() @@ -415,7 +505,7 @@ class PowScale(nn.Module): self.bias = nn.Parameter(torch.tensor(0.0)) else: self.register_buffer("bias", torch.tensor(0.0)) - + def forward(self, x: torch.Tensor): return x.mul(x.abs().square().add(self.bias).to(dtype=x.dtype).sin()) @@ -486,10 +576,10 @@ __all__ = [ complex_mse_loss, UnitaryLayer, unitary, + energy_conserving, clamp, ONN, - ONNMiller, - SemiUnitaryLayer, + ONNRect, DropoutComplex, Identity, Pow, @@ -498,7 +588,9 @@ __all__ = [ ModReLU, CReLU, ZReLU, + MZISingle, + EOActivation, # SaturableAbsorberLambertW, # SaturableAbsorber, # SpreadLayer, -] \ No newline at end of file +]
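Reviewer note (not part of the diff): the eye-diagram plotting added to training.py folds every sample onto a two-symbol window (timestamps scaled by sps, taken modulo 2) and fills a 2D histogram of instantaneous power with a per-sample np.digitize loop. A roughly equivalent vectorized sketch using np.histogram2d is shown below; eye_histogram and its arguments are illustrative names only, and the handling of the right-most bin edge differs slightly from the digitize-based loop in the patch.

import numpy as np


def eye_histogram(power: np.ndarray, timestamps: np.ndarray, sps: int, n_ybins: int = 1000) -> np.ndarray:
    """2D histogram of instantaneous power folded onto a two-symbol window."""
    t = (timestamps / sps) % 2                                    # position inside the two-symbol trace
    x_edges = np.linspace(0, 2, 2 * sps + 1)                      # 2*sps time bins
    y_edges = np.linspace(power.min(), power.max(), n_ybins + 1)  # n_ybins power bins
    counts, _, _ = np.histogram2d(t, power, bins=[x_edges, y_edges])
    return counts.T                                               # shape (n_ybins, 2*sps), like eye_data[i]


# usage sketch: counts = eye_histogram(np.abs(signal[:, 0]) ** 2, timestamps, sps=128)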
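Reviewer note (not part of the diff): the new util.complexNN.energy_conserving parametrization divides the registered weight by its largest singular value (torch.linalg.svdvals(X)[:, 0]), so the parametrized layer cannot amplify total signal power. A minimal, self-contained sketch of the same idea follows; SpectralNormalize and the nn.Linear stand-in are illustrative, not the patch's _EnergyConserving/ONNRect classes.

import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize


class SpectralNormalize(nn.Module):
    def forward(self, W: torch.Tensor) -> torch.Tensor:
        # svdvals returns singular values in descending order, so [0] is the spectral norm
        return W / torch.linalg.svdvals(W)[0]


layer = nn.Linear(8, 4, bias=False).to(torch.complex64)
parametrize.register_parametrization(layer, "weight", SpectralNormalize())

x = torch.randn(16, 8, dtype=torch.complex64)
y = layer(x)

in_power = x.abs().square().sum()
out_power = y.abs().square().sum()
assert out_power <= in_power * (1 + 1e-4)  # spectral norm <= 1 implies no net power gain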
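Reviewer note (not part of the diff): reading the forward pass of the new EOActivation (which cites 10.1109/SiPhotonics60897.2024.10543376), the added code computes the transfer function restated below, where the 1e-8 terms in the code are read as divide-by-zero guards:

\[
f(z) = i\,\sqrt{1-\alpha}\;\exp\!\Big(-\tfrac{i}{2}\big(g_\varphi\,|z|^2 + \varphi_b + \varphi_\mathrm{bias}\big)\Big)\,\cos\!\Big(\tfrac{1}{2}\big(g_\varphi\,|z|^2 + \varphi_b\big)\Big)\,z,
\qquad
g_\varphi = \frac{\pi\,\alpha\,G\,R}{V_\pi},\qquad
\varphi_b = \frac{\pi\,V_\mathrm{bias}}{V_\pi},
\]

with G the gain parameter and R the responsivity buffer of the layer; alpha, V_bias and gain are trainable parameters in the diff, while phase_bias, responsivity and V_pi are registered as buffers.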