update submodule configuration, extend model and optimizer settings, and add eye diagram functionality

Joseph Hopfmüller
2024-12-02 18:50:43 +01:00
parent aa2e7a4cb4
commit 297e9e8d7f
7 changed files with 626 additions and 249 deletions

.gitmodules vendored
View File

@@ -1,3 +1,4 @@
[submodule "pypho"]
path = pypho
url = git@gitlab.lrz.de:000000003B9B3E61/pypho.git
branch = main

View File

@@ -258,12 +258,12 @@ class HyperTraining:
f"model_hidden_dim_{i}",
self.model_settings.n_hidden_nodes,
)
layers.append(util.complexNN.SemiUnitaryLayer(last_dim, hidden_dim, dtype=dtype))
layers.append(util.complexNN.ONNRect(last_dim, hidden_dim, dtype=dtype))
last_dim = hidden_dim
layers.append(getattr(util.complexNN, afunc)())
n_nodes += last_dim
layers.append(util.complexNN.SemiUnitaryLayer(last_dim, self.model_settings.output_dim, dtype=dtype))
layers.append(util.complexNN.ONNRect(last_dim, self.model_settings.output_dim, dtype=dtype))
model = nn.Sequential(*layers)

View File

@@ -11,7 +11,7 @@ class GlobalSettings:
# data settings
@dataclass
class DataSettings:
config_path: str # = "data/*-128-16384-100000-0-0-17-0-PAM4-0.ini"
config_path: tuple | list | str # = "data/*-128-16384-100000-0-0-17-0-PAM4-0.ini"
dtype: tuple = ("complex64", "float64")
symbols: tuple | float | int = 8
output_size: tuple | float | int = 64
@@ -39,7 +39,7 @@ class PytorchSettings:
summary_dir: str = ".runs"
write_every: int = 10
head_symbols: int = 40
eye_symbols: int = 400
eye_symbols: int = 1000
# model settings
@@ -52,13 +52,16 @@ class ModelSettings:
overrides: dict = field(default_factory=dict)
dropout_prob: float | None = None
model_layer_function: str | None = None
scale: bool = False
model_layer_kwargs: dict | None = None
model_layer_parametrizations: list= field(default_factory=list)
@dataclass
class OptimizerSettings:
optimizer: tuple | str = ("Adam", "RMSprop", "SGD")
learning_rate: tuple | float = (1e-5, 1e-1)
optimizer_kwargs: dict | None = None
# learning_rate: tuple | float = (1e-5, 1e-1)
scheduler: str | None = None
scheduler_kwargs: dict | None = None
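
A minimal usage sketch of how these optimizer fields are consumed, mirroring the optimizer construction in training.py further down; `settings` and `model` are hypothetical stand-ins, and `optimizer` is assumed to hold a single name rather than a search tuple:

import torch.optim as optim
# build the optimizer by name; optimizer_kwargs carries lr, amsgrad, weight_decay, ...
optimizer = getattr(optim, settings.optimizer)(model.parameters(), **(settings.optimizer_kwargs or {}))
if settings.scheduler is not None:
    # scheduler_kwargs carries patience, factor, min_lr, ... for e.g. ReduceLROnPlateau
    scheduler = getattr(optim.lr_scheduler, settings.scheduler)(optimizer, **(settings.scheduler_kwargs or {}))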

View File

@@ -1,8 +1,10 @@
import copy
from datetime import datetime
from pathlib import Path
import random
from typing import Literal
import matplotlib
from matplotlib.colors import LinearSegmentedColormap
import torch.nn.utils.parametrize
try:
@@ -50,6 +52,7 @@ class regenerator(nn.Module):
self,
*dims,
layer_function=util.complexNN.ONN,
layer_kwargs: dict | None = None,
layer_parametrizations: list[dict] = None,
# [
# {
@@ -64,6 +67,7 @@ class regenerator(nn.Module):
activation_function=util.complexNN.Pow,
dtype=torch.float64,
dropout_prob=0.01,
scale=False,
**kwargs,
):
super(regenerator, self).__init__()
@@ -74,39 +78,57 @@ class regenerator(nn.Module):
raise ValueError("dims must be provided")
self._n_hidden_layers = len(dims) - 2
self._layers = nn.Sequential()
if layer_kwargs is None:
layer_kwargs = {}
# self.powers = []
for i in range(self._n_hidden_layers + 1):
self._layers.append(layer_function(dims[i], dims[i + 1], dtype=dtype))
if scale:
self._layers.append(util.complexNN.Scale(dims[i]))
self._layers.append(layer_function(dims[i], dims[i + 1], dtype=dtype, **layer_kwargs))
if i < self._n_hidden_layers:
if dropout_prob is not None:
self._layers.append(util.complexNN.DropoutComplex(p=dropout_prob))
self._layers.append(activation_function())
self._layers.append(activation_function(bias=True, size=dims[i + 1]))
self._layers.append(util.complexNN.Scale(dims[-1]))
# add parametrizations
if layer_parametrizations is not None:
for layer in self._layers:
for layer_parametrization in layer_parametrizations:
tensor_name = layer_parametrization.get("tensor_name", None)
parametrization = layer_parametrization.get("parametrization", None)
param_kwargs = layer_parametrization.get("kwargs", {})
if (
tensor_name is not None
and tensor_name in self._layers[-1]._parameters
and parametrization is not None
):
parametrization(self._layers[-1], tensor_name, **param_kwargs)
if tensor_name is not None and tensor_name in layer._parameters and parametrization is not None:
parametrization(layer, tensor_name, **param_kwargs)
def forward(self, input_x):
# def __call__(self, input_x, **kwargs):
# return self.forward(input_x, **kwargs)
def forward(self, input_x, trace_powers=False):
x = input_x
if trace_powers:
powers = [x.abs().square().sum()]
# check if tracing
if torch.jit.is_tracing():
for layer in self._layers:
x = layer(x)
if trace_powers:
powers.append(x.abs().square().sum())
else:
# with torch.nn.utils.parametrize.cached():
for layer in self._layers:
x = layer(x)
if trace_powers:
powers.append(x.abs().square().sum())
if trace_powers:
return x, powers
return x
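# Usage sketch (hypothetical input x): out, powers = model(x, trace_powers=True)
# powers[0] is the total input power |x|^2; every following entry is the power after
# one layer, which is what the "powers" energy-conservation figure below is built from.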
def traverse_dict_update(target, source):
for k, v in source.items():
if isinstance(v, dict):
@@ -119,6 +141,7 @@ def traverse_dict_update(target, source):
except TypeError:
target.__dict__[k] = v
class Trainer:
def __init__(
self,
@@ -142,7 +165,7 @@ class Trainer:
OptimizerSettings,
PytorchSettings,
regenerator,
torch.nn.utils.parametrizations.orthogonal
torch.nn.utils.parametrizations.orthogonal,
])
if self.resume:
self.checkpoint_dict = torch.load(checkpoint_path, weights_only=True)
@@ -206,6 +229,11 @@ class Trainer:
}
def define_model(self, model_kwargs=None):
if self.resume:
model_kwargs = self.checkpoint_dict["model_kwargs"]
else:
model_kwargs = model_kwargs
if model_kwargs is None:
n_hidden_layers = self.model_settings.n_hidden_layers
@@ -228,6 +256,7 @@ class Trainer:
"activation_function": afunc,
"dtype": dtype,
"dropout_prob": self.model_settings.dropout_prob,
"scale": self.model_settings.scale,
}
else:
self.model_kwargs = model_kwargs
@@ -237,9 +266,12 @@ class Trainer:
# dims = self.model_kwargs.pop("dims")
self.model = regenerator(**self.model_kwargs)
if self.writer is not None:
self.writer.add_graph(self.model, torch.zeros(1, input_dim, dtype=dtype))
self.model = self.model.to(self.pytorch_settings.device)
if self.resume:
self.model.load_state_dict(self.checkpoint_dict["model_state_dict"], strict=False)
def get_sliced_data(self, override=None):
symbols = self.data_settings.symbols
@@ -253,11 +285,13 @@ class Trainer:
dtype = getattr(torch, self.data_settings.dtype)
num_symbols = None
config_path = self.data_settings.config_path
if override is not None:
num_symbols = override.get("num_symbols", None)
config_path = override.get("config_path", config_path)
# get dataset
dataset = FiberRegenerationDataset(
file_path=self.data_settings.config_path,
file_path=config_path,
symbols=symbols,
output_dim=data_size,
target_delay=in_out_delay,
@@ -330,10 +364,11 @@ class Trainer:
task = progress.add_task("-.---e--", total=len(train_loader))
progress.start()
running_loss2 = 0.0
# running_loss2 = 0.0
running_loss = 0.0
self.model.train()
for batch_idx, (x, y) in enumerate(train_loader):
loader_len = len(train_loader)
for batch_idx, (x, y, _) in enumerate(train_loader):
self.model.zero_grad(set_to_none=True)
x, y = (
x.to(self.pytorch_settings.device),
@@ -344,24 +379,23 @@ class Trainer:
loss_value = loss.item()
loss.backward()
optimizer.step()
running_loss2 += loss_value
# running_loss2 += loss_value
running_loss += loss_value
if enable_progress:
progress.update(task, advance=1, description=f"{loss_value:.3e}")
progress.update(task, advance=1, description=f"{running_loss/(batch_idx+1):.3e}")
if batch_idx % self.pytorch_settings.write_every == 0:
self.writer.add_scalar(
"training loss",
running_loss2 / (self.pytorch_settings.write_every if batch_idx > 0 else 1),
epoch * len(train_loader) + batch_idx,
running_loss / (batch_idx + 1),
epoch * loader_len + batch_idx,
)
running_loss2 = 0.0
if enable_progress:
progress.stop()
return running_loss / len(train_loader)
return running_loss / (batch_idx + 1)
def eval_model(self, valid_loader, epoch, enable_progress=True):
if enable_progress:
@@ -384,7 +418,7 @@ class Trainer:
self.model.eval()
running_error = 0
with torch.no_grad():
for batch_idx, (x, y) in enumerate(valid_loader):
for batch_idx, (x, y, _) in enumerate(valid_loader):
x, y = (
x.to(self.pytorch_settings.device),
y.to(self.pytorch_settings.device),
@@ -395,15 +429,17 @@ class Trainer:
running_error += error_value
if enable_progress:
progress.update(task, advance=1, description=f"{error_value:.3e}")
progress.update(task, advance=1, description=f"{error_value/(batch_idx+1):.3e}")
running_error /= (batch_idx+1)
running_error /= len(valid_loader)
self.writer.add_scalar(
"eval loss",
running_error,
epoch,
)
if (epoch + 1) % 10 == 0 or epoch < 10:
# plotting is slow, so only do it every 10 epochs
title_append, subtitle = self.build_title(epoch + 1)
self.writer.add_figure(
"fiber response",
@@ -426,45 +462,74 @@ class Trainer:
),
epoch + 1,
)
self.writer_histograms(epoch + 1)
self.writer.add_figure(
"powers",
self.plot_model_response(
model=self.model,
title_append=title_append,
subtitle=subtitle,
mode="powers",
show=False,
),
epoch + 1,
)
self.write_parameters(epoch + 1)
self.writer.flush()
if enable_progress:
progress.stop()
return running_error
def run_model(self, model, loader):
def run_model(self, model, loader, trace_powers=False):
model.eval()
xs = []
ys = []
y_preds = []
fiber_out = []
fiber_in = []
regen = []
timestamps = []
with torch.no_grad():
model = model.to(self.pytorch_settings.device)
for x, y in loader:
for x, y, timestamp in loader:
x, y = (
x.to(self.pytorch_settings.device),
y.to(self.pytorch_settings.device),
)
y_pred = model(x).cpu()
if trace_powers:
y_pred, powers = model(x, trace_powers)
y_pred = y_pred.cpu()
else:
y_pred = model(x, trace_powers).cpu()
# x = x.cpu()
# y = y.cpu()
y_pred = y_pred.view(y_pred.shape[0], -1, 2)
y = y.view(y.shape[0], -1, 2)
x = x.view(x.shape[0], -1, 2)
xs.append(x[:, 0, :].squeeze())
ys.append(y.squeeze())
y_preds.append(y_pred.squeeze())
# timestamp = timestamp.view(-1, 1)
fiber_out.append(x[:, x.shape[1] // 2, :].squeeze())
fiber_in.append(y.squeeze())
regen.append(y_pred.squeeze())
timestamps.append(timestamp.squeeze())
xs = torch.vstack(xs).cpu()
ys = torch.vstack(ys).cpu()
y_preds = torch.vstack(y_preds).cpu()
return ys, xs, y_preds
fiber_out = torch.vstack(fiber_out).cpu()
fiber_in = torch.vstack(fiber_in).cpu()
regen = torch.vstack(regen).cpu()
timestamps = torch.concat(timestamps).cpu()
if trace_powers:
return fiber_in, fiber_out, regen, timestamps, powers
return fiber_in, fiber_out, regen, timestamps
def writer_histograms(self, epoch, attributes=["weight", "weight_U", "weight_V", "bias", "sigma", "scale"]):
def write_parameters(self, epoch, attributes: list[str] | tuple[str] = None):
for i, layer in enumerate(self.model._layers):
tag = f"layer {i}"
for attribute in attributes:
if hasattr(layer, attribute):
if hasattr(layer, "parametrizations"):
attribute_pool = set(layer.parametrizations._modules) | set(layer._parameters)
else:
attribute_pool = set(layer._parameters)
for attribute in attribute_pool:
plot = (attributes is None) or (attribute in attributes)
if plot:
vals: np.ndarray = getattr(layer, attribute).detach().cpu().numpy().flatten()
if vals.ndim <= 1 and len(vals) == 1:
if np.iscomplexobj(vals):
@@ -483,14 +548,11 @@ class Trainer:
if self.writer is None:
self.setup_tb_writer()
if self.resume:
model_kwargs = self.checkpoint_dict["model_kwargs"]
else:
model_kwargs = None
self.define_model()
self.define_model(model_kwargs=model_kwargs)
print(f"number of parameters (trainable): {sum(p.numel() for p in self.model.parameters())} ({sum(p.numel() for p in self.model.parameters() if p.requires_grad)})")
print(
f"number of parameters (trainable): {sum(p.numel() for p in self.model.parameters())} ({sum(p.numel() for p in self.model.parameters() if p.requires_grad)})"
)
title_append, subtitle = self.build_title(0)
@@ -515,36 +577,55 @@ class Trainer:
),
0,
)
self.writer_histograms(0)
self.writer.add_figure(
"powers",
self.plot_model_response(
model=self.model,
title_append=title_append,
subtitle=subtitle,
mode="powers",
show=False,
),
0,
)
self.write_parameters(0)
self.writer.add_text("datasets", '\n'.join(self.data_settings.config_path))
self.writer.flush()
train_loader, valid_loader = self.get_sliced_data()
optimizer_name = self.optimizer_settings.optimizer
lr = self.optimizer_settings.learning_rate
# lr = self.optimizer_settings.learning_rate
self.optimizer: optim.Optimizer = getattr(optim, optimizer_name)(self.model.parameters(), lr=lr)
self.optimizer: optim.Optimizer = getattr(optim, optimizer_name)(
self.model.parameters(), **self.optimizer_settings.optimizer_kwargs
)
if self.optimizer_settings.scheduler is not None:
self.scheduler = getattr(optim.lr_scheduler, self.optimizer_settings.scheduler)(
self.optimizer, **self.optimizer_settings.scheduler_kwargs
)
if self.resume:
try:
self.scheduler.load_state_dict(self.checkpoint_dict["scheduler_state_dict"])
except ValueError:
pass
self.writer.add_scalar("learning rate", self.scheduler.get_last_lr()[0], -1)
# if self.resume:
# try:
# self.scheduler.load_state_dict(self.checkpoint_dict["scheduler_state_dict"])
# except ValueError:
# pass
self.writer.add_scalar("learning rate", self.optimizer.param_groups[0]["lr"], -1)
if not self.resume:
self.best = self.build_checkpoint_dict()
else:
self.best = self.checkpoint_dict
self.model.load_state_dict(self.best["model_state_dict"], strict=False)
try:
self.optimizer.load_state_dict(self.best["optimizer_state_dict"])
except ValueError:
pass
self.best["loss"] = float("inf")
# self.model.load_state_dict(self.best["model_state_dict"], strict=False)
# try:
# self.optimizer.load_state_dict(self.best["optimizer_state_dict"])
# except ValueError:
# pass
for epoch in range(self.best["epoch"] + 1, self.pytorch_settings.epochs):
enable_progress = True
@@ -562,12 +643,8 @@ class Trainer:
enable_progress=enable_progress,
)
if self.optimizer_settings.scheduler is not None:
lr_old = self.scheduler.get_last_lr()
self.scheduler.step(loss)
lr_new = self.scheduler.get_last_lr()
if lr_old[0] != lr_new[0]:
self.writer.add_scalar("learning rate", lr_new[0], epoch)
self.writer.add_scalar("learning rate", self.optimizer.param_groups[0]["lr"], epoch)
if self.pytorch_settings.save_models and self.model is not None:
save_path = (
Path(self.pytorch_settings.model_dir) / f"{self.writer.get_logdir().split('/')[-1]}_{epoch}.tar"
@@ -588,7 +665,28 @@ class Trainer:
self.writer.close()
return self.best
def _plot_model_response_eye(self, *signals, labels=None, sps=None, title_append="", subtitle="", show=True):
def _plot_model_response_powers(self, powers, layer_names, title_append="", subtitle="", show=True):
powers = [power / powers[0] for power in powers]
fig, ax = plt.subplots()
fig.set_figwidth(18)
fig.suptitle(
f"Energy conservation{f' {title_append}' if title_append else ''}{f'\n{subtitle}' if subtitle else ''}"
)
ax.semilogy(powers, marker="o")
ax.set_xticks(range(len(layer_names)), layer_names, rotation=90)
ax.set_xlabel("Layer")
ax.set_ylabel("Normailzed Power")
ax.grid(which="major", axis="x")
ax.grid(which="major", axis="y")
ax.grid(which="minor", axis="y", linestyle=":")
fig.tight_layout()
if show:
plt.show()
return fig
def _plot_model_response_eye(
self, *signals, timestamps, labels=None, sps=None, title_append="", subtitle="", show=True
):
if sps is None:
raise ValueError("sps must be provided")
if not hasattr(labels, "__iter__") or isinstance(labels, (str, type(None))):
@@ -603,16 +701,67 @@ class Trainer:
if not any(labels):
labels = [f"signal {i + 1}" for i in range(len(signals))]
x_bins = np.linspace(0, 2, 2 * sps, endpoint=False)
y_bins = np.zeros((2 * len(signals), 1000))
eye_data = np.zeros((2 * len(signals), 1000, 2 * sps))
# signals = [signal.cpu().numpy() for signal in signals]
for i in range(len(signals) * 2):
eye_signal = signals[i // 2][:, i % 2] # x, y, x, y, ...
eye_signal = np.real(np.square(np.abs(eye_signal)))
data_min = np.min(eye_signal)
data_max = np.max(eye_signal)
y_bins[i] = np.linspace(data_min, data_max, 1000, endpoint=False)
for j in range(len(timestamps)):
t = timestamps[j] / sps
val = eye_signal[j]
x = np.digitize(t % 2, x_bins) - 1
y = np.digitize(val, y_bins[i]) - 1
eye_data[i][y][x] += 1
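# The nested loops above accumulate a 2D eye histogram: each sample's timestamp is
# folded onto a two-symbol window (t % 2) to pick a column, its instantaneous power
# picks a row via digitize, and eye_data[i][y][x] counts the hits, which imshow with
# the custom colormap below then renders as the eye diagram.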
cmap = LinearSegmentedColormap.from_list(
"eyemap",
[
(0, "white"),
(0.001, "dodgerblue"),
(0.1, "blue"),
(0.2, "cyan"),
(0.5, "lime"),
(0.8, "gold"),
(1, "red"),
],
)
# ordering = np.argsort(timestamps)
# signals = [signal[ordering] for signal in signals]
# timestamps = timestamps[ordering]
fig, axs = plt.subplots(1, 2 * len(signals), sharex=True, sharey=True)
fig.set_figwidth(18)
fig.suptitle(f"Eye diagram{f' {title_append}' if title_append else ''}{f'\n{subtitle}' if subtitle else ''}")
xaxis = np.linspace(0, 2, 2 * sps, endpoint=False)
for j, (label, signal) in enumerate(zip(labels, signals)):
# xaxis = timestamps / sps
# xaxis = np.arange(2 * sps) / sps
for j, label in enumerate(labels):
x = eye_data[2 * j]
y = eye_data[2 * j + 1]
# x, y = signal.T
# signal = signal.cpu().numpy()
for i in range(len(signal) // sps - 1):
x, y = signal[i * sps : (i + 2) * sps].T
axs[0 + 2 * j].plot(xaxis, np.abs(x) ** 2, color=f"C{j}", alpha=1 / (len(signal) // sps) * 10)
axs[1 + 2 * j].plot(xaxis, np.abs(y) ** 2, color=f"C{j}", alpha=1 / (len(signal) // sps) * 10)
# for i in range(len(signal) // sps - 1):
# x, y = signal[i * sps : (i + 2) * sps].T
# axs[0 + 2 * j].scatter((timestamps/sps) % 2, np.abs(x) ** 2, color=f"C{j}", alpha=1 / (len(signal) // sps) * 10, s=1)
# axs[1 + 2 * j].scatter((timestamps/sps) % 2, np.abs(y) ** 2, color=f"C{j}", alpha=1 / (len(signal) // sps) * 10, s=1)
axs[0 + 2 * j].imshow(
x, aspect="auto", cmap=cmap, origin="lower", extent=[0, 2, y_bins[2 * j][0], y_bins[2 * j][-1]]
)
axs[1 + 2 * j].imshow(
y, aspect="auto", cmap=cmap, origin="lower", extent=[0, 2, y_bins[2 * j + 1][0], y_bins[2 * j + 1][-1]]
)
axs[0 + 2 * j].set_xlim((x_bins[0], x_bins[-1]))
axs[1 + 2 * j].set_xlim((x_bins[0], x_bins[-1]))
ymin = np.min(y_bins[:, 0])
ymax = np.max(y_bins[:, -1])
ydiff = ymax - ymin
axs[0 + 2 * j].set_ylim((ymin - 0.05 * ydiff, ymax + 0.05 * ydiff))
axs[1 + 2 * j].set_ylim((ymin - 0.05 * ydiff, ymax + 0.05 * ydiff))
axs[0 + 2 * j].set_title(label + " x")
axs[1 + 2 * j].set_title(label + " y")
axs[0 + 2 * j].set_xlabel("Symbol")
@@ -627,7 +776,9 @@ class Trainer:
plt.show()
return fig
def _plot_model_response_head(self, *signals, labels=None, sps=None, title_append="", subtitle="", show=True):
def _plot_model_response_head(
self, *signals, timestamps, labels=None, sps=None, title_append="", subtitle="", show=True
):
if not hasattr(labels, "__iter__") or isinstance(labels, (str, type(None))):
labels = [labels]
else:
@@ -640,19 +791,29 @@ class Trainer:
if not any(labels):
labels = [f"signal {i + 1}" for i in range(len(signals))]
ordering = np.argsort(timestamps)
signals = [signal[ordering] for signal in signals]
timestamps = timestamps[ordering]
fig, axs = plt.subplots(1, 2, sharex=True, sharey=True)
fig.set_figwidth(18)
fig.set_figheight(4)
fig.suptitle(f"Fiber response{f' {title_append}' if title_append else ''}{f'\n{subtitle}' if subtitle else ''}")
for i, ax in enumerate(axs):
ax: plt.Axes
for signal, label in zip(signals, labels):
if sps is not None:
xaxis = np.linspace(0, len(signal) / sps, len(signal), endpoint=False)
xaxis = timestamps / sps
else:
xaxis = np.arange(len(signal))
xaxis = timestamps
ax.plot(xaxis, np.abs(signal[:, i]) ** 2, label=label)
ax.set_xlabel("Sample" if sps is None else "Symbol")
ax.set_ylabel("normalized power")
ax.minorticks_on()
ax.tick_params(axis="y", which="minor", left=False, right=False)
ax.grid(which="major", axis="x")
ax.grid(which="minor", axis="x", linestyle=":")
ax.grid(which="major", axis="y")
ax.legend(loc="upper right")
fig.tight_layout()
if show:
@@ -664,22 +825,51 @@ class Trainer:
model=None,
title_append="",
subtitle="",
mode: Literal["eye", "head"] = "head",
mode: Literal["eye", "head", "powers"] = "head",
show=False,
):
if mode == "powers":
input_data = torch.ones(
1, 2 * self.data_settings.output_size, dtype=getattr(torch, self.data_settings.dtype)
).to(self.pytorch_settings.device)
model = model.to(self.pytorch_settings.device)
model.eval()
with torch.no_grad():
_, powers = model(input_data, trace_powers=True)
powers = [power.item() for power in powers]
layer_names = ["input", *[str(x).split("(")[0] for x in model._layers._modules.values()]]
# remove dropout layers
mask = [1 if "Dropout" not in layer_name else 0 for layer_name in layer_names]
layer_names = [layer_name for layer_name, m in zip(layer_names, mask) if m]
powers = [power for power, m in zip(powers, mask) if m]
fig = self._plot_model_response_powers(
powers, layer_names, title_append=title_append, subtitle=subtitle, show=show
)
return fig
data_settings_backup = copy.deepcopy(self.data_settings)
pytorch_settings_backup = copy.deepcopy(self.pytorch_settings)
self.data_settings.drop_first = 100 * 128
self.data_settings.drop_first = 99.5 + random.randint(0, 1000)
self.data_settings.shuffle = False
self.data_settings.train_split = 1.0
self.pytorch_settings.batchsize = (
self.pytorch_settings.eye_symbols if mode == "eye" else self.pytorch_settings.head_symbols
)
plot_loader, _ = self.get_sliced_data(override={"num_symbols": self.pytorch_settings.batchsize})
config_path = random.choice(self.data_settings.config_path)
fiber_length = int(float(str(config_path).split('-')[-7])/1000)
plot_loader, _ = self.get_sliced_data(
override={
"num_symbols": self.pytorch_settings.batchsize,
"config_path": config_path,
}
)
self.data_settings = data_settings_backup
self.pytorch_settings = pytorch_settings_backup
fiber_in, fiber_out, regen = self.run_model(model, plot_loader)
fiber_in, fiber_out, regen, timestamps = self.run_model(model, plot_loader)
fiber_in = fiber_in.view(-1, 2)
fiber_out = fiber_out.view(-1, 2)
regen = regen.view(-1, 2)
@@ -687,6 +877,7 @@ class Trainer:
fiber_in = fiber_in.numpy()
fiber_out = fiber_out.numpy()
regen = regen.numpy()
timestamps = timestamps.numpy()
# https://github.com/matplotlib/matplotlib/issues/27713#issue-2104110987
# https://github.com/matplotlib/matplotlib/issues/27713#issuecomment-1915497463
@@ -697,9 +888,10 @@ class Trainer:
fiber_in,
fiber_out,
regen,
timestamps=timestamps,
labels=("fiber in", "fiber out", "regen"),
sps=plot_loader.dataset.samples_per_symbol,
title_append=title_append,
title_append=title_append + f" ({fiber_length} km)",
subtitle=subtitle,
show=show,
)
@@ -709,9 +901,10 @@ class Trainer:
fiber_in,
fiber_out,
regen,
timestamps=timestamps,
labels=("fiber in", "fiber out", "regen"),
sps=plot_loader.dataset.samples_per_symbol,
title_append=title_append,
title_append=title_append + f" ({fiber_length} km)",
subtitle=subtitle,
show=show,
)

View File

@@ -1,3 +1,6 @@
import matplotlib
import numpy as np
import torch
from hypertraining.settings import (
GlobalSettings,
DataSettings,
@@ -7,16 +10,20 @@ from hypertraining.settings import (
)
from hypertraining.training import Trainer
import torch
# import torch
import json
import util
from rich import print as rprint
global_settings = GlobalSettings(
seed=42,
seed=0xC0FFEE,
)
data_settings = DataSettings(
config_path="data/*-128-16384-100000-0-0-17-0-PAM4-0.ini",
# config_path="data/*-128-16384-50000-0-0-17-0-PAM4-0.ini",
config_path=[f"data/*-128-16384-{length}-0-0-17-0-PAM4-0.ini" for length in (40000, 50000, 60000)],
dtype="complex64",
# symbols = (9, 20), # 13 symbol @ 10GBd <-> 1.3ns <-> 0.26m of fiber
symbols=13, # study: single_core_regen_20241123_011232
@@ -25,7 +32,7 @@ data_settings = DataSettings(
shuffle=True,
in_out_delay=0,
xy_delay=0,
drop_first=128*64,
drop_first=128 * 64,
train_split=0.8,
)
@@ -45,55 +52,83 @@ model_settings = ModelSettings(
output_dim=2,
n_hidden_layers=4,
overrides={
"n_hidden_nodes_0": 8,
"n_hidden_nodes_1": 8,
"n_hidden_nodes_0": 4,
"n_hidden_nodes_1": 4,
"n_hidden_nodes_2": 4,
"n_hidden_nodes_3": 6,
"n_hidden_nodes_3": 4,
},
model_activation_func="PowScale",
# dropout_prob=0.01,
model_layer_function="ONN",
model_activation_func="EOActivation",
dropout_prob=0.01,
model_layer_function="ONNRect",
model_layer_kwargs={"square": True},
scale=True,
model_layer_parametrizations=[
{
"tensor_name": "weight",
"parametrization": torch.nn.utils.parametrizations.orthogonal,
"parametrization": util.complexNN.energy_conserving,
},
{
"tensor_name": "alpha",
"parametrization": util.complexNN.clamp,
},
{
"tensor_name": "gain",
"parametrization": util.complexNN.clamp,
"kwargs": {
"min": 0,
"max": float("inf"),
},
},
{
"tensor_name": "phase_bias",
"parametrization": util.complexNN.clamp,
"kwargs": {
"min": 0,
"max": 2 * torch.pi,
},
},
{
"tensor_name": "scales",
"parametrization": util.complexNN.clamp,
},
{
"tensor_name": "scale",
"parametrization": util.complexNN.clamp,
},
{
"tensor_name": "bias",
"parametrization": util.complexNN.clamp,
},
# {
# "tensor_name": "scale",
# "parametrization": util.complexNN.clamp,
# },
# {
# "tensor_name": "bias",
# "parametrization": util.complexNN.clamp,
# },
# {
# "tensor_name": "V",
# "parametrization": torch.nn.utils.parametrizations.orthogonal,
# },
# {
# "tensor_name": "S",
# "parametrization": util.complexNN.clamp,
# },
{
"tensor_name": "loss",
"parametrization": util.complexNN.clamp,
},
],
)
optimizer_settings = OptimizerSettings(
optimizer="Adam",
learning_rate=0.05,
optimizer="AdamW",
optimizer_kwargs={
"lr": 0.05,
"amsgrad": True,
# "weight_decay": 1e-7,
},
# learning_rate=0.05,
scheduler="ReduceLROnPlateau",
scheduler_kwargs={
"patience": 2**6,
"factor": 0.9,
"factor": 0.75,
# "threshold": 1e-3,
"min_lr": 1e-6,
"cooldown": 10,
},
)
def save_dict_to_file(dictionary, filename):
"""
Save the best dictionary to a JSON file.
@@ -103,28 +138,79 @@ def save_dict_to_file(dictionary, filename):
:param filename: Path to the JSON file where the dictionary will be saved.
:type filename: str
"""
with open(filename, 'w') as f:
with open(filename, "w") as f:
json.dump(dictionary, f, indent=4)
def sweep_lengths(*lengths, model=None):
assert model is not None, "Model must be provided."
model = model
fiber_ins = {}
fiber_outs = {}
regens = {}
timestampss = {}
for length in lengths:
trainer = Trainer(
checkpoint_path=model,
settings_override={
"data_settings": {
"config_path": f"data/*-128-16384-{length}-0-0-17-0-PAM4-0.ini",
"train_split": 1,
"shuffle": True,
}
},
)
trainer.define_model()
loader, _ = trainer.get_sliced_data()
fiber_in, fiber_out, regen, timestamps = trainer.run_model(trainer.model, loader=loader)
fiber_ins[length] = fiber_in
fiber_outs[length] = fiber_out
regens[length] = regen
timestampss[length] = timestamps
data = torch.zeros(2 * len(lengths), 2, fiber_out.shape[0])
channel_names = ["" for _ in range(2 * len(lengths))]
for li, length in enumerate(lengths):
data[2 * li, 0, :] = timestampss[length] / 128
data[2 * li, 1, :] = regens[length][:, 0].abs().square()
data[2 * li + 1, 0, :] = timestampss[length] / 128
data[2 * li + 1, 1, :] = regens[length][:, 1].abs().square()
channel_names[2 * li] = f"regen x {length}"
channel_names[2 * li + 1] = f"regen y {length}"
# get current backend
backend = matplotlib.get_backend()
matplotlib.use("TkCairo")
eye = util.eye_diagram.eye_diagram(data.to(dtype=torch.float32).detach().cpu().numpy(), channel_names=channel_names)
print_attrs = ("channel", "success", "min_area")
with np.printoptions(precision=3, suppress=True, formatter={'float': '{:0.3e}'.format}):
for result in eye.eye_stats:
print_dict = {attr: result[attr] for attr in print_attrs}
rprint(print_dict)
rprint()
eye.plot()
matplotlib.use(backend)
if __name__ == "__main__":
# sweep_lengths(30000, 40000, 50000, 60000, 70000, model=".models/best_20241202_143149.tar")
trainer = Trainer(
global_settings=global_settings,
data_settings=data_settings,
pytorch_settings=pytorch_settings,
model_settings=model_settings,
optimizer_settings=optimizer_settings,
checkpoint_path='.models/20241128_084935_8885.tar',
settings_override={
"model_settings": {
# "model_activation_func": "PowScale",
"dropout_prob": 0,
}
},
reset_epoch=True,
# checkpoint_path=".models/best_20241202_143149.tar",
# 20241202_143149
)
best = trainer.train()
save_dict_to_file(best, ".models/best_results.json")
...
trainer.train()

View File

@@ -17,3 +17,5 @@ from . import complexNN # noqa: F401
# from .complexNN import complex_sse_loss # noqa: F401
from . import misc # noqa: F401
from . import eye_diagram # noqa: F401

View File

@@ -4,23 +4,36 @@ import torch.nn.functional as F
# from torchlambertw.special import lambertw
def complex_mse_loss(input, target, power=False, reduction="mean"):
def complex_mse_loss(input, target, power=False, normalize=False, reduction="mean"):
"""
Compute the mean squared error between two complex tensors.
If power is set to True, the loss is computed as |input|^2 - |target|^2
"""
reduce = getattr(torch, reduction)
power_penalty = 0
if power:
input = (input * input.conj()).real.to(dtype=input.dtype.to_real())
target = (target * target.conj()).real.to(dtype=target.dtype.to_real())
if normalize:
power_penalty = ((input.max() - input.min()) - (target.max() - target.min())) ** 2
power_penalty += (input.min() - target.min()) ** 2
input = input - input.min()
input = input / input.max()
target = target - target.min()
target = target / target.max()
else:
if normalize:
power_penalty = (input.abs().max() - target.abs().max()) ** 2
input = input / input.abs().max()
target = target / target.abs().max()
if input.is_complex() and target.is_complex():
return reduce(torch.square(input.real - target.real) + torch.square(input.imag - target.imag))
return reduce(torch.square(input.real - target.real) + torch.square(input.imag - target.imag)) + power_penalty
elif input.is_complex() or target.is_complex():
raise ValueError("Input and target must have the same type (real or complex)")
else:
return F.mse_loss(input, target, reduction=reduction)
return F.mse_loss(input, target, reduction=reduction) + power_penalty
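# Usage sketch (hypothetical tensors pred, target):
#     loss = complex_mse_loss(pred, target, power=True, normalize=True)
# With normalize=True both signals are compared on a [0, 1] scale and power_penalty
# additionally penalizes mismatches in power range and offset.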
def complex_sse_loss(input, target):
@@ -53,23 +66,19 @@ class UnitaryLayer(nn.Module):
return f"UnitaryLayer({self.in_features}, {self.out_features})"
class _Unitary(nn.Module):
def forward(self, X:torch.Tensor):
def forward(self, X: torch.Tensor):
if X.ndim < 2:
raise ValueError(
"Only tensors with 2 or more dimensions are supported. "
f"Got a tensor of shape {X.shape}"
)
raise ValueError(f"Only tensors with 2 or more dimensions are supported. Got a tensor of shape {X.shape}")
n, k = X.size(-2), X.size(-1)
transpose = n<k
transpose = n < k
if transpose:
X = X.transpose(-2, -1)
q, r = torch.linalg.qr(X)
# q: torch.Tensor = q
# r: torch.Tensor = r
d = r.diagonal(dim1=-2, dim2=-1).sgn()
q*=d.unsqueeze(-2)
q *= d.unsqueeze(-2)
if transpose:
q = q.transpose(-2, -1)
if n == k:
@@ -80,6 +89,7 @@ class _Unitary(nn.Module):
# X.copy_(q)
return q
def unitary(module: nn.Module, name: str = "weight") -> nn.Module:
weight = getattr(module, name, None)
if not isinstance(weight, torch.Tensor):
@@ -95,19 +105,21 @@ def unitary(module: nn.Module, name: str = "weight") -> nn.Module:
nn.utils.parametrize.register_parametrization(module, name, unit)
return module
class _SpecialUnitary(nn.Module):
def __init__(self):
super().__init__()
def forward(self, X:torch.Tensor):
def forward(self, X: torch.Tensor):
n, k = X.size(-2), X.size(-1)
if n != k:
raise ValueError(f"Expected a square matrix. Got a tensor of shape {X.shape}")
q, _ = torch.linalg.qr(X)
q = q / torch.linalg.det(q).pow(1/n)
q = q / torch.linalg.det(q).pow(1 / n)
return q
def special_unitary(module: nn.Module, name: str = "weight") -> nn.Module:
weight = getattr(module, name, None)
if not isinstance(weight, torch.Tensor):
@@ -123,11 +135,13 @@ def special_unitary(module: nn.Module, name: str = "weight") -> nn.Module:
nn.utils.parametrize.register_parametrization(module, name, unit)
return module
class _Clamp(nn.Module):
def __init__(self, min, max):
super(_Clamp, self).__init__()
self.min = min
self.max = max
def forward(self, x):
if x.is_complex():
# clamp magnitude, ignore phase
@@ -145,43 +159,29 @@ def clamp(module: nn.Module, name: str = "scale", min=0, max=1) -> nn.Module:
return module
class ONNMiller(nn.Module):
def __init__(self, input_dim, output_dim, dtype=None) -> None:
super(ONNMiller, self).__init__()
self.input_dim = input_dim
self.output_dim = output_dim
self.dtype = dtype
class _EnergyConserving(nn.Module):
def __init__(self):
super(_EnergyConserving, self).__init__()
self.dim = max(input_dim, output_dim)
def forward(self, X: torch.Tensor):
if X.ndim == 2:
X = X.unsqueeze(0)
spectral_norm = torch.linalg.svdvals(X)[:, 0]
return (X / spectral_norm).squeeze()
# zero pad input to internal size if smaller
if self.input_dim < self.dim:
self.pad = lambda x: F.pad(x, ((self.dim - self.input_dim) // 2, (self.dim - self.input_dim + 1) // 2))
else:
self.pad = lambda x: x
self.pad.__doc__ = f"Zero pad input from {self.input_dim} to {self.dim}"
# crop output to desired size
if self.output_dim < self.dim:
self.crop = lambda x: x[:, (self.dim - self.output_dim) // 2 : (x.shape[1] - (self.dim - self.output_dim + 1) // 2)]
else:
self.crop = lambda x: x
self.crop.__doc__ = f"Crop output from {self.dim} to {self.output_dim}"
def energy_conserving(module: nn.Module, name: str = "weight") -> nn.Module:
param = getattr(module, name, None)
if not isinstance(param, torch.Tensor):
raise ValueError(f"Module '{module}' has no parameter or buffer '{name}'")
self.U = nn.Parameter(torch.randn(self.dim, self.dim, dtype=self.dtype)) # -> parametrization: Unitary
self.S = nn.Parameter(torch.randn(self.dim, dtype=self.dtype)) # -> parametrization: Clamp (magnitude 0..1)
self.V = nn.Parameter(torch.randn(self.dim, self.dim, dtype=self.dtype)) # -> parametrization: Unitary
self.register_buffer("MZI_scale", torch.tensor(2, dtype=self.dtype.to_real()).sqrt())
# V is actually V.H, but
if not (2 <= param.ndim <= 3):
raise ValueError(f"Expected a matrix or batch of matrices. Got a tensor of {param.ndim} dimensions.")
unit = _EnergyConserving()
nn.utils.parametrize.register_parametrization(module, name, unit)
return module
def forward(self, x_in):
x = x_in
x = self.pad(x)
x = x @ self.U
x = x * (self.S.squeeze() / self.MZI_scale)
x = x @ self.V
x = self.crop(x)
return x
class ONN(nn.Module):
def __init__(self, input_dim, output_dim, dtype=None) -> None:
@@ -202,18 +202,21 @@ class ONN(nn.Module):
# crop output to desired size
if self.output_dim < self.dim:
self.crop = lambda x: x[:, (self.dim - self.output_dim) // 2 : (x.shape[1] - (self.dim - self.output_dim + 1) // 2)]
self.crop = lambda x: x[
:, (self.dim - self.output_dim) // 2 : (x.shape[1] - (self.dim - self.output_dim + 1) // 2)
]
self.crop.__doc__ = f"Crop output from {self.dim} to {self.output_dim}"
else:
self.crop = lambda x: x
self.crop.__doc__ = f"Output size equals internal size {self.dim}"
self.weight = nn.Parameter(torch.randn(self.dim, self.dim, dtype=self.dtype))
# self.scale = nn.Parameter(torch.randn(1, dtype=self.dtype.to_real())+0.5)
def reset_parameters(self):
q, _ = torch.linalg.qr(self.weight)
self.weight.data = q
# def get_M(self):
# return self.U @ self.sigma @ self.V
@@ -221,37 +224,50 @@ class ONN(nn.Module):
return self.crop(self.pad(x) @ self.weight)
class SemiUnitaryLayer(nn.Module):
def __init__(self, input_dim, output_dim, dtype=None):
super(SemiUnitaryLayer, self).__init__()
class ONNRect(nn.Module):
def __init__(self, input_dim, output_dim, square=False, dtype=None):
super(ONNRect, self).__init__()
self.input_dim = input_dim
self.output_dim = output_dim
# Create a larger square matrix for QR decomposition
self.weight = nn.Parameter(torch.randn(max(input_dim, output_dim), max(input_dim, output_dim), dtype=dtype))
self.scale = nn.Parameter(torch.tensor(1.0, dtype=dtype.to_real()))
self.reset_parameters()
if square:
dim = max(input_dim, output_dim)
self.weight = nn.Parameter(torch.randn(dim, dim, dtype=dtype))
def reset_parameters(self):
# Ensure the weights are unitary by QR decomposition
q, _ = torch.linalg.qr(self.weight)
# A = QR with A being a complex square matrix -> Q is unitary, R is upper triangular
# truncate the matrix to the desired size
if self.input_dim > self.output_dim:
self.weight.data = q[: self.input_dim, : self.output_dim]
# zero pad input to internal size if smaller
if self.input_dim < dim:
self.pad = lambda x: F.pad(x, ((dim - self.input_dim) // 2, (dim - self.input_dim + 1) // 2))
self.pad.__doc__ = f"Zero pad input from {self.input_dim} to {dim}"
else:
self.weight.data = q[: self.output_dim, : self.input_dim].t()
...
self.pad = lambda x: x
self.pad.__doc__ = f"Input size equals internal size {dim}"
# crop output to desired size
if self.output_dim < dim:
self.crop = lambda x: x[
:, (dim - self.output_dim) // 2 : (x.shape[1] - (dim - self.output_dim + 1) // 2)
]
self.crop.__doc__ = f"Crop output from {dim} to {self.output_dim}"
else:
self.crop = lambda x: x
self.crop.__doc__ = f"Output size equals internal size {dim}"
else:
self.weight = nn.Parameter(torch.randn(output_dim, input_dim, dtype=dtype))
self.pad = lambda x: x
self.pad.__doc__ = "No padding"
self.crop = lambda x: x
self.crop.__doc__ = "No cropping"
def forward(self, x):
with torch.no_grad():
scale = torch.clamp(self.scale, 0.0, 1.0)
out = torch.matmul(x, scale * self.weight)
x = self.pad(x)
out = self.crop((self.weight @ x.mT).mT)
return out
def __repr__(self):
return f"SemiUnitaryLayer({self.input_dim}, {self.output_dim})"
# def __repr__(self):
# return f"ONNRect({self.input_dim}, {self.output_dim})"
# class SaturableAbsorberLambertW(nn.Module):
@@ -336,6 +352,19 @@ class DropoutComplex(nn.Module):
return self.dropout(x)
class Scale(nn.Module):
def __init__(self, size):
super(Scale, self).__init__()
self.size = size
self.scale = nn.Parameter(torch.ones(size, dtype=torch.float32))
def forward(self, x):
return x * self.scale
def __repr__(self):
return f"Scale({self.size})"
class Identity(nn.Module):
"""
implements the "activation" function
@@ -348,6 +377,7 @@ class Identity(nn.Module):
def forward(self, x):
return x
class PowRot(nn.Module):
def __init__(self, bias=False):
super(PowRot, self).__init__()
@@ -359,15 +389,75 @@ class PowRot(nn.Module):
def forward(self, x: torch.Tensor):
if x.is_complex():
return x * torch.exp(-self.scale*1j*x.abs().square()+self.bias.to(dtype=x.dtype))
return x * torch.exp(-self.scale * 1j * x.abs().square() + self.bias.to(dtype=x.dtype))
else:
return x
class MZISingle(nn.Module):
def __init__(self, bias, size, func=None):
super(MZISingle, self).__init__()
self.omega = nn.Parameter(torch.randn(size))
self.phi = nn.Parameter(torch.randn(size))
self.func = func or (lambda x: x.abs().square()) # default to |z|^2
def forward(self, x: torch.Tensor):
return x * torch.exp(1j * self.phi) * torch.sin(self.omega + self.func(x))
class EOActivation(nn.Module):
def __init__(self, bias, size=None):
# 10.1109/SiPhotonics60897.2024.10543376
super(EOActivation, self).__init__()
if size is None:
raise ValueError("Size must be specified")
self.size = size
self.alpha = nn.Parameter(torch.ones(size))
self.V_bias = nn.Parameter(torch.ones(size))
self.gain = nn.Parameter(torch.ones(size))
# if bias:
# self.phase_bias = nn.Parameter(torch.zeros(size))
# else:
# self.register_buffer("phase_bias", torch.zeros(size))
self.register_buffer("phase_bias", torch.clamp(torch.ones(size) + torch.randn(size)*0.1, 0, 1)*torch.pi)
self.register_buffer("responsivity", torch.ones(size)*0.9)
self.register_buffer("V_pi", torch.ones(size)*3)
self.reset_weights()
def reset_weights(self):
if "alpha" in self._parameters:
self.alpha.data = torch.ones(self.size)*0.5
if "V_pi" in self._parameters:
self.V_pi.data = torch.ones(self.size)*3
if "V_bias" in self._parameters:
self.V_bias.data = torch.zeros(self.size)
if "gain" in self._parameters:
self.gain.data = torch.ones(self.size)
if "responsivity" in self._parameters:
self.responsivity.data = torch.ones(self.size)*0.9
if "bias" in self._parameters:
self.phase_bias.data = torch.zeros(self.size)
def forward(self, x: torch.Tensor):
phi_b = torch.pi * self.V_bias / (self.V_pi + 1e-8)
g_phi = torch.pi * (self.alpha * self.gain * self.responsivity) / (self.V_pi + 1e-8)
intermediate = g_phi * x.abs().square() + phi_b
return (
1j
* torch.sqrt(1 - self.alpha)
* torch.exp(-0.5j * (intermediate + self.phase_bias))
* torch.cos(0.5 * intermediate)
* x
)
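# Written out, the forward pass implements (with phi_b = pi*V_bias/V_pi and
# g = pi*alpha*gain*responsivity/V_pi, both stabilized by the 1e-8 term):
#     f(z) = 1j * sqrt(1 - alpha) * exp(-0.5j*(g*|z|^2 + phi_b + phase_bias)) * cos(0.5*(g*|z|^2 + phi_b)) * z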
class Pow(nn.Module):
"""
implements the activation function
M(z) = ||z||^2 + b
"""
def __init__(self, bias=False):
super(Pow, self).__init__()
if bias:
@@ -375,7 +465,6 @@ class Pow(nn.Module):
else:
self.register_buffer("bias", torch.tensor(0.0))
def forward(self, x: torch.Tensor):
return x.abs().square().add(self.bias).to(dtype=x.dtype)
@@ -408,6 +497,7 @@ class MagScale(nn.Module):
def forward(self, x: torch.Tensor):
return x.abs().add(self.bias).to(dtype=x.dtype).sin().mul(x)
class PowScale(nn.Module):
def __init__(self, bias=False):
super(PowScale, self).__init__()
@@ -486,10 +576,10 @@ __all__ = [
complex_mse_loss,
UnitaryLayer,
unitary,
energy_conserving,
clamp,
ONN,
ONNMiller,
SemiUnitaryLayer,
ONNRect,
DropoutComplex,
Identity,
Pow,
@@ -498,6 +588,8 @@ __all__ = [
ModReLU,
CReLU,
ZReLU,
MZISingle,
EOActivation,
# SaturableAbsorberLambertW,
# SaturableAbsorber,
# SpreadLayer,