update submodule configuration and enhance model settings; add eye diagram functionality

This commit is contained in:
Joseph Hopfmüller
2024-12-02 18:50:43 +01:00
parent aa2e7a4cb4
commit 297e9e8d7f
7 changed files with 626 additions and 249 deletions

.gitmodules
View File

@@ -1,3 +1,4 @@
 [submodule "pypho"]
 	path = pypho
 	url = git@gitlab.lrz.de:000000003B9B3E61/pypho.git
+	branch = main

View File

@@ -258,12 +258,12 @@ class HyperTraining:
f"model_hidden_dim_{i}", f"model_hidden_dim_{i}",
self.model_settings.n_hidden_nodes, self.model_settings.n_hidden_nodes,
) )
layers.append(util.complexNN.SemiUnitaryLayer(last_dim, hidden_dim, dtype=dtype)) layers.append(util.complexNN.ONNRect(last_dim, hidden_dim, dtype=dtype))
last_dim = hidden_dim last_dim = hidden_dim
layers.append(getattr(util.complexNN, afunc)()) layers.append(getattr(util.complexNN, afunc)())
n_nodes += last_dim n_nodes += last_dim
layers.append(util.complexNN.SemiUnitaryLayer(last_dim, self.model_settings.output_dim, dtype=dtype)) layers.append(util.complexNN.ONNRect(last_dim, self.model_settings.output_dim, dtype=dtype))
model = nn.Sequential(*layers) model = nn.Sequential(*layers)

View File

@@ -11,7 +11,7 @@ class GlobalSettings:
 # data settings
 @dataclass
 class DataSettings:
-    config_path: str  # = "data/*-128-16384-100000-0-0-17-0-PAM4-0.ini"
+    config_path: tuple | list | str  # = "data/*-128-16384-100000-0-0-17-0-PAM4-0.ini"
     dtype: tuple = ("complex64", "float64")
     symbols: tuple | float | int = 8
     output_size: tuple | float | int = 64
@@ -39,7 +39,7 @@ class PytorchSettings:
     summary_dir: str = ".runs"
     write_every: int = 10
     head_symbols: int = 40
-    eye_symbols: int = 400
+    eye_symbols: int = 1000

 # model settings
@@ -52,13 +52,16 @@ class ModelSettings:
     overrides: dict = field(default_factory=dict)
     dropout_prob: float | None = None
     model_layer_function: str | None = None
+    scale: bool = False
+    model_layer_kwargs: dict | None = None
     model_layer_parametrizations: list= field(default_factory=list)

 @dataclass
 class OptimizerSettings:
     optimizer: tuple | str = ("Adam", "RMSprop", "SGD")
-    learning_rate: tuple | float = (1e-5, 1e-1)
+    optimizer_kwargs: dict | None = None
+    # learning_rate: tuple | float = (1e-5, 1e-1)
     scheduler: str | None = None
     scheduler_kwargs: dict | None = None

View File

@@ -1,8 +1,10 @@
 import copy
 from datetime import datetime
 from pathlib import Path
+import random
 from typing import Literal

 import matplotlib
+from matplotlib.colors import LinearSegmentedColormap
 import torch.nn.utils.parametrize

 try:
@@ -50,6 +52,7 @@ class regenerator(nn.Module):
         self,
         *dims,
         layer_function=util.complexNN.ONN,
+        layer_kwargs: dict | None = None,
         layer_parametrizations: list[dict] = None,
         # [
         #     {
@@ -64,6 +67,7 @@ class regenerator(nn.Module):
         activation_function=util.complexNN.Pow,
         dtype=torch.float64,
         dropout_prob=0.01,
+        scale=False,
         **kwargs,
     ):
         super(regenerator, self).__init__()
@@ -74,39 +78,57 @@ class regenerator(nn.Module):
raise ValueError("dims must be provided") raise ValueError("dims must be provided")
self._n_hidden_layers = len(dims) - 2 self._n_hidden_layers = len(dims) - 2
self._layers = nn.Sequential() self._layers = nn.Sequential()
if layer_kwargs is None:
layer_kwargs = {}
# self.powers = []
for i in range(self._n_hidden_layers + 1): for i in range(self._n_hidden_layers + 1):
self._layers.append(layer_function(dims[i], dims[i + 1], dtype=dtype)) if scale:
self._layers.append(util.complexNN.Scale(dims[i]))
self._layers.append(layer_function(dims[i], dims[i + 1], dtype=dtype, **layer_kwargs))
if i < self._n_hidden_layers: if i < self._n_hidden_layers:
if dropout_prob is not None: if dropout_prob is not None:
self._layers.append(util.complexNN.DropoutComplex(p=dropout_prob)) self._layers.append(util.complexNN.DropoutComplex(p=dropout_prob))
self._layers.append(activation_function()) self._layers.append(activation_function(bias=True, size=dims[i + 1]))
# add parametrizations self._layers.append(util.complexNN.Scale(dims[-1]))
if layer_parametrizations is not None:
# add parametrizations
if layer_parametrizations is not None:
for layer in self._layers:
for layer_parametrization in layer_parametrizations: for layer_parametrization in layer_parametrizations:
tensor_name = layer_parametrization.get("tensor_name", None) tensor_name = layer_parametrization.get("tensor_name", None)
parametrization = layer_parametrization.get("parametrization", None) parametrization = layer_parametrization.get("parametrization", None)
param_kwargs = layer_parametrization.get("kwargs", {}) param_kwargs = layer_parametrization.get("kwargs", {})
if ( if tensor_name is not None and tensor_name in layer._parameters and parametrization is not None:
tensor_name is not None parametrization(layer, tensor_name, **param_kwargs)
and tensor_name in self._layers[-1]._parameters
and parametrization is not None
):
parametrization(self._layers[-1], tensor_name, **param_kwargs)
def forward(self, input_x): # def __call__(self, input_x, **kwargs):
# return self.forward(input_x, **kwargs)
def forward(self, input_x, trace_powers=False):
x = input_x x = input_x
if trace_powers:
powers = [x.abs().square().sum()]
# check if tracing # check if tracing
if torch.jit.is_tracing(): if torch.jit.is_tracing():
for layer in self._layers: for layer in self._layers:
x = layer(x) x = layer(x)
if trace_powers:
powers.append(x.abs().square().sum())
else: else:
# with torch.nn.utils.parametrize.cached(): # with torch.nn.utils.parametrize.cached():
for layer in self._layers: for layer in self._layers:
x = layer(x) x = layer(x)
if trace_powers:
powers.append(x.abs().square().sum())
if trace_powers:
return x, powers
return x return x
def traverse_dict_update(target, source): def traverse_dict_update(target, source):
for k, v in source.items(): for k, v in source.items():
if isinstance(v, dict): if isinstance(v, dict):
@@ -119,6 +141,7 @@ def traverse_dict_update(target, source):
except TypeError: except TypeError:
target.__dict__[k] = v target.__dict__[k] = v
class Trainer: class Trainer:
def __init__( def __init__(
self, self,
@@ -142,7 +165,7 @@ class Trainer:
OptimizerSettings, OptimizerSettings,
PytorchSettings, PytorchSettings,
regenerator, regenerator,
torch.nn.utils.parametrizations.orthogonal torch.nn.utils.parametrizations.orthogonal,
]) ])
if self.resume: if self.resume:
self.checkpoint_dict = torch.load(checkpoint_path, weights_only=True) self.checkpoint_dict = torch.load(checkpoint_path, weights_only=True)
@@ -167,7 +190,7 @@ class Trainer:
raise ValueError("model_settings must be provided") raise ValueError("model_settings must be provided")
if optimizer_settings is None: if optimizer_settings is None:
raise ValueError("optimizer_settings must be provided") raise ValueError("optimizer_settings must be provided")
self.global_settings: GlobalSettings = global_settings self.global_settings: GlobalSettings = global_settings
self.data_settings: DataSettings = data_settings self.data_settings: DataSettings = data_settings
self.pytorch_settings: PytorchSettings = pytorch_settings self.pytorch_settings: PytorchSettings = pytorch_settings
@@ -206,6 +229,11 @@ class Trainer:
} }
def define_model(self, model_kwargs=None): def define_model(self, model_kwargs=None):
if self.resume:
model_kwargs = self.checkpoint_dict["model_kwargs"]
else:
model_kwargs = model_kwargs
if model_kwargs is None: if model_kwargs is None:
n_hidden_layers = self.model_settings.n_hidden_layers n_hidden_layers = self.model_settings.n_hidden_layers
@@ -228,6 +256,7 @@ class Trainer:
"activation_function": afunc, "activation_function": afunc,
"dtype": dtype, "dtype": dtype,
"dropout_prob": self.model_settings.dropout_prob, "dropout_prob": self.model_settings.dropout_prob,
"scale": self.model_settings.scale,
} }
else: else:
self.model_kwargs = model_kwargs self.model_kwargs = model_kwargs
@@ -237,9 +266,12 @@ class Trainer:
# dims = self.model_kwargs.pop("dims") # dims = self.model_kwargs.pop("dims")
self.model = regenerator(**self.model_kwargs) self.model = regenerator(**self.model_kwargs)
self.writer.add_graph(self.model, torch.zeros(1, input_dim, dtype=dtype)) if self.writer is not None:
self.writer.add_graph(self.model, torch.zeros(1, input_dim, dtype=dtype))
self.model = self.model.to(self.pytorch_settings.device) self.model = self.model.to(self.pytorch_settings.device)
if self.resume:
self.model.load_state_dict(self.checkpoint_dict["model_state_dict"], strict=False)
def get_sliced_data(self, override=None): def get_sliced_data(self, override=None):
symbols = self.data_settings.symbols symbols = self.data_settings.symbols
@@ -253,11 +285,13 @@ class Trainer:
dtype = getattr(torch, self.data_settings.dtype) dtype = getattr(torch, self.data_settings.dtype)
num_symbols = None num_symbols = None
config_path = self.data_settings.config_path
if override is not None: if override is not None:
num_symbols = override.get("num_symbols", None) num_symbols = override.get("num_symbols", None)
config_path = override.get("config_path", config_path)
# get dataset # get dataset
dataset = FiberRegenerationDataset( dataset = FiberRegenerationDataset(
file_path=self.data_settings.config_path, file_path=config_path,
symbols=symbols, symbols=symbols,
output_dim=data_size, output_dim=data_size,
target_delay=in_out_delay, target_delay=in_out_delay,
@@ -330,10 +364,11 @@ class Trainer:
task = progress.add_task("-.---e--", total=len(train_loader)) task = progress.add_task("-.---e--", total=len(train_loader))
progress.start() progress.start()
running_loss2 = 0.0 # running_loss2 = 0.0
running_loss = 0.0 running_loss = 0.0
self.model.train() self.model.train()
for batch_idx, (x, y) in enumerate(train_loader): loader_len = len(train_loader)
for batch_idx, (x, y, _) in enumerate(train_loader):
self.model.zero_grad(set_to_none=True) self.model.zero_grad(set_to_none=True)
x, y = ( x, y = (
x.to(self.pytorch_settings.device), x.to(self.pytorch_settings.device),
@@ -344,24 +379,23 @@ class Trainer:
loss_value = loss.item() loss_value = loss.item()
loss.backward() loss.backward()
optimizer.step() optimizer.step()
running_loss2 += loss_value # running_loss2 += loss_value
running_loss += loss_value running_loss += loss_value
if enable_progress: if enable_progress:
progress.update(task, advance=1, description=f"{loss_value:.3e}") progress.update(task, advance=1, description=f"{running_loss/(batch_idx+1):.3e}")
if batch_idx % self.pytorch_settings.write_every == 0: if batch_idx % self.pytorch_settings.write_every == 0:
self.writer.add_scalar( self.writer.add_scalar(
"training loss", "training loss",
running_loss2 / (self.pytorch_settings.write_every if batch_idx > 0 else 1), running_loss / (batch_idx + 1),
epoch * len(train_loader) + batch_idx, epoch * loader_len + batch_idx,
) )
running_loss2 = 0.0
if enable_progress: if enable_progress:
progress.stop() progress.stop()
return running_loss / len(train_loader) return running_loss / (batch_idx + 1)
def eval_model(self, valid_loader, epoch, enable_progress=True): def eval_model(self, valid_loader, epoch, enable_progress=True):
if enable_progress: if enable_progress:
@@ -384,7 +418,7 @@ class Trainer:
self.model.eval() self.model.eval()
running_error = 0 running_error = 0
with torch.no_grad(): with torch.no_grad():
for batch_idx, (x, y) in enumerate(valid_loader): for batch_idx, (x, y, _) in enumerate(valid_loader):
x, y = ( x, y = (
x.to(self.pytorch_settings.device), x.to(self.pytorch_settings.device),
y.to(self.pytorch_settings.device), y.to(self.pytorch_settings.device),
@@ -395,76 +429,107 @@ class Trainer:
                 running_error += error_value
                 if enable_progress:
-                    progress.update(task, advance=1, description=f"{error_value:.3e}")
+                    progress.update(task, advance=1, description=f"{error_value/(batch_idx+1):.3e}")
+        running_error /= (batch_idx+1)

-        running_error /= len(valid_loader)
         self.writer.add_scalar(
             "eval loss",
             running_error,
             epoch,
         )
-        title_append, subtitle = self.build_title(epoch + 1)
-        self.writer.add_figure(
-            "fiber response",
-            self.plot_model_response(
-                model=self.model,
-                title_append=title_append,
-                subtitle=subtitle,
-                show=False,
-            ),
-            epoch + 1,
-        )
-        self.writer.add_figure(
-            "eye diagram",
-            self.plot_model_response(
-                model=self.model,
-                title_append=title_append,
-                subtitle=subtitle,
-                show=False,
-                mode="eye",
-            ),
-            epoch + 1,
-        )
-        self.writer_histograms(epoch + 1)
+        if (epoch + 1) % 10 == 0 or epoch < 10:
+            # plotting is slow, so only do it every 10 epochs
+            title_append, subtitle = self.build_title(epoch + 1)
+            self.writer.add_figure(
+                "fiber response",
+                self.plot_model_response(
+                    model=self.model,
+                    title_append=title_append,
+                    subtitle=subtitle,
+                    show=False,
+                ),
+                epoch + 1,
+            )
+            self.writer.add_figure(
+                "eye diagram",
+                self.plot_model_response(
+                    model=self.model,
+                    title_append=title_append,
+                    subtitle=subtitle,
+                    show=False,
+                    mode="eye",
+                ),
+                epoch + 1,
+            )
+            self.writer.add_figure(
+                "powers",
+                self.plot_model_response(
+                    model=self.model,
+                    title_append=title_append,
+                    subtitle=subtitle,
+                    mode="powers",
+                    show=False,
+                ),
+                epoch + 1,
+            )
+            self.write_parameters(epoch + 1)
+            self.writer.flush()

         if enable_progress:
             progress.stop()
         return running_error
-    def run_model(self, model, loader):
+    def run_model(self, model, loader, trace_powers=False):
         model.eval()
-        xs = []
-        ys = []
-        y_preds = []
+        fiber_out = []
+        fiber_in = []
+        regen = []
+        timestamps = []
         with torch.no_grad():
             model = model.to(self.pytorch_settings.device)
-            for x, y in loader:
+            for x, y, timestamp in loader:
                 x, y = (
                     x.to(self.pytorch_settings.device),
                     y.to(self.pytorch_settings.device),
                 )
-                y_pred = model(x).cpu()
+                if trace_powers:
+                    y_pred, powers = model(x, trace_powers).cpu()
+                else:
+                    y_pred = model(x, trace_powers).cpu()
                 # x = x.cpu()
                 # y = y.cpu()
                 y_pred = y_pred.view(y_pred.shape[0], -1, 2)
                 y = y.view(y.shape[0], -1, 2)
                 x = x.view(x.shape[0], -1, 2)
-                xs.append(x[:, 0, :].squeeze())
-                ys.append(y.squeeze())
-                y_preds.append(y_pred.squeeze())
+                # timestamp = timestamp.view(-1, 1)
+                fiber_out.append(x[:, x.shape[1] // 2, :].squeeze())
+                fiber_in.append(y.squeeze())
+                regen.append(y_pred.squeeze())
+                timestamps.append(timestamp.squeeze())

-        xs = torch.vstack(xs).cpu()
-        ys = torch.vstack(ys).cpu()
-        y_preds = torch.vstack(y_preds).cpu()
-        return ys, xs, y_preds
+        fiber_out = torch.vstack(fiber_out).cpu()
+        fiber_in = torch.vstack(fiber_in).cpu()
+        regen = torch.vstack(regen).cpu()
+        timestamps = torch.concat(timestamps).cpu()
+        if trace_powers:
+            return fiber_in, fiber_out, regen, timestamps, powers
+        return fiber_in, fiber_out, regen, timestamps
-    def writer_histograms(self, epoch, attributes=["weight", "weight_U", "weight_V", "bias", "sigma", "scale"]):
+    def write_parameters(self, epoch, attributes: list[str] | tuple[str] = None):
         for i, layer in enumerate(self.model._layers):
             tag = f"layer {i}"
-            for attribute in attributes:
-                if hasattr(layer, attribute):
+            if hasattr(layer, "parametrizations"):
+                attribute_pool = set(layer.parametrizations._modules) | set(layer._parameters)
+            else:
+                attribute_pool = set(layer._parameters)
+            for attribute in attribute_pool:
+                plot = (attributes is None) or (attribute in attributes)
+                if plot:
                     vals: np.ndarray = getattr(layer, attribute).detach().cpu().numpy().flatten()
                     if vals.ndim <= 1 and len(vals) == 1:
                         if np.iscomplexobj(vals):
@@ -483,14 +548,11 @@ class Trainer:
         if self.writer is None:
             self.setup_tb_writer()

-        if self.resume:
-            model_kwargs = self.checkpoint_dict["model_kwargs"]
-        else:
-            model_kwargs = None
-        self.define_model(model_kwargs=model_kwargs)
+        self.define_model()

-        print(f"number of parameters (trainable): {sum(p.numel() for p in self.model.parameters())} ({sum(p.numel() for p in self.model.parameters() if p.requires_grad)})")
+        print(
+            f"number of parameters (trainable): {sum(p.numel() for p in self.model.parameters())} ({sum(p.numel() for p in self.model.parameters() if p.requires_grad)})"
+        )

         title_append, subtitle = self.build_title(0)
@@ -515,36 +577,55 @@ class Trainer:
             ),
             0,
         )
-        self.writer_histograms(0)
+        self.writer.add_figure(
+            "powers",
+            self.plot_model_response(
+                model=self.model,
+                title_append=title_append,
+                subtitle=subtitle,
+                mode="powers",
+                show=False,
+            ),
+            0,
+        )
+        self.write_parameters(0)
+        self.writer.add_text("datasets", '\n'.join(self.data_settings.config_path))
+        self.writer.flush()

         train_loader, valid_loader = self.get_sliced_data()

         optimizer_name = self.optimizer_settings.optimizer
-        lr = self.optimizer_settings.learning_rate
-        self.optimizer: optim.Optimizer = getattr(optim, optimizer_name)(self.model.parameters(), lr=lr)
+        # lr = self.optimizer_settings.learning_rate
+        self.optimizer: optim.Optimizer = getattr(optim, optimizer_name)(
+            self.model.parameters(), **self.optimizer_settings.optimizer_kwargs
+        )
         if self.optimizer_settings.scheduler is not None:
             self.scheduler = getattr(optim.lr_scheduler, self.optimizer_settings.scheduler)(
                 self.optimizer, **self.optimizer_settings.scheduler_kwargs
             )
-            if self.resume:
-                try:
-                    self.scheduler.load_state_dict(self.checkpoint_dict["scheduler_state_dict"])
-                except ValueError:
-                    pass
-        self.writer.add_scalar("learning rate", self.scheduler.get_last_lr()[0], -1)
+            # if self.resume:
+            #     try:
+            #         self.scheduler.load_state_dict(self.checkpoint_dict["scheduler_state_dict"])
+            #     except ValueError:
+            #         pass
+        self.writer.add_scalar("learning rate", self.optimizer.param_groups[0]["lr"], -1)

         if not self.resume:
             self.best = self.build_checkpoint_dict()
         else:
             self.best = self.checkpoint_dict
-            self.model.load_state_dict(self.best["model_state_dict"], strict=False)
-            try:
-                self.optimizer.load_state_dict(self.best["optimizer_state_dict"])
-            except ValueError:
-                pass
+            self.best["loss"] = float("inf")
+            # self.model.load_state_dict(self.best["model_state_dict"], strict=False)
+            # try:
+            #     self.optimizer.load_state_dict(self.best["optimizer_state_dict"])
+            # except ValueError:
+            #     pass

         for epoch in range(self.best["epoch"] + 1, self.pytorch_settings.epochs):
             enable_progress = True
@@ -562,12 +643,8 @@ class Trainer:
                 enable_progress=enable_progress,
             )
             if self.optimizer_settings.scheduler is not None:
-                lr_old = self.scheduler.get_last_lr()
                 self.scheduler.step(loss)
-                lr_new = self.scheduler.get_last_lr()
-                if lr_old[0] != lr_new[0]:
-                    self.writer.add_scalar("learning rate", lr_new[0], epoch)
+                self.writer.add_scalar("learning rate", self.optimizer.param_groups[0]["lr"], epoch)

             if self.pytorch_settings.save_models and self.model is not None:
                 save_path = (
                     Path(self.pytorch_settings.model_dir) / f"{self.writer.get_logdir().split('/')[-1]}_{epoch}.tar"
@@ -588,7 +665,28 @@ class Trainer:
         self.writer.close()
         return self.best

-    def _plot_model_response_eye(self, *signals, labels=None, sps=None, title_append="", subtitle="", show=True):
+    def _plot_model_response_powers(self, powers, layer_names, title_append="", subtitle="", show=True):
+        powers = [power / powers[0] for power in powers]
+        fig, ax = plt.subplots()
+        fig.set_figwidth(18)
+        fig.suptitle(
+            f"Energy conservation{f' {title_append}' if title_append else ''}{f'\n{subtitle}' if subtitle else ''}"
+        )
+        ax.semilogy(powers, marker="o")
+        ax.set_xticks(range(len(layer_names)), layer_names, rotation=90)
+        ax.set_xlabel("Layer")
+        ax.set_ylabel("Normailzed Power")
+        ax.grid(which="major", axis="x")
+        ax.grid(which="major", axis="y")
+        ax.grid(which="minor", axis="y", linestyle=":")
+        fig.tight_layout()
+        if show:
+            plt.show()
+        return fig
+
+    def _plot_model_response_eye(
+        self, *signals, timestamps, labels=None, sps=None, title_append="", subtitle="", show=True
+    ):
         if sps is None:
             raise ValueError("sps must be provided")
         if not hasattr(labels, "__iter__") or isinstance(labels, (str, type(None))):
@@ -603,22 +701,73 @@ class Trainer:
         if not any(labels):
             labels = [f"signal {i + 1}" for i in range(len(signals))]

+        x_bins = np.linspace(0, 2, 2 * sps, endpoint=False)
+        y_bins = np.zeros((2 * len(signals), 1000))
+        eye_data = np.zeros((2 * len(signals), 1000, 2 * sps))
+        # signals = [signal.cpu().numpy() for signal in signals]
+        for i in range(len(signals) * 2):
+            eye_signal = signals[i // 2][:, i % 2]  # x, y, x, y, ...
+            eye_signal = np.real(np.square(np.abs(eye_signal)))
+            data_min = np.min(eye_signal)
+            data_max = np.max(eye_signal)
+            y_bins[i] = np.linspace(data_min, data_max, 1000, endpoint=False)
+            for j in range(len(timestamps)):
+                t = timestamps[j] / sps
+                val = eye_signal[j]
+                x = np.digitize(t % 2, x_bins) - 1
+                y = np.digitize(val, y_bins[i]) - 1
+                eye_data[i][y][x] += 1
+
+        cmap = LinearSegmentedColormap.from_list(
+            "eyemap",
+            [
+                (0, "white"),
+                (0.001, "dodgerblue"),
+                (0.1, "blue"),
+                (0.2, "cyan"),
+                (0.5, "lime"),
+                (0.8, "gold"),
+                (1, "red"),
+            ],
+        )
+
+        # ordering = np.argsort(timestamps)
+        # signals = [signal[ordering] for signal in signals]
+        # timestamps = timestamps[ordering]
+
         fig, axs = plt.subplots(1, 2 * len(signals), sharex=True, sharey=True)
         fig.set_figwidth(18)
         fig.suptitle(f"Eye diagram{f' {title_append}' if title_append else ''}{f'\n{subtitle}' if subtitle else ''}")
-        xaxis = np.linspace(0, 2, 2 * sps, endpoint=False)
-        for j, (label, signal) in enumerate(zip(labels, signals)):
+        # xaxis = timestamps / sps
+        # xaxis = np.arange(2 * sps) / sps
+        for j, label in enumerate(labels):
+            x = eye_data[2 * j]
+            y = eye_data[2 * j + 1]
+            # x, y = signal.T
             # signal = signal.cpu().numpy()
-            for i in range(len(signal) // sps - 1):
-                x, y = signal[i * sps : (i + 2) * sps].T
-                axs[0 + 2 * j].plot(xaxis, np.abs(x) ** 2, color=f"C{j}", alpha=1 / (len(signal) // sps) * 10)
-                axs[1 + 2 * j].plot(xaxis, np.abs(y) ** 2, color=f"C{j}", alpha=1 / (len(signal) // sps) * 10)
-            axs[0 + 2 * j].set_title(label + " x")
-            axs[1 + 2 * j].set_title(label + " y")
-            axs[0 + 2 * j].set_xlabel("Symbol")
-            axs[1 + 2 * j].set_xlabel("Symbol")
-            axs[0 + 2 * j].set_box_aspect(1)
-            axs[1 + 2 * j].set_box_aspect(1)
+            # for i in range(len(signal) // sps - 1):
+            #     x, y = signal[i * sps : (i + 2) * sps].T
+            #     axs[0 + 2 * j].scatter((timestamps/sps) % 2, np.abs(x) ** 2, color=f"C{j}", alpha=1 / (len(signal) // sps) * 10, s=1)
+            #     axs[1 + 2 * j].scatter((timestamps/sps) % 2, np.abs(y) ** 2, color=f"C{j}", alpha=1 / (len(signal) // sps) * 10, s=1)
+            axs[0 + 2 * j].imshow(
+                x, aspect="auto", cmap=cmap, origin="lower", extent=[0, 2, y_bins[2 * j][0], y_bins[2 * j][-1]]
+            )
+            axs[1 + 2 * j].imshow(
+                y, aspect="auto", cmap=cmap, origin="lower", extent=[0, 2, y_bins[2 * j + 1][0], y_bins[2 * j + 1][-1]]
+            )
+            axs[0 + 2 * j].set_xlim((x_bins[0], x_bins[-1]))
+            axs[1 + 2 * j].set_xlim((x_bins[0], x_bins[-1]))
+            ymin = np.min(y_bins[:, 0])
+            ymax = np.max(y_bins[:, -1])
+            ydiff = ymax - ymin
+            axs[0 + 2 * j].set_ylim((ymin - 0.05 * ydiff, ymax + 0.05 * ydiff))
+            axs[1 + 2 * j].set_ylim((ymin - 0.05 * ydiff, ymax + 0.05 * ydiff))
+            axs[0 + 2 * j].set_title(label + " x")
+            axs[1 + 2 * j].set_title(label + " y")
+            axs[0 + 2 * j].set_xlabel("Symbol")
+            axs[1 + 2 * j].set_xlabel("Symbol")
+            axs[0 + 2 * j].set_box_aspect(1)
+            axs[1 + 2 * j].set_box_aspect(1)
         axs[0].set_ylabel("normalized power")
         fig.tight_layout()
# axs[1+2*len(labels)-1].set_ylabel("normalized power") # axs[1+2*len(labels)-1].set_ylabel("normalized power")
@@ -627,7 +776,9 @@ class Trainer:
plt.show() plt.show()
return fig return fig
def _plot_model_response_head(self, *signals, labels=None, sps=None, title_append="", subtitle="", show=True): def _plot_model_response_head(
self, *signals, timestamps, labels=None, sps=None, title_append="", subtitle="", show=True
):
if not hasattr(labels, "__iter__") or isinstance(labels, (str, type(None))): if not hasattr(labels, "__iter__") or isinstance(labels, (str, type(None))):
labels = [labels] labels = [labels]
else: else:
@@ -640,19 +791,29 @@ class Trainer:
if not any(labels): if not any(labels):
labels = [f"signal {i + 1}" for i in range(len(signals))] labels = [f"signal {i + 1}" for i in range(len(signals))]
ordering = np.argsort(timestamps)
signals = [signal[ordering] for signal in signals]
timestamps = timestamps[ordering]
fig, axs = plt.subplots(1, 2, sharex=True, sharey=True) fig, axs = plt.subplots(1, 2, sharex=True, sharey=True)
fig.set_figwidth(18) fig.set_figwidth(18)
fig.set_figheight(4) fig.set_figheight(4)
fig.suptitle(f"Fiber response{f' {title_append}' if title_append else ''}{f'\n{subtitle}' if subtitle else ''}") fig.suptitle(f"Fiber response{f' {title_append}' if title_append else ''}{f'\n{subtitle}' if subtitle else ''}")
for i, ax in enumerate(axs): for i, ax in enumerate(axs):
ax: plt.Axes
for signal, label in zip(signals, labels): for signal, label in zip(signals, labels):
if sps is not None: if sps is not None:
xaxis = np.linspace(0, len(signal) / sps, len(signal), endpoint=False) xaxis = timestamps / sps
else: else:
xaxis = np.arange(len(signal)) xaxis = timestamps
ax.plot(xaxis, np.abs(signal[:, i]) ** 2, label=label) ax.plot(xaxis, np.abs(signal[:, i]) ** 2, label=label)
ax.set_xlabel("Sample" if sps is None else "Symbol") ax.set_xlabel("Sample" if sps is None else "Symbol")
ax.set_ylabel("normalized power") ax.set_ylabel("normalized power")
ax.minorticks_on()
ax.tick_params(axis="y", which="minor", left=False, right=False)
ax.grid(which="major", axis="x")
ax.grid(which="minor", axis="x", linestyle=":")
ax.grid(which="major", axis="y")
ax.legend(loc="upper right") ax.legend(loc="upper right")
fig.tight_layout() fig.tight_layout()
if show: if show:
@@ -664,22 +825,51 @@ class Trainer:
         model=None,
         title_append="",
         subtitle="",
-        mode: Literal["eye", "head"] = "head",
+        mode: Literal["eye", "head", "powers"] = "head",
         show=False,
     ):
+        if mode == "powers":
+            input_data = torch.ones(
+                1, 2 * self.data_settings.output_size, dtype=getattr(torch, self.data_settings.dtype)
+            ).to(self.pytorch_settings.device)
+            model = model.to(self.pytorch_settings.device)
+            model.eval()
+            with torch.no_grad():
+                _, powers = model(input_data, trace_powers=True)
+            powers = [power.item() for power in powers]
+            layer_names = ["input", *[str(x).split("(")[0] for x in model._layers._modules.values()]]
+            # remove dropout layers
+            mask = [1 if "Dropout" not in layer_name else 0 for layer_name in layer_names]
+            layer_names = [layer_name for layer_name, m in zip(layer_names, mask) if m]
+            powers = [power for power, m in zip(powers, mask) if m]
+            fig = self._plot_model_response_powers(
+                powers, layer_names, title_append=title_append, subtitle=subtitle, show=show
+            )
+            return fig
+
         data_settings_backup = copy.deepcopy(self.data_settings)
         pytorch_settings_backup = copy.deepcopy(self.pytorch_settings)
-        self.data_settings.drop_first = 100 * 128
+        self.data_settings.drop_first = 99.5 + random.randint(0, 1000)
         self.data_settings.shuffle = False
         self.data_settings.train_split = 1.0
         self.pytorch_settings.batchsize = (
             self.pytorch_settings.eye_symbols if mode == "eye" else self.pytorch_settings.head_symbols
         )
-        plot_loader, _ = self.get_sliced_data(override={"num_symbols": self.pytorch_settings.batchsize})
+        config_path = random.choice(self.data_settings.config_path)
+        fiber_length = int(float(str(config_path).split('-')[-7])/1000)
+        plot_loader, _ = self.get_sliced_data(
+            override={
+                "num_symbols": self.pytorch_settings.batchsize,
+                "config_path": config_path,
+            }
+        )
         self.data_settings = data_settings_backup
         self.pytorch_settings = pytorch_settings_backup
-        fiber_in, fiber_out, regen = self.run_model(model, plot_loader)
+        fiber_in, fiber_out, regen, timestamps = self.run_model(model, plot_loader)
         fiber_in = fiber_in.view(-1, 2)
         fiber_out = fiber_out.view(-1, 2)
         regen = regen.view(-1, 2)
@@ -687,6 +877,7 @@ class Trainer:
         fiber_in = fiber_in.numpy()
         fiber_out = fiber_out.numpy()
         regen = regen.numpy()
+        timestamps = timestamps.numpy()

         # https://github.com/matplotlib/matplotlib/issues/27713#issue-2104110987
         # https://github.com/matplotlib/matplotlib/issues/27713#issuecomment-1915497463
@@ -697,9 +888,10 @@
                 fiber_in,
                 fiber_out,
                 regen,
+                timestamps=timestamps,
                 labels=("fiber in", "fiber out", "regen"),
                 sps=plot_loader.dataset.samples_per_symbol,
-                title_append=title_append,
+                title_append=title_append + f" ({fiber_length} km)",
                 subtitle=subtitle,
                 show=show,
             )
@@ -709,9 +901,10 @@
                 fiber_in,
                 fiber_out,
                 regen,
+                timestamps=timestamps,
                 labels=("fiber in", "fiber out", "regen"),
                 sps=plot_loader.dataset.samples_per_symbol,
-                title_append=title_append,
+                title_append=title_append + f" ({fiber_length} km)",
                 subtitle=subtitle,
                 show=show,
             )
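
Note: for readers unfamiliar with the histogram-based eye diagram added above, the following is a minimal standalone sketch of the same idea (fold samples onto a two-symbol window, bin intensity vs. time, render as an image). The synthetic PAM-like signal, the `sps` value, and the colormap are illustrative assumptions, not part of the commit.

# Hedged sketch of a 2D-histogram eye diagram (not the project's code).
import numpy as np
import matplotlib.pyplot as plt

rng = np.random.default_rng(0)
sps = 128                                  # samples per symbol (assumed example value)
n_symbols = 400
levels = rng.choice([0.0, 1.0 / 3, 2.0 / 3, 1.0], size=n_symbols)        # PAM4-like levels
power = np.repeat(levels, sps) + 0.02 * rng.standard_normal(n_symbols * sps)
timestamps = np.arange(power.size)         # sample indices, analogous to the loader's timestamps

# fold onto a two-symbol window and accumulate a (power, time) histogram
t_folded = (timestamps / sps) % 2
counts, p_edges, t_edges = np.histogram2d(power, t_folded, bins=(1000, 2 * sps))

fig, ax = plt.subplots()
ax.imshow(counts, aspect="auto", origin="lower", cmap="inferno",
          extent=[t_edges[0], t_edges[-1], p_edges[0], p_edges[-1]])
ax.set_xlabel("Symbol")
ax.set_ylabel("normalized power")
plt.show()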

View File

@@ -1,3 +1,6 @@
import matplotlib
import numpy as np
import torch
from hypertraining.settings import ( from hypertraining.settings import (
GlobalSettings, GlobalSettings,
DataSettings, DataSettings,
@@ -7,16 +10,20 @@ from hypertraining.settings import (
 )
 from hypertraining.training import Trainer

-import torch
+# import torch
 import json

 import util

+from rich import print as rprint
+
 global_settings = GlobalSettings(
-    seed=42,
+    seed=0xC0FFEE,
 )

 data_settings = DataSettings(
-    config_path="data/*-128-16384-100000-0-0-17-0-PAM4-0.ini",
+    # config_path="data/*-128-16384-50000-0-0-17-0-PAM4-0.ini",
+    config_path=[f"data/*-128-16384-{length}-0-0-17-0-PAM4-0.ini" for length in (40000, 50000, 60000)],
     dtype="complex64",
     # symbols = (9, 20), # 13 symbol @ 10GBd <-> 1.3ns <-> 0.26m of fiber
     symbols=13,  # study: single_core_regen_20241123_011232
@@ -25,7 +32,7 @@ data_settings = DataSettings(
     shuffle=True,
     in_out_delay=0,
     xy_delay=0,
-    drop_first=128*64,
+    drop_first=128 * 64,
     train_split=0.8,
 )
@@ -45,55 +52,83 @@ model_settings = ModelSettings(
     output_dim=2,
     n_hidden_layers=4,
     overrides={
-        "n_hidden_nodes_0": 8,
-        "n_hidden_nodes_1": 8,
+        "n_hidden_nodes_0": 4,
+        "n_hidden_nodes_1": 4,
         "n_hidden_nodes_2": 4,
-        "n_hidden_nodes_3": 6,
+        "n_hidden_nodes_3": 4,
     },
-    model_activation_func="PowScale",
-    # dropout_prob=0.01,
-    model_layer_function="ONN",
+    model_activation_func="EOActivation",
+    dropout_prob=0.01,
+    model_layer_function="ONNRect",
+    model_layer_kwargs={"square": True},
+    scale=True,
     model_layer_parametrizations=[
         {
             "tensor_name": "weight",
-            "parametrization": torch.nn.utils.parametrizations.orthogonal,
+            "parametrization": util.complexNN.energy_conserving,
         },
+        {
+            "tensor_name": "alpha",
+            "parametrization": util.complexNN.clamp,
+        },
+        {
+            "tensor_name": "gain",
+            "parametrization": util.complexNN.clamp,
+            "kwargs": {
+                "min": 0,
+                "max": float("inf"),
+            },
+        },
+        {
+            "tensor_name": "phase_bias",
+            "parametrization": util.complexNN.clamp,
+            "kwargs": {
+                "min": 0,
+                "max": 2 * torch.pi,
+            },
+        },
         {
             "tensor_name": "scales",
             "parametrization": util.complexNN.clamp,
         },
-        {
-            "tensor_name": "scale",
-            "parametrization": util.complexNN.clamp,
-        },
-        {
-            "tensor_name": "bias",
-            "parametrization": util.complexNN.clamp,
-        },
+        # {
+        #     "tensor_name": "scale",
+        #     "parametrization": util.complexNN.clamp,
+        # },
+        # {
+        #     "tensor_name": "bias",
+        #     "parametrization": util.complexNN.clamp,
+        # },
         # {
         #     "tensor_name": "V",
         #     "parametrization": torch.nn.utils.parametrizations.orthogonal,
         # },
-        # {
-        #     "tensor_name": "S",
-        #     "parametrization": util.complexNN.clamp,
-        # },
+        {
+            "tensor_name": "loss",
+            "parametrization": util.complexNN.clamp,
+        },
     ],
 )

 optimizer_settings = OptimizerSettings(
-    optimizer="Adam",
-    learning_rate=0.05,
+    optimizer="AdamW",
+    optimizer_kwargs={
+        "lr": 0.05,
+        "amsgrad": True,
+        # "weight_decay": 1e-7,
+    },
+    # learning_rate=0.05,
     scheduler="ReduceLROnPlateau",
     scheduler_kwargs={
         "patience": 2**6,
-        "factor": 0.9,
+        "factor": 0.75,
         # "threshold": 1e-3,
         "min_lr": 1e-6,
         "cooldown": 10,
     },
 )
def save_dict_to_file(dictionary, filename): def save_dict_to_file(dictionary, filename):
""" """
Save the best dictionary to a JSON file. Save the best dictionary to a JSON file.
@@ -103,28 +138,79 @@ def save_dict_to_file(dictionary, filename):
:param filename: Path to the JSON file where the dictionary will be saved. :param filename: Path to the JSON file where the dictionary will be saved.
:type filename: str :type filename: str
""" """
-    with open(filename, 'w') as f:
+    with open(filename, "w") as f:
         json.dump(dictionary, f, indent=4)
def sweep_lengths(*lengths, model=None):
assert model is not None, "Model must be provided."
model = model
fiber_ins = {}
fiber_outs = {}
regens = {}
timestampss = {}
for length in lengths:
trainer = Trainer(
checkpoint_path=model,
settings_override={
"data_settings": {
"config_path": f"data/*-128-16384-{length}-0-0-17-0-PAM4-0.ini",
"train_split": 1,
"shuffle": True,
}
},
)
trainer.define_model()
loader, _ = trainer.get_sliced_data()
fiber_in, fiber_out, regen, timestamps = trainer.run_model(trainer.model, loader=loader)
fiber_ins[length] = fiber_in
fiber_outs[length] = fiber_out
regens[length] = regen
timestampss[length] = timestamps
data = torch.zeros(2 * len(lengths), 2, fiber_out.shape[0])
channel_names = ["" for _ in range(2 * len(lengths))]
for li, length in enumerate(lengths):
data[2 * li, 0, :] = timestampss[length] / 128
data[2 * li, 1, :] = regens[length][:, 0].abs().square()
data[2 * li + 1, 0, :] = timestampss[length] / 128
data[2 * li + 1, 1, :] = regens[length][:, 1].abs().square()
channel_names[2 * li] = f"regen x {length}"
channel_names[2 * li + 1] = f"regen y {length}"
# get current backend
backend = matplotlib.get_backend()
matplotlib.use("TkCairo")
eye = util.eye_diagram.eye_diagram(data.to(dtype=torch.float32).detach().cpu().numpy(), channel_names=channel_names)
print_attrs = ("channel", "success", "min_area")
with np.printoptions(precision=3, suppress=True, formatter={'float': '{:0.3e}'.format}):
for result in eye.eye_stats:
print_dict = {attr: result[attr] for attr in print_attrs}
rprint(print_dict)
rprint()
eye.plot()
matplotlib.use(backend)
if __name__ == "__main__": if __name__ == "__main__":
# sweep_lengths(30000, 40000, 50000, 60000, 70000, model=".models/best_20241202_143149.tar")
trainer = Trainer( trainer = Trainer(
global_settings=global_settings, global_settings=global_settings,
data_settings=data_settings, data_settings=data_settings,
pytorch_settings=pytorch_settings, pytorch_settings=pytorch_settings,
model_settings=model_settings, model_settings=model_settings,
optimizer_settings=optimizer_settings, optimizer_settings=optimizer_settings,
checkpoint_path='.models/20241128_084935_8885.tar', # checkpoint_path=".models/best_20241202_143149.tar",
settings_override={ # 20241202_143149
"model_settings": {
# "model_activation_func": "PowScale",
"dropout_prob": 0,
}
},
reset_epoch=True,
) )
trainer.train()
best = trainer.train()
save_dict_to_file(best, ".models/best_results.json")
...
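
For reference, a minimal sketch of how the new `optimizer_kwargs` field is consumed (the Trainer resolves the optimizer and scheduler by name via `getattr` and unpacks the kwargs, as in the training diff above). The toy model and the single dummy step are illustrative assumptions, not project code.

# Hedged sketch: settings-driven optimizer/scheduler construction.
import torch
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(4, 2)                                   # placeholder model
optimizer_name, optimizer_kwargs = "AdamW", {"lr": 0.05, "amsgrad": True}
scheduler_name = "ReduceLROnPlateau"
scheduler_kwargs = {"patience": 2**6, "factor": 0.75, "min_lr": 1e-6, "cooldown": 10}

optimizer = getattr(optim, optimizer_name)(model.parameters(), **optimizer_kwargs)
scheduler = getattr(optim.lr_scheduler, scheduler_name)(optimizer, **scheduler_kwargs)

val_loss = torch.tensor(1.0)        # stand-in for a validation loss
optimizer.step()                    # would follow loss.backward() in a real loop
scheduler.step(val_loss)            # ReduceLROnPlateau steps on the monitored metric
print(optimizer.param_groups[0]["lr"])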

View File

@@ -16,4 +16,6 @@ from . import complexNN # noqa: F401
# from .complexNN import complex_mse_loss # noqa: F401 # from .complexNN import complex_mse_loss # noqa: F401
# from .complexNN import complex_sse_loss # noqa: F401 # from .complexNN import complex_sse_loss # noqa: F401
from . import misc # noqa: F401 from . import misc # noqa: F401
from . import eye_diagram # noqa: F401

View File

@@ -4,23 +4,36 @@ import torch.nn.functional as F
 # from torchlambertw.special import lambertw

-def complex_mse_loss(input, target, power=False, reduction="mean"):
+def complex_mse_loss(input, target, power=False, normalize=False, reduction="mean"):
     """
     Compute the mean squared error between two complex tensors.
     If power is set to True, the loss is computed as |input|^2 - |target|^2
     """
     reduce = getattr(torch, reduction)
+    power_penalty = 0
     if power:
         input = (input * input.conj()).real.to(dtype=input.dtype.to_real())
         target = (target * target.conj()).real.to(dtype=target.dtype.to_real())
+        if normalize:
+            power_penalty = ((input.max() - input.min()) - (target.max() - target.min())) ** 2
+            power_penalty += (input.min() - target.min()) ** 2
+            input = input - input.min()
+            input = input / input.max()
+            target = target - target.min()
+            target = target / target.max()
+    else:
+        if normalize:
+            power_penalty = (input.abs().max() - target.abs().max()) ** 2
+            input = input / input.abs().max()
+            target = target / target.abs().max()

     if input.is_complex() and target.is_complex():
-        return reduce(torch.square(input.real - target.real) + torch.square(input.imag - target.imag))
+        return reduce(torch.square(input.real - target.real) + torch.square(input.imag - target.imag)) + power_penalty
     elif input.is_complex() or target.is_complex():
         raise ValueError("Input and target must have the same type (real or complex)")
     else:
-        return F.mse_loss(input, target, reduction=reduction)
+        return F.mse_loss(input, target, reduction=reduction) + power_penalty
def complex_sse_loss(input, target): def complex_sse_loss(input, target):
@@ -53,23 +66,19 @@ class UnitaryLayer(nn.Module):
return f"UnitaryLayer({self.in_features}, {self.out_features})" return f"UnitaryLayer({self.in_features}, {self.out_features})"
class _Unitary(nn.Module): class _Unitary(nn.Module):
def forward(self, X:torch.Tensor): def forward(self, X: torch.Tensor):
if X.ndim < 2: if X.ndim < 2:
raise ValueError( raise ValueError(f"Only tensors with 2 or more dimensions are supported. Got a tensor of shape {X.shape}")
"Only tensors with 2 or more dimensions are supported. "
f"Got a tensor of shape {X.shape}"
)
n, k = X.size(-2), X.size(-1) n, k = X.size(-2), X.size(-1)
transpose = n<k transpose = n < k
if transpose: if transpose:
X = X.transpose(-2, -1) X = X.transpose(-2, -1)
q, r = torch.linalg.qr(X) q, r = torch.linalg.qr(X)
# q: torch.Tensor = q # q: torch.Tensor = q
# r: torch.Tensor = r # r: torch.Tensor = r
d = r.diagonal(dim1=-2, dim2=-1).sgn() d = r.diagonal(dim1=-2, dim2=-1).sgn()
q*=d.unsqueeze(-2) q *= d.unsqueeze(-2)
if transpose: if transpose:
q = q.transpose(-2, -1) q = q.transpose(-2, -1)
if n == k: if n == k:
@@ -80,6 +89,7 @@ class _Unitary(nn.Module):
# X.copy_(q) # X.copy_(q)
return q return q
def unitary(module: nn.Module, name: str = "weight") -> nn.Module: def unitary(module: nn.Module, name: str = "weight") -> nn.Module:
weight = getattr(module, name, None) weight = getattr(module, name, None)
if not isinstance(weight, torch.Tensor): if not isinstance(weight, torch.Tensor):
@@ -87,27 +97,29 @@ def unitary(module: nn.Module, name: str = "weight") -> nn.Module:
if weight.ndim < 2: if weight.ndim < 2:
raise ValueError(f"Expected a matrix or batch of matrices. Got a tensor of {weight.ndim} dimensions.") raise ValueError(f"Expected a matrix or batch of matrices. Got a tensor of {weight.ndim} dimensions.")
if weight.shape[-2] != weight.shape[-1]: if weight.shape[-2] != weight.shape[-1]:
raise ValueError(f"Expected a square matrix or batch of square matrices. Got a tensor of shape {weight.shape}") raise ValueError(f"Expected a square matrix or batch of square matrices. Got a tensor of shape {weight.shape}")
unit = _Unitary() unit = _Unitary()
nn.utils.parametrize.register_parametrization(module, name, unit) nn.utils.parametrize.register_parametrization(module, name, unit)
return module return module
class _SpecialUnitary(nn.Module): class _SpecialUnitary(nn.Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
def forward(self, X:torch.Tensor): def forward(self, X: torch.Tensor):
n, k = X.size(-2), X.size(-1) n, k = X.size(-2), X.size(-1)
if n != k: if n != k:
raise ValueError(f"Expected a square matrix. Got a tensor of shape {X.shape}") raise ValueError(f"Expected a square matrix. Got a tensor of shape {X.shape}")
q, _ = torch.linalg.qr(X) q, _ = torch.linalg.qr(X)
q = q / torch.linalg.det(q).pow(1/n) q = q / torch.linalg.det(q).pow(1 / n)
return q return q
def special_unitary(module: nn.Module, name: str = "weight") -> nn.Module: def special_unitary(module: nn.Module, name: str = "weight") -> nn.Module:
weight = getattr(module, name, None) weight = getattr(module, name, None)
if not isinstance(weight, torch.Tensor): if not isinstance(weight, torch.Tensor):
@@ -115,73 +127,61 @@ def special_unitary(module: nn.Module, name: str = "weight") -> nn.Module:
if weight.ndim < 2: if weight.ndim < 2:
raise ValueError(f"Expected a matrix or batch of matrices. Got a tensor of {weight.ndim} dimensions.") raise ValueError(f"Expected a matrix or batch of matrices. Got a tensor of {weight.ndim} dimensions.")
if weight.shape[-2] != weight.shape[-1]: if weight.shape[-2] != weight.shape[-1]:
raise ValueError(f"Expected a square matrix or batch of square matrices. Got a tensor of shape {weight.shape}") raise ValueError(f"Expected a square matrix or batch of square matrices. Got a tensor of shape {weight.shape}")
unit = _SpecialUnitary() unit = _SpecialUnitary()
nn.utils.parametrize.register_parametrization(module, name, unit) nn.utils.parametrize.register_parametrization(module, name, unit)
return module return module
class _Clamp(nn.Module): class _Clamp(nn.Module):
def __init__(self, min, max): def __init__(self, min, max):
super(_Clamp, self).__init__() super(_Clamp, self).__init__()
self.min = min self.min = min
self.max = max self.max = max
def forward(self, x): def forward(self, x):
if x.is_complex(): if x.is_complex():
# clamp magnitude, ignore phase # clamp magnitude, ignore phase
return torch.clamp(x.abs(), self.min, self.max) * x / x.abs() return torch.clamp(x.abs(), self.min, self.max) * x / x.abs()
return torch.clamp(x, self.min, self.max) return torch.clamp(x, self.min, self.max)
def clamp(module: nn.Module, name: str = "scale", min=0, max=1) -> nn.Module: def clamp(module: nn.Module, name: str = "scale", min=0, max=1) -> nn.Module:
scale = getattr(module, name, None) scale = getattr(module, name, None)
if not isinstance(scale, torch.Tensor): if not isinstance(scale, torch.Tensor):
raise ValueError(f"Module '{module}' has no parameter or buffer '{name}'") raise ValueError(f"Module '{module}' has no parameter or buffer '{name}'")
cl = _Clamp(min, max) cl = _Clamp(min, max)
nn.utils.parametrize.register_parametrization(module, name, cl) nn.utils.parametrize.register_parametrization(module, name, cl)
return module return module
-class ONNMiller(nn.Module):
-    def __init__(self, input_dim, output_dim, dtype=None) -> None:
-        super(ONNMiller, self).__init__()
-        self.input_dim = input_dim
-        self.output_dim = output_dim
-        self.dtype = dtype
-        self.dim = max(input_dim, output_dim)
-
-        # zero pad input to internal size if smaller
-        if self.input_dim < self.dim:
-            self.pad = lambda x: F.pad(x, ((self.dim - self.input_dim) // 2, (self.dim - self.input_dim + 1) // 2))
-        else:
-            self.pad = lambda x: x
-        self.pad.__doc__ = f"Zero pad input from {self.input_dim} to {self.dim}"
-
-        # crop output to desired size
-        if self.output_dim < self.dim:
-            self.crop = lambda x: x[:, (self.dim - self.output_dim) // 2 : (x.shape[1] - (self.dim - self.output_dim + 1) // 2)]
-        else:
-            self.crop = lambda x: x
-        self.crop.__doc__ = f"Crop output from {self.dim} to {self.output_dim}"
-
-        self.U = nn.Parameter(torch.randn(self.dim, self.dim, dtype=self.dtype))  # -> parametrization: Unitary
-        self.S = nn.Parameter(torch.randn(self.dim, dtype=self.dtype))  # -> parametrization: Clamp (magnitude 0..1)
-        self.V = nn.Parameter(torch.randn(self.dim, self.dim, dtype=self.dtype))  # -> parametrization: Unitary
-        self.register_buffer("MZI_scale", torch.tensor(2, dtype=self.dtype.to_real()).sqrt())
-        # V is actually V.H, but
-
-    def forward(self, x_in):
-        x = x_in
-        x = self.pad(x)
-        x = x @ self.U
-        x = x * (self.S.squeeze() / self.MZI_scale)
-        x = x @ self.V
-        x = self.crop(x)
-        return x
+class _EnergyConserving(nn.Module):
+    def __init__(self):
+        super(_EnergyConserving, self).__init__()
+
+    def forward(self, X: torch.Tensor):
+        if X.ndim == 2:
+            X = X.unsqueeze(0)
+        spectral_norm = torch.linalg.svdvals(X)[:, 0]
+        return (X / spectral_norm).squeeze()
+
+
+def energy_conserving(module: nn.Module, name: str = "weight") -> nn.Module:
+    param = getattr(module, name, None)
+    if not isinstance(param, torch.Tensor):
+        raise ValueError(f"Module '{module}' has no parameter or buffer '{name}'")
+    if not (2 <= param.ndim <= 3):
+        raise ValueError(f"Expected a matrix or batch of matrices. Got a tensor of {param.ndim} dimensions.")
+
+    unit = _EnergyConserving()
+    nn.utils.parametrize.register_parametrization(module, name, unit)
+    return module

 class ONN(nn.Module):
     def __init__(self, input_dim, output_dim, dtype=None) -> None:
@@ -202,56 +202,72 @@ class ONN(nn.Module):
# crop output to desired size # crop output to desired size
if self.output_dim < self.dim: if self.output_dim < self.dim:
self.crop = lambda x: x[:, (self.dim - self.output_dim) // 2 : (x.shape[1] - (self.dim - self.output_dim + 1) // 2)] self.crop = lambda x: x[
:, (self.dim - self.output_dim) // 2 : (x.shape[1] - (self.dim - self.output_dim + 1) // 2)
]
self.crop.__doc__ = f"Crop output from {self.dim} to {self.output_dim}" self.crop.__doc__ = f"Crop output from {self.dim} to {self.output_dim}"
else: else:
self.crop = lambda x: x self.crop = lambda x: x
self.crop.__doc__ = f"Output size equals internal size {self.dim}" self.crop.__doc__ = f"Output size equals internal size {self.dim}"
self.weight = nn.Parameter(torch.randn(self.dim, self.dim, dtype=self.dtype)) self.weight = nn.Parameter(torch.randn(self.dim, self.dim, dtype=self.dtype))
# self.scale = nn.Parameter(torch.randn(1, dtype=self.dtype.to_real())+0.5)
def reset_parameters(self): def reset_parameters(self):
q, _ = torch.linalg.qr(self.weight) q, _ = torch.linalg.qr(self.weight)
self.weight.data = q self.weight.data = q
# def get_M(self): # def get_M(self):
# return self.U @ self.sigma @ self.V # return self.U @ self.sigma @ self.V
def forward(self, x): def forward(self, x):
return self.crop(self.pad(x) @ self.weight) return self.crop(self.pad(x) @ self.weight)
-class SemiUnitaryLayer(nn.Module):
-    def __init__(self, input_dim, output_dim, dtype=None):
-        super(SemiUnitaryLayer, self).__init__()
+class ONNRect(nn.Module):
+    def __init__(self, input_dim, output_dim, square=False, dtype=None):
+        super(ONNRect, self).__init__()
         self.input_dim = input_dim
         self.output_dim = output_dim
-        # Create a larger square matrix for QR decomposition
-        self.weight = nn.Parameter(torch.randn(max(input_dim, output_dim), max(input_dim, output_dim), dtype=dtype))
-        self.scale = nn.Parameter(torch.tensor(1.0, dtype=dtype.to_real()))
-        self.reset_parameters()
-
-    def reset_parameters(self):
-        # Ensure the weights are unitary by QR decomposition
-        q, _ = torch.linalg.qr(self.weight)
-        # A = QR with A being a complex square matrix -> Q is unitary, R is upper triangular
-        # truncate the matrix to the desired size
-        if self.input_dim > self.output_dim:
-            self.weight.data = q[: self.input_dim, : self.output_dim]
-        else:
-            self.weight.data = q[: self.output_dim, : self.input_dim].t()
-        ...
+        if square:
+            dim = max(input_dim, output_dim)
+            self.weight = nn.Parameter(torch.randn(dim, dim, dtype=dtype))
+
+            # zero pad input to internal size if smaller
+            if self.input_dim < dim:
+                self.pad = lambda x: F.pad(x, ((dim - self.input_dim) // 2, (dim - self.input_dim + 1) // 2))
+                self.pad.__doc__ = f"Zero pad input from {self.input_dim} to {dim}"
+            else:
+                self.pad = lambda x: x
+                self.pad.__doc__ = f"Input size equals internal size {dim}"
+
+            # crop output to desired size
+            if self.output_dim < dim:
+                self.crop = lambda x: x[
+                    :, (dim - self.output_dim) // 2 : (x.shape[1] - (dim - self.output_dim + 1) // 2)
+                ]
+                self.crop.__doc__ = f"Crop output from {dim} to {self.output_dim}"
+            else:
+                self.crop = lambda x: x
+                self.crop.__doc__ = f"Output size equals internal size {dim}"
+        else:
+            self.weight = nn.Parameter(torch.randn(output_dim, input_dim, dtype=dtype))
+            self.pad = lambda x: x
+            self.pad.__doc__ = "No padding"
+            self.crop = lambda x: x
+            self.crop.__doc__ = "No cropping"

     def forward(self, x):
-        with torch.no_grad():
-            scale = torch.clamp(self.scale, 0.0, 1.0)
-        out = torch.matmul(x, scale * self.weight)
+        x = self.pad(x)
+        out = self.crop((self.weight @ x.mT).mT)
         return out

-    def __repr__(self):
-        return f"SemiUnitaryLayer({self.input_dim}, {self.output_dim})"
+    # def __repr__(self):
+    #     return f"ONNRect({self.input_dim}, {self.output_dim})"
# class SaturableAbsorberLambertW(nn.Module): # class SaturableAbsorberLambertW(nn.Module):
@@ -336,6 +352,19 @@ class DropoutComplex(nn.Module):
return self.dropout(x) return self.dropout(x)
class Scale(nn.Module):
def __init__(self, size):
super(Scale, self).__init__()
self.size = size
self.scale = nn.Parameter(torch.ones(size, dtype=torch.float32))
def forward(self, x):
return x * self.scale
def __repr__(self):
return f"Scale({self.size})"
class Identity(nn.Module): class Identity(nn.Module):
""" """
implements the "activation" function implements the "activation" function
@@ -348,6 +377,7 @@ class Identity(nn.Module):
def forward(self, x): def forward(self, x):
return x return x
class PowRot(nn.Module): class PowRot(nn.Module):
def __init__(self, bias=False): def __init__(self, bias=False):
super(PowRot, self).__init__() super(PowRot, self).__init__()
@@ -359,15 +389,75 @@ class PowRot(nn.Module):
def forward(self, x: torch.Tensor): def forward(self, x: torch.Tensor):
if x.is_complex(): if x.is_complex():
return x * torch.exp(-self.scale*1j*x.abs().square()+self.bias.to(dtype=x.dtype)) return x * torch.exp(-self.scale * 1j * x.abs().square() + self.bias.to(dtype=x.dtype))
else: else:
return x return x
class MZISingle(nn.Module):
def __init__(self, bias, size, func=None):
super(MZISingle, self).__init__()
self.omega = nn.Parameter(torch.randn(size))
self.phi = nn.Parameter(torch.randn(size))
self.func = func or (lambda x: x.abs().square()) # default to |z|^2
def forward(self, x: torch.Tensor):
return x * torch.exp(1j * self.phi) * torch.sin(self.omega + self.func(x))
class EOActivation(nn.Module):
def __init__(self, bias, size=None):
# 10.1109/SiPhotonics60897.2024.10543376
super(EOActivation, self).__init__()
if size is None:
raise ValueError("Size must be specified")
self.size = size
self.alpha = nn.Parameter(torch.ones(size))
self.V_bias = nn.Parameter(torch.ones(size))
self.gain = nn.Parameter(torch.ones(size))
# if bias:
# self.phase_bias = nn.Parameter(torch.zeros(size))
# else:
# self.register_buffer("phase_bias", torch.zeros(size))
self.register_buffer("phase_bias", torch.clamp(torch.ones(size) + torch.randn(size)*0.1, 0, 1)*torch.pi)
self.register_buffer("responsivity", torch.ones(size)*0.9)
self.register_buffer("V_pi", torch.ones(size)*3)
self.reset_weights()
def reset_weights(self):
if "alpha" in self._parameters:
self.alpha.data = torch.ones(self.size)*0.5
if "V_pi" in self._parameters:
self.V_pi.data = torch.ones(self.size)*3
if "V_bias" in self._parameters:
self.V_bias.data = torch.zeros(self.size)
if "gain" in self._parameters:
self.gain.data = torch.ones(self.size)
if "responsivity" in self._parameters:
self.responsivity.data = torch.ones(self.size)*0.9
if "bias" in self._parameters:
self.phase_bias.data = torch.zeros(self.size)
def forward(self, x: torch.Tensor):
phi_b = torch.pi * self.V_bias / (self.V_pi + 1e-8)
g_phi = torch.pi * (self.alpha * self.gain * self.responsivity) / (self.V_pi + 1e-8)
intermediate = g_phi * x.abs().square() + phi_b
return (
1j
* torch.sqrt(1 - self.alpha)
* torch.exp(-0.5j * (intermediate + self.phase_bias))
* torch.cos(0.5 * intermediate)
* x
)
class Pow(nn.Module): class Pow(nn.Module):
""" """
implements the activation function implements the activation function
M(z) = ||z||^2 + b M(z) = ||z||^2 + b
""" """
def __init__(self, bias=False): def __init__(self, bias=False):
super(Pow, self).__init__() super(Pow, self).__init__()
if bias: if bias:
@@ -375,7 +465,6 @@ class Pow(nn.Module):
else: else:
self.register_buffer("bias", torch.tensor(0.0)) self.register_buffer("bias", torch.tensor(0.0))
def forward(self, x: torch.Tensor): def forward(self, x: torch.Tensor):
return x.abs().square().add(self.bias).to(dtype=x.dtype) return x.abs().square().add(self.bias).to(dtype=x.dtype)
@@ -395,7 +484,7 @@ class Mag(nn.Module):
def forward(self, x: torch.Tensor): def forward(self, x: torch.Tensor):
return x.abs().add(self.bias).to(dtype=x.dtype) return x.abs().add(self.bias).to(dtype=x.dtype)
class MagScale(nn.Module): class MagScale(nn.Module):
def __init__(self, bias=False): def __init__(self, bias=False):
@@ -404,10 +493,11 @@ class MagScale(nn.Module):
self.bias = nn.Parameter(torch.tensor(0.0)) self.bias = nn.Parameter(torch.tensor(0.0))
else: else:
self.register_buffer("bias", torch.tensor(0.0)) self.register_buffer("bias", torch.tensor(0.0))
def forward(self, x: torch.Tensor): def forward(self, x: torch.Tensor):
return x.abs().add(self.bias).to(dtype=x.dtype).sin().mul(x) return x.abs().add(self.bias).to(dtype=x.dtype).sin().mul(x)
class PowScale(nn.Module): class PowScale(nn.Module):
def __init__(self, bias=False): def __init__(self, bias=False):
super(PowScale, self).__init__() super(PowScale, self).__init__()
@@ -415,7 +505,7 @@ class PowScale(nn.Module):
self.bias = nn.Parameter(torch.tensor(0.0)) self.bias = nn.Parameter(torch.tensor(0.0))
else: else:
self.register_buffer("bias", torch.tensor(0.0)) self.register_buffer("bias", torch.tensor(0.0))
def forward(self, x: torch.Tensor): def forward(self, x: torch.Tensor):
return x.mul(x.abs().square().add(self.bias).to(dtype=x.dtype).sin()) return x.mul(x.abs().square().add(self.bias).to(dtype=x.dtype).sin())
@@ -486,10 +576,10 @@ __all__ = [
     complex_mse_loss,
     UnitaryLayer,
     unitary,
+    energy_conserving,
     clamp,
     ONN,
-    ONNMiller,
-    SemiUnitaryLayer,
+    ONNRect,
     DropoutComplex,
     Identity,
     Pow,
@@ -498,7 +588,9 @@ __all__ = [
     ModReLU,
     CReLU,
     ZReLU,
+    MZISingle,
+    EOActivation,
     # SaturableAbsorberLambertW,
     # SaturableAbsorber,
     # SpreadLayer,
 ]
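
For readers unfamiliar with parametrizations, a small usage sketch of the spectral-norm idea behind the new `energy_conserving` parametrization, applied here to a plain `nn.Linear` so it stays self-contained. The `SpectralNormalize` helper and the layer sizes are illustrative assumptions, not the project's API.

# Hedged sketch: a _EnergyConserving-style parametrization on nn.Linear.
import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize

class SpectralNormalize(nn.Module):
    def forward(self, X: torch.Tensor) -> torch.Tensor:
        # divide by the largest singular value, like _EnergyConserving for a single matrix
        return X / torch.linalg.svdvals(X)[0]

layer = nn.Linear(8, 8, bias=False)
parametrize.register_parametrization(layer, "weight", SpectralNormalize())

x = torch.randn(3, 8)
y = layer(x)
print(torch.linalg.svdvals(layer.weight)[0])       # ~1.0 after parametrization
print(y.norm(dim=1) <= x.norm(dim=1) + 1e-6)       # layer cannot amplify the input power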