Compare commits


2 Commits

10 changed files with 561 additions and 305 deletions

View File

@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:746ea83013870351296e01e294905ee027291ef79fd78c1e6b69dd9ebaa1cba0
-size 10240000
+oid sha256:76934d1d202aea1311ba67f5ea35eeb99a9c5c856f491565032e7d54ca6f072d
+size 13598720

View File

@@ -26,6 +26,8 @@ import torch
 import torch.optim as optim
 import torch.utils.data
+import hypertraining.models as models
 from torch.utils.tensorboard import SummaryWriter
 import multiprocessing
@@ -253,14 +255,17 @@ class HyperTraining:
         model_kwargs = {
             "dims": (input_dim, *hidden_dims, self.model_settings.output_dim),
             "layer_function": layer_func,
-            "layer_parametrizations": layer_parametrizations,
-            "activation_function": afunc,
+            "layer_func_kwargs": self.model_settings.model_layer_kwargs,
+            "act_function": afunc,
+            "act_func_kwargs": None,
+            "parametrizations": layer_parametrizations,
             "dtype": dtype,
-            "droupout_prob": self.model_settings.dropout_prob,
-            "scale": scale_layers,
+            "dropout_prob": self.model_settings.dropout_prob,
+            "scale_layers": scale_layers,
+            "rotate": False,
         }
-        model = util.complexNN.regenerator(**model_kwargs)
+        model = models.regenerator(*model_kwargs.pop("dims"), **model_kwargs)
         n_nodes = sum(hidden_dims)
         if writer is not None:
@@ -381,7 +386,10 @@ class HyperTraining:
         running_loss = 0.0
         model.train()
         loader_len = len(train_loader)
-        for batch_idx, (x, y, _) in enumerate(train_loader):
+        for batch_idx, batch in enumerate(train_loader):
+            x = batch["x"]
+            y = batch["y"]
             if batch_idx >= self.optuna_settings._n_train_batches:
                 break
             model.zero_grad(set_to_none=True)
@@ -390,7 +398,7 @@ class HyperTraining:
                 y.to(self.pytorch_settings.device),
             )
             y_pred = model(x)
-            loss = util.complexNN.complex_mse_loss(y_pred, y)
+            loss = util.complexNN.complex_mse_loss(y_pred, y, power=True)
             loss_value = loss.item()
             loss.backward()
             optimizer.step()
@@ -444,7 +452,9 @@ class HyperTraining:
         model.eval()
         running_error = 0
         with torch.no_grad():
-            for batch_idx, (x, y, _) in enumerate(valid_loader):
+            for batch_idx, batch in enumerate(valid_loader):
+                x = batch["x"]
+                y = batch["y"]
                 if batch_idx >= self.optuna_settings._n_valid_batches:
                     break
                 x, y = (
@@ -452,50 +462,44 @@ class HyperTraining:
                     y.to(self.pytorch_settings.device),
                 )
                 y_pred = model(x)
-                error = util.complexNN.complex_mse_loss(y_pred, y)
+                error = util.complexNN.complex_mse_loss(y_pred, y, power=True)
                 error_value = error.item()
                 running_error += error_value
         running_error /= min(len(valid_loader), self.optuna_settings._n_valid_batches)
         if writer is not None:
-            title_append, subtitle = self.build_title(trial)
-            writer.add_figure(
-                "fiber response",
-                self.plot_model_response(
-                    trial,
-                    model=model,
-                    title_append=title_append,
-                    subtitle=subtitle,
-                    show=False,
-                ),
-                epoch + 1,
-            )
-            writer.add_figure(
-                "eye diagram",
-                self.plot_model_response(
-                    trial,
-                    model=self.model,
-                    title_append=title_append,
-                    subtitle=subtitle,
-                    show=False,
-                    mode="eye",
-                ),
-                epoch + 1,
-            )
-            writer.add_figure(
-                "powers",
-                self.plot_model_response(
-                    trial,
-                    model=self.model,
-                    title_append=title_append,
-                    subtitle=subtitle,
-                    mode="powers",
-                    show=False,
-                ),
-                epoch + 1,
-            )
+            writer.add_scalar(
+                "eval loss",
+                running_error,
+                epoch,
+            )
+            # if (epoch + 1) % 10 == 0 or epoch < 10:
+            #     # plotting is slow, so only do it every 10 epochs
+            #     title_append, subtitle = self.build_title(trial)
+            #     head_fig, eye_fig, powers_fig = self.plot_model_response(
+            #         model=model,
+            #         title_append=title_append,
+            #         subtitle=subtitle,
+            #         show=False,
+            #     )
+            #     writer.add_figure(
+            #         "fiber response",
+            #         head_fig,
+            #         epoch + 1,
+            #     )
+            #     writer.add_figure(
+            #         "eye diagram",
+            #         eye_fig,
+            #         epoch + 1,
+            #     )
+            #     writer.add_figure(
+            #         "powers",
+            #         powers_fig,
+            #         epoch + 1,
+            #     )
+            #     writer.flush()
         # if enable_progress:
         #     progress.stop()
@@ -511,15 +515,18 @@ class HyperTraining:
         with torch.no_grad():
             model = model.to(self.pytorch_settings.device)
-            for x, y, timestamp in loader:
+            for batch in loader:
+                x = batch["x"]
+                y = batch["y"]
+                timestamp = batch["timestamp"]
                 x, y = (
                     x.to(self.pytorch_settings.device),
                     y.to(self.pytorch_settings.device),
                 )
                 if trace_powers:
-                    y_pred, powers = model(x, trace_powers).cpu()
+                    y_pred, powers = model(x, trace_powers=True).cpu()
                 else:
-                    y_pred = model(x, trace_powers).cpu()
+                    y_pred = model(x, trace_powers=True).cpu()
                 # x = x.cpu()
                 # y = y.cpu()
                 y_pred = y_pred.view(y_pred.shape[0], -1, 2)
@@ -539,7 +546,7 @@ class HyperTraining:
             return fiber_in, fiber_out, regen, timestamps, powers
         return fiber_in, fiber_out, regen, timestamps

-    def objective(self, trial: optuna.Trial, plot_before=False):
+    def objective(self, trial: optuna.Trial):
         if self.stop_study:
             trial.study.stop()
         model = None
@@ -555,54 +562,54 @@ class HyperTraining:
            title_append, subtitle = self.build_title(trial)
-            writer.add_figure(
-                "fiber response",
-                self.plot_model_response(
-                    trial,
-                    model=model,
-                    title_append=title_append,
-                    subtitle=subtitle,
-                    show=False,
-                ),
-                0,
-            )
-            writer.add_figure(
-                "eye diagram",
-                self.plot_model_response(
-                    trial,
-                    model=self.model,
-                    title_append=title_append,
-                    subtitle=subtitle,
-                    mode="eye",
-                    show=False,
-                ),
-                0,
-            )
-            writer.add_figure(
-                "powers",
-                self.plot_model_response(
-                    trial,
-                    model=self.model,
-                    title_append=title_append,
-                    subtitle=subtitle,
-                    mode="powers",
-                    show=False,
-                ),
-                0,
-            )
+            # writer.add_figure(
+            #     "fiber response",
+            #     self.plot_model_response(
+            #         trial,
+            #         model=model,
+            #         title_append=title_append,
+            #         subtitle=subtitle,
+            #         show=False,
+            #     ),
+            #     0,
+            # )
+            # writer.add_figure(
+            #     "eye diagram",
+            #     self.plot_model_response(
+            #         trial,
+            #         model=self.model,
+            #         title_append=title_append,
+            #         subtitle=subtitle,
+            #         mode="eye",
+            #         show=False,
+            #     ),
+            #     0,
+            # )
+            # writer.add_figure(
+            #     "powers",
+            #     self.plot_model_response(
+            #         trial,
+            #         model=self.model,
+            #         title_append=title_append,
+            #         subtitle=subtitle,
+            #         mode="powers",
+            #         show=False,
+            #     ),
+            #     0,
+            # )
         train_loader, valid_loader = self.get_sliced_data(trial)
         optimizer_name = trial.suggest_categorical_optional("optimizer", self.optimizer_settings.optimizer)
-        lr = trial.suggest_float_optional("lr", self.optimizer_settings.learning_rate, log=True)
+        lr = trial.suggest_float_optional("lr", self.optimizer_settings.optimizer_kwargs["lr"], log=True)
         optimizer = getattr(optim, optimizer_name)(model.parameters(), lr=lr)
-        if self.optimizer_settings.scheduler is not None:
-            scheduler = getattr(optim.lr_scheduler, self.optimizer_settings.scheduler)(
-                optimizer, **self.optimizer_settings.scheduler_kwargs
-            )
+        # if self.optimizer_settings.scheduler is not None:
+        #     scheduler = getattr(optim.lr_scheduler, self.optimizer_settings.scheduler)(
+        #         optimizer, **self.optimizer_settings.scheduler_kwargs
+        #     )
         for epoch in range(self.pytorch_settings.epochs):
             trial.set_user_attr("epoch", epoch)
@@ -628,8 +635,8 @@ class HyperTraining:
                 writer,
                 # enable_progress=enable_progress,
             )
-            if self.optimizer_settings.scheduler is not None:
-                scheduler.step(error)
+            # if self.optimizer_settings.scheduler is not None:
+            #     scheduler.step(error)
             trial.set_user_attr("mse", error)
             trial.set_user_attr("log_mse", np.log10(error + np.finfo(float).eps))
@@ -645,10 +652,10 @@ class HyperTraining:
         if self.optuna_settings._multi_objective:
             return -np.log10(error + np.finfo(float).eps), trial.user_attrs.get("model_n_nodes", -1)
-        if self.pytorch_settings.save_models and model is not None:
-            save_path = Path(self.pytorch_settings.model_dir) / f"{self.optuna_settings.study_name}_{trial.number}.pth"
-            save_path.parent.mkdir(parents=True, exist_ok=True)
-            torch.save(model, save_path)
+        # if self.pytorch_settings.save_models and model is not None:
+        #     save_path = Path(self.pytorch_settings.model_dir) / f"{self.optuna_settings.study_name}_{trial.number}.pth"
+        #     save_path.parent.mkdir(parents=True, exist_ok=True)
+        #     torch.save(model, save_path)
         return error

View File

@@ -8,7 +8,8 @@ from util.complexNN import (
     photodiode,
     EOActivation,
     polarimeter,
-    normalize_by_first
+    # normalize_by_first,
+    rotate,
 )
@@ -19,11 +20,11 @@ class polarisation_estimator2(Module):
             polarimeter(),
             torch.nn.Linear(4, 4),
             torch.nn.ReLU(),
-            torch.nn.Dropout(p=0.01),
+            # torch.nn.Dropout(p=0.01),
             torch.nn.Linear(4, 4),
             torch.nn.ReLU(),
-            torch.nn.Dropout(p=0.01),
-            torch.nn.Linear(4, 4),
+            # torch.nn.Dropout(p=0.01),
+            torch.nn.Linear(4, 1),
         )

     def forward(self, x):
@@ -124,6 +125,7 @@ class regenerator(Module):
         dtype=torch.float64,
         dropout_prob=0.01,
         scale_layers=False,
+        rotate=False,
     ):
         super(regenerator, self).__init__()
         self._n_hidden_layers = len(dims) - 2
@@ -131,6 +133,8 @@ class regenerator(Module):
         layer_func_kwargs = layer_func_kwargs or {}
         act_func_kwargs = act_func_kwargs or {}
+        self.rotation = rotate
+
         self.build_model(dims, layer_function, layer_func_kwargs, act_function, act_func_kwargs, parametrizations, dtype, dropout_prob, scale_layers)

     def build_model(self, dims, layer_function, layer_func_kwargs, act_function, act_func_kwargs, parametrizations, dtype, dropout_prob, scale_layers):
@@ -146,8 +150,9 @@ class regenerator(Module):
             module = act_function(size=dims[i + 1], **act_func_kwargs)
             self.get_submodule(f"layer_{i}").add_module("activation", module)
-            module = DropoutComplex(p=dropout_prob)
-            self.get_submodule(f"layer_{i}").add_module("dropout", module)
+            if dropout_prob is not None and dropout_prob > 0:
+                module = DropoutComplex(p=dropout_prob)
+                self.get_submodule(f"layer_{i}").add_module("dropout", module)
         self.add_module(f"layer_{self._n_hidden_layers}", Sequential())
@@ -160,6 +165,10 @@ class regenerator(Module):
         module = act_function(size=dims[-1], **act_func_kwargs)
         self.get_submodule(f"layer_{self._n_hidden_layers}").add_module("activation", module)
+        if self.rotation:
+            module = rotate()
+            self.add_module("rotate", module)
+
         # module = Scale(size=dims[-1])
         # self.get_submodule(f"layer_{self._n_hidden_layers}").add_module("out_scale", module)
@@ -190,15 +199,27 @@ class regenerator(Module):
             powers.append(x.abs().square().sum())
         return powers

-    def forward(self, x, trace_powers=False):
+    def forward(self, x, angle=None, pre_rot=False, trace_powers=False):
         powers = self._trace_powers(trace_powers, x)
-        x = self.layer_0(x)
-        powers = self._trace_powers(trace_powers, x, powers)
-        for i in range(1, self._n_hidden_layers):
+        # x = self.layer_0(x)
+        # powers = self._trace_powers(trace_powers, x, powers)
+        for i in range(0, self._n_hidden_layers):
             x = getattr(self, f"layer_{i}")(x)
             powers = self._trace_powers(trace_powers, x, powers)
         x = getattr(self, f"layer_{self._n_hidden_layers}")(x)
-        powers = self._trace_powers(trace_powers, x, powers)
-        if trace_powers:
-            return x, powers
-        return x
+        if self.rotation:
+            try:
+                x_rot = self.rotate(x, angle)
+            except AttributeError:
+                pass
+            powers = self._trace_powers(trace_powers, x_rot, powers)
+        else:
+            x_rot = x
+        if pre_rot and trace_powers:
+            return x_rot, x, powers
+        if pre_rot and not trace_powers:
+            return x_rot, x
+        if not pre_rot and trace_powers:
+            return x_rot, powers
+        return x_rot
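Note: the new forward signature returns different tuples depending on pre_rot and trace_powers. A minimal sketch of the four call patterns follows; it is not part of the diff, and the dims, dtype, and reliance on constructor defaults are illustrative assumptions.

import torch
import hypertraining.models as models

# Sketch only: assumes regenerator's constructor defaults are usable and that
# rotate=True is passed, as wired up in this commit; dims/dtype are made up.
model = models.regenerator(4, 8, 4, rotate=True)
x = torch.randn(32, 4, dtype=torch.complex64)      # complex input batch
angle = torch.rand(32, 1) * 2 * torch.pi           # one rotation angle per sample

x_rot = model(x, angle)                            # rotated output only
x_rot, x_pre = model(x, angle, pre_rot=True)       # plus the pre-rotation output
x_rot, powers = model(x, angle, trace_powers=True) # plus per-layer powers
x_rot, x_pre, powers = model(x, angle, pre_rot=True, trace_powers=True)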

View File

@@ -22,6 +22,8 @@ class DataSettings:
     train_split: float = 0.8
     polarisations: tuple | list = (0,)
     randomise_polarisations: bool = False
+    osnr: float | int = None
+    seed: int = None
     """
     change to:

View File

@@ -2,7 +2,6 @@ import copy
 from datetime import datetime
 from pathlib import Path
 import random
-from typing import Literal
 import matplotlib
 from matplotlib.colors import LinearSegmentedColormap
 import torch.nn.utils.parametrize
@@ -60,33 +59,35 @@ def traverse_dict_update(target, source):
         except TypeError:
             target.__dict__[k] = v

+
 def get_parameter_names_and_values(model):
     def is_parametrized(module):
         if hasattr(module, "parametrizations"):
             return True
         return False

-    def _get_param_info(module, prefix='', parametrization=False):
+    def _get_param_info(module, prefix="", parametrization=False):
         param_list = []
-        for name, param in module.named_parameters(recurse = parametrization):
+        for name, param in module.named_parameters(recurse=parametrization):
             if parametrization and name.startswith("parametrizations"):
-                name_parts = name.split('.')
+                name_parts = name.split(".")
                 name = name_parts[1]
                 param = getattr(module, name)
-            full_name = prefix + ('.' if prefix else '') + name
+            full_name = prefix + ("." if prefix else "") + name
             param_value = param.data
             param_list.append((full_name, param_value))
         for child_name, child_module in module.named_children():
-            child_prefix = prefix + ('.' if prefix else '') + child_name
+            child_prefix = prefix + ("." if prefix else "") + child_name
             if child_name == "parametrizations":
                 continue
             param_list.extend(_get_param_info(child_module, child_prefix, is_parametrized(child_module)))
         return param_list

     return _get_param_info(model)

+
 class PolarizationTrainer:
     def __init__(
         self,
@@ -101,7 +102,7 @@ class PolarizationTrainer:
         settings_override=None,
         reset_epoch=False,
     ):
-        self.mod = torch.pi/2
+        self.mod = torch.pi / 2
         self.resume = checkpoint_path is not None
         torch.serialization.add_safe_globals([
             *util.complexNN.__all__,
@@ -219,7 +220,7 @@ class PolarizationTrainer:
         # dims = self.model_kwargs.pop("dims")
         model_kwargs = copy.deepcopy(self.model_kwargs)
-        self.model = models.polarisation_estimator(*model_kwargs.pop('dims'),**model_kwargs)
+        self.model = models.polarisation_estimator(*model_kwargs.pop("dims"), **model_kwargs)
         # self.model = models.polarisation_estimator2()

         if self.writer is not None:
@@ -336,17 +337,20 @@ class PolarizationTrainer:
             write_div = 0
             loss_div = 0
             for batch_idx, batch in enumerate(train_loader):
-                x = batch["x"]
-                y = batch["sop"]
+                x = batch["angle_data2"]
+                y = batch["center_angle"]
                 self.model.zero_grad(set_to_none=True)
                 x, y = (
                     x.to(self.pytorch_settings.device),
                     y.to(self.pytorch_settings.device),
                 )
-                y_pred = self.model(x)
+                y_pred = self.model(x).abs().real
+                # y_pred = torch.fmod(y_pred, self.mod)
+                y = y.abs().real
+                # y = torch.fmod(y, self.mod)
+                # loss = torch.nn.functional.mse_loss(torch.cos(y_pred), torch.cos(y))
                 # loss = torch.nn.functional.smooth_l1_loss(torch.cos(torch.fmod(y_pred, torch.pi/2)).squeeze(), torch.cos(torch.fmod(y, torch.pi/2)).squeeze(), beta=0.5)
-                loss = torch.nn.functional.mse_loss(y_pred, y)
-                # loss = util.complexNN.naive_angle_loss(y_pred, y, mod=torch.pi/2)
+                loss = util.complexNN.naive_angle_loss(y_pred, y, mod=self.mod)
                 loss_value = loss.item()
                 loss.backward()
                 optimizer.step()
@@ -356,7 +360,7 @@ class PolarizationTrainer:
                 loss_div += 1
                 if enable_progress:
-                    progress.update(task, advance=1, description=f"{loss_value:.3e}")
+                    progress.update(task, advance=1, description=f"{loss_value/np.pi*180:.3e} °")
                 if batch_idx % self.pytorch_settings.write_every == 0:
                     self.writer.add_scalar(
@@ -395,24 +399,28 @@ class PolarizationTrainer:
             loss_div = 0
             with torch.no_grad():
                 for _, batch in enumerate(valid_loader):
-                    x = batch["x"]
-                    y = batch["sop"]
+                    # x = batch["angle_data2"]
+                    x = batch["angle_data2"]
+                    y = batch["center_angle"]
                     x, y = (
                         x.to(self.pytorch_settings.device),
                         y.to(self.pytorch_settings.device),
                     )
-                    y_pred = self.model(x)
+                    y_pred = self.model(x).abs().real
+                    # y_pred = torch.fmod(y_pred, self.mod)
+                    y = y.abs().real
+                    # y = torch.fmod(y, self.mod)
+                    # loss = torch.nn.functional.mse_loss(torch.cos(y_pred), torch.cos(y))
                     # loss = torch.nn.functional.smooth_l1_loss(torch.cos(torch.fmod(y_pred, torch.pi/2)).squeeze(), torch.cos(torch.fmod(y, torch.pi/2)).squeeze(), beta=0.5)
-                    loss = torch.nn.functional.mse_loss(y_pred, y)
-                    # loss = util.complexNN.naive_angle_loss(y_pred, y, mod=torch.pi/2)
+                    loss = util.complexNN.naive_angle_loss(y_pred, y, mod=self.mod)
                     loss_value = loss.item()
                     running_loss += loss_value
                     loss_div += 1
                     if enable_progress:
-                        progress.update(task, advance=1, description=f"{loss_value:.3e}")
-            running_loss = running_loss/loss_div
+                        progress.update(task, advance=1, description=f"{loss_value/np.pi*180:.3e} °")
+            running_loss = running_loss / loss_div
             self.writer.add_scalar(
                 "eval loss",
@@ -506,19 +514,19 @@ class PolarizationTrainer:
             for i, config_path in enumerate(self.data_settings.config_path):
                 paths = Path.cwd().glob(config_path)
                 for j, path in enumerate(paths):
-                    text = str(path) + '\n'
-                    with open(path, 'r') as f:
+                    text = str(path) + "\n"
+                    with open(path, "r") as f:
                         text += f.read()
-                    text += '\n'
-                    self.writer.add_text(f"config_{i*len(self.data_settings.config_path)+j}", text)
+                    text += "\n"
+                    self.writer.add_text(f"config_{i * len(self.data_settings.config_path) + j}", text)
         elif isinstance(self.data_settings.config_path, str):
             paths = Path.cwd().glob(self.data_settings.config_path)
             for j, path in enumerate(paths):
-                text = str(path) + '\n'
-                with open(path, 'r') as f:
+                text = str(path) + "\n"
+                with open(path, "r") as f:
                     text += f.read()
-                text += '\n'
+                text += "\n"
                 self.writer.add_text(f"config_{j}", text)
         self.writer.flush()
@@ -571,7 +579,8 @@ class PolarizationTrainer:
                 if loss < self.best["loss"]:
                     self.best = checkpoint
                     save_path = (
-                        Path(self.pytorch_settings.model_dir) / f"best_pol_{self.writer.get_logdir().split('/')[-1]}.tar"
+                        Path(self.pytorch_settings.model_dir)
+                        / f"best_pol_{self.writer.get_logdir().split('/')[-1]}.tar"
                     )
                     save_path.parent.mkdir(parents=True, exist_ok=True)
                     self.save_checkpoint(self.best, save_path)
@@ -580,6 +589,7 @@ class PolarizationTrainer:
         self.writer.close()
         return self.best

+
 class RegenerationTrainer:
     def __init__(
         self,
@@ -636,6 +646,10 @@ class RegenerationTrainer:
         self.model_settings: ModelSettings = model_settings
         self.optimizer_settings: OptimizerSettings = optimizer_settings

+        if self.global_settings.seed is not None:
+            random.seed(self.global_settings.seed)
+            np.random.seed(self.global_settings.seed)
+
         self.console = console or Console()
         self.writer = None
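Note: the hunk above seeds Python's random and NumPy from global_settings.seed, but not torch, so weight initialisation and dropout remain nondeterministic. A fuller seeding helper might look like the sketch below; this is a suggestion, not code from the diff.

import random
import numpy as np
import torch

def seed_everything(seed: int) -> None:
    # seed Python, NumPy and torch; torch.manual_seed also seeds the CUDA RNGs
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)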
@@ -706,10 +720,12 @@ class RegenerationTrainer:
         # dims = self.model_kwargs.pop("dims")
         model_kwargs = copy.deepcopy(self.model_kwargs)
-        self.model = models.regenerator(*model_kwargs.pop('dims'),**model_kwargs)
+        self.model = models.regenerator(*model_kwargs.pop("dims"), **model_kwargs)

         if self.writer is not None:
-            self.writer.add_graph(self.model, torch.zeros(1, input_dim, dtype=dtype))
+            self.writer.add_graph(
+                self.model, (torch.rand(1, input_dim, dtype=dtype), torch.rand(1, 1, dtype=dtype.to_real()))
+            )
         self.model = self.model.to(self.pytorch_settings.device)

         if self.resume:
@@ -728,12 +744,12 @@ class RegenerationTrainer:
         num_symbols = None
         config_path = self.data_settings.config_path
-        polarisations = self.data_settings.polarisations
         randomise_polarisations = self.data_settings.randomise_polarisations
+        osnr = self.data_settings.osnr
         if override is not None:
             num_symbols = override.get("num_symbols", None)
             config_path = override.get("config_path", config_path)
-            polarisations = override.get("polarisations", polarisations)
+            # polarisations = override.get("polarisations", polarisations)
             randomise_polarisations = override.get("randomise_polarisation", randomise_polarisations)

         # get dataset
         dataset = FiberRegenerationDataset(
@@ -746,8 +762,8 @@ class RegenerationTrainer:
             dtype=dtype,
             real=not dtype.is_complex,
             num_symbols=num_symbols,
-            polarisations=polarisations,
             randomise_polarisations=randomise_polarisations,
+            osnr = osnr,
         )

         dataset_size = len(dataset)
@@ -819,12 +835,14 @@ class RegenerationTrainer:
             for batch_idx, batch in enumerate(train_loader):
                 x = batch["x"]
                 y = batch["y"]
+                angles = batch["mean_angle"]
                 self.model.zero_grad(set_to_none=True)
-                x, y = (
+                x, y, angles = (
                     x.to(self.pytorch_settings.device),
                     y.to(self.pytorch_settings.device),
+                    angles.to(self.pytorch_settings.device),
                 )
-                y_pred = self.model(x)
+                y_pred = self.model(x, -angles)
                 loss = util.complexNN.complex_mse_loss(y_pred, y, power=True)
                 loss_value = loss.item()
                 loss.backward()
@@ -872,11 +890,13 @@ class RegenerationTrainer:
                 for _, batch in enumerate(valid_loader):
                     x = batch["x"]
                     y = batch["y"]
-                    x, y = (
+                    angles = batch["mean_angle"]
+                    x, y, angles = (
                         x.to(self.pytorch_settings.device),
                         y.to(self.pytorch_settings.device),
+                        angles.to(self.pytorch_settings.device),
                     )
-                    y_pred = self.model(x)
+                    y_pred = self.model(x, -angles)
                     error = util.complexNN.complex_mse_loss(y_pred, y, power=True)
                     error_value = error.item()
                     running_error += error_value
@@ -884,7 +904,7 @@ class RegenerationTrainer:
                     if enable_progress:
                         progress.update(task, advance=1, description=f"{error_value:.3e}")
-            running_error = running_error/len(valid_loader)
+            running_error = running_error / len(valid_loader)
             self.writer.add_scalar(
                 "eval loss",
@@ -928,45 +948,65 @@ class RegenerationTrainer:
     def run_model(self, model, loader, trace_powers=False):
         model.eval()
         fiber_out = []
+        fiber_out_rot = []
         fiber_in = []
         regen = []
         timestamps = []
+        angles = []
         with torch.no_grad():
             model = model.to(self.pytorch_settings.device)
             for batch in loader:
                 x = batch["x"]
                 y = batch["y"]
+                plot_target = batch["plot_target"]
+                angle = batch["mean_angle"]
+                center_angle = batch["center_angle"]
                 timestamp = batch["timestamp"]
                 plot_data = batch["plot_data"]
-                x, y = (
+                plot_data_rot = batch["plot_data_rot"]
+                x, y, angle = (
                     x.to(self.pytorch_settings.device),
                     y.to(self.pytorch_settings.device),
+                    angle.to(self.pytorch_settings.device),
                 )
                 if trace_powers:
-                    y_pred, powers = model(x, trace_powers).cpu()
+                    y_pred, powers = model(x, angle, True).cpu()
                 else:
-                    y_pred = model(x, trace_powers).cpu()
+                    y_pred = model(x, angle).cpu()
                 # x = x.cpu()
                 # y = y.cpu()
                 y_pred = y_pred.view(y_pred.shape[0], -1, 2)
+                y_pred = y_pred[:, y_pred.shape[1]//2, :]
                 y = y.view(y.shape[0], -1, 2)
-                plot_data = plot_data.view(plot_data.shape[0], -1, 2)
+                # plot_data = plot_data.view(plot_data.shape[0], -1, 2)
+                # c = torch.cos(-angle).cpu()
+                # s = torch.sin(-angle).cpu()
+                # rot = torch.stack([torch.stack([c, -s], dim=1), torch.stack([s, c], dim=1)], dim=2).squeeze(-1)
+                # plot_data = torch.bmm(plot_data, rot.to(dtype=plot_data.dtype))
+                # plot_data = plot_data
+                # sines = torch.sin(-angle.cpu())
+                # cosines = torch.cos(-angle.cpu())
+                # plot_data = torch.stack((plot_data[..., 0] * cosines - plot_data[..., 1] * sines, plot_data[..., 0] * sines + plot_data[..., 1] * cosines), dim=-1)
                 # x = x.view(x.shape[0], -1, 2)
                 # timestamp = timestamp.view(-1, 1)
                 fiber_out.append(plot_data.squeeze())
-                fiber_in.append(y.squeeze())
+                fiber_out_rot.append(plot_data_rot.squeeze())
+                fiber_in.append(plot_target.squeeze())
                 regen.append(y_pred.squeeze())
                 timestamps.append(timestamp.squeeze())
+                angles.append(center_angle.squeeze())

         fiber_out = torch.vstack(fiber_out).cpu()
+        fiber_out_rot = torch.vstack(fiber_out_rot).cpu()
         fiber_in = torch.vstack(fiber_in).cpu()
         regen = torch.vstack(regen).cpu()
+        angles = torch.vstack(angles).cpu()
         timestamps = torch.concat(timestamps).cpu()
         if trace_powers:
-            return fiber_in, fiber_out, regen, timestamps, powers
-        return fiber_in, fiber_out, regen, timestamps
+            return fiber_in, fiber_out, fiber_out_rot, angles, regen, timestamps, powers
+        return fiber_in, fiber_out, fiber_out_rot, angles, regen, timestamps

     def write_parameters(self, epoch, attributes: list[str] | tuple[str] = None):
         parameter_list = get_parameter_names_and_values(self.model)
@@ -1027,18 +1067,18 @@ class RegenerationTrainer:
             for i, config_path in enumerate(self.data_settings.config_path):
                 paths = Path.cwd().glob(config_path)
                 for j, path in enumerate(paths):
-                    text = str(path) + '\n'
-                    with open(path, 'r') as f:
+                    text = str(path) + "\n"
+                    with open(path, "r") as f:
                         text += f.read()
-                    text += '\n'
-                    self.writer.add_text(f"config_{i*len(self.data_settings.config_path)+j}", text)
+                    text += "\n"
+                    self.writer.add_text(f"config_{i * len(self.data_settings.config_path) + j}", text)
         elif isinstance(self.data_settings.config_path, str):
             paths = Path.cwd().glob(self.data_settings.config_path)
             for j, path in enumerate(paths):
-                text = str(path) + '\n'
-                with open(path, 'r') as f:
+                text = str(path) + "\n"
+                with open(path, "r") as f:
                     text += f.read()
-                text += '\n'
+                text += "\n"
                 self.writer.add_text(f"config_{j}", text)
         self.writer.flush()
@@ -1116,6 +1156,7 @@ class RegenerationTrainer:
         powers = [power / powers[0] for power in powers]
         fig, ax = plt.subplots()
         fig.set_figwidth(18)
+        fig.set_figheight(4)
         fig.suptitle(
             f"Energy conservation{f' {title_append}' if title_append else ''}{f'\n{subtitle}' if subtitle else ''}"
         )
@@ -1184,6 +1225,7 @@ class RegenerationTrainer:
         fig, axs = plt.subplots(1, 2 * len(signals), sharex=True, sharey=True)
         fig.set_figwidth(18)
+        fig.set_figheight(4)
         fig.suptitle(f"Eye diagram{f' {title_append}' if title_append else ''}{f'\n{subtitle}' if subtitle else ''}")
         # xaxis = timestamps / sps
         # xaxis = np.arange(2 * sps) / sps
@@ -1253,7 +1295,7 @@ class RegenerationTrainer:
                 xaxis = timestamps / sps
             else:
                 xaxis = timestamps
-            ax.plot(xaxis, np.abs(signal[:, i]) ** 2, label=label)
+            ax.plot(xaxis, np.abs(signal[:, i]) ** 2, label=label, alpha=0.7)
         ax.set_xlabel("Sample" if sps is None else "Symbol")
         ax.set_ylabel("normalized power")
         ax.minorticks_on()
@@ -1269,7 +1311,7 @@ class RegenerationTrainer:
     def plot_model_response(
         self,
-        model:torch.nn.Module=None,
+        model: torch.nn.Module = None,
         title_append="",
         subtitle="",
         # mode: Literal["eye", "head", "powers"] = "head",
@@ -1281,7 +1323,9 @@ class RegenerationTrainer:
         model = model.to(self.pytorch_settings.device)
         model.eval()
         with torch.no_grad():
-            _, powers = model(input_data, trace_powers=True)
+            _, powers = model(
+                input_data, torch.zeros(input_data.shape[0], 1).to(self.pytorch_settings.device), trace_powers=True
+            )
         powers = [power.item() for power in powers]

         layer_names = [name for (name, _) in model.named_children()]
@@ -1296,29 +1340,42 @@ class RegenerationTrainer:
         self.data_settings.shuffle = False
         self.data_settings.train_split = 1.0
         self.pytorch_settings.batchsize = max(self.pytorch_settings.head_symbols, self.pytorch_settings.eye_symbols)
-        config_path = random.choice(self.data_settings.config_path) if isinstance(self.data_settings.config_path, (list, tuple)) else self.data_settings.config_path
-        fiber_length = int(float(str(config_path).split('-')[4])/1000)
+        config_path = (
+            random.choice(self.data_settings.config_path)
+            if isinstance(self.data_settings.config_path, (list, tuple))
+            else self.data_settings.config_path
+        )
+        fiber_length = int(float(str(config_path).split("-")[4]) / 1000)
         if not hasattr(self, "_plot_loader"):
             self._plot_loader, _ = self.get_sliced_data(
                 override={
                     "num_symbols": self.pytorch_settings.batchsize,
                     "config_path": config_path,
                     "shuffle": False,
-                    "polarisations": (np.random.rand(1)*np.pi*2,),
-                    "randomise_polarisation": False,
+                    "polarisations": (np.random.rand(1) * np.pi * 2,),
+                    "randomise_polarisation": self.data_settings.randomise_polarisations,
                 }
            )
            self._sps = self._plot_loader.dataset.samples_per_symbol
         self.data_settings = data_settings_backup
         self.pytorch_settings = pytorch_settings_backup

-        fiber_in, fiber_out, regen, timestamps = self.run_model(model, self._plot_loader)
+        fiber_in, fiber_out, fiber_out_rot, angles, regen, timestamps = self.run_model(model, self._plot_loader)
         fiber_in = fiber_in.view(-1, 2)
         fiber_out = fiber_out.view(-1, 2)
+        fiber_out_rot = fiber_out_rot.view(-1, 2)
+        angles = angles.view(-1, 1)
+        angles = angles.real
+        angles = torch.fmod(angles, 2 * torch.pi)
+        angles = torch.div(angles, 2*torch.pi)
+        angles = torch.repeat_interleave(angles, 2, dim=1)
         regen = regen.view(-1, 2)
         fiber_in = fiber_in.numpy()
         fiber_out = fiber_out.numpy()
+        fiber_out_rot = fiber_out_rot.numpy()
+        angles = angles.numpy()
         regen = regen.numpy()
         timestamps = timestamps.numpy()
@@ -1327,28 +1384,29 @@ class RegenerationTrainer:
         import gc

         head_fig = self._plot_model_response_head(
-            fiber_in[:self.pytorch_settings.head_symbols*self._sps],
-            fiber_out[:self.pytorch_settings.head_symbols*self._sps],
-            regen[:self.pytorch_settings.head_symbols*self._sps],
-            timestamps=timestamps[:self.pytorch_settings.head_symbols*self._sps],
+            fiber_out_rot[: self.pytorch_settings.head_symbols * self._sps],
+            fiber_in[: self.pytorch_settings.head_symbols * self._sps],
+            regen[: self.pytorch_settings.head_symbols * self._sps],
+            angles[: self.pytorch_settings.head_symbols * self._sps],
+            timestamps=timestamps[: self.pytorch_settings.head_symbols * self._sps],
+            labels=("fiber out", "fiber in", "regen", "normed angle"),
+            sps=self._sps,
+            title_append=title_append + f" ({fiber_length} km)",
+            subtitle=subtitle,
+            show=show,
+        )
+        # raise NotImplementedError("Eye diagram not implemented")
+        eye_fig = self._plot_model_response_eye(
+            fiber_in[: self.pytorch_settings.eye_symbols * self._sps],
+            fiber_out_rot[: self.pytorch_settings.eye_symbols * self._sps],
+            regen[: self.pytorch_settings.eye_symbols * self._sps],
+            timestamps=timestamps[: self.pytorch_settings.eye_symbols * self._sps],
             labels=("fiber in", "fiber out", "regen"),
             sps=self._sps,
             title_append=title_append + f" ({fiber_length} km)",
             subtitle=subtitle,
             show=show,
         )
-        # raise NotImplementedError("Eye diagram not implemented")
-        eye_fig = self._plot_model_response_eye(
-            fiber_in[:self.pytorch_settings.eye_symbols*self._sps],
-            fiber_out[:self.pytorch_settings.eye_symbols*self._sps],
-            regen[:self.pytorch_settings.eye_symbols*self._sps],
-            timestamps=timestamps[:self.pytorch_settings.eye_symbols*self._sps],
-            labels=("fiber in", "fiber out", "regen"),
-            sps=self._sps,
-            title_append=title_append + f" ({fiber_length} km)",
-            subtitle=subtitle,
-            show=show,
-        )
         gc.collect()
         return head_fig, eye_fig, power_fig
@@ -1361,7 +1419,7 @@ class RegenerationTrainer:
             self.model_settings.overrides.get(f"n_hidden_nodes_{i}", -1) for i in range(model_n_hidden_layers)
         ]
         model_dims.insert(0, input_dim)
-        model_dims.append(2)
+        model_dims.append(self.model_settings.output_dim)
         model_dims = [str(dim) for dim in model_dims]
         model_activation_func = self.model_settings.model_activation_func
         model_dtype = self.data_settings.dtype

View File

@@ -1,6 +1,8 @@
 from datetime import datetime
 import optuna
+import torch
+import util
 from hypertraining.hypertraining import HyperTraining
 from hypertraining.settings import (
     GlobalSettings,
@@ -16,24 +18,29 @@ global_settings = GlobalSettings(
 )

 data_settings = DataSettings(
-    config_path="data/*-128-16384-100000-0-0-17-0-PAM4-0.ini",
+    # config_path="data/*-128-16384-100000-0-0-17-0-PAM4-0.ini",
+    config_path="data/20241204-131003-128-16384-100000-0-0-17-0-PAM4-0.ini",
     dtype="complex64",
     # symbols = (9, 20), # 13 symbol @ 10GBd <-> 1.3ns <-> 0.26m of fiber
-    symbols=13, # study: single_core_regen_20241123_011232
+    # symbols=13, # study: single_core_regen_20241123_011232
+    # symbols = (3, 13),
+    symbols=4,
     # output_size = (11, 32), # ballpark 26 taps -> 2 taps per input symbol -> 1 tap every 0.01m (model has 52 inputs)
-    output_size=26, # study: single_core_regen_20241123_011232 (model_input_dim/2)
+    # output_size=26, # study: single_core_regen_20241123_011232 (model_input_dim/2)
+    output_size=(8, 30),
     shuffle=True,
     in_out_delay=0,
     xy_delay=0,
-    drop_first=128 * 100,
+    drop_first=256,
     train_split=0.8,
+    randomise_polarisations=False,
 )

 pytorch_settings = PytorchSettings(
-    epochs=10000,
+    epochs=10,
     batchsize=2**10,
     device="cuda",
-    dataloader_workers=12,
+    dataloader_workers=4,
     dataloader_prefetch=4,
     summary_dir=".runs",
     write_every=2**5,
@@ -43,28 +50,70 @@ pytorch_settings = PytorchSettings(
 model_settings = ModelSettings(
     output_dim=2,
-    # n_hidden_layers = (3, 8),
-    n_hidden_layers=4,
-    overrides={
-        "n_hidden_nodes_0": 8,
-        "n_hidden_nodes_1": 6,
-        "n_hidden_nodes_2": 4,
-        "n_hidden_nodes_3": 8,
-    },
-    model_activation_func="Mag",
-    # satabsT0=(1e-6, 1),
+    n_hidden_layers = (2, 5),
+    n_hidden_nodes=(2, 16),
+    model_activation_func="EOActivation",
+    dropout_prob=0,
+    model_layer_function="ONNRect",
+    model_layer_kwargs={"square": True},
+    # scale=(False, True),
+    scale=False,
+    model_layer_parametrizations=[
+        {
+            "tensor_name": "weight",
+            "parametrization": util.complexNN.energy_conserving,
+        },
+        {
+            "tensor_name": "alpha",
+            "parametrization": util.complexNN.clamp,
+        },
+        {
+            "tensor_name": "gain",
+            "parametrization": util.complexNN.clamp,
+            "kwargs": {
+                "min": 0,
+                "max": float("inf"),
+            },
+        },
+        {
+            "tensor_name": "phase_bias",
+            "parametrization": util.complexNN.clamp,
+            "kwargs": {
+                "min": 0,
+                "max": 2 * torch.pi,
+            },
+        },
+        {
+            "tensor_name": "scales",
+            "parametrization": util.complexNN.clamp,
+        },
+        {
+            "tensor_name": "angle",
+            "parametrization": util.complexNN.clamp,
+            "kwargs": {
+                "min": -torch.pi,
+                "max": torch.pi,
+            },
+        },
+        {
+            "tensor_name": "loss",
+            "parametrization": util.complexNN.clamp,
+        },
+    ],
 )

 optimizer_settings = OptimizerSettings(
-    optimizer="Adam",
-    # learning_rate = (1e-5, 1e-1),
-    learning_rate=5e-3
-    # learning_rate=5e-4,
+    optimizer="AdamW",
+    optimizer_kwargs={
+        "lr": 5e-3,
+        "amsgrad": True,
+        # "weight_decay": 1e-7,
+    },
 )

 optuna_settings = OptunaSettings(
-    n_trials=1,
-    n_workers=1,
+    n_trials=1024,
+    n_workers=8,
     timeout=3600,
     directions=("minimize",),
     metrics_names=("mse",),
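Note: each model_layer_parametrizations entry names a tensor, a parametrization, and optional kwargs. The code that consumes these entries is not shown in this diff; a plausible application loop, assuming the "parametrization" values are torch.nn.Module subclasses and using a hypothetical helper name, would be:

import torch.nn.utils.parametrize as parametrize

def apply_parametrizations(layer, specs):
    # hypothetical helper: register each spec on layers that own the named tensor
    for spec in specs:
        name = spec["tensor_name"]
        if hasattr(layer, name):
            p = spec["parametrization"](**spec.get("kwargs", {}))
            parametrize.register_parametrization(layer, name, p)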

View File

@@ -26,24 +26,26 @@ global_settings = GlobalSettings(
 )

 data_settings = DataSettings(
-    config_path="data/*-128-16384-10000-0-0-0-0-PAM4-0-0.4.ini",
+    # config_path="data/*-128-16384-1-0-0-0-0-PAM4-0-0.ini",
+    config_path="data/*-128-16384-10000-0-0-17-0-PAM4-0.ini",
     # config_path=[f"data/20241202-*-128-16384-{length}-0-0-17-0-PAM4-0.ini" for length in range(48000, 53000, 1000)],
     dtype="complex64",
     # symbols = (9, 20), # 13 symbol @ 10GBd <-> 1.3ns <-> 0.26m of fiber
-    symbols=13, # study: single_core_regen_20241123_011232
+    symbols=4, # study: single_core_regen_20241123_011232
     # output_size = (11, 32), # 26 taps -> 2 taps per input symbol -> 1 tap every 1cm (model has 52 inputs (x/y))
-    output_size=26, # study: single_core_regen_20241123_011232 (model_input_dim/2)
+    output_size=20, # study: single_core_regen_20241123_011232 (model_input_dim/2)
     shuffle=True,
     drop_first=64,
     train_split=0.8,
     randomise_polarisations=True,
+    osnr=10,
 )

 pytorch_settings = PytorchSettings(
     epochs=10000,
     batchsize=2**14,
     device="cuda",
-    dataloader_workers=16,
+    dataloader_workers=24,
     dataloader_prefetch=8,
     summary_dir=".runs",
     write_every=2**5,
@@ -53,17 +55,17 @@ pytorch_settings = PytorchSettings(
 model_settings = ModelSettings(
     output_dim=2,
-    n_hidden_layers=5,
+    n_hidden_layers=3,
     overrides={
         # "hidden_layer_dims": (8, 8, 4, 4),
-        "n_hidden_nodes_0": 8,
+        "n_hidden_nodes_0": 16,
         "n_hidden_nodes_1": 8,
-        "n_hidden_nodes_2": 4,
-        "n_hidden_nodes_3": 4,
-        "n_hidden_nodes_4": 2,
+        "n_hidden_nodes_2": 8,
+        # "n_hidden_nodes_3": 4,
+        # "n_hidden_nodes_4": 2,
     },
     model_activation_func="EOActivation",
-    dropout_prob=0.01,
+    dropout_prob=0,
     model_layer_function="ONNRect",
     model_layer_kwargs={"square": True},
     scale=False,
@@ -126,7 +128,7 @@ model_settings = ModelSettings(
 optimizer_settings = OptimizerSettings(
     optimizer="AdamW",
     optimizer_kwargs={
-        "lr": 0.01,
+        "lr": 0.005,
         "amsgrad": True,
         # "weight_decay": 1e-7,
     },
@@ -242,7 +244,15 @@ if __name__ == "__main__":
         pytorch_settings=pytorch_settings,
         model_settings=model_settings,
         optimizer_settings=optimizer_settings,
-        # checkpoint_path=".models/best_20241205_235929.tar",
+        checkpoint_path=".models/best_20241216_221359.tar",
+        reset_epoch=True,
+        # settings_override={
+        #     "optimizer_settings": {
+        #         "optimizer_kwargs": {
+        #             "lr": 0.01,
+        #         },
+        #     }
+        # }
         # 20241202_143149
     )
     trainer.train()

View File

@@ -26,7 +26,7 @@ global_settings = GlobalSettings(
 )

 data_settings = DataSettings(
-    config_path="data/*-128-16384-10000-0-0-0-0-PAM4-0-0.4.ini",
+    config_path="data/20241211-105524-128-16384-1-0-0-0-0-PAM4-0-0.ini",
     # config_path=[f"data/20241202-*-128-16384-{length}-0-0-17-0-PAM4-0.ini" for length in range(48000, 53000, 1000)],
     dtype="complex64",
     # symbols = (9, 20), # 13 symbol @ 10GBd <-> 1.3ns <-> 0.26m of fiber
@@ -53,14 +53,14 @@ pytorch_settings = PytorchSettings(
 )

 model_settings = ModelSettings(
-    output_dim=3,
+    output_dim=1,
     n_hidden_layers=3,
     overrides={
-        "n_hidden_nodes_0": 2,
-        "n_hidden_nodes_1": 2,
-        "n_hidden_nodes_2": 2,
+        "n_hidden_nodes_0": 4,
+        "n_hidden_nodes_1": 4,
+        "n_hidden_nodes_2": 4,
     },
-    dropout_prob=0.01,
+    dropout_prob=0,
     model_layer_function="ONNRect",
     model_activation_func="EOActivation",
     model_layer_kwargs={"square": True},
@@ -110,20 +110,24 @@ model_settings = ModelSettings(
 )

 optimizer_settings = OptimizerSettings(
-    optimizer="AdamW",
+    optimizer="RMSprop",
+    # optimizer="AdamW",
     optimizer_kwargs={
-        "lr": 0.005,
-        "amsgrad": True,
+        "lr": 0.01,
+        "alpha": 0.9,
+        "momentum": 0.1,
+        "eps": 1e-8,
+        "centered": True,
+        # "amsgrad": True,
         # "weight_decay": 1e-7,
     },
-    # learning_rate=0.05,
     scheduler="ReduceLROnPlateau",
     scheduler_kwargs={
-        "patience": 2**6,
+        "patience": 2**5,
         "factor": 0.75,
         # "threshold": 1e-3,
         "min_lr": 1e-6,
-        "cooldown": 10,
+        # "cooldown": 10,
     },
 )

View File

@@ -319,6 +319,29 @@ class normalize_by_first(nn.Module):
     def forward(self, data):
         return data / data[:, 0].unsqueeze(1)

+
+class rotate(nn.Module):
+    def __init__(self):
+        super(rotate, self).__init__()
+
+    def forward(self, data, angle):
+        # data -> (batch, n*2)
+        # angle -> (batch, n)
+        data_ = data
+        if angle.ndim == 1:
+            angle_ = angle.unsqueeze(1)
+        else:
+            angle_ = angle
+        angle_ = angle_.expand(-1, data_.shape[1]//2)
+        c = torch.cos(angle_)
+        s = torch.sin(angle_)
+        rot = torch.stack([torch.stack([c, -s], dim=2),
+                           torch.stack([s, c], dim=2)], dim=3)
+        d = torch.bmm(data_.reshape(-1, 1, 2), rot.view(-1, 2, 2).to(dtype=data_.dtype)).reshape(*data.shape)
+        # d = torch.bmm(data.unsqueeze(-1).mT, rot.to(dtype=data.dtype).mT).mT.squeeze(-1)
+        return d
+

 class photodiode(nn.Module):
     def __init__(self, size, bias=True):
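Note: a quick sanity check of the new rotate module, which applies a 2-D rotation to each interleaved (x, y) pair; with a 90° angle, (1, 0) maps to (0, 1). Illustration only, not part of the diff.

import torch
from util.complexNN import rotate  # added in this commit

rot = rotate()
data = torch.tensor([[1.0, 0.0, 1.0, 0.0]])  # one sample, two (x, y) pairs
angle = torch.tensor([torch.pi / 2])         # single angle, broadcast over pairs
print(rot(data, angle))                      # ~[[0., 1., 0., 1.]]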
@@ -487,7 +510,7 @@ class MZISingle(nn.Module):
         return x * torch.exp(1j * self.phi) * torch.sin(self.omega + self.func(x))

 def naive_angle_loss(x: torch.Tensor, target: torch.Tensor, mod=2*torch.pi):
-    return torch.fmod((x - target), mod).square().mean()
+    return torch.fmod((x.abs().real - target.abs().real), mod).abs().mean()

 def cosine_loss(x: torch.Tensor, target: torch.Tensor):
     return (2*(1 - torch.cos(x - target))).mean()
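Note: the revised naive_angle_loss compares magnitudes modulo mod, so predictions off by a whole multiple of mod cost nothing, which suits the pi/2 quadrant ambiguity of polarisation angles. A standalone illustration using a copy of the new definition:

import torch

def naive_angle_loss(x, target, mod=2 * torch.pi):
    return torch.fmod((x.abs().real - target.abs().real), mod).abs().mean()

t = torch.tensor([0.3])
print(naive_angle_loss(t + torch.pi / 2, t, mod=torch.pi / 2))  # ~0, quadrant wrap
print(naive_angle_loss(t + 0.1, t, mod=torch.pi / 2))           # ~0.1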

View File

@@ -5,6 +5,7 @@ from torch.utils.data import Dataset
 # from torch.utils.data import Sampler
 import numpy as np
 import configparser
+import multiprocessing as mp

 # class SubsetSampler(Sampler[int]):
 #     """
@@ -113,8 +114,9 @@ class FiberRegenerationDataset(Dataset):
         dtype: torch.dtype = None,
         real: bool = False,
         device=None,
-        polarisations: tuple | list = (0,),
+        osnr: float = None,
         randomise_polarisations: bool = False,
+        repeat_randoms: int = 1,
         **kwargs,
     ):
         """
@@ -190,18 +192,20 @@ class FiberRegenerationDataset(Dataset):
             files.append(config["data"]["file"].strip('"'))
         self.config["data"]["file"] = str(files)

-        for i, angle in enumerate(torch.tensor(np.array(polarisations))):
-            data_raw_copy = data_raw.clone()
-            if angle == 0:
-                continue
-            sine = torch.sin(angle)
-            cosine = torch.cos(angle)
-            data_raw_copy[:, 2] = data_raw[:, 2] * cosine - data_raw[:, 3] * sine
-            data_raw_copy[:, 3] = data_raw[:, 2] * sine + data_raw[:, 3] * cosine
-            if i == 0:
-                data_raw = data_raw_copy
-            else:
-                data_raw = torch.cat([data_raw, data_raw_copy], dim=0)
+        # if polarisations is not None:
+        #     self.angles = torch.tensor(polarisations).repeat(len(data_raw), 1)
+        # for i, angle in enumerate(torch.tensor(np.array(polarisations))):
+        #     data_raw_copy = data_raw.clone()
+        #     if angle == 0:
+        #         continue
+        #     sine = torch.sin(angle)
+        #     cosine = torch.cos(angle)
+        #     data_raw_copy[:, 2] = data_raw[:, 2] * cosine - data_raw[:, 3] * sine
+        #     data_raw_copy[:, 3] = data_raw[:, 2] * sine + data_raw[:, 3] * cosine
+        #     if i == 0:
+        #         data_raw = data_raw_copy
+        #     else:
+        #         data_raw = torch.cat([data_raw, data_raw_copy], dim=0)

         self.device = data_raw.device
@@ -278,23 +282,61 @@ class FiberRegenerationDataset(Dataset):
         timestamps = data_raw[4, :]
         data_raw = data_raw[:4, :]
         data_raw = data_raw.view(2, 2, -1)
-        timestamps_doubled = torch.cat([timestamps.unsqueeze(dim=0), timestamps.unsqueeze(dim=0)], dim=0).unsqueeze(
-            dim=1
-        )
-        data_raw = torch.cat([data_raw, timestamps_doubled], dim=1)
+        fiber_in = data_raw[0, :, :]
+        fiber_out = data_raw[1, :, :]
+        # timestamps_doubled = torch.cat([timestamps.unsqueeze(dim=0), timestamps.unsqueeze(dim=0)], dim=0).unsqueeze(
+        #     dim=1
+        # )
+        fiber_in = torch.cat([fiber_in, timestamps.unsqueeze(0)], dim=0)
+        fiber_out = torch.cat([fiber_out, timestamps.unsqueeze(0)], dim=0)
+        if repeat_randoms > 1:
+            fiber_in = fiber_in.repeat(1, 1, repeat_randoms)
+            fiber_out = fiber_out.repeat(1, 1, repeat_randoms)
+            # review: potential problems with repeated timestamps when plotting
+        else:
+            repeat_randoms = 1
+        if self.randomise_polarisations:
+            angles = torch.fmod(torch.randperm(data_raw.shape[-1]*repeat_randoms), 2) * torch.pi
+            # start_angle = torch.rand(1) * 2 * torch.pi
+            # angles = start_angle + torch.cumsum(torch.randn(data_raw.shape[-1])/333, dim=0) * torch.pi * 2 # random walk
+            # self.angles = torch.rand(self.data.shape[0]) * 2 * torch.pi
+        else:
+            angles = torch.zeros(data_raw.shape[-1])
+        sin = torch.sin(angles)
+        cos = torch.cos(angles)
+        rot = torch.stack([torch.stack([cos, -sin], dim=1), torch.stack([sin, cos], dim=1)], dim=2)
+        data_rot = torch.bmm(fiber_out[:2, :].T.unsqueeze(1), rot.to(dtype=data_raw.dtype)).squeeze(1).T
+        fiber_out = torch.cat((fiber_out, data_rot), dim=0)
+        fiber_out = torch.cat([fiber_out, angles.unsqueeze(0)], dim=0)
+        if osnr is not None:
+            popt = torch.mean(fiber_out[:2, :, :].abs().flatten(), dim=-1)
+            noise = torch.randn_like(fiber_out[:2, :, :])
+            pn = torch.mean(noise.abs().flatten(), dim=-1)
+            noise = noise * (popt / pn) * 10 ** (-osnr / 20)
+            fiber_out[:2, :, :] = torch.add(fiber_out[:2, :, :], noise)
         # data_raw = torch.cat([data_raw, timestamps_doubled], dim=1)

         # data layout
         # [ [E_in_x, E_in_y, timestamps],
         #   [E_out_x, E_out_y, timestamps] ]
-        self.data = data_raw.unfold(dimension=-1, size=self.samples_per_slice, step=1)
-        self.data = self.data.movedim(-2, 0)
-        if randomise_polarisations:
-            self.angles = torch.rand(self.data.shape[0]) * np.pi * 2
-            # self.data[:, 1, :2, :] = self.rotate(self.data[:, 1, :2, :], self.angles)
-        else:
-            self.angles = torch.zeros(self.data.shape[0])
+        self.fiber_in = fiber_in.unfold(dimension=-1, size=self.samples_per_slice, step=1)
+        self.fiber_in = self.fiber_in.movedim(-2, 0)
+        self.fiber_out = fiber_out.unfold(dimension=-1, size=self.samples_per_slice, step=1)
+        self.fiber_out = self.fiber_out.movedim(-2, 0)
+
+        # self.data = data_raw.unfold(dimension=-1, size=self.samples_per_slice, step=1)
+        # self.data = self.data.movedim(-2, 0)
+        # self.angles = torch.zeros(self.data.shape[0])
+        ...
         # ...
         # -> [no_slices, 2, 3, samples_per_slice]
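Note: the new osnr branch scales white noise so that its mean amplitude sits osnr dB below the mean signal amplitude (an amplitude ratio, hence the 20·log10 / 10**(-osnr/20) scaling). A standalone illustration with made-up data, mirroring the scaling above:

import torch

osnr = 10.0  # dB, as in DataSettings.osnr
signal = torch.randn(2, 1000, dtype=torch.complex64)
noise = torch.randn_like(signal)
popt = signal.abs().flatten().mean()  # mean signal amplitude
pn = noise.abs().flatten().mean()     # mean noise amplitude
noise = noise * (popt / pn) * 10 ** (-osnr / 20)
print(20 * torch.log10(popt / noise.abs().flatten().mean()))  # ~10 dB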
@@ -305,51 +347,56 @@ class FiberRegenerationDataset(Dataset):
# ... # ...
# ] -> [no_slices, 2, 3, samples_per_slice] # ] -> [no_slices, 2, 3, samples_per_slice]
...
def __len__(self): def __len__(self):
return self.data.shape[0] return self.fiber_in.shape[0]
    def __getitem__(self, idx):
        if isinstance(idx, slice):
            return [self.__getitem__(i) for i in range(*idx.indices(len(self)))]
        else:
-            data_slice = self.data[idx].squeeze()
-            data_slice = data_slice[:, :, : data_slice.shape[2] // self.output_dim * self.output_dim]
-            data_slice = data_slice.view(data_slice.shape[0], data_slice.shape[1], self.output_dim, -1)
+            # fiber in:  [E_in_x, E_in_y, timestamps]
+            # fiber out: [E_out_x, E_out_y, timestamps, E_out_x_rot, E_out_y_rot, angle]
+            fiber_in = self.fiber_in[idx].squeeze()
+            fiber_out = self.fiber_out[idx].squeeze()
+            # truncate to a multiple of output_dim, then split into output_dim groups
+            fiber_in = fiber_in[..., : fiber_in.shape[-1] // self.output_dim * self.output_dim]
+            fiber_out = fiber_out[..., : fiber_out.shape[-1] // self.output_dim * self.output_dim]
+            fiber_in = fiber_in.view(fiber_in.shape[0], self.output_dim, -1)
+            fiber_out = fiber_out.view(fiber_out.shape[0], self.output_dim, -1)
+            # data_slice = data_slice[:, :, : data_slice.shape[2] // self.output_dim * self.output_dim]
+            # data_slice = data_slice.view(data_slice.shape[0], data_slice.shape[1], self.output_dim, -1)
+            # angle = self.angles[idx]
+            center_angle = fiber_out[5, self.output_dim // 2, 0]
+            angles = fiber_out[5, :, 0]
+            plot_data = fiber_out[:2, self.output_dim // 2, 0].detach().clone()
+            plot_data_rot = fiber_out[3:5, self.output_dim // 2, 0].detach().clone()
+            data = fiber_out[3:5, :, 0]
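Editor's note: the truncate-and-view pair above reshapes each [channels, T] slice into output_dim groups of consecutive samples, so index [..., k, 0] addresses the first sample of group k — this is how the centre symbol at output_dim // 2 is picked out. A sketch with toy sizes:

import torch

output_dim = 4
fiber = torch.arange(30).reshape(3, 10)                            # [channels=3, T=10]
fiber = fiber[..., : fiber.shape[-1] // output_dim * output_dim]   # drop remainder -> [3, 8]
fiber = fiber.view(fiber.shape[0], output_dim, -1)                 # -> [3, 4, 2]
print(fiber[0])  # tensor([[0, 1], [2, 3], [4, 5], [6, 7]])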
            # if self.randomise_polarisations:
-            #     angle = torch.rand(1) * torch.pi * 2
-            #     sine = torch.sin(angle)
-            #     cosine = torch.cos(angle)
-            #     data_slice_ = data_slice[1]
-            #     data_slice[1, 0] = data_slice_[0] * cosine - data_slice_[1] * sine
-            #     data_slice[1, 1] = data_slice_[0] * sine + data_slice_[1] * cosine
-            # else:
-            #     angle = torch.zeros(1)
-            # data = data_slice[1, :2, :, 0]
-            angle = self.angles[idx]
-            data_index = 1
-            data_slice[1, :2, :, :] = self.rotate(data_slice[data_index, :2, :, :], angle)
-            data = data_slice[1, :2, :, 0]
-            # data = self.rotate(data, angle)
+            #     data = data.mT
+            #     c = torch.cos(angle).unsqueeze(-1)
+            #     s = torch.sin(angle).unsqueeze(-1)
+            #     rot = torch.stack([torch.stack([c, -s], dim=1), torch.stack([s, c], dim=1)], dim=2).squeeze(-1)
+            #     data = torch.bmm(data.mT.unsqueeze(0), rot.to(dtype=data.dtype)).squeeze(-1)
            ...
+            # angle = torch.zeros_like(angle)
            # for both polarisations (x/y), calculate the mean of the signal around the current symbol (-> corresponding to a lowpass filter)
-            angle_data = data_slice[1, :2, :, :].reshape(2, -1).mean(dim=1)
-            angle_data2 = self.complex_max(data_slice[1, :2, :, :].reshape(2, -1))
-            plot_data = data_slice[1, :2, self.output_dim // 2, 0]
-            sop = self.polarimeter(plot_data)
+            angle_data = fiber_out[:2, :, :].reshape(2, -1).mean(dim=1).repeat(1, self.output_dim)
+            angle_data2 = self.complex_max(fiber_out[:2, :, :].reshape(2, -1)).repeat(1, self.output_dim)
+            # sop = self.polarimeter(plot_data)
            # angle_data = data_slice[1, :2, :self.output_dim//2, :].squeeze().reshape(2, -1).mean(dim=-1)
            # angle = data_slice[1, 3, self.output_dim // 2, 0].real
-            target = data_slice[0, :2, self.output_dim // 2, 0]
-            target_timestamp = data_slice[0, 2, self.output_dim // 2, 0].real
+            target = fiber_in[:2, self.output_dim // 2, 0]
+            plot_target = fiber_in[:2, self.output_dim // 2, 0].detach().clone()
+            target_timestamp = fiber_in[2, self.output_dim // 2, 0].real
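The "lowpass" reading of the windowed mean can be verified directly: the mean equals the DC (0 Hz) bin of the window's DFT divided by the window length. Illustrative only, on random data:

import torch

w = torch.randn(2, 64, dtype=torch.complex64)        # toy [2, N] window
dc = torch.fft.fft(w, dim=-1)[..., 0] / w.shape[-1]  # normalised 0 Hz bin
assert torch.allclose(w.mean(dim=-1), dc, atol=1e-5)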
            ...
            # data_timestamps = data[-1,:].real
@@ -360,22 +407,39 @@ class FiberRegenerationDataset(Dataset):
            # transpose to interleave the x and y data in the output tensor
            data = data.transpose(0, 1).flatten().squeeze()
-            angle_data = angle_data.flatten().squeeze()
-            angle_data2 = angle_data.flatten().squeeze()
-            angle = angle.flatten().squeeze()
+            angle_data = angle_data.transpose(0, 1).flatten().squeeze()
+            angle_data2 = angle_data2.transpose(0, 1).flatten().squeeze()
+            center_angle = center_angle.flatten().squeeze()
+            angles = angles.flatten().squeeze()
            # data_timestamps = data_timestamps.flatten().squeeze()
+            # target = target.transpose(0,1).flatten().squeeze()
            target = target.flatten().squeeze()
            target_timestamp = target_timestamp.flatten().squeeze()
+            plot_target = plot_target.flatten().squeeze()
+            plot_data = plot_data.flatten().squeeze()
+            plot_data_rot = plot_data_rot.flatten().squeeze()
-            return {"x": data, "y": target, "angle": angle, "sop": sop, "angle_data": angle_data, "angle_data2": angle_data2, "timestamp": target_timestamp, "plot_data": plot_data}
+            return {
+                "x": data,
+                "y": target,
+                "center_angle": center_angle,
+                "angles": angles,
+                "mean_angle": angles.mean(),
+                # "sop": sop,
+                "angle_data": angle_data,
+                "angle_data2": angle_data2,
+                "timestamp": target_timestamp,
+                "plot_target": plot_target,
+                "plot_data": plot_data,
+                "plot_data_rot": plot_data_rot,
+            }
    def complex_max(self, data, dim=-1):
        # returns element(s) with the maximum absolute value along a given dimension
        # ind = torch.argmax(data.abs(), dim=dim, keepdim=True)
        # max_values = torch.gather(data, dim, ind).squeeze(dim=dim)
        # return max_values
        return torch.gather(data, dim, torch.argmax(data.abs(), dim=dim, keepdim=True)).squeeze(dim=dim)
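A quick standalone check of the gather/argmax idiom (values made up): argmax runs on the real-valued magnitudes, gather then pulls the corresponding complex entries.

import torch

z = torch.tensor([[1 + 1j, 3 + 0j, 0 - 2j],
                  [0 + 0.5j, -4 + 0j, 1 + 0j]])
ind = torch.argmax(z.abs(), dim=-1, keepdim=True)
print(torch.gather(z, -1, ind).squeeze(dim=-1))  # tensor([ 3.+0.j, -4.+0.j])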
    def rotate(self, data, angle):
        # rotates a 2d tensor by a given angle
@@ -388,7 +452,25 @@ class FiberRegenerationDataset(Dataset):
        cosine = torch.cos(angle)
        return torch.stack([data[0] * cosine - data[1] * sine, data[0] * sine + data[1] * cosine], dim=0)
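Sanity check for rotate(): a polarisation rotation is unitary, so the per-sample power |E_x|^2 + |E_y|^2 must be preserved. A sketch on random complex data:

import torch

field = torch.randn(2, 8, dtype=torch.complex64)
angle = torch.rand(1) * 2 * torch.pi
sine, cosine = torch.sin(angle), torch.cos(angle)
rotated = torch.stack([field[0] * cosine - field[1] * sine,
                       field[0] * sine + field[1] * cosine], dim=0)
assert torch.allclose(rotated.abs().square().sum(dim=0),
                      field.abs().square().sum(dim=0), atol=1e-4)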
-    def rotate_all(self):
-        def do_rotation(j, num_processes):
-            for i in range(len(self) // num_processes):
-                index = i * num_processes + j
-                self.data[index, 1, :2, :] = self.rotate(self.data[index, 1, :2, :], self.angles[index])
-
-        self.processes = []
-        for j in range(mp.cpu_count()):
-            self.processes.append(mp.Process(target=do_rotation, args=(j, mp.cpu_count())))
-            self.processes[-1].start()
-        for p in self.processes:
-            p.join()
-        for i in range(len(self) // mp.cpu_count() * mp.cpu_count(), len(self)):
-            self.data[i, 1, :2, :] = self.rotate(self.data[i, 1, :2, :], self.angles[i])
    def polarimeter(self, data):
        # data: [2, ...] -> x, y
        # returns [4] -> S0, S1, S2, S3
@@ -396,12 +478,12 @@ class FiberRegenerationDataset(Dataset):
        y = data[1].mean()
        I_X = x.abs().square()
        I_Y = y.abs().square()
-        I_45 = (x+y).abs().square()
-        I_RHC = (x + 1j*y).abs().square()
+        I_45 = (x + y).abs().square()
+        I_RHC = (x + 1j * y).abs().square()
        S0 = I_X + I_Y
-        S1 = (2*I_X - S0) / S0
-        S2 = (2*I_45 - S0) / S0
-        S3 = (2*I_RHC - S0) / S0
-        return torch.stack([S1, S2, S3], dim=0)
+        S1 = (2 * I_X - S0) / S0
+        S2 = (2 * I_45 - S0) / S0
+        S3 = (2 * I_RHC - S0) / S0
+        return torch.stack([S0, S1, S2, S3], dim=0)
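Note that in the usual textbook convention the analyser intensities carry a factor 1/2 (e.g. I_45 = |x + y|^2 / 2), which puts the normalised Stokes vector on the Poincare sphere. The equivalent coherency form below is a hedged reference sketch, not this repo's method; the sign of S3 depends on the handedness convention.

import torch

x = torch.tensor(1.0 + 0.0j)  # Jones vector (x, y) = (1, i): circular polarisation
y = 1j * x
S0 = x.abs().square() + y.abs().square()
S1 = (x.abs().square() - y.abs().square()) / S0
S2 = 2 * (x * y.conj()).real / S0
S3 = 2 * (x * y.conj()).imag / S0
print(S1, S2, S3)  # -> 0, 0, -1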