Compare commits


7 Commits

Author             SHA1        Date                        Message
Joseph Hopfmüller  487288c923  2024-11-29 15:51:25 +01:00  define new activation functions and parametrizations
Joseph Hopfmüller  bdf6f5bfb8  2024-11-29 15:50:34 +01:00  clean up regen_no_hyper.py
Joseph Hopfmüller  e02662ed4f  2024-11-29 15:49:59 +01:00  new optuna studies
Joseph Hopfmüller  fd7a0b9c31  2024-11-29 15:49:46 +01:00  using latest knowledge for hyperparameter search
Joseph Hopfmüller  ff32aefd52  2024-11-29 15:49:10 +01:00  minor fixes and changes
Joseph Hopfmüller  b156b9ceaf  2024-11-29 15:48:27 +01:00  refactor hypertraining.py to improve model layer handling and response plotting; adjust data settings for batch processing
Joseph Hopfmüller  cfa08aae4e  2024-11-29 15:48:18 +01:00  add training.py for defining and running models without hyperparametertuning
9 changed files with 1258 additions and 461 deletions

View File

@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:f3510d41f9f0605e438a09767c43edda38162601292be1207f50747117ae5479
-size 9863168
+oid sha256:746ea83013870351296e01e294905ee027291ef79fd78c1e6b69dd9ebaa1cba0
+size 10240000

View File

@@ -19,7 +19,7 @@ import time
 from matplotlib import pyplot as plt  # noqa: F401
 import numpy as np
-import path_fix
+import add_pypho  # noqa: F401
 import pypho

 default_config = f"""
@@ -497,18 +497,18 @@ def plot_eye_diagram(
 if __name__ == "__main__":
-    path_fix.show_log()
+    add_pypho.show_log()
     config = get_config()
-    length_ranges = [1000, 10000]
-    length_scales = [1, 2, 3, 4, 5, 6, 7, 8, 9]
-    lengths = [
-        length_scale * length_range
-        for length_range in length_ranges
-        for length_scale in length_scales
-    ]
-    lengths.append(max(length_ranges)*10)
+    # length_ranges = [1000, 10000]
+    # length_scales = [1, 2, 3, 4, 5, 6, 7, 8, 9]
+    # lengths = [
+    #     length_scale * length_range
+    #     for length_range in length_ranges
+    #     for length_scale in length_scales
+    # ]
+    # lengths.append(max(length_ranges)*10)
     # length_loop(config, lengths)

View File

@@ -245,18 +245,18 @@ class HyperTraining:
         dtype = getattr(torch, dtype)
         afunc = trial.suggest_categorical_optional("model_activation_func", self.model_settings.model_activation_func)
+        # T0 = trial.suggest_float_optional("T0", self.model_settings.satabsT0 , log=True)
         layers = []
         last_dim = input_dim
         n_nodes = last_dim
         for i in range(n_layers):
             if hidden_dim_override := self.model_settings.overrides.get(f"n_hidden_nodes_{i}", False):
-                hidden_dim = trial.suggest_int_optional(f"model_hidden_dim_{i}", hidden_dim_override, force=True)
+                hidden_dim = trial.suggest_int_optional(f"model_hidden_dim_{i}", hidden_dim_override)
             else:
                 hidden_dim = trial.suggest_int_optional(
                     f"model_hidden_dim_{i}",
                     self.model_settings.n_hidden_nodes,
-                    # step=2,
                 )
             layers.append(util.complexNN.SemiUnitaryLayer(last_dim, hidden_dim, dtype=dtype))
             last_dim = hidden_dim
@@ -642,6 +642,7 @@ class HyperTraining:
         if show:
             plt.show()
+        return fig

     def _plot_model_response_head(self, *signals, labels=None, sps=None, title_append="", subtitle="", show=True):
         if not hasattr(labels, "__iter__") or isinstance(labels, (str, type(None))):
@@ -684,7 +685,7 @@ class HyperTraining:
     ):
         data_settings_backup = copy.deepcopy(self.data_settings)
         pytorch_settings_backup = copy.deepcopy(self.pytorch_settings)
-        self.data_settings.drop_first = 100
+        self.data_settings.drop_first = 100*128
         self.data_settings.shuffle = False
         self.data_settings.train_split = 1.0
         self.pytorch_settings.batchsize = (
@@ -739,11 +740,15 @@ class HyperTraining:
     @staticmethod
     def build_title(trial: optuna.trial.Trial):
         title_append = f"for trial {trial.number}"
-        model_n_layers = util.misc.multi_getattr((trial.params, trial.user_attrs), "model_n_layers", 0)
-        model_hidden_dims = [
+        model_n_hidden_layers = util.misc.multi_getattr((trial.params, trial.user_attrs), "model_n_hidden_layers", 0)
+        input_dim = util.misc.multi_getattr((trial.params, trial.user_attrs), "model_input_dim", 0)
+        model_dims = [
             util.misc.multi_getattr((trial.params, trial.user_attrs), f"model_hidden_dim_{i}", 0)
-            for i in range(model_n_layers)
+            for i in range(model_n_hidden_layers)
         ]
+        model_dims.insert(0, input_dim)
+        model_dims.append(2)
+        model_dims = [str(dim) for dim in model_dims]
         model_activation_func = util.misc.multi_getattr(
             (trial.params, trial.user_attrs),
             "model_activation_func",
@@ -752,7 +757,7 @@ class HyperTraining:
         model_dtype = util.misc.multi_getattr((trial.params, trial.user_attrs), "model_dtype", "unknown dtype")
         subtitle = (
-            f"{model_n_layers} layers à ({', '.join(model_hidden_dims)}) units, {model_activation_func}, {model_dtype}"
+            f"{model_n_hidden_layers+2} layers à ({', '.join(model_dims)}) units, {model_activation_func}, {model_dtype}"
         )
         return title_append, subtitle
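Note: the revised build_title reports the whole layer stack (input layer, hidden layers, 2-dimensional output layer) instead of only the hidden dimensions. A minimal sketch of the string it assembles, using made-up trial values (52 inputs, four hidden layers of 8, 6, 4 and 8 nodes, "Mag", complex64), not output from an actual study:

# Sketch only: mirrors the new subtitle logic with hypothetical trial values.
model_n_hidden_layers = 4
input_dim = 52                    # hypothetical model_input_dim
model_dims = [8, 6, 4, 8]         # hypothetical model_hidden_dim_0..3
model_dims.insert(0, input_dim)   # prepend the input layer
model_dims.append(2)              # append the output layer
model_dims = [str(dim) for dim in model_dims]
subtitle = f"{model_n_hidden_layers + 2} layers à ({', '.join(model_dims)}) units, Mag, complex64"
# -> '6 layers à (52, 8, 6, 4, 8, 2) units, Mag, complex64'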

View File

@@ -39,7 +39,7 @@ class PytorchSettings:
     summary_dir: str = ".runs"
     write_every: int = 10
     head_symbols: int = 40
-    eye_symbols: int = 1000
+    eye_symbols: int = 400

 # model settings
@@ -48,8 +48,11 @@ class ModelSettings:
     output_dim: int = 2
     n_hidden_layers: tuple | int = 3
     n_hidden_nodes: tuple | int = 8
-    model_activation_func: tuple = "ModReLU"
+    model_activation_func: tuple | str = "ModReLU"
     overrides: dict = field(default_factory=dict)
+    dropout_prob: float | None = None
+    model_layer_function: str | None = None
+    model_layer_parametrizations: list = field(default_factory=list)

 @dataclass

View File

@@ -0,0 +1,739 @@
import copy
from datetime import datetime
from pathlib import Path
from typing import Literal
import matplotlib
import torch.nn.utils.parametrize
try:
matplotlib.use("cairo")
except ImportError:
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
# import torch.nn.functional as F # mse_loss doesn't support complex numbers
import torch.optim as optim
import torch.utils.data
from torch.utils.tensorboard import SummaryWriter
from rich.progress import (
Progress,
TextColumn,
BarColumn,
TaskProgressColumn,
TimeRemainingColumn,
MofNCompleteColumn,
TimeElapsedColumn,
)
from rich.console import Console
from util.datasets import FiberRegenerationDataset
import util
from .settings import (
GlobalSettings,
DataSettings,
ModelSettings,
OptimizerSettings,
PytorchSettings,
)
class regenerator(nn.Module):
def __init__(
self,
*dims,
layer_function=util.complexNN.ONN,
layer_parametrizations: list[dict] = None,
# [
# {
# "tensor_name": "weight",
# "parametrization": util.complexNN.Unitary,
# },
# {
# "tensor_name": "scale",
# "parametrization": util.complexNN.Clamp,
# },
# ],
activation_function=util.complexNN.Pow,
dtype=torch.float64,
dropout_prob=0.01,
**kwargs,
):
super(regenerator, self).__init__()
if len(dims) == 0:
try:
dims = kwargs["dims"]
except KeyError:
raise ValueError("dims must be provided")
self._n_hidden_layers = len(dims) - 2
self._layers = nn.Sequential()
for i in range(self._n_hidden_layers + 1):
self._layers.append(layer_function(dims[i], dims[i + 1], dtype=dtype))
if i < self._n_hidden_layers:
if dropout_prob is not None:
self._layers.append(util.complexNN.DropoutComplex(p=dropout_prob))
self._layers.append(activation_function())
# add parametrizations
if layer_parametrizations is not None:
for layer_parametrization in layer_parametrizations:
tensor_name = layer_parametrization.get("tensor_name", None)
parametrization = layer_parametrization.get("parametrization", None)
param_kwargs = layer_parametrization.get("kwargs", {})
if (
tensor_name is not None
and tensor_name in self._layers[-1]._parameters
and parametrization is not None
):
parametrization(self._layers[-1], tensor_name, **param_kwargs)
def forward(self, input_x):
x = input_x
# check if tracing
if torch.jit.is_tracing():
for layer in self._layers:
x = layer(x)
else:
# with torch.nn.utils.parametrize.cached():
for layer in self._layers:
x = layer(x)
return x
def traverse_dict_update(target, source):
for k, v in source.items():
if isinstance(v, dict):
if k not in target:
target[k] = {}
traverse_dict_update(target[k], v)
else:
try:
target[k] = v
except TypeError:
target.__dict__[k] = v
class Trainer:
def __init__(
self,
*,
global_settings=None,
data_settings=None,
pytorch_settings=None,
model_settings=None,
optimizer_settings=None,
console=None,
checkpoint_path=None,
settings_override=None,
reset_epoch=False,
):
self.resume = checkpoint_path is not None
torch.serialization.add_safe_globals([
*util.complexNN.__all__,
GlobalSettings,
DataSettings,
ModelSettings,
OptimizerSettings,
PytorchSettings,
regenerator,
torch.nn.utils.parametrizations.orthogonal
])
if self.resume:
self.checkpoint_dict = torch.load(checkpoint_path, weights_only=True)
if settings_override is not None:
traverse_dict_update(self.checkpoint_dict["settings"], settings_override)
if reset_epoch:
self.checkpoint_dict["epoch"] = -1
self.global_settings: GlobalSettings = self.checkpoint_dict["settings"]["global_settings"]
self.data_settings: DataSettings = self.checkpoint_dict["settings"]["data_settings"]
self.pytorch_settings: PytorchSettings = self.checkpoint_dict["settings"]["pytorch_settings"]
self.model_settings: ModelSettings = self.checkpoint_dict["settings"]["model_settings"]
self.optimizer_settings: OptimizerSettings = self.checkpoint_dict["settings"]["optimizer_settings"]
else:
if global_settings is None:
raise ValueError("global_settings must be provided")
if data_settings is None:
raise ValueError("data_settings must be provided")
if pytorch_settings is None:
raise ValueError("pytorch_settings must be provided")
if model_settings is None:
raise ValueError("model_settings must be provided")
if optimizer_settings is None:
raise ValueError("optimizer_settings must be provided")
self.global_settings: GlobalSettings = global_settings
self.data_settings: DataSettings = data_settings
self.pytorch_settings: PytorchSettings = pytorch_settings
self.model_settings: ModelSettings = model_settings
self.optimizer_settings: OptimizerSettings = optimizer_settings
self.console = console or Console()
self.writer = None
def setup_tb_writer(self, append=None):
log_dir = self.pytorch_settings.summary_dir + "/" + (datetime.now().strftime("%Y%m%d_%H%M%S"))
if append is not None:
log_dir += "_" + str(append)
print(f"Logging to {log_dir}")
self.writer = SummaryWriter(log_dir=log_dir)
def save_checkpoint(self, save_dict, filename):
torch.save(save_dict, filename)
def build_checkpoint_dict(self, loss=None, epoch=None):
return {
"epoch": -1 if epoch is None else epoch,
"loss": float("inf") if loss is None else loss,
"model_state_dict": copy.deepcopy(self.model.state_dict()),
"optimizer_state_dict": copy.deepcopy(self.optimizer.state_dict()),
"scheduler_state_dict": copy.deepcopy(self.scheduler.state_dict()) if hasattr(self, "scheduler") else None,
"model_kwargs": copy.deepcopy(self.model_kwargs),
"settings": {
"global_settings": copy.deepcopy(self.global_settings),
"data_settings": copy.deepcopy(self.data_settings),
"pytorch_settings": copy.deepcopy(self.pytorch_settings),
"model_settings": copy.deepcopy(self.model_settings),
"optimizer_settings": copy.deepcopy(self.optimizer_settings),
},
}
def define_model(self, model_kwargs=None):
if model_kwargs is None:
n_hidden_layers = self.model_settings.n_hidden_layers
input_dim = 2 * self.data_settings.output_size
dtype = getattr(torch, self.data_settings.dtype)
afunc = getattr(util.complexNN, self.model_settings.model_activation_func)
layer_func = getattr(util.complexNN, self.model_settings.model_layer_function)
layer_parametrizations = self.model_settings.model_layer_parametrizations
hidden_dims = [self.model_settings.overrides.get(f"n_hidden_nodes_{i}") for i in range(n_hidden_layers)]
self.model_kwargs = {
"dims": (input_dim, *hidden_dims, self.model_settings.output_dim),
"layer_function": layer_func,
"layer_parametrizations": layer_parametrizations,
"activation_function": afunc,
"dtype": dtype,
"dropout_prob": self.model_settings.dropout_prob,
}
else:
self.model_kwargs = model_kwargs
input_dim = self.model_kwargs["dims"][0]
dtype = self.model_kwargs["dtype"]
# dims = self.model_kwargs.pop("dims")
self.model = regenerator(**self.model_kwargs)
self.writer.add_graph(self.model, torch.zeros(1, input_dim, dtype=dtype))
self.model = self.model.to(self.pytorch_settings.device)
def get_sliced_data(self, override=None):
symbols = self.data_settings.symbols
in_out_delay = self.data_settings.in_out_delay
xy_delay = self.data_settings.xy_delay
data_size = self.data_settings.output_size
dtype = getattr(torch, self.data_settings.dtype)
num_symbols = None
if override is not None:
num_symbols = override.get("num_symbols", None)
# get dataset
dataset = FiberRegenerationDataset(
file_path=self.data_settings.config_path,
symbols=symbols,
output_dim=data_size,
target_delay=in_out_delay,
xy_delay=xy_delay,
drop_first=self.data_settings.drop_first,
dtype=dtype,
real=not dtype.is_complex,
num_symbols=num_symbols,
)
dataset_size = len(dataset)
indices = list(range(dataset_size))
split = int(np.floor(self.data_settings.train_split * dataset_size))
if self.data_settings.shuffle:
np.random.seed(self.global_settings.seed)
np.random.shuffle(indices)
train_indices, valid_indices = indices[:split], indices[split:]
if self.data_settings.shuffle:
train_sampler = torch.utils.data.SubsetRandomSampler(train_indices)
valid_sampler = torch.utils.data.SubsetRandomSampler(valid_indices)
else:
train_sampler = train_indices
valid_sampler = valid_indices
train_loader = torch.utils.data.DataLoader(
dataset,
batch_size=self.pytorch_settings.batchsize,
sampler=train_sampler,
drop_last=True,
pin_memory=True,
num_workers=self.pytorch_settings.dataloader_workers,
prefetch_factor=self.pytorch_settings.dataloader_prefetch,
)
valid_loader = torch.utils.data.DataLoader(
dataset,
batch_size=self.pytorch_settings.batchsize,
sampler=valid_sampler,
drop_last=True,
pin_memory=True,
num_workers=self.pytorch_settings.dataloader_workers,
prefetch_factor=self.pytorch_settings.dataloader_prefetch,
)
return train_loader, valid_loader
def train_model(
self,
optimizer,
train_loader,
epoch,
enable_progress=False,
):
if enable_progress:
progress = Progress(
TextColumn("[yellow] Training..."),
TextColumn("Error: {task.description}"),
BarColumn(),
TaskProgressColumn(),
TextColumn("[green]Batch"),
MofNCompleteColumn(),
TimeRemainingColumn(),
TimeElapsedColumn(),
transient=False,
console=self.console,
refresh_per_second=10,
)
task = progress.add_task("-.---e--", total=len(train_loader))
progress.start()
running_loss2 = 0.0
running_loss = 0.0
self.model.train()
for batch_idx, (x, y) in enumerate(train_loader):
self.model.zero_grad(set_to_none=True)
x, y = (
x.to(self.pytorch_settings.device),
y.to(self.pytorch_settings.device),
)
y_pred = self.model(x)
loss = util.complexNN.complex_mse_loss(y_pred, y, power=True)
loss_value = loss.item()
loss.backward()
optimizer.step()
running_loss2 += loss_value
running_loss += loss_value
if enable_progress:
progress.update(task, advance=1, description=f"{loss_value:.3e}")
if batch_idx % self.pytorch_settings.write_every == 0:
self.writer.add_scalar(
"training loss",
running_loss2 / (self.pytorch_settings.write_every if batch_idx > 0 else 1),
epoch * len(train_loader) + batch_idx,
)
running_loss2 = 0.0
if enable_progress:
progress.stop()
return running_loss / len(train_loader)
def eval_model(self, valid_loader, epoch, enable_progress=True):
if enable_progress:
progress = Progress(
TextColumn("[green]Evaluating..."),
TextColumn("Error: {task.description}"),
BarColumn(),
TaskProgressColumn(),
TextColumn("[green]Batch"),
MofNCompleteColumn(),
TimeRemainingColumn(),
TimeElapsedColumn(),
transient=False,
console=self.console,
refresh_per_second=10,
)
progress.start()
task = progress.add_task("-.---e--", total=len(valid_loader))
self.model.eval()
running_error = 0
with torch.no_grad():
for batch_idx, (x, y) in enumerate(valid_loader):
x, y = (
x.to(self.pytorch_settings.device),
y.to(self.pytorch_settings.device),
)
y_pred = self.model(x)
error = util.complexNN.complex_mse_loss(y_pred, y, power=True)
error_value = error.item()
running_error += error_value
if enable_progress:
progress.update(task, advance=1, description=f"{error_value:.3e}")
running_error /= len(valid_loader)
self.writer.add_scalar(
"eval loss",
running_error,
epoch,
)
title_append, subtitle = self.build_title(epoch + 1)
self.writer.add_figure(
"fiber response",
self.plot_model_response(
model=self.model,
title_append=title_append,
subtitle=subtitle,
show=False,
),
epoch + 1,
)
self.writer.add_figure(
"eye diagram",
self.plot_model_response(
model=self.model,
title_append=title_append,
subtitle=subtitle,
show=False,
mode="eye",
),
epoch + 1,
)
self.writer_histograms(epoch + 1)
if enable_progress:
progress.stop()
return running_error
def run_model(self, model, loader):
model.eval()
xs = []
ys = []
y_preds = []
with torch.no_grad():
model = model.to(self.pytorch_settings.device)
for x, y in loader:
x, y = (
x.to(self.pytorch_settings.device),
y.to(self.pytorch_settings.device),
)
y_pred = model(x).cpu()
# x = x.cpu()
# y = y.cpu()
y_pred = y_pred.view(y_pred.shape[0], -1, 2)
y = y.view(y.shape[0], -1, 2)
x = x.view(x.shape[0], -1, 2)
xs.append(x[:, 0, :].squeeze())
ys.append(y.squeeze())
y_preds.append(y_pred.squeeze())
xs = torch.vstack(xs).cpu()
ys = torch.vstack(ys).cpu()
y_preds = torch.vstack(y_preds).cpu()
return ys, xs, y_preds
def writer_histograms(self, epoch, attributes=["weight", "weight_U", "weight_V", "bias", "sigma", "scale"]):
for i, layer in enumerate(self.model._layers):
tag = f"layer {i}"
for attribute in attributes:
if hasattr(layer, attribute):
vals: np.ndarray = getattr(layer, attribute).detach().cpu().numpy().flatten()
if vals.ndim <= 1 and len(vals) == 1:
if np.iscomplexobj(vals):
self.writer.add_scalar(f"{tag} {attribute} (Mag)", np.abs(vals), epoch)
self.writer.add_scalar(f"{tag} {attribute} (Phase)", np.angle(vals), epoch)
else:
self.writer.add_scalar(f"{tag} {attribute}", vals, epoch)
else:
if np.iscomplexobj(vals):
self.writer.add_histogram(f"{tag} {attribute} (Mag)", np.abs(vals), epoch, bins="fd")
self.writer.add_histogram(f"{tag} {attribute} (Phase)", np.angle(vals), epoch, bins="fd")
else:
self.writer.add_histogram(f"{tag} {attribute}", vals, epoch, bins="fd")
def train(self):
if self.writer is None:
self.setup_tb_writer()
if self.resume:
model_kwargs = self.checkpoint_dict["model_kwargs"]
else:
model_kwargs = None
self.define_model(model_kwargs=model_kwargs)
print(f"number of parameters (trainable): {sum(p.numel() for p in self.model.parameters())} ({sum(p.numel() for p in self.model.parameters() if p.requires_grad)})")
title_append, subtitle = self.build_title(0)
self.writer.add_figure(
"fiber response",
self.plot_model_response(
model=self.model,
title_append=title_append,
subtitle=subtitle,
show=False,
),
0,
)
self.writer.add_figure(
"eye diagram",
self.plot_model_response(
model=self.model,
title_append=title_append,
subtitle=subtitle,
mode="eye",
show=False,
),
0,
)
self.writer_histograms(0)
train_loader, valid_loader = self.get_sliced_data()
optimizer_name = self.optimizer_settings.optimizer
lr = self.optimizer_settings.learning_rate
self.optimizer: optim.Optimizer = getattr(optim, optimizer_name)(self.model.parameters(), lr=lr)
if self.optimizer_settings.scheduler is not None:
self.scheduler = getattr(optim.lr_scheduler, self.optimizer_settings.scheduler)(
self.optimizer, **self.optimizer_settings.scheduler_kwargs
)
if self.resume:
try:
self.scheduler.load_state_dict(self.checkpoint_dict["scheduler_state_dict"])
except ValueError:
pass
self.writer.add_scalar("learning rate", self.scheduler.get_last_lr()[0], -1)
if not self.resume:
self.best = self.build_checkpoint_dict()
else:
self.best = self.checkpoint_dict
self.model.load_state_dict(self.best["model_state_dict"], strict=False)
try:
self.optimizer.load_state_dict(self.best["optimizer_state_dict"])
except ValueError:
pass
for epoch in range(self.best["epoch"] + 1, self.pytorch_settings.epochs):
enable_progress = True
if enable_progress:
self.console.rule(f"Epoch {epoch + 1}/{self.pytorch_settings.epochs}")
self.train_model(
self.optimizer,
train_loader,
epoch,
enable_progress=enable_progress,
)
loss = self.eval_model(
valid_loader,
epoch,
enable_progress=enable_progress,
)
if self.optimizer_settings.scheduler is not None:
lr_old = self.scheduler.get_last_lr()
self.scheduler.step(loss)
lr_new = self.scheduler.get_last_lr()
if lr_old[0] != lr_new[0]:
self.writer.add_scalar("learning rate", lr_new[0], epoch)
if self.pytorch_settings.save_models and self.model is not None:
save_path = (
Path(self.pytorch_settings.model_dir) / f"{self.writer.get_logdir().split('/')[-1]}_{epoch}.tar"
)
save_path.parent.mkdir(parents=True, exist_ok=True)
checkpoint = self.build_checkpoint_dict(loss, epoch)
self.save_checkpoint(checkpoint, save_path)
if loss < self.best["loss"]:
self.best = checkpoint
save_path = (
Path(self.pytorch_settings.model_dir) / f"best_{self.writer.get_logdir().split('/')[-1]}.tar"
)
save_path.parent.mkdir(parents=True, exist_ok=True)
self.save_checkpoint(self.best, save_path)
self.writer.flush()
self.writer.close()
return self.best
def _plot_model_response_eye(self, *signals, labels=None, sps=None, title_append="", subtitle="", show=True):
if sps is None:
raise ValueError("sps must be provided")
if not hasattr(labels, "__iter__") or isinstance(labels, (str, type(None))):
labels = [labels]
else:
labels = list(labels)
while len(labels) < len(signals):
labels.append(None)
# check if there are any labels
if not any(labels):
labels = [f"signal {i + 1}" for i in range(len(signals))]
fig, axs = plt.subplots(1, 2 * len(signals), sharex=True, sharey=True)
fig.set_figwidth(18)
fig.suptitle(f"Eye diagram{f' {title_append}' if title_append else ''}{f'\n{subtitle}' if subtitle else ''}")
xaxis = np.linspace(0, 2, 2 * sps, endpoint=False)
for j, (label, signal) in enumerate(zip(labels, signals)):
# signal = signal.cpu().numpy()
for i in range(len(signal) // sps - 1):
x, y = signal[i * sps : (i + 2) * sps].T
axs[0 + 2 * j].plot(xaxis, np.abs(x) ** 2, color=f"C{j}", alpha=1 / (len(signal) // sps) * 10)
axs[1 + 2 * j].plot(xaxis, np.abs(y) ** 2, color=f"C{j}", alpha=1 / (len(signal) // sps) * 10)
axs[0 + 2 * j].set_title(label + " x")
axs[1 + 2 * j].set_title(label + " y")
axs[0 + 2 * j].set_xlabel("Symbol")
axs[1 + 2 * j].set_xlabel("Symbol")
axs[0 + 2 * j].set_box_aspect(1)
axs[1 + 2 * j].set_box_aspect(1)
axs[0].set_ylabel("normalized power")
fig.tight_layout()
# axs[1+2*len(labels)-1].set_ylabel("normalized power")
if show:
plt.show()
return fig
def _plot_model_response_head(self, *signals, labels=None, sps=None, title_append="", subtitle="", show=True):
if not hasattr(labels, "__iter__") or isinstance(labels, (str, type(None))):
labels = [labels]
else:
labels = list(labels)
while len(labels) < len(signals):
labels.append(None)
# check if there are any labels
if not any(labels):
labels = [f"signal {i + 1}" for i in range(len(signals))]
fig, axs = plt.subplots(1, 2, sharex=True, sharey=True)
fig.set_figwidth(18)
fig.set_figheight(4)
fig.suptitle(f"Fiber response{f' {title_append}' if title_append else ''}{f'\n{subtitle}' if subtitle else ''}")
for i, ax in enumerate(axs):
for signal, label in zip(signals, labels):
if sps is not None:
xaxis = np.linspace(0, len(signal) / sps, len(signal), endpoint=False)
else:
xaxis = np.arange(len(signal))
ax.plot(xaxis, np.abs(signal[:, i]) ** 2, label=label)
ax.set_xlabel("Sample" if sps is None else "Symbol")
ax.set_ylabel("normalized power")
ax.legend(loc="upper right")
fig.tight_layout()
if show:
plt.show()
return fig
def plot_model_response(
self,
model=None,
title_append="",
subtitle="",
mode: Literal["eye", "head"] = "head",
show=False,
):
data_settings_backup = copy.deepcopy(self.data_settings)
pytorch_settings_backup = copy.deepcopy(self.pytorch_settings)
self.data_settings.drop_first = 100 * 128
self.data_settings.shuffle = False
self.data_settings.train_split = 1.0
self.pytorch_settings.batchsize = (
self.pytorch_settings.eye_symbols if mode == "eye" else self.pytorch_settings.head_symbols
)
plot_loader, _ = self.get_sliced_data(override={"num_symbols": self.pytorch_settings.batchsize})
self.data_settings = data_settings_backup
self.pytorch_settings = pytorch_settings_backup
fiber_in, fiber_out, regen = self.run_model(model, plot_loader)
fiber_in = fiber_in.view(-1, 2)
fiber_out = fiber_out.view(-1, 2)
regen = regen.view(-1, 2)
fiber_in = fiber_in.numpy()
fiber_out = fiber_out.numpy()
regen = regen.numpy()
# https://github.com/matplotlib/matplotlib/issues/27713#issue-2104110987
# https://github.com/matplotlib/matplotlib/issues/27713#issuecomment-1915497463
import gc
if mode == "head":
fig = self._plot_model_response_head(
fiber_in,
fiber_out,
regen,
labels=("fiber in", "fiber out", "regen"),
sps=plot_loader.dataset.samples_per_symbol,
title_append=title_append,
subtitle=subtitle,
show=show,
)
elif mode == "eye":
# raise NotImplementedError("Eye diagram not implemented")
fig = self._plot_model_response_eye(
fiber_in,
fiber_out,
regen,
labels=("fiber in", "fiber out", "regen"),
sps=plot_loader.dataset.samples_per_symbol,
title_append=title_append,
subtitle=subtitle,
show=show,
)
else:
raise ValueError(f"Unknown mode: {mode}")
gc.collect()
return fig
def build_title(self, number: int):
title_append = f"epoch {number}"
model_n_hidden_layers = self.model_settings.n_hidden_layers
input_dim = 2 * self.data_settings.output_size
model_dims = [
self.model_settings.overrides.get(f"n_hidden_nodes_{i}", -1) for i in range(model_n_hidden_layers)
]
model_dims.insert(0, input_dim)
model_dims.append(2)
model_dims = [str(dim) for dim in model_dims]
model_activation_func = self.model_settings.model_activation_func
model_dtype = self.data_settings.dtype
subtitle = f"{model_n_hidden_layers + 2} layers à ({', '.join(model_dims)}) units, {model_activation_func}, {model_dtype}"
return title_append, subtitle
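Note: the settings_override path above relies on traverse_dict_update to merge a nested override into the settings restored from a checkpoint; for dataclass values it falls back to writing through __dict__. A minimal sketch of the merge semantics, using plain dicts in place of the settings objects and hypothetical values:

# Sketch only: analogous to passing settings_override= to Trainer.
checkpoint_settings = {"model_settings": {"dropout_prob": 0.01, "model_activation_func": "PowScale"}}
override = {"model_settings": {"dropout_prob": 0}}

traverse_dict_update(checkpoint_settings, override)
# checkpoint_settings["model_settings"] is now
# {"dropout_prob": 0, "model_activation_func": "PowScale"}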

View File

@@ -30,10 +30,10 @@ data_settings = DataSettings(
 )
 pytorch_settings = PytorchSettings(
-    epochs=10,
+    epochs=10000,
     batchsize=2**10,
     device="cuda",
-    dataloader_workers=2,
+    dataloader_workers=12,
     dataloader_prefetch=4,
     summary_dir=".runs",
     write_every=2**5,
@@ -44,33 +44,31 @@ pytorch_settings = PytorchSettings(
 model_settings = ModelSettings(
     output_dim=2,
     # n_hidden_layers = (3, 8),
-    n_hidden_layers=(4, 6),  # study: single_core_regen_20241123_011232
-    n_hidden_nodes=(4,20),
-    # overrides={
-    #     "n_hidden_nodes_0": (14, 20),  # study: single_core_regen_20241123_011232
-    #     "n_hidden_nodes_1": (8, 16),
-    #     "n_hidden_nodes_2": (10, 16),
-    #     # "n_hidden_nodes_3": (4, 20),  # study: single_core_regen_20241123_135749
-    #     "n_hidden_nodes_4": (2, 8),
-    #     "n_hidden_nodes_5": (10, 16),
-    # },
-    # model_activation_func = ("ModReLU", "Mag", "Identity")
-    model_activation_func="Mag",  # study: single_core_regen_20241123_011232
+    n_hidden_layers=4,
+    overrides={
+        "n_hidden_nodes_0": 8,
+        "n_hidden_nodes_1": 6,
+        "n_hidden_nodes_2": 4,
+        "n_hidden_nodes_3": 8,
+    },
+    model_activation_func="Mag",
+    # satabsT0=(1e-6, 1),
 )
 optimizer_settings = OptimizerSettings(
     optimizer="Adam",
     # learning_rate = (1e-5, 1e-1),
-    learning_rate=5e-4,
+    learning_rate=5e-3
+    # learning_rate=5e-4,
 )
 optuna_settings = OptunaSettings(
-    n_trials=512,
-    n_workers=14,
+    n_trials=1,
+    n_workers=1,
     timeout=3600,
-    directions=("maximize", "minimize"),
-    metrics_names=("neg_log_mse","n_nodes"),
-    limit_examples=True,
+    directions=("minimize",),
+    metrics_names=("mse",),
+    limit_examples=False,
     n_train_batches=500,
     # n_valid_batches = 100,
     storage="sqlite:///data/single_core_regen.db",

View File

@@ -1,414 +1,130 @@
-import copy
-from dataclasses import dataclass
-from datetime import datetime
-from pathlib import Path
-import matplotlib.pyplot as plt
-import numpy as np
-import torch
-import torch.nn as nn
-# import torch.nn.functional as F  # mse_loss doesn't support complex numbers
-import torch.optim as optim
-import torch.utils.data
-from torch.utils.tensorboard import SummaryWriter
-from rich.progress import (
-    Progress,
-    TextColumn,
-    BarColumn,
-    TaskProgressColumn,
-    TimeRemainingColumn,
-    MofNCompleteColumn,
-    TimeElapsedColumn,
-)
-from rich.console import Console
-from rich import print as rprint
-# from util.optuna_helpers import optional_suggest_categorical, optional_suggest_float, optional_suggest_int
+from hypertraining.settings import (
+    GlobalSettings,
+    DataSettings,
+    PytorchSettings,
+    ModelSettings,
+    OptimizerSettings,
+)
+from hypertraining.training import Trainer
+import torch
+import json
 import util
-
-# global settings
-@dataclass
-class GlobalSettings:
-    seed: int = 42
-
-# data settings
-@dataclass
-class DataSettings:
-    config_path: str = "data/*-128-16384-100000-0-0-17-0-PAM4-0.ini"
-    dtype: torch.dtype = torch.complex64
-    symbols_range: float | int = 8
-    data_size_range: float | int = 64
-    shuffle: bool = True
-    target_delay: float = 0
-    xy_delay_range: float | int = 0
-    drop_first: int = 10
-    train_split: float = 0.8
-
-# pytorch settings
-@dataclass
-class PytorchSettings:
-    epochs: int = 1000
-    batchsize: int = 2**12
-    device: str = "cuda"
-    summary_dir: str = ".runs"
-    model_dir: str = ".models"
-
-# model settings
-@dataclass
-class ModelSettings:
-    output_size: int = 2
-    # n_layer_range: float|int = 2
-    # n_units_range: float|int = 32
-    n_layers: int = 3
-    n_units: int = 32
-    activation_func: tuple | str = "ModReLU"
-
-@dataclass
-class OptimizerSettings:
-    optimizer_range: str = "Adam"
-    lr_range: float = 2e-3
-
-class Training:
-    def __init__(self):
-        self.global_settings = GlobalSettings()
-        self.data_settings = DataSettings()
-        self.pytorch_settings = PytorchSettings()
-        self.model_settings = ModelSettings()
-        self.optimizer_settings = OptimizerSettings()
-        self.study_name = (
-            f"single_core_regen_{datetime.now().strftime('%Y-%m-%d_%H:%M:%S')}"
-        )
-        if not hasattr(self.pytorch_settings, "model_dir"):
-            self.pytorch_settings.model_dir = ".models"
-
-        self.writer = None
-        self.console = Console()
-
-    def setup_tb_writer(self, study_name=None, append=None):
-        log_dir = (
-            self.pytorch_settings.summary_dir + "/" + (study_name or self.study_name) + ("_" + str(append)) if append else ""
-        )
-        self.writer = SummaryWriter(log_dir)
-        return self.writer
-
-    def plot_eye(self, width=2, symbols=None, alpha=None, complex=False, show=True):
-        if not hasattr(self, "eye_data"):
-            data, config = util.datasets.load_data(
-                self.data_settings.config_path,
-                skipfirst=10,
-                symbols=symbols or 1000,
-                real=not self.data_settings.dtype.is_complex,
-                normalize=True,
-            )
-            self.eye_data = {"data": data, "sps": int(config["glova"]["sps"])}
-        return util.plot.eye(
-            **self.eye_data,
-            width=width,
-            show=show,
-            alpha=alpha,
-            complex=complex,
-            symbols=symbols or 1000,
-            skipfirst=0,
-        )
-
-    def define_model(self):
-        n_layers = self.model_settings.n_layers
-        in_features = 2 * self.data_settings.data_size_range
-        layers = []
-        for i in range(n_layers):
-            out_features = self.model_settings.n_units
-            layers.append(util.complexNN.UnitaryLayer(in_features, out_features))
-            # layers.append(getattr(nn, self.model_settings.activation_func)())
-            layers.append(
-                getattr(util.complexNN, self.model_settings.activation_func)()
-            )
-            in_features = out_features
-        layers.append(
-            util.complexNN.UnitaryLayer(in_features, self.model_settings.output_size)
-        )
-        if self.writer is not None:
-            self.writer.add_graph(
-                nn.Sequential(*layers),
-                torch.zeros(1, layers[0].in_features, dtype=self.data_settings.dtype),
-            )
-        return nn.Sequential(*layers)
-
-    def get_sliced_data(self):
-        symbols = self.data_settings.symbols_range
-        xy_delay = self.data_settings.xy_delay_range
-        data_size = self.data_settings.data_size_range
-        # get dataset
-        dataset = util.datasets.FiberRegenerationDataset(
-            file_path=self.data_settings.config_path,
-            symbols=symbols,
-            output_dim=data_size,
-            target_delay=self.data_settings.target_delay,
-            xy_delay=xy_delay,
-            drop_first=self.data_settings.drop_first,
-            dtype=self.data_settings.dtype,
-            real=not self.data_settings.dtype.is_complex,
-            # device=self.pytorch_settings.device,
-        )
-        dataset_size = len(dataset)
-        indices = list(range(dataset_size))
-        split = int(np.floor(self.data_settings.train_split * dataset_size))
-        if self.data_settings.shuffle:
-            np.random.seed(self.global_settings.seed)
-            np.random.shuffle(indices)
-        train_indices, valid_indices = indices[:split], indices[split:]
-        if self.data_settings.shuffle:
-            train_sampler = torch.utils.data.SubsetRandomSampler(train_indices)
-            valid_sampler = torch.utils.data.SubsetRandomSampler(valid_indices)
-        else:
-            train_sampler = train_indices
-            valid_sampler = valid_indices
-        train_loader = torch.utils.data.DataLoader(
-            dataset,
-            batch_size=self.pytorch_settings.batchsize,
-            sampler=train_sampler,
-            drop_last=True,
-            pin_memory=True,
-            num_workers=24,
-            prefetch_factor=4,
-            # persistent_workers=True
-        )
-        valid_loader = torch.utils.data.DataLoader(
-            dataset,
-            batch_size=self.pytorch_settings.batchsize,
-            sampler=valid_sampler,
-            drop_last=True,
-            pin_memory=True,
-            num_workers=24,
-            prefetch_factor=4,
-            # persistent_workers=True
-        )
-        return train_loader, valid_loader
-
-    def train_model(self, model, optimizer, train_loader, epoch):
-        with Progress(
-            TextColumn("[yellow] Training..."),
-            TextColumn("Error: {task.description}"),
-            BarColumn(),
-            TaskProgressColumn(),
-            TextColumn("[green]Batch"),
-            MofNCompleteColumn(),
-            TimeRemainingColumn(),
-            TimeElapsedColumn(),
-            # description="Training",
-            transient=False,
-            console=self.console,
-            refresh_per_second=10,
-        ) as progress:
-            task = progress.add_task("-.---e--", total=len(train_loader))
-            running_loss = 0.0
-            model.train()
-            for batch_idx, (x, y) in enumerate(train_loader):
-                model.zero_grad(set_to_none=True)
-                x, y = (
-                    x.to(self.pytorch_settings.device),
-                    y.to(self.pytorch_settings.device),
-                )
-                y_pred = model(x)
-                loss = util.complexNN.complex_mse_loss(y_pred, y)
-                loss.backward()
-                optimizer.step()
-                progress.update(task, advance=1, description=f"{loss.item():.3e}")
-                running_loss += loss.item()
-                if self.writer is not None:
-                    if (batch_idx + 1) % 10 == 0:
-                        self.writer.add_scalar(
-                            "training loss",
-                            running_loss / 10,
-                            epoch * len(train_loader) + batch_idx,
-                        )
-                        running_loss = 0.0
-        return running_loss
-
-    def eval_model(self, model, valid_loader, epoch):
-        with Progress(
-            TextColumn("[green]Evaluating..."),
-            TextColumn("Error: {task.description}"),
-            BarColumn(),
-            TaskProgressColumn(),
-            TextColumn("[green]Batch"),
-            MofNCompleteColumn(),
-            TimeRemainingColumn(),
-            TimeElapsedColumn(),
-            # description="Training",
-            transient=False,
-            console=self.console,
-            refresh_per_second=10,
-        ) as progress:
-            task = progress.add_task("-.---e--", total=len(valid_loader))
-            model.eval()
-            running_loss = 0
-            running_loss2 = 0
-            with torch.no_grad():
-                for batch_idx, (x, y) in enumerate(valid_loader):
-                    x, y = (
-                        x.to(self.pytorch_settings.device),
-                        y.to(self.pytorch_settings.device),
-                    )
-                    y_pred = model(x)
-                    loss = util.complexNN.complex_mse_loss(y_pred, y)
-                    running_loss += loss.item()
-                    running_loss2 += loss.item()
-                    progress.update(task, advance=1, description=f"{loss.item():.3e}")
-                    if self.writer is not None:
-                        if (batch_idx + 1) % 10 == 0:
-                            self.writer.add_scalar(
-                                "loss",
-                                running_loss / 10,
-                                epoch * len(valid_loader) + batch_idx,
-                            )
-                            running_loss = 0.0
-        if self.writer is not None:
-            self.writer.add_figure("fiber response", self.plot_model_response(model, plot=False), epoch+1)
-        return running_loss2 / len(valid_loader)
-
-    def run_model(self, model, loader):
-        model.eval()
-        xs = []
-        ys = []
-        y_preds = []
-        with torch.no_grad():
-            model = model.to(self.pytorch_settings.device)
-            for x, y in loader:
-                x, y = (
-                    x.to(self.pytorch_settings.device),
-                    y.to(self.pytorch_settings.device),
-                )
-                y_pred = model(x).cpu()
-                # x = x.cpu()
-                # y = y.cpu()
-                y_pred = y_pred.view(y_pred.shape[0], -1, 2)
-                y = y.view(y.shape[0], -1, 2)
-                x = x.view(x.shape[0], -1, 2)
-                xs.append(x[:, 0, :].squeeze())
-                ys.append(y.squeeze())
-                y_preds.append(y_pred.squeeze())
-            xs = torch.vstack(xs).cpu()
-            ys = torch.vstack(ys).cpu()
-            y_preds = torch.vstack(y_preds).cpu()
-        return ys, xs, y_preds
-
-    def objective(self, save=False, plot_before=False):
-        try:
-            rprint(*list(self.study_name.split("_")))
-            self.model = self.define_model().to(self.pytorch_settings.device)
-            if self.writer is not None:
-                self.writer.add_figure("fiber response", self.plot_model_response(plot=plot_before), 0)
-            train_loader, valid_loader = self.get_sliced_data()
-            optimizer_name = self.optimizer_settings.optimizer_range
-            lr = self.optimizer_settings.lr_range
-            optimizer = getattr(optim, optimizer_name)(self.model.parameters(), lr=lr)
-            for epoch in range(self.pytorch_settings.epochs):
-                self.console.rule(f"Epoch {epoch + 1}/{self.pytorch_settings.epochs}")
-                self.train_model(self.model, optimizer, train_loader, epoch)
-                eval_loss = self.eval_model(self.model, valid_loader, epoch)
-            return eval_loss
-        except KeyboardInterrupt:
-            ...
-        finally:
-            if hasattr(self, "model"):
-                save_path = (
-                    Path(self.pytorch_settings.model_dir) / f"{self.study_name}.pth"
-                )
-                save_path.parent.mkdir(parents=True, exist_ok=True)
-                torch.save(self.model, save_path)
-
-    def _plot_model_response_plotter(self, fiber_in, fiber_out, regen, plot=True):
-        fig, axs = plt.subplots(2)
-        for i, ax in enumerate(axs):
-            ax.plot(np.abs(fiber_in[:, i]) ** 2, label="fiber in")
-            ax.plot(np.abs(fiber_out[:, i]) ** 2, label="fiber out")
-            ax.plot(np.abs(regen[:, i]) ** 2, label="regenerated")
-            ax.legend()
-        if plot:
-            plt.show()
-        return fig
-
-    def plot_model_response(self, model=None, plot=True):
-        data_settings_backup = copy.copy(self.data_settings)
-        self.data_settings.shuffle = False
-        self.data_settings.train_split = 0.01
-        self.data_settings.drop_first = 100
-        plot_loader, _ = self.get_sliced_data()
-        self.data_settings = data_settings_backup
-        fiber_in, fiber_out, regen = self.run_model(model or self.model, plot_loader)
-        fiber_in = fiber_in.view(-1, 2)
-        fiber_out = fiber_out.view(-1, 2)
-        regen = regen.view(-1, 2)
-        fiber_in = fiber_in.numpy()
-        fiber_out = fiber_out.numpy()
-        regen = regen.numpy()
-        # https://github.com/matplotlib/matplotlib/issues/27713#issue-2104110987
-        # https://github.com/matplotlib/matplotlib/issues/27713#issuecomment-1915497463
-        import gc
-        fig = self._plot_model_response_plotter(fiber_in, fiber_out, regen, plot=plot)
-        gc.collect()
-        return fig
+
+global_settings = GlobalSettings(
+    seed=42,
+)
+data_settings = DataSettings(
+    config_path="data/*-128-16384-100000-0-0-17-0-PAM4-0.ini",
+    dtype="complex64",
+    # symbols = (9, 20),  # 13 symbol @ 10GBd <-> 1.3ns <-> 0.26m of fiber
+    symbols=13,  # study: single_core_regen_20241123_011232
+    # output_size = (11, 32),  # 26 taps -> 2 taps per input symbol -> 1 tap every 1cm (model has 52 inputs (x/y))
+    output_size=26,  # study: single_core_regen_20241123_011232 (model_input_dim/2)
+    shuffle=True,
+    in_out_delay=0,
+    xy_delay=0,
+    drop_first=128*64,
+    train_split=0.8,
+)
+pytorch_settings = PytorchSettings(
+    epochs=10000,
+    batchsize=2**12,
+    device="cuda",
+    dataloader_workers=12,
+    dataloader_prefetch=8,
+    summary_dir=".runs",
+    write_every=2**5,
+    save_models=True,
+    model_dir=".models",
+)
+model_settings = ModelSettings(
+    output_dim=2,
+    n_hidden_layers=4,
+    overrides={
+        "n_hidden_nodes_0": 8,
+        "n_hidden_nodes_1": 8,
+        "n_hidden_nodes_2": 4,
+        "n_hidden_nodes_3": 6,
+    },
+    model_activation_func="PowScale",
+    # dropout_prob=0.01,
+    model_layer_function="ONN",
+    model_layer_parametrizations=[
+        {
+            "tensor_name": "weight",
+            "parametrization": torch.nn.utils.parametrizations.orthogonal,
+        },
+        {
+            "tensor_name": "scales",
+            "parametrization": util.complexNN.clamp,
+        },
+        {
+            "tensor_name": "scale",
+            "parametrization": util.complexNN.clamp,
+        },
+        {
+            "tensor_name": "bias",
+            "parametrization": util.complexNN.clamp,
+        },
+        # {
+        #     "tensor_name": "V",
+        #     "parametrization": torch.nn.utils.parametrizations.orthogonal,
+        # },
+        # {
+        #     "tensor_name": "S",
+        #     "parametrization": util.complexNN.clamp,
+        # },
+    ],
+)
+optimizer_settings = OptimizerSettings(
+    optimizer="Adam",
+    learning_rate=0.05,
+    scheduler="ReduceLROnPlateau",
+    scheduler_kwargs={
+        "patience": 2**6,
+        "factor": 0.9,
+        # "threshold": 1e-3,
+        "min_lr": 1e-6,
+        "cooldown": 10,
+    },
+)
+
+def save_dict_to_file(dictionary, filename):
+    """
+    Save the best dictionary to a JSON file.
+
+    :param best: Dictionary containing the best training results.
+    :type best: dict
+    :param filename: Path to the JSON file where the dictionary will be saved.
+    :type filename: str
+    """
+    with open(filename, 'w') as f:
+        json.dump(dictionary, f, indent=4)
+
 if __name__ == "__main__":
-    trainer = Training()
-    # trainer.plot_eye()
-    trainer.setup_tb_writer()
-    trainer.objective(save=True)
-    best_model = trainer.model
-    # best_model = trainer.define_model(trainer.study.best_trial).to(trainer.pytorch_settings.device)
-    trainer.plot_model_response(best_model)
-    # print(f"Best model: {best_model}")
+    trainer = Trainer(
+        global_settings=global_settings,
+        data_settings=data_settings,
+        pytorch_settings=pytorch_settings,
+        model_settings=model_settings,
+        optimizer_settings=optimizer_settings,
+        checkpoint_path='.models/20241128_084935_8885.tar',
+        settings_override={
+            "model_settings": {
+                # "model_activation_func": "PowScale",
+                "dropout_prob": 0,
+            }
+        },
+        reset_epoch=True,
+    )
+    best = trainer.train()
+    save_dict_to_file(best, ".models/best_results.json")
     ...

View File

@@ -1,16 +1,26 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+# from torchlambertw.special import lambertw

-def complex_mse_loss(input, target):
+def complex_mse_loss(input, target, power=False, reduction="mean"):
     """
     Compute the mean squared error between two complex tensors.
+    If power is set to True, the loss is computed as |input|^2 - |target|^2
     """
-    if input.is_complex():
-        return torch.mean(torch.square(input.real - target.real) + torch.square(input.imag - target.imag))
+    reduce = getattr(torch, reduction)
+    if power:
+        input = (input * input.conj()).real.to(dtype=input.dtype.to_real())
+        target = (target * target.conj()).real.to(dtype=target.dtype.to_real())
+    if input.is_complex() and target.is_complex():
+        return reduce(torch.square(input.real - target.real) + torch.square(input.imag - target.imag))
+    elif input.is_complex() or target.is_complex():
+        raise ValueError("Input and target must have the same type (real or complex)")
     else:
-        return F.mse_loss(input, target)
+        return F.mse_loss(input, target, reduction=reduction)

 def complex_sse_loss(input, target):
@@ -43,6 +53,174 @@ class UnitaryLayer(nn.Module):
         return f"UnitaryLayer({self.in_features}, {self.out_features})"
class _Unitary(nn.Module):
def forward(self, X:torch.Tensor):
if X.ndim < 2:
raise ValueError(
"Only tensors with 2 or more dimensions are supported. "
f"Got a tensor of shape {X.shape}"
)
n, k = X.size(-2), X.size(-1)
transpose = n<k
if transpose:
X = X.transpose(-2, -1)
q, r = torch.linalg.qr(X)
# q: torch.Tensor = q
# r: torch.Tensor = r
d = r.diagonal(dim1=-2, dim2=-1).sgn()
q*=d.unsqueeze(-2)
if transpose:
q = q.transpose(-2, -1)
if n == k:
mask = (torch.linalg.det(q).abs() >= 0).to(q.dtype.to_real())
mask[mask == 0] = -1
mask = mask.unsqueeze(-1)
q[..., 0] *= mask
# X.copy_(q)
return q
def unitary(module: nn.Module, name: str = "weight") -> nn.Module:
weight = getattr(module, name, None)
if not isinstance(weight, torch.Tensor):
raise ValueError(f"Module '{module}' has no parameter or buffer '{name}'")
if weight.ndim < 2:
raise ValueError(f"Expected a matrix or batch of matrices. Got a tensor of {weight.ndim} dimensions.")
if weight.shape[-2] != weight.shape[-1]:
raise ValueError(f"Expected a square matrix or batch of square matrices. Got a tensor of shape {weight.shape}")
unit = _Unitary()
nn.utils.parametrize.register_parametrization(module, name, unit)
return module
class _SpecialUnitary(nn.Module):
def __init__(self):
super().__init__()
def forward(self, X:torch.Tensor):
n, k = X.size(-2), X.size(-1)
if n != k:
raise ValueError(f"Expected a square matrix. Got a tensor of shape {X.shape}")
q, _ = torch.linalg.qr(X)
q = q / torch.linalg.det(q).pow(1/n)
return q
def special_unitary(module: nn.Module, name: str = "weight") -> nn.Module:
weight = getattr(module, name, None)
if not isinstance(weight, torch.Tensor):
raise ValueError(f"Module '{module}' has no parameter or buffer '{name}'")
if weight.ndim < 2:
raise ValueError(f"Expected a matrix or batch of matrices. Got a tensor of {weight.ndim} dimensions.")
if weight.shape[-2] != weight.shape[-1]:
raise ValueError(f"Expected a square matrix or batch of square matrices. Got a tensor of shape {weight.shape}")
unit = _SpecialUnitary()
nn.utils.parametrize.register_parametrization(module, name, unit)
return module
class _Clamp(nn.Module):
def __init__(self, min, max):
super(_Clamp, self).__init__()
self.min = min
self.max = max
def forward(self, x):
if x.is_complex():
# clamp magnitude, ignore phase
return torch.clamp(x.abs(), self.min, self.max) * x / x.abs()
return torch.clamp(x, self.min, self.max)
def clamp(module: nn.Module, name: str = "scale", min=0, max=1) -> nn.Module:
scale = getattr(module, name, None)
if not isinstance(scale, torch.Tensor):
raise ValueError(f"Module '{module}' has no parameter or buffer '{name}'")
cl = _Clamp(min, max)
nn.utils.parametrize.register_parametrization(module, name, cl)
return module
class ONNMiller(nn.Module):
def __init__(self, input_dim, output_dim, dtype=None) -> None:
super(ONNMiller, self).__init__()
self.input_dim = input_dim
self.output_dim = output_dim
self.dtype = dtype
self.dim = max(input_dim, output_dim)
# zero pad input to internal size if smaller
if self.input_dim < self.dim:
self.pad = lambda x: F.pad(x, ((self.dim - self.input_dim) // 2, (self.dim - self.input_dim + 1) // 2))
else:
self.pad = lambda x: x
self.pad.__doc__ = f"Zero pad input from {self.input_dim} to {self.dim}"
# crop output to desired size
if self.output_dim < self.dim:
self.crop = lambda x: x[:, (self.dim - self.output_dim) // 2 : (x.shape[1] - (self.dim - self.output_dim + 1) // 2)]
else:
self.crop = lambda x: x
self.crop.__doc__ = f"Crop output from {self.dim} to {self.output_dim}"
self.U = nn.Parameter(torch.randn(self.dim, self.dim, dtype=self.dtype)) # -> parametrization: Unitary
self.S = nn.Parameter(torch.randn(self.dim, dtype=self.dtype)) # -> parametrization: Clamp (magnitude 0..1)
self.V = nn.Parameter(torch.randn(self.dim, self.dim, dtype=self.dtype)) # -> parametrization: Unitary
self.register_buffer("MZI_scale", torch.tensor(2, dtype=self.dtype.to_real()).sqrt())
# V is actually V.H, but
def forward(self, x_in):
x = x_in
x = self.pad(x)
x = x @ self.U
x = x * (self.S.squeeze() / self.MZI_scale)
x = x @ self.V
x = self.crop(x)
return x
class ONN(nn.Module):
def __init__(self, input_dim, output_dim, dtype=None) -> None:
super(ONN, self).__init__()
self.input_dim = input_dim
self.output_dim = output_dim
self.dtype = dtype
self.dim = max(input_dim, output_dim)
# zero pad input to internal size if smaller
if self.input_dim < self.dim:
self.pad = lambda x: F.pad(x, ((self.dim - self.input_dim) // 2, (self.dim - self.input_dim + 1) // 2))
self.pad.__doc__ = f"Zero pad input from {self.input_dim} to {self.dim}"
else:
self.pad = lambda x: x
self.pad.__doc__ = f"Input size equals internal size {self.dim}"
# crop output to desired size
if self.output_dim < self.dim:
self.crop = lambda x: x[:, (self.dim - self.output_dim) // 2 : (x.shape[1] - (self.dim - self.output_dim + 1) // 2)]
self.crop.__doc__ = f"Crop output from {self.dim} to {self.output_dim}"
else:
self.crop = lambda x: x
self.crop.__doc__ = f"Output size equals internal size {self.dim}"
self.weight = nn.Parameter(torch.randn(self.dim, self.dim, dtype=self.dtype))
def reset_parameters(self):
q, _ = torch.linalg.qr(self.weight)
self.weight.data = q
# def get_M(self):
# return self.U @ self.sigma @ self.V
def forward(self, x):
return self.crop(self.pad(x) @ self.weight)
 class SemiUnitaryLayer(nn.Module):
     def __init__(self, input_dim, output_dim, dtype=None):
         super(SemiUnitaryLayer, self).__init__()
@@ -51,24 +229,84 @@ class SemiUnitaryLayer(nn.Module):
         # Create a larger square matrix for QR decomposition
         self.weight = nn.Parameter(torch.randn(max(input_dim, output_dim), max(input_dim, output_dim), dtype=dtype))
+        self.scale = nn.Parameter(torch.tensor(1.0, dtype=dtype.to_real()))
         self.reset_parameters()

     def reset_parameters(self):
-        # Ensure the weights are semi-unitary by QR decomposition
+        # Ensure the weights are unitary by QR decomposition
         q, _ = torch.linalg.qr(self.weight)
+        # A = QR with A being a complex square matrix -> Q is unitary, R is upper triangular
+        # truncate the matrix to the desired size
         if self.input_dim > self.output_dim:
             self.weight.data = q[: self.input_dim, : self.output_dim]
         else:
             self.weight.data = q[: self.output_dim, : self.input_dim].t()
+        ...

     def forward(self, x):
-        out = torch.matmul(x, self.weight)
+        with torch.no_grad():
+            scale = torch.clamp(self.scale, 0.0, 1.0)
+        out = torch.matmul(x, scale * self.weight)
         return out

     def __repr__(self):
         return f"SemiUnitaryLayer({self.input_dim}, {self.output_dim})"
# class SaturableAbsorberLambertW(nn.Module):
# """
# Implements the activation function for an optical saturable absorber
# base eqn: sigma*tau*I0 = 0.5*(log(Tm/T0))/(1-Tm),
# where: sigma is the absorption cross section
# tau is the radiative lifetime of the absorber material
# T0 is the initial transmittance
# I0 is the input intensity
# Tm is the transmittance of the absorber
# The activation function is defined as:
# Iout = I0 * Tm(I0)
# where Tm(I0) is the transmittance of the absorber as a function of the input intensity I0
# for a unit sigma*tau product, he solution Tm(I0) is given by:
# Tm(I0) = (W(2*exp(2*I0)*I0*T0))/(2*I0),
# where W is the Lambert W function
# if sigma*tau is not 1, I0 has to be scaled by sigma*tau
# (-> x has to be scaled by sqrt(sigma*tau))
# """
# def __init__(self, T0):
# super(SaturableAbsorberLambertW, self).__init__()
# self.register_buffer("T0", torch.tensor(T0))
# def forward(self, x: torch.Tensor):
# xc = x.conj()
# two_x_xc = (2 * x * xc).real
# return (lambertw(2 * torch.exp(two_x_xc) * (x * self.T0 * xc).real) / two_x_xc).to(dtype=x.dtype)
# def backward(self, x):
# xc = x.conj()
# lambert_eval = lambertw(2 * torch.exp(2 * x * xc).real * (x * self.T0 * xc).real)
# return (((xc * (-2 * lambert_eval + 2 * torch.square(x) - 1) + 2 * x * torch.square(xc) + x) * lambert_eval) / (
# 2 * torch.pow(x, 3) * xc * (lambert_eval + 1)
# )).to(dtype=x.dtype)
# class SaturableAbsorber(nn.Module):
# def __init__(self, alpha, I0):
# super(SaturableAbsorber, self).__init__()
# self.register_buffer("alpha", torch.tensor(alpha))
# self.register_buffer("I0", torch.tensor(I0))
# def forward(self, x):
# I = (x*x.conj()).to(dtype=x.dtype.to_real())
# A = self.alpha/(1+I/self.I0)
 # class SpreadLayer(nn.Module):
 #     def __init__(self, in_features, out_features, dtype=None):
 #         super(SpreadLayer, self).__init__()
@@ -85,6 +323,19 @@ class SemiUnitaryLayer(nn.Module):
 #### as defined by zhang et al
class DropoutComplex(nn.Module):
def __init__(self, p=0.5):
super(DropoutComplex, self).__init__()
self.dropout = nn.Dropout(p=p)
def forward(self, x):
if x.is_complex():
mask = self.dropout(torch.ones_like(x.real))
return x * mask
else:
return self.dropout(x)
 class Identity(nn.Module):
     """
     implements the "activation" function
@@ -97,18 +348,76 @@ class Identity(nn.Module):
     def forward(self, x):
         return x
class PowRot(nn.Module):
def __init__(self, bias=False):
super(PowRot, self).__init__()
self.scale = nn.Parameter(torch.tensor(1.0))
if bias:
self.bias = nn.Parameter(torch.tensor(0.0))
else:
self.register_buffer("bias", torch.tensor(0.0))
def forward(self, x: torch.Tensor):
if x.is_complex():
return x * torch.exp(-self.scale*1j*x.abs().square()+self.bias.to(dtype=x.dtype))
else:
return x
class Pow(nn.Module):
"""
implements the activation function
M(z) = ||z||^2 + b
"""
def __init__(self, bias=False):
super(Pow, self).__init__()
if bias:
self.bias = nn.Parameter(torch.tensor(0.0))
else:
self.register_buffer("bias", torch.tensor(0.0))
def forward(self, x: torch.Tensor):
return x.abs().square().add(self.bias).to(dtype=x.dtype)
 class Mag(nn.Module):
     """
     implements the activation function
-    M(z) = ||z||
+    M(z) = ||z||+b
     """

-    def __init__(self):
+    def __init__(self, bias=False):
         super(Mag, self).__init__()
+        if bias:
+            self.bias = nn.Parameter(torch.tensor(0.0))
+        else:
+            self.register_buffer("bias", torch.tensor(0.0))

-    def forward(self, x):
-        return torch.abs(x).to(dtype=x.dtype)
+    def forward(self, x: torch.Tensor):
+        return x.abs().add(self.bias).to(dtype=x.dtype)
class MagScale(nn.Module):
def __init__(self, bias=False):
super(MagScale, self).__init__()
if bias:
self.bias = nn.Parameter(torch.tensor(0.0))
else:
self.register_buffer("bias", torch.tensor(0.0))
def forward(self, x: torch.Tensor):
return x.abs().add(self.bias).to(dtype=x.dtype).sin().mul(x)
class PowScale(nn.Module):
def __init__(self, bias=False):
super(PowScale, self).__init__()
if bias:
self.bias = nn.Parameter(torch.tensor(0.0))
else:
self.register_buffer("bias", torch.tensor(0.0))
def forward(self, x: torch.Tensor):
return x.mul(x.abs().square().add(self.bias).to(dtype=x.dtype).sin())
 class ModReLU(nn.Module):
@@ -118,17 +427,21 @@ class ModReLU(nn.Module):
     = ReLU(||z|| + b)*z/||z||
     """

-    def __init__(self, b=0):
+    def __init__(self, bias=True):
         super(ModReLU, self).__init__()
-        self.b = torch.tensor(b)
+        if bias:
+            self.bias = nn.Parameter(torch.tensor(0.0))
+        else:
+            self.register_buffer("bias", torch.tensor(0.0))

     def forward(self, x):
         if x.is_complex():
-            mod = torch.abs(x.real**2 + x.imag**2)
-            return torch.relu(mod + self.b) * x / mod
+            mod = x.abs()
+            out = torch.relu(mod + self.bias) * x / mod
+            return out.to(dtype=x.dtype)
         else:
-            return torch.relu(x + self.b)
+            return torch.relu(x + self.bias).to(dtype=x.dtype)

     def __repr__(self):
         return f"ModReLU(b={self.b})"
@@ -166,3 +479,26 @@ class ZReLU(nn.Module):
             return x * (torch.angle(x) >= 0) * (torch.angle(x) <= torch.pi / 2)
         else:
             return torch.relu(x)
__all__ = [
complex_sse_loss,
complex_mse_loss,
UnitaryLayer,
unitary,
clamp,
ONN,
ONNMiller,
SemiUnitaryLayer,
DropoutComplex,
Identity,
Pow,
PowRot,
Mag,
ModReLU,
CReLU,
ZReLU,
# SaturableAbsorberLambertW,
# SaturableAbsorber,
# SpreadLayer,
]
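Note: as a rough orientation, the new layer, parametrization, and activation pieces added above are meant to compose per layer. A minimal sketch, assuming util.complexNN from this commit is importable and using made-up dimensions (not a configuration from the repo):

# Sketch only: one ONN stage with a unitary weight, a Pow activation, and the power-based MSE loss.
import torch
import util.complexNN as cnn

layer = cnn.ONN(4, 2, dtype=torch.complex64)   # square internal weight, input padded/cropped 4 -> 2
cnn.unitary(layer, "weight")                   # QR-based parametrization keeps the weight unitary
act = cnn.Pow(bias=True)                       # M(z) = ||z||^2 + b with a learnable bias

x = torch.randn(8, 4, dtype=torch.complex64)
target = torch.randn(8, 2, dtype=torch.complex64)

loss = cnn.complex_mse_loss(act(layer(x)), target, power=True)  # compare powers, not fields
loss.backward()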

View File

@@ -20,7 +20,7 @@ def _optional_suggest(
     type: str,
     log: bool = False,
     step: int | float | None = None,
-    add_user: bool = False,
+    add_user: bool = True,
     force: bool = False,
     multiply: float | int = 1,
     set_new: bool = True,
@@ -96,7 +96,7 @@ def suggest_categorical_optional(
     trial: trial.Trial,
     name: str,
     choices_or_value: tuple[Any] | list[Any] | Any,
-    add_user: bool = False,
+    add_user: bool = True,
     force: bool = False,
     set_new: bool = True,
 ):
@@ -129,7 +129,7 @@ def suggest_int_optional(
     range_or_value: tuple[int] | list[int] | int,
     step: int = 1,
     log: bool = False,
-    add_user: bool = False,
+    add_user: bool = True,
     force: bool = False,
     multiply: int = 1,
     set_new: bool = True,
@@ -174,7 +174,7 @@ def suggest_float_optional(
     range_or_value: tuple[float] | list[float] | float,
     step: float | None = None,
     log: bool = False,
-    add_user: bool = False,
+    add_user: bool = True,
     force: bool = False,
     multiply: float = 1,
     set_new: bool = True,
@@ -222,7 +222,7 @@ def suggest_categorical_optional_wrapper(
     self: trial.Trial,
     name: str,
     choices_or_value: tuple[Any] | list[Any] | Any,
-    add_user: bool = False,
+    add_user: bool = True,
     force: bool = False,
     set_new: bool = True,
 ):
@@ -253,7 +253,7 @@ def suggest_int_optional_wrapper(
     range_or_value: tuple[int] | list[int] | int,
     step: int = 1,
     log: bool = False,
-    add_user: bool = False,
+    add_user: bool = True,
     force: bool = False,
     multiply: int = 1,
     set_new: bool = True,
@@ -295,7 +295,7 @@ def suggest_float_optional_wrapper(
     range_or_value: tuple[float] | list[float] | float,
     step: float | None = None,
     log: bool = False,
-    add_user: bool = False,
+    add_user: bool = True,
     force: bool = False,
     multiply: float = 1,
     set_new: bool = True,