Compare commits
7 Commits
0422c81f3b ... 487288c923

| Author | SHA1 | Date |
| --- | --- | --- |
|  | 487288c923 |  |
|  | bdf6f5bfb8 |  |
|  | e02662ed4f |  |
|  | fd7a0b9c31 |  |
|  | ff32aefd52 |  |
|  | b156b9ceaf |  |
|  | cfa08aae4e |  |
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f3510d41f9f0605e438a09767c43edda38162601292be1207f50747117ae5479
size 9863168
oid sha256:746ea83013870351296e01e294905ee027291ef79fd78c1e6b69dd9ebaa1cba0
size 10240000
@@ -19,7 +19,7 @@ import time
from matplotlib import pyplot as plt # noqa: F401
import numpy as np

import path_fix
import add_pypho # noqa: F401
import pypho

default_config = f"""
@@ -497,18 +497,18 @@ def plot_eye_diagram(


if __name__ == "__main__":
    path_fix.show_log()
    add_pypho.show_log()
    config = get_config()

    length_ranges = [1000, 10000]
    length_scales = [1, 2, 3, 4, 5, 6, 7, 8, 9]
    # length_ranges = [1000, 10000]
    # length_scales = [1, 2, 3, 4, 5, 6, 7, 8, 9]

    lengths = [
        length_scale * length_range
        for length_range in length_ranges
        for length_scale in length_scales
    ]
    lengths.append(max(length_ranges)*10)
    # lengths = [
    #     length_scale * length_range
    #     for length_range in length_ranges
    #     for length_scale in length_scales
    # ]
    # lengths.append(max(length_ranges)*10)

    # length_loop(config, lengths)

@@ -245,18 +245,18 @@ class HyperTraining:
        dtype = getattr(torch, dtype)

        afunc = trial.suggest_categorical_optional("model_activation_func", self.model_settings.model_activation_func)
        # T0 = trial.suggest_float_optional("T0", self.model_settings.satabsT0 , log=True)

        layers = []
        last_dim = input_dim
        n_nodes = last_dim
        for i in range(n_layers):
            if hidden_dim_override := self.model_settings.overrides.get(f"n_hidden_nodes_{i}", False):
                hidden_dim = trial.suggest_int_optional(f"model_hidden_dim_{i}", hidden_dim_override, force=True)
                hidden_dim = trial.suggest_int_optional(f"model_hidden_dim_{i}", hidden_dim_override)
            else:
                hidden_dim = trial.suggest_int_optional(
                    f"model_hidden_dim_{i}",
                    self.model_settings.n_hidden_nodes,
                    # step=2,
                )
            layers.append(util.complexNN.SemiUnitaryLayer(last_dim, hidden_dim, dtype=dtype))
            last_dim = hidden_dim
@@ -642,6 +642,7 @@ class HyperTraining:

        if show:
            plt.show()
        return fig

    def _plot_model_response_head(self, *signals, labels=None, sps=None, title_append="", subtitle="", show=True):
        if not hasattr(labels, "__iter__") or isinstance(labels, (str, type(None))):
@@ -684,7 +685,7 @@ class HyperTraining:
    ):
        data_settings_backup = copy.deepcopy(self.data_settings)
        pytorch_settings_backup = copy.deepcopy(self.pytorch_settings)
        self.data_settings.drop_first = 100
        self.data_settings.drop_first = 100*128
        self.data_settings.shuffle = False
        self.data_settings.train_split = 1.0
        self.pytorch_settings.batchsize = (
@@ -739,11 +740,15 @@ class HyperTraining:
    @staticmethod
    def build_title(trial: optuna.trial.Trial):
        title_append = f"for trial {trial.number}"
        model_n_layers = util.misc.multi_getattr((trial.params, trial.user_attrs), "model_n_layers", 0)
        model_hidden_dims = [
        model_n_hidden_layers = util.misc.multi_getattr((trial.params, trial.user_attrs), "model_n_hidden_layers", 0)
        input_dim = util.misc.multi_getattr((trial.params, trial.user_attrs), "model_input_dim", 0)
        model_dims = [
            util.misc.multi_getattr((trial.params, trial.user_attrs), f"model_hidden_dim_{i}", 0)
            for i in range(model_n_layers)
            for i in range(model_n_hidden_layers)
        ]
        model_dims.insert(0, input_dim)
        model_dims.append(2)
        model_dims = [str(dim) for dim in model_dims]
        model_activation_func = util.misc.multi_getattr(
            (trial.params, trial.user_attrs),
            "model_activation_func",
@@ -752,7 +757,7 @@ class HyperTraining:
        model_dtype = util.misc.multi_getattr((trial.params, trial.user_attrs), "model_dtype", "unknown dtype")

        subtitle = (
            f"{model_n_layers} layers à ({', '.join(model_hidden_dims)}) units, {model_activation_func}, {model_dtype}"
            f"{model_n_hidden_layers+2} layers à ({', '.join(model_dims)}) units, {model_activation_func}, {model_dtype}"
        )

        return title_append, subtitle

@@ -39,7 +39,7 @@ class PytorchSettings:
    summary_dir: str = ".runs"
    write_every: int = 10
    head_symbols: int = 40
    eye_symbols: int = 1000
    eye_symbols: int = 400


# model settings
@@ -48,8 +48,11 @@ class ModelSettings:
    output_dim: int = 2
    n_hidden_layers: tuple | int = 3
    n_hidden_nodes: tuple | int = 8
    model_activation_func: tuple = "ModReLU"
    model_activation_func: tuple | str = "ModReLU"
    overrides: dict = field(default_factory=dict)
    dropout_prob: float | None = None
    model_layer_function: str | None = None
    model_layer_parametrizations: list = field(default_factory=list)


@dataclass

src/single-core-regen/hypertraining/training.py (new file, 739 lines)
@@ -0,0 +1,739 @@
|
||||
import copy
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Literal
|
||||
import matplotlib
|
||||
import torch.nn.utils.parametrize
|
||||
|
||||
try:
|
||||
matplotlib.use("cairo")
|
||||
except ImportError:
|
||||
matplotlib.use("Agg")
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
# import torch.nn.functional as F # mse_loss doesn't support complex numbers
|
||||
import torch.optim as optim
|
||||
import torch.utils.data
|
||||
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from rich.progress import (
|
||||
Progress,
|
||||
TextColumn,
|
||||
BarColumn,
|
||||
TaskProgressColumn,
|
||||
TimeRemainingColumn,
|
||||
MofNCompleteColumn,
|
||||
TimeElapsedColumn,
|
||||
)
|
||||
from rich.console import Console
|
||||
|
||||
from util.datasets import FiberRegenerationDataset
|
||||
import util
|
||||
|
||||
from .settings import (
|
||||
GlobalSettings,
|
||||
DataSettings,
|
||||
ModelSettings,
|
||||
OptimizerSettings,
|
||||
PytorchSettings,
|
||||
)
|
||||
|
||||
|
||||
class regenerator(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*dims,
|
||||
layer_function=util.complexNN.ONN,
|
||||
layer_parametrizations: list[dict] = None,
|
||||
# [
|
||||
# {
|
||||
# "tensor_name": "weight",
|
||||
# "parametrization": util.complexNN.Unitary,
|
||||
# },
|
||||
# {
|
||||
# "tensor_name": "scale",
|
||||
# "parametrization": util.complexNN.Clamp,
|
||||
# },
|
||||
# ],
|
||||
activation_function=util.complexNN.Pow,
|
||||
dtype=torch.float64,
|
||||
dropout_prob=0.01,
|
||||
**kwargs,
|
||||
):
|
||||
super(regenerator, self).__init__()
|
||||
if len(dims) == 0:
|
||||
try:
|
||||
dims = kwargs["dims"]
|
||||
except KeyError:
|
||||
raise ValueError("dims must be provided")
|
||||
self._n_hidden_layers = len(dims) - 2
|
||||
self._layers = nn.Sequential()
|
||||
|
||||
for i in range(self._n_hidden_layers + 1):
|
||||
self._layers.append(layer_function(dims[i], dims[i + 1], dtype=dtype))
|
||||
if i < self._n_hidden_layers:
|
||||
if dropout_prob is not None:
|
||||
self._layers.append(util.complexNN.DropoutComplex(p=dropout_prob))
|
||||
self._layers.append(activation_function())
|
||||
|
||||
# add parametrizations
|
||||
if layer_parametrizations is not None:
|
||||
for layer_parametrization in layer_parametrizations:
|
||||
tensor_name = layer_parametrization.get("tensor_name", None)
|
||||
parametrization = layer_parametrization.get("parametrization", None)
|
||||
param_kwargs = layer_parametrization.get("kwargs", {})
|
||||
if (
|
||||
tensor_name is not None
|
||||
and tensor_name in self._layers[-1]._parameters
|
||||
and parametrization is not None
|
||||
):
|
||||
parametrization(self._layers[-1], tensor_name, **param_kwargs)
|
||||
|
||||
def forward(self, input_x):
|
||||
x = input_x
|
||||
# check if tracing
|
||||
if torch.jit.is_tracing():
|
||||
for layer in self._layers:
|
||||
x = layer(x)
|
||||
else:
|
||||
# with torch.nn.utils.parametrize.cached():
|
||||
for layer in self._layers:
|
||||
x = layer(x)
|
||||
return x
|
||||
|
||||
def traverse_dict_update(target, source):
|
||||
for k, v in source.items():
|
||||
if isinstance(v, dict):
|
||||
if k not in target:
|
||||
target[k] = {}
|
||||
traverse_dict_update(target[k], v)
|
||||
else:
|
||||
try:
|
||||
target[k] = v
|
||||
except TypeError:
|
||||
target.__dict__[k] = v
|
||||
|
||||
class Trainer:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
global_settings=None,
|
||||
data_settings=None,
|
||||
pytorch_settings=None,
|
||||
model_settings=None,
|
||||
optimizer_settings=None,
|
||||
console=None,
|
||||
checkpoint_path=None,
|
||||
settings_override=None,
|
||||
reset_epoch=False,
|
||||
):
|
||||
self.resume = checkpoint_path is not None
|
||||
torch.serialization.add_safe_globals([
|
||||
*util.complexNN.__all__,
|
||||
GlobalSettings,
|
||||
DataSettings,
|
||||
ModelSettings,
|
||||
OptimizerSettings,
|
||||
PytorchSettings,
|
||||
regenerator,
|
||||
torch.nn.utils.parametrizations.orthogonal
|
||||
])
|
||||
if self.resume:
|
||||
self.checkpoint_dict = torch.load(checkpoint_path, weights_only=True)
|
||||
if settings_override is not None:
|
||||
traverse_dict_update(self.checkpoint_dict["settings"], settings_override)
|
||||
if reset_epoch:
|
||||
self.checkpoint_dict["epoch"] = -1
|
||||
|
||||
self.global_settings: GlobalSettings = self.checkpoint_dict["settings"]["global_settings"]
|
||||
self.data_settings: DataSettings = self.checkpoint_dict["settings"]["data_settings"]
|
||||
self.pytorch_settings: PytorchSettings = self.checkpoint_dict["settings"]["pytorch_settings"]
|
||||
self.model_settings: ModelSettings = self.checkpoint_dict["settings"]["model_settings"]
|
||||
self.optimizer_settings: OptimizerSettings = self.checkpoint_dict["settings"]["optimizer_settings"]
|
||||
else:
|
||||
if global_settings is None:
|
||||
raise ValueError("global_settings must be provided")
|
||||
if data_settings is None:
|
||||
raise ValueError("data_settings must be provided")
|
||||
if pytorch_settings is None:
|
||||
raise ValueError("pytorch_settings must be provided")
|
||||
if model_settings is None:
|
||||
raise ValueError("model_settings must be provided")
|
||||
if optimizer_settings is None:
|
||||
raise ValueError("optimizer_settings must be provided")
|
||||
|
||||
self.global_settings: GlobalSettings = global_settings
|
||||
self.data_settings: DataSettings = data_settings
|
||||
self.pytorch_settings: PytorchSettings = pytorch_settings
|
||||
self.model_settings: ModelSettings = model_settings
|
||||
self.optimizer_settings: OptimizerSettings = optimizer_settings
|
||||
|
||||
self.console = console or Console()
|
||||
self.writer = None
|
||||
|
||||
def setup_tb_writer(self, append=None):
|
||||
log_dir = self.pytorch_settings.summary_dir + "/" + (datetime.now().strftime("%Y%m%d_%H%M%S"))
|
||||
if append is not None:
|
||||
log_dir += "_" + str(append)
|
||||
|
||||
print(f"Logging to {log_dir}")
|
||||
self.writer = SummaryWriter(log_dir=log_dir)
|
||||
|
||||
def save_checkpoint(self, save_dict, filename):
|
||||
torch.save(save_dict, filename)
|
||||
|
||||
def build_checkpoint_dict(self, loss=None, epoch=None):
|
||||
return {
|
||||
"epoch": -1 if epoch is None else epoch,
|
||||
"loss": float("inf") if loss is None else loss,
|
||||
"model_state_dict": copy.deepcopy(self.model.state_dict()),
|
||||
"optimizer_state_dict": copy.deepcopy(self.optimizer.state_dict()),
|
||||
"scheduler_state_dict": copy.deepcopy(self.scheduler.state_dict()) if hasattr(self, "scheduler") else None,
|
||||
"model_kwargs": copy.deepcopy(self.model_kwargs),
|
||||
"settings": {
|
||||
"global_settings": copy.deepcopy(self.global_settings),
|
||||
"data_settings": copy.deepcopy(self.data_settings),
|
||||
"pytorch_settings": copy.deepcopy(self.pytorch_settings),
|
||||
"model_settings": copy.deepcopy(self.model_settings),
|
||||
"optimizer_settings": copy.deepcopy(self.optimizer_settings),
|
||||
},
|
||||
}
|
||||
|
||||
def define_model(self, model_kwargs=None):
|
||||
if model_kwargs is None:
|
||||
n_hidden_layers = self.model_settings.n_hidden_layers
|
||||
|
||||
input_dim = 2 * self.data_settings.output_size
|
||||
|
||||
dtype = getattr(torch, self.data_settings.dtype)
|
||||
|
||||
afunc = getattr(util.complexNN, self.model_settings.model_activation_func)
|
||||
|
||||
layer_func = getattr(util.complexNN, self.model_settings.model_layer_function)
|
||||
|
||||
layer_parametrizations = self.model_settings.model_layer_parametrizations
|
||||
|
||||
hidden_dims = [self.model_settings.overrides.get(f"n_hidden_nodes_{i}") for i in range(n_hidden_layers)]
|
||||
|
||||
self.model_kwargs = {
|
||||
"dims": (input_dim, *hidden_dims, self.model_settings.output_dim),
|
||||
"layer_function": layer_func,
|
||||
"layer_parametrizations": layer_parametrizations,
|
||||
"activation_function": afunc,
|
||||
"dtype": dtype,
|
||||
"dropout_prob": self.model_settings.dropout_prob,
|
||||
}
|
||||
else:
|
||||
self.model_kwargs = model_kwargs
|
||||
input_dim = self.model_kwargs["dims"][0]
|
||||
dtype = self.model_kwargs["dtype"]
|
||||
|
||||
# dims = self.model_kwargs.pop("dims")
|
||||
self.model = regenerator(**self.model_kwargs)
|
||||
|
||||
self.writer.add_graph(self.model, torch.zeros(1, input_dim, dtype=dtype))
|
||||
|
||||
self.model = self.model.to(self.pytorch_settings.device)
|
||||
|
||||
def get_sliced_data(self, override=None):
|
||||
symbols = self.data_settings.symbols
|
||||
|
||||
in_out_delay = self.data_settings.in_out_delay
|
||||
|
||||
xy_delay = self.data_settings.xy_delay
|
||||
|
||||
data_size = self.data_settings.output_size
|
||||
|
||||
dtype = getattr(torch, self.data_settings.dtype)
|
||||
|
||||
num_symbols = None
|
||||
if override is not None:
|
||||
num_symbols = override.get("num_symbols", None)
|
||||
# get dataset
|
||||
dataset = FiberRegenerationDataset(
|
||||
file_path=self.data_settings.config_path,
|
||||
symbols=symbols,
|
||||
output_dim=data_size,
|
||||
target_delay=in_out_delay,
|
||||
xy_delay=xy_delay,
|
||||
drop_first=self.data_settings.drop_first,
|
||||
dtype=dtype,
|
||||
real=not dtype.is_complex,
|
||||
num_symbols=num_symbols,
|
||||
)
|
||||
|
||||
dataset_size = len(dataset)
|
||||
indices = list(range(dataset_size))
|
||||
split = int(np.floor(self.data_settings.train_split * dataset_size))
|
||||
if self.data_settings.shuffle:
|
||||
np.random.seed(self.global_settings.seed)
|
||||
np.random.shuffle(indices)
|
||||
|
||||
train_indices, valid_indices = indices[:split], indices[split:]
|
||||
|
||||
if self.data_settings.shuffle:
|
||||
train_sampler = torch.utils.data.SubsetRandomSampler(train_indices)
|
||||
valid_sampler = torch.utils.data.SubsetRandomSampler(valid_indices)
|
||||
else:
|
||||
train_sampler = train_indices
|
||||
valid_sampler = valid_indices
|
||||
|
||||
train_loader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
batch_size=self.pytorch_settings.batchsize,
|
||||
sampler=train_sampler,
|
||||
drop_last=True,
|
||||
pin_memory=True,
|
||||
num_workers=self.pytorch_settings.dataloader_workers,
|
||||
prefetch_factor=self.pytorch_settings.dataloader_prefetch,
|
||||
)
|
||||
|
||||
valid_loader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
batch_size=self.pytorch_settings.batchsize,
|
||||
sampler=valid_sampler,
|
||||
drop_last=True,
|
||||
pin_memory=True,
|
||||
num_workers=self.pytorch_settings.dataloader_workers,
|
||||
prefetch_factor=self.pytorch_settings.dataloader_prefetch,
|
||||
)
|
||||
|
||||
return train_loader, valid_loader
|
||||
|
||||
def train_model(
|
||||
self,
|
||||
optimizer,
|
||||
train_loader,
|
||||
epoch,
|
||||
enable_progress=False,
|
||||
):
|
||||
if enable_progress:
|
||||
progress = Progress(
|
||||
TextColumn("[yellow] Training..."),
|
||||
TextColumn("Error: {task.description}"),
|
||||
BarColumn(),
|
||||
TaskProgressColumn(),
|
||||
TextColumn("[green]Batch"),
|
||||
MofNCompleteColumn(),
|
||||
TimeRemainingColumn(),
|
||||
TimeElapsedColumn(),
|
||||
transient=False,
|
||||
console=self.console,
|
||||
refresh_per_second=10,
|
||||
)
|
||||
task = progress.add_task("-.---e--", total=len(train_loader))
|
||||
progress.start()
|
||||
|
||||
running_loss2 = 0.0
|
||||
running_loss = 0.0
|
||||
self.model.train()
|
||||
for batch_idx, (x, y) in enumerate(train_loader):
|
||||
self.model.zero_grad(set_to_none=True)
|
||||
x, y = (
|
||||
x.to(self.pytorch_settings.device),
|
||||
y.to(self.pytorch_settings.device),
|
||||
)
|
||||
y_pred = self.model(x)
|
||||
loss = util.complexNN.complex_mse_loss(y_pred, y, power=True)
|
||||
loss_value = loss.item()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
running_loss2 += loss_value
|
||||
running_loss += loss_value
|
||||
|
||||
if enable_progress:
|
||||
progress.update(task, advance=1, description=f"{loss_value:.3e}")
|
||||
|
||||
if batch_idx % self.pytorch_settings.write_every == 0:
|
||||
self.writer.add_scalar(
|
||||
"training loss",
|
||||
running_loss2 / (self.pytorch_settings.write_every if batch_idx > 0 else 1),
|
||||
epoch * len(train_loader) + batch_idx,
|
||||
)
|
||||
running_loss2 = 0.0
|
||||
|
||||
if enable_progress:
|
||||
progress.stop()
|
||||
|
||||
return running_loss / len(train_loader)
|
||||
|
||||
def eval_model(self, valid_loader, epoch, enable_progress=True):
|
||||
if enable_progress:
|
||||
progress = Progress(
|
||||
TextColumn("[green]Evaluating..."),
|
||||
TextColumn("Error: {task.description}"),
|
||||
BarColumn(),
|
||||
TaskProgressColumn(),
|
||||
TextColumn("[green]Batch"),
|
||||
MofNCompleteColumn(),
|
||||
TimeRemainingColumn(),
|
||||
TimeElapsedColumn(),
|
||||
transient=False,
|
||||
console=self.console,
|
||||
refresh_per_second=10,
|
||||
)
|
||||
progress.start()
|
||||
task = progress.add_task("-.---e--", total=len(valid_loader))
|
||||
|
||||
self.model.eval()
|
||||
running_error = 0
|
||||
with torch.no_grad():
|
||||
for batch_idx, (x, y) in enumerate(valid_loader):
|
||||
x, y = (
|
||||
x.to(self.pytorch_settings.device),
|
||||
y.to(self.pytorch_settings.device),
|
||||
)
|
||||
y_pred = self.model(x)
|
||||
error = util.complexNN.complex_mse_loss(y_pred, y, power=True)
|
||||
error_value = error.item()
|
||||
running_error += error_value
|
||||
|
||||
if enable_progress:
|
||||
progress.update(task, advance=1, description=f"{error_value:.3e}")
|
||||
|
||||
|
||||
running_error /= len(valid_loader)
|
||||
self.writer.add_scalar(
|
||||
"eval loss",
|
||||
running_error,
|
||||
epoch,
|
||||
)
|
||||
title_append, subtitle = self.build_title(epoch + 1)
|
||||
self.writer.add_figure(
|
||||
"fiber response",
|
||||
self.plot_model_response(
|
||||
model=self.model,
|
||||
title_append=title_append,
|
||||
subtitle=subtitle,
|
||||
show=False,
|
||||
),
|
||||
epoch + 1,
|
||||
)
|
||||
self.writer.add_figure(
|
||||
"eye diagram",
|
||||
self.plot_model_response(
|
||||
model=self.model,
|
||||
title_append=title_append,
|
||||
subtitle=subtitle,
|
||||
show=False,
|
||||
mode="eye",
|
||||
),
|
||||
epoch + 1,
|
||||
)
|
||||
self.writer_histograms(epoch + 1)
|
||||
|
||||
if enable_progress:
|
||||
progress.stop()
|
||||
|
||||
return running_error
|
||||
|
||||
def run_model(self, model, loader):
|
||||
model.eval()
|
||||
xs = []
|
||||
ys = []
|
||||
y_preds = []
|
||||
with torch.no_grad():
|
||||
model = model.to(self.pytorch_settings.device)
|
||||
for x, y in loader:
|
||||
x, y = (
|
||||
x.to(self.pytorch_settings.device),
|
||||
y.to(self.pytorch_settings.device),
|
||||
)
|
||||
y_pred = model(x).cpu()
|
||||
# x = x.cpu()
|
||||
# y = y.cpu()
|
||||
y_pred = y_pred.view(y_pred.shape[0], -1, 2)
|
||||
y = y.view(y.shape[0], -1, 2)
|
||||
x = x.view(x.shape[0], -1, 2)
|
||||
xs.append(x[:, 0, :].squeeze())
|
||||
ys.append(y.squeeze())
|
||||
y_preds.append(y_pred.squeeze())
|
||||
|
||||
xs = torch.vstack(xs).cpu()
|
||||
ys = torch.vstack(ys).cpu()
|
||||
y_preds = torch.vstack(y_preds).cpu()
|
||||
return ys, xs, y_preds
|
||||
|
||||
def writer_histograms(self, epoch, attributes=["weight", "weight_U", "weight_V", "bias", "sigma", "scale"]):
|
||||
for i, layer in enumerate(self.model._layers):
|
||||
tag = f"layer {i}"
|
||||
for attribute in attributes:
|
||||
if hasattr(layer, attribute):
|
||||
vals: np.ndarray = getattr(layer, attribute).detach().cpu().numpy().flatten()
|
||||
if vals.ndim <= 1 and len(vals) == 1:
|
||||
if np.iscomplexobj(vals):
|
||||
self.writer.add_scalar(f"{tag} {attribute} (Mag)", np.abs(vals), epoch)
|
||||
self.writer.add_scalar(f"{tag} {attribute} (Phase)", np.angle(vals), epoch)
|
||||
else:
|
||||
self.writer.add_scalar(f"{tag} {attribute}", vals, epoch)
|
||||
else:
|
||||
if np.iscomplexobj(vals):
|
||||
self.writer.add_histogram(f"{tag} {attribute} (Mag)", np.abs(vals), epoch, bins="fd")
|
||||
self.writer.add_histogram(f"{tag} {attribute} (Phase)", np.angle(vals), epoch, bins="fd")
|
||||
else:
|
||||
self.writer.add_histogram(f"{tag} {attribute}", vals, epoch, bins="fd")
|
||||
|
||||
def train(self):
|
||||
if self.writer is None:
|
||||
self.setup_tb_writer()
|
||||
|
||||
if self.resume:
|
||||
model_kwargs = self.checkpoint_dict["model_kwargs"]
|
||||
else:
|
||||
model_kwargs = None
|
||||
|
||||
self.define_model(model_kwargs=model_kwargs)
|
||||
|
||||
print(f"number of parameters (trainable): {sum(p.numel() for p in self.model.parameters())} ({sum(p.numel() for p in self.model.parameters() if p.requires_grad)})")
|
||||
|
||||
title_append, subtitle = self.build_title(0)
|
||||
|
||||
self.writer.add_figure(
|
||||
"fiber response",
|
||||
self.plot_model_response(
|
||||
model=self.model,
|
||||
title_append=title_append,
|
||||
subtitle=subtitle,
|
||||
show=False,
|
||||
),
|
||||
0,
|
||||
)
|
||||
self.writer.add_figure(
|
||||
"eye diagram",
|
||||
self.plot_model_response(
|
||||
model=self.model,
|
||||
title_append=title_append,
|
||||
subtitle=subtitle,
|
||||
mode="eye",
|
||||
show=False,
|
||||
),
|
||||
0,
|
||||
)
|
||||
self.writer_histograms(0)
|
||||
|
||||
train_loader, valid_loader = self.get_sliced_data()
|
||||
|
||||
optimizer_name = self.optimizer_settings.optimizer
|
||||
|
||||
lr = self.optimizer_settings.learning_rate
|
||||
|
||||
self.optimizer: optim.Optimizer = getattr(optim, optimizer_name)(self.model.parameters(), lr=lr)
|
||||
if self.optimizer_settings.scheduler is not None:
|
||||
self.scheduler = getattr(optim.lr_scheduler, self.optimizer_settings.scheduler)(
|
||||
self.optimizer, **self.optimizer_settings.scheduler_kwargs
|
||||
)
|
||||
if self.resume:
|
||||
try:
|
||||
self.scheduler.load_state_dict(self.checkpoint_dict["scheduler_state_dict"])
|
||||
except ValueError:
|
||||
pass
|
||||
self.writer.add_scalar("learning rate", self.scheduler.get_last_lr()[0], -1)
|
||||
|
||||
|
||||
if not self.resume:
|
||||
self.best = self.build_checkpoint_dict()
|
||||
else:
|
||||
self.best = self.checkpoint_dict
|
||||
self.model.load_state_dict(self.best["model_state_dict"], strict=False)
|
||||
try:
|
||||
self.optimizer.load_state_dict(self.best["optimizer_state_dict"])
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
for epoch in range(self.best["epoch"] + 1, self.pytorch_settings.epochs):
|
||||
enable_progress = True
|
||||
if enable_progress:
|
||||
self.console.rule(f"Epoch {epoch + 1}/{self.pytorch_settings.epochs}")
|
||||
self.train_model(
|
||||
self.optimizer,
|
||||
train_loader,
|
||||
epoch,
|
||||
enable_progress=enable_progress,
|
||||
)
|
||||
loss = self.eval_model(
|
||||
valid_loader,
|
||||
epoch,
|
||||
enable_progress=enable_progress,
|
||||
)
|
||||
if self.optimizer_settings.scheduler is not None:
|
||||
lr_old = self.scheduler.get_last_lr()
|
||||
self.scheduler.step(loss)
|
||||
lr_new = self.scheduler.get_last_lr()
|
||||
if lr_old[0] != lr_new[0]:
|
||||
self.writer.add_scalar("learning rate", lr_new[0], epoch)
|
||||
|
||||
if self.pytorch_settings.save_models and self.model is not None:
|
||||
save_path = (
|
||||
Path(self.pytorch_settings.model_dir) / f"{self.writer.get_logdir().split('/')[-1]}_{epoch}.tar"
|
||||
)
|
||||
save_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
checkpoint = self.build_checkpoint_dict(loss, epoch)
|
||||
self.save_checkpoint(checkpoint, save_path)
|
||||
|
||||
if loss < self.best["loss"]:
|
||||
self.best = checkpoint
|
||||
save_path = (
|
||||
Path(self.pytorch_settings.model_dir) / f"best_{self.writer.get_logdir().split('/')[-1]}.tar"
|
||||
)
|
||||
save_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
self.save_checkpoint(self.best, save_path)
|
||||
self.writer.flush()
|
||||
|
||||
self.writer.close()
|
||||
return self.best
|
||||
|
||||
def _plot_model_response_eye(self, *signals, labels=None, sps=None, title_append="", subtitle="", show=True):
|
||||
if sps is None:
|
||||
raise ValueError("sps must be provided")
|
||||
if not hasattr(labels, "__iter__") or isinstance(labels, (str, type(None))):
|
||||
labels = [labels]
|
||||
else:
|
||||
labels = list(labels)
|
||||
|
||||
while len(labels) < len(signals):
|
||||
labels.append(None)
|
||||
|
||||
# check if there are any labels
|
||||
if not any(labels):
|
||||
labels = [f"signal {i + 1}" for i in range(len(signals))]
|
||||
|
||||
fig, axs = plt.subplots(1, 2 * len(signals), sharex=True, sharey=True)
|
||||
fig.set_figwidth(18)
|
||||
fig.suptitle(f"Eye diagram{f' {title_append}' if title_append else ''}{f'\n{subtitle}' if subtitle else ''}")
|
||||
xaxis = np.linspace(0, 2, 2 * sps, endpoint=False)
|
||||
for j, (label, signal) in enumerate(zip(labels, signals)):
|
||||
# signal = signal.cpu().numpy()
|
||||
for i in range(len(signal) // sps - 1):
|
||||
x, y = signal[i * sps : (i + 2) * sps].T
|
||||
axs[0 + 2 * j].plot(xaxis, np.abs(x) ** 2, color=f"C{j}", alpha=1 / (len(signal) // sps) * 10)
|
||||
axs[1 + 2 * j].plot(xaxis, np.abs(y) ** 2, color=f"C{j}", alpha=1 / (len(signal) // sps) * 10)
|
||||
axs[0 + 2 * j].set_title(label + " x")
|
||||
axs[1 + 2 * j].set_title(label + " y")
|
||||
axs[0 + 2 * j].set_xlabel("Symbol")
|
||||
axs[1 + 2 * j].set_xlabel("Symbol")
|
||||
axs[0 + 2 * j].set_box_aspect(1)
|
||||
axs[1 + 2 * j].set_box_aspect(1)
|
||||
axs[0].set_ylabel("normalized power")
|
||||
fig.tight_layout()
|
||||
# axs[1+2*len(labels)-1].set_ylabel("normalized power")
|
||||
|
||||
if show:
|
||||
plt.show()
|
||||
return fig
|
||||
|
||||
def _plot_model_response_head(self, *signals, labels=None, sps=None, title_append="", subtitle="", show=True):
|
||||
if not hasattr(labels, "__iter__") or isinstance(labels, (str, type(None))):
|
||||
labels = [labels]
|
||||
else:
|
||||
labels = list(labels)
|
||||
|
||||
while len(labels) < len(signals):
|
||||
labels.append(None)
|
||||
|
||||
# check if there are any labels
|
||||
if not any(labels):
|
||||
labels = [f"signal {i + 1}" for i in range(len(signals))]
|
||||
|
||||
fig, axs = plt.subplots(1, 2, sharex=True, sharey=True)
|
||||
fig.set_figwidth(18)
|
||||
fig.set_figheight(4)
|
||||
fig.suptitle(f"Fiber response{f' {title_append}' if title_append else ''}{f'\n{subtitle}' if subtitle else ''}")
|
||||
for i, ax in enumerate(axs):
|
||||
for signal, label in zip(signals, labels):
|
||||
if sps is not None:
|
||||
xaxis = np.linspace(0, len(signal) / sps, len(signal), endpoint=False)
|
||||
else:
|
||||
xaxis = np.arange(len(signal))
|
||||
ax.plot(xaxis, np.abs(signal[:, i]) ** 2, label=label)
|
||||
ax.set_xlabel("Sample" if sps is None else "Symbol")
|
||||
ax.set_ylabel("normalized power")
|
||||
ax.legend(loc="upper right")
|
||||
fig.tight_layout()
|
||||
if show:
|
||||
plt.show()
|
||||
return fig
|
||||
|
||||
def plot_model_response(
|
||||
self,
|
||||
model=None,
|
||||
title_append="",
|
||||
subtitle="",
|
||||
mode: Literal["eye", "head"] = "head",
|
||||
show=False,
|
||||
):
|
||||
data_settings_backup = copy.deepcopy(self.data_settings)
|
||||
pytorch_settings_backup = copy.deepcopy(self.pytorch_settings)
|
||||
self.data_settings.drop_first = 100 * 128
|
||||
self.data_settings.shuffle = False
|
||||
self.data_settings.train_split = 1.0
|
||||
self.pytorch_settings.batchsize = (
|
||||
self.pytorch_settings.eye_symbols if mode == "eye" else self.pytorch_settings.head_symbols
|
||||
)
|
||||
plot_loader, _ = self.get_sliced_data(override={"num_symbols": self.pytorch_settings.batchsize})
|
||||
self.data_settings = data_settings_backup
|
||||
self.pytorch_settings = pytorch_settings_backup
|
||||
|
||||
fiber_in, fiber_out, regen = self.run_model(model, plot_loader)
|
||||
fiber_in = fiber_in.view(-1, 2)
|
||||
fiber_out = fiber_out.view(-1, 2)
|
||||
regen = regen.view(-1, 2)
|
||||
|
||||
fiber_in = fiber_in.numpy()
|
||||
fiber_out = fiber_out.numpy()
|
||||
regen = regen.numpy()
|
||||
|
||||
# https://github.com/matplotlib/matplotlib/issues/27713#issue-2104110987
|
||||
# https://github.com/matplotlib/matplotlib/issues/27713#issuecomment-1915497463
|
||||
import gc
|
||||
|
||||
if mode == "head":
|
||||
fig = self._plot_model_response_head(
|
||||
fiber_in,
|
||||
fiber_out,
|
||||
regen,
|
||||
labels=("fiber in", "fiber out", "regen"),
|
||||
sps=plot_loader.dataset.samples_per_symbol,
|
||||
title_append=title_append,
|
||||
subtitle=subtitle,
|
||||
show=show,
|
||||
)
|
||||
elif mode == "eye":
|
||||
# raise NotImplementedError("Eye diagram not implemented")
|
||||
fig = self._plot_model_response_eye(
|
||||
fiber_in,
|
||||
fiber_out,
|
||||
regen,
|
||||
labels=("fiber in", "fiber out", "regen"),
|
||||
sps=plot_loader.dataset.samples_per_symbol,
|
||||
title_append=title_append,
|
||||
subtitle=subtitle,
|
||||
show=show,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown mode: {mode}")
|
||||
gc.collect()
|
||||
|
||||
return fig
|
||||
|
||||
def build_title(self, number: int):
|
||||
title_append = f"epoch {number}"
|
||||
model_n_hidden_layers = self.model_settings.n_hidden_layers
|
||||
input_dim = 2 * self.data_settings.output_size
|
||||
model_dims = [
|
||||
self.model_settings.overrides.get(f"n_hidden_nodes_{i}", -1) for i in range(model_n_hidden_layers)
|
||||
]
|
||||
model_dims.insert(0, input_dim)
|
||||
model_dims.append(2)
|
||||
model_dims = [str(dim) for dim in model_dims]
|
||||
model_activation_func = self.model_settings.model_activation_func
|
||||
model_dtype = self.data_settings.dtype
|
||||
|
||||
subtitle = f"{model_n_hidden_layers + 2} layers à ({', '.join(model_dims)}) units, {model_activation_func}, {model_dtype}"
|
||||
|
||||
return title_append, subtitle
|
||||
@@ -30,10 +30,10 @@ data_settings = DataSettings(
)

pytorch_settings = PytorchSettings(
    epochs=10,
    epochs=10000,
    batchsize=2**10,
    device="cuda",
    dataloader_workers=2,
    dataloader_workers=12,
    dataloader_prefetch=4,
    summary_dir=".runs",
    write_every=2**5,
@@ -44,33 +44,31 @@ pytorch_settings = PytorchSettings(
model_settings = ModelSettings(
    output_dim=2,
    # n_hidden_layers = (3, 8),
    n_hidden_layers=(4, 6), # study: single_core_regen_20241123_011232
    n_hidden_nodes=(4,20),
    # overrides={
    #     "n_hidden_nodes_0": (14, 20), # study: single_core_regen_20241123_011232
    #     "n_hidden_nodes_1": (8, 16),
    #     "n_hidden_nodes_2": (10, 16),
    #     # "n_hidden_nodes_3": (4, 20), # study: single_core_regen_20241123_135749
    #     "n_hidden_nodes_4": (2, 8),
    #     "n_hidden_nodes_5": (10, 16),
    # },
    # model_activation_func = ("ModReLU", "Mag", "Identity")
    model_activation_func="Mag", # study: single_core_regen_20241123_011232
    n_hidden_layers=4,
    overrides={
        "n_hidden_nodes_0": 8,
        "n_hidden_nodes_1": 6,
        "n_hidden_nodes_2": 4,
        "n_hidden_nodes_3": 8,
    },
    model_activation_func="Mag",
    # satabsT0=(1e-6, 1),
)

optimizer_settings = OptimizerSettings(
    optimizer="Adam",
    # learning_rate = (1e-5, 1e-1),
    learning_rate=5e-4,
    learning_rate=5e-3
    # learning_rate=5e-4,
)

optuna_settings = OptunaSettings(
    n_trials=512,
    n_workers=14,
    n_trials=1,
    n_workers=1,
    timeout=3600,
    directions=("maximize", "minimize"),
    metrics_names=("neg_log_mse","n_nodes"),
    limit_examples=True,
    directions=("minimize",),
    metrics_names=("mse",),
    limit_examples=False,
    n_train_batches=500,
    # n_valid_batches = 100,
    storage="sqlite:///data/single_core_regen.db",

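The OptunaSettings above switch the study from a two-objective search to a single minimized MSE metric, run with one worker and one trial against the SQLite storage. As a rough sketch of how settings like these would typically be turned into a study (the study name and `load_if_exists` flag are assumptions, not taken from this diff):

```python
# Hedged sketch: building a study that matches the settings above.
import optuna

study = optuna.create_study(
    study_name="single_core_regen",  # assumed name, not shown in the diff
    storage="sqlite:///data/single_core_regen.db",
    directions=("minimize",),        # matches metrics_names=("mse",)
    load_if_exists=True,
)
# study.optimize(objective, n_trials=1, timeout=3600)  # objective() assumed to exist
```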
@@ -1,414 +1,130 @@
|
||||
import copy
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
# import torch.nn.functional as F # mse_loss doesn't support complex numbers
|
||||
import torch.optim as optim
|
||||
import torch.utils.data
|
||||
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from rich.progress import (
|
||||
Progress,
|
||||
TextColumn,
|
||||
BarColumn,
|
||||
TaskProgressColumn,
|
||||
TimeRemainingColumn,
|
||||
MofNCompleteColumn,
|
||||
TimeElapsedColumn,
|
||||
from hypertraining.settings import (
|
||||
GlobalSettings,
|
||||
DataSettings,
|
||||
PytorchSettings,
|
||||
ModelSettings,
|
||||
OptimizerSettings,
|
||||
)
|
||||
from rich.console import Console
|
||||
from rich import print as rprint
|
||||
|
||||
# from util.optuna_helpers import optional_suggest_categorical, optional_suggest_float, optional_suggest_int
|
||||
from hypertraining.training import Trainer
|
||||
import torch
|
||||
import json
|
||||
import util
|
||||
|
||||
global_settings = GlobalSettings(
|
||||
seed=42,
|
||||
)
|
||||
|
||||
# global settings
|
||||
@dataclass
|
||||
class GlobalSettings:
|
||||
seed: int = 42
|
||||
data_settings = DataSettings(
|
||||
config_path="data/*-128-16384-100000-0-0-17-0-PAM4-0.ini",
|
||||
dtype="complex64",
|
||||
# symbols = (9, 20), # 13 symbol @ 10GBd <-> 1.3ns <-> 0.26m of fiber
|
||||
symbols=13, # study: single_core_regen_20241123_011232
|
||||
# output_size = (11, 32), # 26 taps -> 2 taps per input symbol -> 1 tap every 1cm (model has 52 inputs (x/y))
|
||||
output_size=26, # study: single_core_regen_20241123_011232 (model_input_dim/2)
|
||||
shuffle=True,
|
||||
in_out_delay=0,
|
||||
xy_delay=0,
|
||||
drop_first=128*64,
|
||||
train_split=0.8,
|
||||
)
|
||||
|
||||
pytorch_settings = PytorchSettings(
|
||||
epochs=10000,
|
||||
batchsize=2**12,
|
||||
device="cuda",
|
||||
dataloader_workers=12,
|
||||
dataloader_prefetch=8,
|
||||
summary_dir=".runs",
|
||||
write_every=2**5,
|
||||
save_models=True,
|
||||
model_dir=".models",
|
||||
)
|
||||
|
||||
# data settings
|
||||
@dataclass
|
||||
class DataSettings:
|
||||
config_path: str = "data/*-128-16384-100000-0-0-17-0-PAM4-0.ini"
|
||||
dtype: torch.dtype = torch.complex64
|
||||
symbols_range: float | int = 8
|
||||
data_size_range: float | int = 64
|
||||
shuffle: bool = True
|
||||
target_delay: float = 0
|
||||
xy_delay_range: float | int = 0
|
||||
drop_first: int = 10
|
||||
train_split: float = 0.8
|
||||
model_settings = ModelSettings(
|
||||
output_dim=2,
|
||||
n_hidden_layers=4,
|
||||
overrides={
|
||||
"n_hidden_nodes_0": 8,
|
||||
"n_hidden_nodes_1": 8,
|
||||
"n_hidden_nodes_2": 4,
|
||||
"n_hidden_nodes_3": 6,
|
||||
},
|
||||
model_activation_func="PowScale",
|
||||
# dropout_prob=0.01,
|
||||
model_layer_function="ONN",
|
||||
model_layer_parametrizations=[
|
||||
{
|
||||
"tensor_name": "weight",
|
||||
"parametrization": torch.nn.utils.parametrizations.orthogonal,
|
||||
},
|
||||
{
|
||||
"tensor_name": "scales",
|
||||
"parametrization": util.complexNN.clamp,
|
||||
},
|
||||
{
|
||||
"tensor_name": "scale",
|
||||
"parametrization": util.complexNN.clamp,
|
||||
},
|
||||
{
|
||||
"tensor_name": "bias",
|
||||
"parametrization": util.complexNN.clamp,
|
||||
},
|
||||
# {
|
||||
# "tensor_name": "V",
|
||||
# "parametrization": torch.nn.utils.parametrizations.orthogonal,
|
||||
# },
|
||||
# {
|
||||
# "tensor_name": "S",
|
||||
# "parametrization": util.complexNN.clamp,
|
||||
# },
|
||||
],
|
||||
)
|
||||
|
||||
optimizer_settings = OptimizerSettings(
|
||||
optimizer="Adam",
|
||||
learning_rate=0.05,
|
||||
scheduler="ReduceLROnPlateau",
|
||||
scheduler_kwargs={
|
||||
"patience": 2**6,
|
||||
"factor": 0.9,
|
||||
# "threshold": 1e-3,
|
||||
"min_lr": 1e-6,
|
||||
"cooldown": 10,
|
||||
},
|
||||
)
|
||||
|
||||
# pytorch settings
|
||||
@dataclass
|
||||
class PytorchSettings:
|
||||
epochs: int = 1000
|
||||
batchsize: int = 2**12
|
||||
device: str = "cuda"
|
||||
summary_dir: str = ".runs"
|
||||
model_dir: str = ".models"
|
||||
def save_dict_to_file(dictionary, filename):
|
||||
"""
|
||||
Save the best dictionary to a JSON file.
|
||||
|
||||
:param best: Dictionary containing the best training results.
|
||||
:type best: dict
|
||||
:param filename: Path to the JSON file where the dictionary will be saved.
|
||||
:type filename: str
|
||||
"""
|
||||
with open(filename, 'w') as f:
|
||||
json.dump(dictionary, f, indent=4)
|
||||
|
||||
# model settings
|
||||
@dataclass
|
||||
class ModelSettings:
|
||||
output_size: int = 2
|
||||
# n_layer_range: float|int = 2
|
||||
# n_units_range: float|int = 32
|
||||
n_layers: int = 3
|
||||
n_units: int = 32
|
||||
activation_func: tuple | str = "ModReLU"
|
||||
|
||||
|
||||
@dataclass
|
||||
class OptimizerSettings:
|
||||
optimizer_range: str = "Adam"
|
||||
lr_range: float = 2e-3
|
||||
|
||||
|
||||
class Training:
|
||||
def __init__(self):
|
||||
self.global_settings = GlobalSettings()
|
||||
self.data_settings = DataSettings()
|
||||
self.pytorch_settings = PytorchSettings()
|
||||
self.model_settings = ModelSettings()
|
||||
self.optimizer_settings = OptimizerSettings()
|
||||
self.study_name = (
|
||||
f"single_core_regen_{datetime.now().strftime('%Y-%m-%d_%H:%M:%S')}"
|
||||
)
|
||||
|
||||
if not hasattr(self.pytorch_settings, "model_dir"):
|
||||
self.pytorch_settings.model_dir = ".models"
|
||||
|
||||
self.writer = None
|
||||
self.console = Console()
|
||||
|
||||
def setup_tb_writer(self, study_name=None, append=None):
|
||||
log_dir = (
|
||||
self.pytorch_settings.summary_dir + "/" + (study_name or self.study_name) + ("_" + str(append)) if append else ""
|
||||
)
|
||||
self.writer = SummaryWriter(log_dir)
|
||||
return self.writer
|
||||
|
||||
def plot_eye(self, width=2, symbols=None, alpha=None, complex=False, show=True):
|
||||
if not hasattr(self, "eye_data"):
|
||||
data, config = util.datasets.load_data(
|
||||
self.data_settings.config_path,
|
||||
skipfirst=10,
|
||||
symbols=symbols or 1000,
|
||||
real=not self.data_settings.dtype.is_complex,
|
||||
normalize=True,
|
||||
)
|
||||
self.eye_data = {"data": data, "sps": int(config["glova"]["sps"])}
|
||||
return util.plot.eye(
|
||||
**self.eye_data,
|
||||
width=width,
|
||||
show=show,
|
||||
alpha=alpha,
|
||||
complex=complex,
|
||||
symbols=symbols or 1000,
|
||||
skipfirst=0,
|
||||
)
|
||||
|
||||
def define_model(self):
|
||||
n_layers = self.model_settings.n_layers
|
||||
|
||||
in_features = 2 * self.data_settings.data_size_range
|
||||
|
||||
layers = []
|
||||
for i in range(n_layers):
|
||||
out_features = self.model_settings.n_units
|
||||
|
||||
layers.append(util.complexNN.UnitaryLayer(in_features, out_features))
|
||||
# layers.append(getattr(nn, self.model_settings.activation_func)())
|
||||
layers.append(
|
||||
getattr(util.complexNN, self.model_settings.activation_func)()
|
||||
)
|
||||
in_features = out_features
|
||||
|
||||
layers.append(
|
||||
util.complexNN.UnitaryLayer(in_features, self.model_settings.output_size)
|
||||
)
|
||||
|
||||
if self.writer is not None:
|
||||
self.writer.add_graph(
|
||||
nn.Sequential(*layers),
|
||||
torch.zeros(1, layers[0].in_features, dtype=self.data_settings.dtype),
|
||||
)
|
||||
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def get_sliced_data(self):
|
||||
symbols = self.data_settings.symbols_range
|
||||
|
||||
xy_delay = self.data_settings.xy_delay_range
|
||||
|
||||
data_size = self.data_settings.data_size_range
|
||||
|
||||
# get dataset
|
||||
dataset = util.datasets.FiberRegenerationDataset(
|
||||
file_path=self.data_settings.config_path,
|
||||
symbols=symbols,
|
||||
output_dim=data_size,
|
||||
target_delay=self.data_settings.target_delay,
|
||||
xy_delay=xy_delay,
|
||||
drop_first=self.data_settings.drop_first,
|
||||
dtype=self.data_settings.dtype,
|
||||
real=not self.data_settings.dtype.is_complex,
|
||||
# device=self.pytorch_settings.device,
|
||||
)
|
||||
|
||||
dataset_size = len(dataset)
|
||||
indices = list(range(dataset_size))
|
||||
split = int(np.floor(self.data_settings.train_split * dataset_size))
|
||||
if self.data_settings.shuffle:
|
||||
np.random.seed(self.global_settings.seed)
|
||||
np.random.shuffle(indices)
|
||||
|
||||
train_indices, valid_indices = indices[:split], indices[split:]
|
||||
|
||||
if self.data_settings.shuffle:
|
||||
train_sampler = torch.utils.data.SubsetRandomSampler(train_indices)
|
||||
valid_sampler = torch.utils.data.SubsetRandomSampler(valid_indices)
|
||||
else:
|
||||
train_sampler = train_indices
|
||||
valid_sampler = valid_indices
|
||||
|
||||
|
||||
train_loader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
batch_size=self.pytorch_settings.batchsize,
|
||||
sampler=train_sampler,
|
||||
drop_last=True,
|
||||
pin_memory=True,
|
||||
num_workers=24,
|
||||
prefetch_factor=4,
|
||||
# persistent_workers=True
|
||||
)
|
||||
valid_loader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
batch_size=self.pytorch_settings.batchsize,
|
||||
sampler=valid_sampler,
|
||||
drop_last=True,
|
||||
pin_memory=True,
|
||||
num_workers=24,
|
||||
prefetch_factor=4,
|
||||
# persistent_workers=True
|
||||
)
|
||||
|
||||
return train_loader, valid_loader
|
||||
|
||||
def train_model(self, model, optimizer, train_loader, epoch):
|
||||
with Progress(
|
||||
TextColumn("[yellow] Training..."),
|
||||
TextColumn("Error: {task.description}"),
|
||||
BarColumn(),
|
||||
TaskProgressColumn(),
|
||||
TextColumn("[green]Batch"),
|
||||
MofNCompleteColumn(),
|
||||
TimeRemainingColumn(),
|
||||
TimeElapsedColumn(),
|
||||
# description="Training",
|
||||
transient=False,
|
||||
console=self.console,
|
||||
refresh_per_second=10,
|
||||
) as progress:
|
||||
task = progress.add_task("-.---e--", total=len(train_loader))
|
||||
|
||||
running_loss = 0.0
|
||||
model.train()
|
||||
for batch_idx, (x, y) in enumerate(train_loader):
|
||||
model.zero_grad(set_to_none=True)
|
||||
x, y = (
|
||||
x.to(self.pytorch_settings.device),
|
||||
y.to(self.pytorch_settings.device),
|
||||
)
|
||||
y_pred = model(x)
|
||||
loss = util.complexNN.complex_mse_loss(y_pred, y)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
progress.update(task, advance=1, description=f"{loss.item():.3e}")
|
||||
|
||||
running_loss += loss.item()
|
||||
if self.writer is not None:
|
||||
if (batch_idx + 1) % 10 == 0:
|
||||
self.writer.add_scalar(
|
||||
"training loss",
|
||||
running_loss / 10,
|
||||
epoch * len(train_loader) + batch_idx,
|
||||
)
|
||||
running_loss = 0.0
|
||||
|
||||
return running_loss
|
||||
|
||||
def eval_model(self, model, valid_loader, epoch):
|
||||
with Progress(
|
||||
TextColumn("[green]Evaluating..."),
|
||||
TextColumn("Error: {task.description}"),
|
||||
BarColumn(),
|
||||
TaskProgressColumn(),
|
||||
TextColumn("[green]Batch"),
|
||||
MofNCompleteColumn(),
|
||||
TimeRemainingColumn(),
|
||||
TimeElapsedColumn(),
|
||||
# description="Training",
|
||||
transient=False,
|
||||
console=self.console,
|
||||
refresh_per_second=10,
|
||||
) as progress:
|
||||
task = progress.add_task("-.---e--", total=len(valid_loader))
|
||||
|
||||
model.eval()
|
||||
running_loss = 0
|
||||
running_loss2 = 0
|
||||
with torch.no_grad():
|
||||
for batch_idx, (x, y) in enumerate(valid_loader):
|
||||
x, y = (
|
||||
x.to(self.pytorch_settings.device),
|
||||
y.to(self.pytorch_settings.device),
|
||||
)
|
||||
y_pred = model(x)
|
||||
loss = util.complexNN.complex_mse_loss(y_pred, y)
|
||||
running_loss += loss.item()
|
||||
running_loss2 += loss.item()
|
||||
|
||||
progress.update(task, advance=1, description=f"{loss.item():.3e}")
|
||||
if self.writer is not None:
|
||||
if (batch_idx + 1) % 10 == 0:
|
||||
self.writer.add_scalar(
|
||||
"loss",
|
||||
running_loss / 10,
|
||||
epoch * len(valid_loader) + batch_idx,
|
||||
)
|
||||
running_loss = 0.0
|
||||
|
||||
if self.writer is not None:
|
||||
self.writer.add_figure("fiber response", self.plot_model_response(model, plot=False), epoch+1)
|
||||
|
||||
return running_loss2 / len(valid_loader)
|
||||
|
||||
def run_model(self, model, loader):
|
||||
model.eval()
|
||||
xs = []
|
||||
ys = []
|
||||
y_preds = []
|
||||
with torch.no_grad():
|
||||
model = model.to(self.pytorch_settings.device)
|
||||
for x, y in loader:
|
||||
x, y = (
|
||||
x.to(self.pytorch_settings.device),
|
||||
y.to(self.pytorch_settings.device),
|
||||
)
|
||||
y_pred = model(x).cpu()
|
||||
# x = x.cpu()
|
||||
# y = y.cpu()
|
||||
y_pred = y_pred.view(y_pred.shape[0], -1, 2)
|
||||
y = y.view(y.shape[0], -1, 2)
|
||||
x = x.view(x.shape[0], -1, 2)
|
||||
xs.append(x[:, 0, :].squeeze())
|
||||
ys.append(y.squeeze())
|
||||
y_preds.append(y_pred.squeeze())
|
||||
|
||||
xs = torch.vstack(xs).cpu()
|
||||
ys = torch.vstack(ys).cpu()
|
||||
y_preds = torch.vstack(y_preds).cpu()
|
||||
return ys, xs, y_preds
|
||||
|
||||
def objective(self, save=False, plot_before=False):
|
||||
try:
|
||||
rprint(*list(self.study_name.split("_")))
|
||||
|
||||
self.model = self.define_model().to(self.pytorch_settings.device)
|
||||
|
||||
if self.writer is not None:
|
||||
self.writer.add_figure("fiber response", self.plot_model_response(plot=plot_before), 0)
|
||||
|
||||
train_loader, valid_loader = self.get_sliced_data()
|
||||
|
||||
optimizer_name = self.optimizer_settings.optimizer_range
|
||||
|
||||
lr = self.optimizer_settings.lr_range
|
||||
|
||||
optimizer = getattr(optim, optimizer_name)(self.model.parameters(), lr=lr)
|
||||
|
||||
for epoch in range(self.pytorch_settings.epochs):
|
||||
self.console.rule(f"Epoch {epoch + 1}/{self.pytorch_settings.epochs}")
|
||||
self.train_model(self.model, optimizer, train_loader, epoch)
|
||||
eval_loss = self.eval_model(self.model, valid_loader, epoch)
|
||||
|
||||
return eval_loss
|
||||
|
||||
except KeyboardInterrupt:
|
||||
...
|
||||
finally:
|
||||
if hasattr(self, "model"):
|
||||
save_path = (
|
||||
Path(self.pytorch_settings.model_dir) / f"{self.study_name}.pth"
|
||||
)
|
||||
save_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
torch.save(self.model, save_path)
|
||||
|
||||
def _plot_model_response_plotter(self, fiber_in, fiber_out, regen, plot=True):
|
||||
fig, axs = plt.subplots(2)
|
||||
for i, ax in enumerate(axs):
|
||||
ax.plot(np.abs(fiber_in[:, i]) ** 2, label="fiber in")
|
||||
ax.plot(np.abs(fiber_out[:, i]) ** 2, label="fiber out")
|
||||
ax.plot(np.abs(regen[:, i]) ** 2, label="regenerated")
|
||||
ax.legend()
|
||||
if plot:
|
||||
plt.show()
|
||||
return fig
|
||||
|
||||
def plot_model_response(self, model=None, plot=True):
|
||||
data_settings_backup = copy.copy(self.data_settings)
|
||||
self.data_settings.shuffle = False
|
||||
self.data_settings.train_split = 0.01
|
||||
self.data_settings.drop_first = 100
|
||||
plot_loader, _ = self.get_sliced_data()
|
||||
self.data_settings = data_settings_backup
|
||||
|
||||
fiber_in, fiber_out, regen = self.run_model(model or self.model, plot_loader)
|
||||
fiber_in = fiber_in.view(-1, 2)
|
||||
fiber_out = fiber_out.view(-1, 2)
|
||||
regen = regen.view(-1, 2)
|
||||
|
||||
fiber_in = fiber_in.numpy()
|
||||
fiber_out = fiber_out.numpy()
|
||||
regen = regen.numpy()
|
||||
|
||||
# https://github.com/matplotlib/matplotlib/issues/27713#issue-2104110987
|
||||
# https://github.com/matplotlib/matplotlib/issues/27713#issuecomment-1915497463
|
||||
import gc
|
||||
fig = self._plot_model_response_plotter(fiber_in, fiber_out, regen, plot=plot)
|
||||
gc.collect()
|
||||
|
||||
return fig
|
||||
|
||||
if __name__ == "__main__":
|
||||
trainer = Training()
|
||||
trainer = Trainer(
|
||||
global_settings=global_settings,
|
||||
data_settings=data_settings,
|
||||
pytorch_settings=pytorch_settings,
|
||||
model_settings=model_settings,
|
||||
optimizer_settings=optimizer_settings,
|
||||
checkpoint_path='.models/20241128_084935_8885.tar',
|
||||
settings_override={
|
||||
"model_settings": {
|
||||
# "model_activation_func": "PowScale",
|
||||
"dropout_prob": 0,
|
||||
}
|
||||
},
|
||||
reset_epoch=True,
|
||||
)
|
||||
|
||||
# trainer.plot_eye()
|
||||
trainer.setup_tb_writer()
|
||||
trainer.objective(save=True)
|
||||
|
||||
best_model = trainer.model
|
||||
|
||||
# best_model = trainer.define_model(trainer.study.best_trial).to(trainer.pytorch_settings.device)
|
||||
trainer.plot_model_response(best_model)
|
||||
|
||||
# print(f"Best model: {best_model}")
|
||||
best = trainer.train()
|
||||
save_dict_to_file(best, ".models/best_results.json")
|
||||
|
||||
...
|
||||
|
||||
@@ -1,16 +1,26 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
# from torchlambertw.special import lambertw


def complex_mse_loss(input, target):
def complex_mse_loss(input, target, power=False, reduction="mean"):
    """
    Compute the mean squared error between two complex tensors.
    If power is set to True, the loss is computed as |input|^2 - |target|^2
    """
    if input.is_complex():
        return torch.mean(torch.square(input.real - target.real) + torch.square(input.imag - target.imag))
    reduce = getattr(torch, reduction)

    if power:
        input = (input * input.conj()).real.to(dtype=input.dtype.to_real())
        target = (target * target.conj()).real.to(dtype=target.dtype.to_real())

    if input.is_complex() and target.is_complex():
        return reduce(torch.square(input.real - target.real) + torch.square(input.imag - target.imag))
    elif input.is_complex() or target.is_complex():
        raise ValueError("Input and target must have the same type (real or complex)")
    else:
        return F.mse_loss(input, target)
        return F.mse_loss(input, target, reduction=reduction)


def complex_sse_loss(input, target):
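The updated complex_mse_loss gains a power flag (compare squared magnitudes instead of complex values) and a reduction argument. A minimal usage sketch of the new signature, assuming the repository's util package is importable as in the training code:

```python
# Hedged sketch of calling the new complex_mse_loss signature shown in this hunk.
import torch
import util.complexNN as cnn

pred = torch.randn(8, 2, dtype=torch.complex64)
target = torch.randn(8, 2, dtype=torch.complex64)

loss_field = cnn.complex_mse_loss(pred, target)              # MSE over real and imaginary parts
loss_power = cnn.complex_mse_loss(pred, target, power=True)  # MSE over |.|^2, as used by the Trainer
```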
@@ -43,6 +53,174 @@ class UnitaryLayer(nn.Module):
        return f"UnitaryLayer({self.in_features}, {self.out_features})"


class _Unitary(nn.Module):
    def forward(self, X:torch.Tensor):
        if X.ndim < 2:
            raise ValueError(
                "Only tensors with 2 or more dimensions are supported. "
                f"Got a tensor of shape {X.shape}"
            )
        n, k = X.size(-2), X.size(-1)
        transpose = n<k
        if transpose:
            X = X.transpose(-2, -1)
        q, r = torch.linalg.qr(X)
        # q: torch.Tensor = q
        # r: torch.Tensor = r
        d = r.diagonal(dim1=-2, dim2=-1).sgn()
        q*=d.unsqueeze(-2)
        if transpose:
            q = q.transpose(-2, -1)
        if n == k:
            mask = (torch.linalg.det(q).abs() >= 0).to(q.dtype.to_real())
            mask[mask == 0] = -1
            mask = mask.unsqueeze(-1)
            q[..., 0] *= mask
        # X.copy_(q)
        return q


def unitary(module: nn.Module, name: str = "weight") -> nn.Module:
    weight = getattr(module, name, None)
    if not isinstance(weight, torch.Tensor):
        raise ValueError(f"Module '{module}' has no parameter or buffer '{name}'")

    if weight.ndim < 2:
        raise ValueError(f"Expected a matrix or batch of matrices. Got a tensor of {weight.ndim} dimensions.")

    if weight.shape[-2] != weight.shape[-1]:
        raise ValueError(f"Expected a square matrix or batch of square matrices. Got a tensor of shape {weight.shape}")

    unit = _Unitary()
    nn.utils.parametrize.register_parametrization(module, name, unit)
    return module

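The _Unitary module and the unitary() helper register a QR-based parametrization, analogous to torch.nn.utils.parametrizations.orthogonal. A minimal sketch of applying it to a square complex linear layer (the 8x8 shape is chosen purely for illustration):

```python
# Hedged sketch: keeping a square weight unitary via the helper above.
import torch
import torch.nn as nn

layer = nn.Linear(8, 8, bias=False, dtype=torch.complex64)
unitary(layer, "weight")  # registers _Unitary as a parametrization on layer.weight

w = layer.weight  # recomputed through QR on every access
print(torch.allclose(w.conj().T @ w, torch.eye(8, dtype=torch.complex64), atol=1e-5))
```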
class _SpecialUnitary(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, X:torch.Tensor):
        n, k = X.size(-2), X.size(-1)
        if n != k:
            raise ValueError(f"Expected a square matrix. Got a tensor of shape {X.shape}")
        q, _ = torch.linalg.qr(X)
        q = q / torch.linalg.det(q).pow(1/n)

        return q


def special_unitary(module: nn.Module, name: str = "weight") -> nn.Module:
    weight = getattr(module, name, None)
    if not isinstance(weight, torch.Tensor):
        raise ValueError(f"Module '{module}' has no parameter or buffer '{name}'")

    if weight.ndim < 2:
        raise ValueError(f"Expected a matrix or batch of matrices. Got a tensor of {weight.ndim} dimensions.")

    if weight.shape[-2] != weight.shape[-1]:
        raise ValueError(f"Expected a square matrix or batch of square matrices. Got a tensor of shape {weight.shape}")

    unit = _SpecialUnitary()
    nn.utils.parametrize.register_parametrization(module, name, unit)
    return module


class _Clamp(nn.Module):
    def __init__(self, min, max):
        super(_Clamp, self).__init__()
        self.min = min
        self.max = max
    def forward(self, x):
        if x.is_complex():
            # clamp magnitude, ignore phase
            return torch.clamp(x.abs(), self.min, self.max) * x / x.abs()
        return torch.clamp(x, self.min, self.max)


def clamp(module: nn.Module, name: str = "scale", min=0, max=1) -> nn.Module:
    scale = getattr(module, name, None)
    if not isinstance(scale, torch.Tensor):
        raise ValueError(f"Module '{module}' has no parameter or buffer '{name}'")

    cl = _Clamp(min, max)
    nn.utils.parametrize.register_parametrization(module, name, cl)
    return module


class ONNMiller(nn.Module):
    def __init__(self, input_dim, output_dim, dtype=None) -> None:
        super(ONNMiller, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.dtype = dtype

        self.dim = max(input_dim, output_dim)

        # zero pad input to internal size if smaller
        if self.input_dim < self.dim:
            self.pad = lambda x: F.pad(x, ((self.dim - self.input_dim) // 2, (self.dim - self.input_dim + 1) // 2))
        else:
            self.pad = lambda x: x
        self.pad.__doc__ = f"Zero pad input from {self.input_dim} to {self.dim}"

        # crop output to desired size
        if self.output_dim < self.dim:
            self.crop = lambda x: x[:, (self.dim - self.output_dim) // 2 : (x.shape[1] - (self.dim - self.output_dim + 1) // 2)]
        else:
            self.crop = lambda x: x
        self.crop.__doc__ = f"Crop output from {self.dim} to {self.output_dim}"

        self.U = nn.Parameter(torch.randn(self.dim, self.dim, dtype=self.dtype)) # -> parametrization: Unitary
        self.S = nn.Parameter(torch.randn(self.dim, dtype=self.dtype)) # -> parametrization: Clamp (magnitude 0..1)
        self.V = nn.Parameter(torch.randn(self.dim, self.dim, dtype=self.dtype)) # -> parametrization: Unitary
        self.register_buffer("MZI_scale", torch.tensor(2, dtype=self.dtype.to_real()).sqrt())
        # V is actually V.H, but

    def forward(self, x_in):
        x = x_in
        x = self.pad(x)
        x = x @ self.U
        x = x * (self.S.squeeze() / self.MZI_scale)
        x = x @ self.V
        x = self.crop(x)
        return x

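ONNMiller factors the layer into U · diag(S)/√2 · V, and the inline comments mark which parametrization is intended for each tensor. A sketch of how those intentions would map onto the model_layer_parametrizations mechanism used elsewhere in this diff; the exact combination below is an assumption, not something the diff configures:

```python
# Hedged sketch: the parametrization list suggested by the comments in ONNMiller.__init__.
miller_parametrizations = [
    {"tensor_name": "U", "parametrization": unitary},
    {"tensor_name": "V", "parametrization": unitary},
    {"tensor_name": "S", "parametrization": clamp, "kwargs": {"min": 0, "max": 1}},
]
```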


class ONN(nn.Module):
    def __init__(self, input_dim, output_dim, dtype=None) -> None:
        super(ONN, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.dtype = dtype

        self.dim = max(input_dim, output_dim)

        # zero pad input to internal size if smaller
        if self.input_dim < self.dim:
            self.pad = lambda x: F.pad(x, ((self.dim - self.input_dim) // 2, (self.dim - self.input_dim + 1) // 2))
            self.pad.__doc__ = f"Zero pad input from {self.input_dim} to {self.dim}"
        else:
            self.pad = lambda x: x
            self.pad.__doc__ = f"Input size equals internal size {self.dim}"

        # crop output to desired size
        if self.output_dim < self.dim:
            self.crop = lambda x: x[:, (self.dim - self.output_dim) // 2 : (x.shape[1] - (self.dim - self.output_dim + 1) // 2)]
            self.crop.__doc__ = f"Crop output from {self.dim} to {self.output_dim}"
        else:
            self.crop = lambda x: x
            self.crop.__doc__ = f"Output size equals internal size {self.dim}"

        self.weight = nn.Parameter(torch.randn(self.dim, self.dim, dtype=self.dtype))

    def reset_parameters(self):
        q, _ = torch.linalg.qr(self.weight)
        self.weight.data = q

    # def get_M(self):
    #     return self.U @ self.sigma @ self.V

    def forward(self, x):
        return self.crop(self.pad(x) @ self.weight)


class SemiUnitaryLayer(nn.Module):
    def __init__(self, input_dim, output_dim, dtype=None):
        super(SemiUnitaryLayer, self).__init__()
@@ -51,24 +229,84 @@ class SemiUnitaryLayer(nn.Module):

        # Create a larger square matrix for QR decomposition
        self.weight = nn.Parameter(torch.randn(max(input_dim, output_dim), max(input_dim, output_dim), dtype=dtype))
        self.scale = nn.Parameter(torch.tensor(1.0, dtype=dtype.to_real()))
        self.reset_parameters()

    def reset_parameters(self):
        # Ensure the weights are semi-unitary by QR decomposition
        # Ensure the weights are unitary by QR decomposition
        q, _ = torch.linalg.qr(self.weight)
        # A = QR with A being a complex square matrix -> Q is unitary, R is upper triangular

        # truncate the matrix to the desired size
        if self.input_dim > self.output_dim:
            self.weight.data = q[: self.input_dim, : self.output_dim]
        else:
            self.weight.data = q[: self.output_dim, : self.input_dim].t()
        ...

    def forward(self, x):
        out = torch.matmul(x, self.weight)
        with torch.no_grad():
            scale = torch.clamp(self.scale, 0.0, 1.0)
        out = torch.matmul(x, scale * self.weight)
        return out

    def __repr__(self):
        return f"SemiUnitaryLayer({self.input_dim}, {self.output_dim})"
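
# A quick check of the QR-based initialisation above (this file's imports assumed; the part of
# __init__ hidden by the hunk is assumed to store input_dim/output_dim): with
# input_dim >= output_dim the truncated weight keeps orthonormal columns, i.e. W^H W = I.
layer = SemiUnitaryLayer(8, 4, dtype=torch.complex64)
w = layer.weight.detach()
print(w.shape)  # torch.Size([8, 4])
print(torch.allclose(w.conj().T @ w, torch.eye(4, dtype=w.dtype), atol=1e-5))  # expected: True
print(layer)    # SemiUnitaryLayer(8, 4)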


# class SaturableAbsorberLambertW(nn.Module):
#     """
#     Implements the activation function for an optical saturable absorber

#     base eqn: sigma*tau*I0 = 0.5*(log(Tm/T0))/(1-Tm),
#     where: sigma is the absorption cross section
#            tau is the radiative lifetime of the absorber material
#            T0 is the initial transmittance
#            I0 is the input intensity
#            Tm is the transmittance of the absorber

#     The activation function is defined as:
#     Iout = I0 * Tm(I0)

#     where Tm(I0) is the transmittance of the absorber as a function of the input intensity I0

#     for a unit sigma*tau product, the solution Tm(I0) is given by:
#     Tm(I0) = (W(2*exp(2*I0)*I0*T0))/(2*I0),
#     where W is the Lambert W function

#     if sigma*tau is not 1, I0 has to be scaled by sigma*tau
#     (-> x has to be scaled by sqrt(sigma*tau))

#     """

#     def __init__(self, T0):
#         super(SaturableAbsorberLambertW, self).__init__()
#         self.register_buffer("T0", torch.tensor(T0))

#     def forward(self, x: torch.Tensor):
#         xc = x.conj()
#         two_x_xc = (2 * x * xc).real
#         return (lambertw(2 * torch.exp(two_x_xc) * (x * self.T0 * xc).real) / two_x_xc).to(dtype=x.dtype)

#     def backward(self, x):
#         xc = x.conj()
#         lambert_eval = lambertw(2 * torch.exp(2 * x * xc).real * (x * self.T0 * xc).real)
#         return (((xc * (-2 * lambert_eval + 2 * torch.square(x) - 1) + 2 * x * torch.square(xc) + x) * lambert_eval) / (
#             2 * torch.pow(x, 3) * xc * (lambert_eval + 1)
#         )).to(dtype=x.dtype)
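
# A hedged numerical check (NumPy/SciPy assumed available) that the closed form quoted in the
# docstring above, Tm(I0) = W(2*I0*T0*exp(2*I0)) / (2*I0), satisfies the base equation
# sigma*tau*I0 = 0.5*log(Tm/T0)/(1 - Tm) for sigma*tau = 1; T0 and I0 are arbitrary test values.
import numpy as np
from scipy.special import lambertw

T0, I0 = 0.3, 1.5
Tm = np.real(lambertw(2 * I0 * T0 * np.exp(2 * I0))) / (2 * I0)
print(np.isclose(I0, 0.5 * np.log(Tm / T0) / (1 - Tm)))  # True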

# class SaturableAbsorber(nn.Module):
#     def __init__(self, alpha, I0):
#         super(SaturableAbsorber, self).__init__()
#         self.register_buffer("alpha", torch.tensor(alpha))
#         self.register_buffer("I0", torch.tensor(I0))

#     def forward(self, x):
#         I = (x*x.conj()).to(dtype=x.dtype.to_real())
#         A = self.alpha/(1+I/self.I0)


# class SpreadLayer(nn.Module):
#     def __init__(self, in_features, out_features, dtype=None):
#         super(SpreadLayer, self).__init__()
@@ -85,6 +323,19 @@ class SemiUnitaryLayer(nn.Module):
#### as defined by zhang et al


class DropoutComplex(nn.Module):
    def __init__(self, p=0.5):
        super(DropoutComplex, self).__init__()
        self.dropout = nn.Dropout(p=p)

    def forward(self, x):
        if x.is_complex():
            mask = self.dropout(torch.ones_like(x.real))
            return x * mask
        else:
            return self.dropout(x)
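
# A small sketch of DropoutComplex (this file's imports assumed): one real-valued Bernoulli
# mask is shared by the real and imaginary parts, so surviving entries keep their phase and
# are rescaled by 1/(1-p), exactly as nn.Dropout does for real inputs.
drop = DropoutComplex(p=0.5)
drop.train()
z = torch.randn(4, 8, dtype=torch.complex64)
out = drop(z)
kept = out != 0
print(torch.allclose(out[kept], z[kept] * 2))  # kept entries are the inputs scaled by 1/(1-p)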


class Identity(nn.Module):
    """
    implements the "activation" function
@@ -97,18 +348,76 @@ class Identity(nn.Module):
    def forward(self, x):
        return x


class PowRot(nn.Module):
    def __init__(self, bias=False):
        super(PowRot, self).__init__()
        self.scale = nn.Parameter(torch.tensor(1.0))
        if bias:
            self.bias = nn.Parameter(torch.tensor(0.0))
        else:
            self.register_buffer("bias", torch.tensor(0.0))

    def forward(self, x: torch.Tensor):
        if x.is_complex():
            return x * torch.exp(-self.scale*1j*x.abs().square()+self.bias.to(dtype=x.dtype))
        else:
            return x


class Pow(nn.Module):
    """
    implements the activation function
    M(z) = ||z||^2 + b
    """
    def __init__(self, bias=False):
        super(Pow, self).__init__()
        if bias:
            self.bias = nn.Parameter(torch.tensor(0.0))
        else:
            self.register_buffer("bias", torch.tensor(0.0))

    def forward(self, x: torch.Tensor):
        return x.abs().square().add(self.bias).to(dtype=x.dtype)


class Mag(nn.Module):
    """
    implements the activation function
    M(z) = ||z||
    M(z) = ||z||+b
    """

    def __init__(self):
    def __init__(self, bias=False):
        super(Mag, self).__init__()
        if bias:
            self.bias = nn.Parameter(torch.tensor(0.0))
        else:
            self.register_buffer("bias", torch.tensor(0.0))

    def forward(self, x):
        return torch.abs(x).to(dtype=x.dtype)
    def forward(self, x: torch.Tensor):
        return x.abs().add(self.bias).to(dtype=x.dtype)


class MagScale(nn.Module):
    def __init__(self, bias=False):
        super(MagScale, self).__init__()
        if bias:
            self.bias = nn.Parameter(torch.tensor(0.0))
        else:
            self.register_buffer("bias", torch.tensor(0.0))

    def forward(self, x: torch.Tensor):
        return x.abs().add(self.bias).to(dtype=x.dtype).sin().mul(x)


class PowScale(nn.Module):
    def __init__(self, bias=False):
        super(PowScale, self).__init__()
        if bias:
            self.bias = nn.Parameter(torch.tensor(0.0))
        else:
            self.register_buffer("bias", torch.tensor(0.0))

    def forward(self, x: torch.Tensor):
        return x.mul(x.abs().square().add(self.bias).to(dtype=x.dtype).sin())
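
# A short sketch (this file's imports assumed) of how these element-wise activations treat a
# complex sample: Pow and Mag return magnitude information recast to the input dtype, while
# MagScale and PowScale modulate the input by a sine of its magnitude or power.
z = torch.tensor([3.0 + 4.0j])  # magnitude 5
print(Pow()(z))       # ~25+0j : ||z||^2 + bias (bias fixed at 0 here)
print(Mag()(z))       # ~5+0j  : ||z|| + bias
print(MagScale()(z))  # z * sin(||z|| + bias)
print(PowScale()(z))  # z * sin(||z||^2 + bias)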


class ModReLU(nn.Module):
@@ -118,17 +427,21 @@ class ModReLU(nn.Module):
    = ReLU(||z|| + b)*z/||z||
    """

    def __init__(self, b=0):
    def __init__(self, bias=True):
        super(ModReLU, self).__init__()
        self.b = torch.tensor(b)
        if bias:
            self.bias = nn.Parameter(torch.tensor(0.0))
        else:
            self.register_buffer("bias", torch.tensor(0.0))

    def forward(self, x):
        if x.is_complex():
            mod = torch.abs(x.real**2 + x.imag**2)
            return torch.relu(mod + self.b) * x / mod
            mod = x.abs()
            out = torch.relu(mod + self.bias) * x / mod
            return out.to(dtype=x.dtype)

        else:
            return torch.relu(x + self.b)
            return torch.relu(x + self.bias).to(dtype=x.dtype)

    def __repr__(self):
        return f"ModReLU(b={self.b})"
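
# A numeric sanity check of the formula above (this file's imports assumed), using the updated
# ModReLU with a learnable bias: magnitudes below -bias are zeroed, larger ones are shifted by
# bias while the phase is preserved.
act = ModReLU(bias=True)
with torch.no_grad():
    act.bias.fill_(-1.0)
z = torch.tensor([0.5 + 0.0j, 0.0 + 2.0j])
print(act(z))  # ~[0+0j, 0+1j]: |0.5|-1 < 0 -> 0;  |2j|-1 = 1 with the phase of 2j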

@@ -166,3 +479,26 @@ class ZReLU(nn.Module):
            return x * (torch.angle(x) >= 0) * (torch.angle(x) <= torch.pi / 2)
        else:
            return torch.relu(x)


__all__ = [
    complex_sse_loss,
    complex_mse_loss,
    UnitaryLayer,
    unitary,
    clamp,
    ONN,
    ONNMiller,
    SemiUnitaryLayer,
    DropoutComplex,
    Identity,
    Pow,
    PowRot,
    Mag,
    ModReLU,
    CReLU,
    ZReLU,
    # SaturableAbsorberLambertW,
    # SaturableAbsorber,
    # SpreadLayer,
]

@@ -20,7 +20,7 @@ def _optional_suggest(
    type: str,
    log: bool = False,
    step: int | float | None = None,
    add_user: bool = False,
    add_user: bool = True,
    force: bool = False,
    multiply: float | int = 1,
    set_new: bool = True,
@@ -96,7 +96,7 @@ def suggest_categorical_optional(
    trial: trial.Trial,
    name: str,
    choices_or_value: tuple[Any] | list[Any] | Any,
    add_user: bool = False,
    add_user: bool = True,
    force: bool = False,
    set_new: bool = True,
):
@@ -129,7 +129,7 @@ def suggest_int_optional(
    range_or_value: tuple[int] | list[int] | int,
    step: int = 1,
    log: bool = False,
    add_user: bool = False,
    add_user: bool = True,
    force: bool = False,
    multiply: int = 1,
    set_new: bool = True,
@@ -174,7 +174,7 @@ def suggest_float_optional(
    range_or_value: tuple[float] | list[float] | float,
    step: float | None = None,
    log: bool = False,
    add_user: bool = False,
    add_user: bool = True,
    force: bool = False,
    multiply: float = 1,
    set_new: bool = True,
@@ -222,7 +222,7 @@ def suggest_categorical_optional_wrapper(
    self: trial.Trial,
    name: str,
    choices_or_value: tuple[Any] | list[Any] | Any,
    add_user: bool = False,
    add_user: bool = True,
    force: bool = False,
    set_new: bool = True,
):
@@ -253,7 +253,7 @@ def suggest_int_optional_wrapper(
    range_or_value: tuple[int] | list[int] | int,
    step: int = 1,
    log: bool = False,
    add_user: bool = False,
    add_user: bool = True,
    force: bool = False,
    multiply: int = 1,
    set_new: bool = True,
@@ -295,7 +295,7 @@ def suggest_float_optional_wrapper(
    range_or_value: tuple[float] | list[float] | float,
    step: float | None = None,
    log: bool = False,
    add_user: bool = False,
    add_user: bool = True,
    force: bool = False,
    multiply: float = 1,
    set_new: bool = True,