update dataset configurations, add rotation module, and refine model settings for training, new hyperparameter tuning run for corrected datasets
This commit is contained in:
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:746ea83013870351296e01e294905ee027291ef79fd78c1e6b69dd9ebaa1cba0
|
||||
size 10240000
|
||||
oid sha256:76934d1d202aea1311ba67f5ea35eeb99a9c5c856f491565032e7d54ca6f072d
|
||||
size 13598720
|
||||
|
||||
@@ -26,6 +26,8 @@ import torch
|
||||
import torch.optim as optim
|
||||
import torch.utils.data
|
||||
|
||||
import hypertraining.models as models
|
||||
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
import multiprocessing
|
||||
@@ -253,14 +255,17 @@ class HyperTraining:
|
||||
model_kwargs = {
|
||||
"dims": (input_dim, *hidden_dims, self.model_settings.output_dim),
|
||||
"layer_function": layer_func,
|
||||
"layer_parametrizations": layer_parametrizations,
|
||||
"activation_function": afunc,
|
||||
"layer_func_kwargs": self.model_settings.model_layer_kwargs,
|
||||
"act_function": afunc,
|
||||
"act_func_kwargs": None,
|
||||
"parametrizations": layer_parametrizations,
|
||||
"dtype": dtype,
|
||||
"droupout_prob": self.model_settings.dropout_prob,
|
||||
"scale": scale_layers,
|
||||
"dropout_prob": self.model_settings.dropout_prob,
|
||||
"scale_layers": scale_layers,
|
||||
"rotate": False,
|
||||
}
|
||||
|
||||
model = util.complexNN.regenerator(**model_kwargs)
|
||||
model = models.regenerator(*model_kwargs.pop("dims"), **model_kwargs)
|
||||
n_nodes = sum(hidden_dims)
|
||||
|
||||
if writer is not None:
|
||||
@@ -381,7 +386,10 @@ class HyperTraining:
|
||||
running_loss = 0.0
|
||||
model.train()
|
||||
loader_len = len(train_loader)
|
||||
for batch_idx, (x, y, _) in enumerate(train_loader):
|
||||
for batch_idx, batch in enumerate(train_loader):
|
||||
x = batch["x"]
|
||||
y = batch["y"]
|
||||
|
||||
if batch_idx >= self.optuna_settings._n_train_batches:
|
||||
break
|
||||
model.zero_grad(set_to_none=True)
|
||||
@@ -390,7 +398,7 @@ class HyperTraining:
|
||||
y.to(self.pytorch_settings.device),
|
||||
)
|
||||
y_pred = model(x)
|
||||
loss = util.complexNN.complex_mse_loss(y_pred, y)
|
||||
loss = util.complexNN.complex_mse_loss(y_pred, y, power=True)
|
||||
loss_value = loss.item()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
@@ -444,7 +452,9 @@ class HyperTraining:
|
||||
model.eval()
|
||||
running_error = 0
|
||||
with torch.no_grad():
|
||||
for batch_idx, (x, y, _) in enumerate(valid_loader):
|
||||
for batch_idx, batch in enumerate(valid_loader):
|
||||
x = batch["x"]
|
||||
y = batch["y"]
|
||||
if batch_idx >= self.optuna_settings._n_valid_batches:
|
||||
break
|
||||
x, y = (
|
||||
@@ -452,50 +462,44 @@ class HyperTraining:
|
||||
y.to(self.pytorch_settings.device),
|
||||
)
|
||||
y_pred = model(x)
|
||||
error = util.complexNN.complex_mse_loss(y_pred, y)
|
||||
error = util.complexNN.complex_mse_loss(y_pred, y, power=True)
|
||||
error_value = error.item()
|
||||
running_error += error_value
|
||||
|
||||
running_error /= min(len(valid_loader), self.optuna_settings._n_valid_batches)
|
||||
|
||||
if writer is not None:
|
||||
title_append, subtitle = self.build_title(trial)
|
||||
writer.add_figure(
|
||||
"fiber response",
|
||||
self.plot_model_response(
|
||||
trial,
|
||||
model=model,
|
||||
title_append=title_append,
|
||||
subtitle=subtitle,
|
||||
show=False,
|
||||
),
|
||||
epoch + 1,
|
||||
)
|
||||
writer.add_figure(
|
||||
"eye diagram",
|
||||
self.plot_model_response(
|
||||
trial,
|
||||
model=self.model,
|
||||
title_append=title_append,
|
||||
subtitle=subtitle,
|
||||
show=False,
|
||||
mode="eye",
|
||||
),
|
||||
epoch + 1,
|
||||
writer.add_scalar(
|
||||
"eval loss",
|
||||
running_error,
|
||||
epoch,
|
||||
)
|
||||
# if (epoch + 1) % 10 == 0 or epoch < 10:
|
||||
# # plotting is slow, so only do it every 10 epochs
|
||||
# title_append, subtitle = self.build_title(trial)
|
||||
# head_fig, eye_fig, powers_fig = self.plot_model_response(
|
||||
# model=model,
|
||||
# title_append=title_append,
|
||||
# subtitle=subtitle,
|
||||
# show=False,
|
||||
# )
|
||||
# writer.add_figure(
|
||||
# "fiber response",
|
||||
# head_fig,
|
||||
# epoch + 1,
|
||||
# )
|
||||
# writer.add_figure(
|
||||
# "eye diagram",
|
||||
# eye_fig,
|
||||
# epoch + 1,
|
||||
# )
|
||||
|
||||
writer.add_figure(
|
||||
"powers",
|
||||
self.plot_model_response(
|
||||
trial,
|
||||
model=self.model,
|
||||
title_append=title_append,
|
||||
subtitle=subtitle,
|
||||
mode="powers",
|
||||
show=False,
|
||||
),
|
||||
epoch + 1,
|
||||
)
|
||||
# writer.add_figure(
|
||||
# "powers",
|
||||
# powers_fig,
|
||||
# epoch + 1,
|
||||
# )
|
||||
# writer.flush()
|
||||
|
||||
# if enable_progress:
|
||||
# progress.stop()
|
||||
@@ -511,15 +515,18 @@ class HyperTraining:
|
||||
|
||||
with torch.no_grad():
|
||||
model = model.to(self.pytorch_settings.device)
|
||||
for x, y, timestamp in loader:
|
||||
for batch in loader:
|
||||
x = batch["x"]
|
||||
y = batch["y"]
|
||||
timestamp = batch["timestamp"]
|
||||
x, y = (
|
||||
x.to(self.pytorch_settings.device),
|
||||
y.to(self.pytorch_settings.device),
|
||||
)
|
||||
if trace_powers:
|
||||
y_pred, powers = model(x, trace_powers).cpu()
|
||||
y_pred, powers = model(x, trace_powers=True).cpu()
|
||||
else:
|
||||
y_pred = model(x, trace_powers).cpu()
|
||||
y_pred = model(x, trace_powers=True).cpu()
|
||||
# x = x.cpu()
|
||||
# y = y.cpu()
|
||||
y_pred = y_pred.view(y_pred.shape[0], -1, 2)
|
||||
@@ -539,7 +546,7 @@ class HyperTraining:
|
||||
return fiber_in, fiber_out, regen, timestamps, powers
|
||||
return fiber_in, fiber_out, regen, timestamps
|
||||
|
||||
def objective(self, trial: optuna.Trial, plot_before=False):
|
||||
def objective(self, trial: optuna.Trial):
|
||||
if self.stop_study:
|
||||
trial.study.stop()
|
||||
model = None
|
||||
@@ -555,54 +562,54 @@ class HyperTraining:
|
||||
|
||||
title_append, subtitle = self.build_title(trial)
|
||||
|
||||
writer.add_figure(
|
||||
"fiber response",
|
||||
self.plot_model_response(
|
||||
trial,
|
||||
model=model,
|
||||
title_append=title_append,
|
||||
subtitle=subtitle,
|
||||
show=False,
|
||||
),
|
||||
0,
|
||||
)
|
||||
writer.add_figure(
|
||||
"eye diagram",
|
||||
self.plot_model_response(
|
||||
trial,
|
||||
model=self.model,
|
||||
title_append=title_append,
|
||||
subtitle=subtitle,
|
||||
mode="eye",
|
||||
show=False,
|
||||
),
|
||||
0,
|
||||
)
|
||||
# writer.add_figure(
|
||||
# "fiber response",
|
||||
# self.plot_model_response(
|
||||
# trial,
|
||||
# model=model,
|
||||
# title_append=title_append,
|
||||
# subtitle=subtitle,
|
||||
# show=False,
|
||||
# ),
|
||||
# 0,
|
||||
# )
|
||||
# writer.add_figure(
|
||||
# "eye diagram",
|
||||
# self.plot_model_response(
|
||||
# trial,
|
||||
# model=self.model,
|
||||
# title_append=title_append,
|
||||
# subtitle=subtitle,
|
||||
# mode="eye",
|
||||
# show=False,
|
||||
# ),
|
||||
# 0,
|
||||
# )
|
||||
|
||||
writer.add_figure(
|
||||
"powers",
|
||||
self.plot_model_response(
|
||||
trial,
|
||||
model=self.model,
|
||||
title_append=title_append,
|
||||
subtitle=subtitle,
|
||||
mode="powers",
|
||||
show=False,
|
||||
),
|
||||
0,
|
||||
)
|
||||
# writer.add_figure(
|
||||
# "powers",
|
||||
# self.plot_model_response(
|
||||
# trial,
|
||||
# model=self.model,
|
||||
# title_append=title_append,
|
||||
# subtitle=subtitle,
|
||||
# mode="powers",
|
||||
# show=False,
|
||||
# ),
|
||||
# 0,
|
||||
# )
|
||||
|
||||
train_loader, valid_loader = self.get_sliced_data(trial)
|
||||
|
||||
optimizer_name = trial.suggest_categorical_optional("optimizer", self.optimizer_settings.optimizer)
|
||||
|
||||
lr = trial.suggest_float_optional("lr", self.optimizer_settings.learning_rate, log=True)
|
||||
lr = trial.suggest_float_optional("lr", self.optimizer_settings.optimizer_kwargs["lr"], log=True)
|
||||
|
||||
optimizer = getattr(optim, optimizer_name)(model.parameters(), lr=lr)
|
||||
if self.optimizer_settings.scheduler is not None:
|
||||
scheduler = getattr(optim.lr_scheduler, self.optimizer_settings.scheduler)(
|
||||
optimizer, **self.optimizer_settings.scheduler_kwargs
|
||||
)
|
||||
# if self.optimizer_settings.scheduler is not None:
|
||||
# scheduler = getattr(optim.lr_scheduler, self.optimizer_settings.scheduler)(
|
||||
# optimizer, **self.optimizer_settings.scheduler_kwargs
|
||||
# )
|
||||
|
||||
for epoch in range(self.pytorch_settings.epochs):
|
||||
trial.set_user_attr("epoch", epoch)
|
||||
@@ -628,8 +635,8 @@ class HyperTraining:
|
||||
writer,
|
||||
# enable_progress=enable_progress,
|
||||
)
|
||||
if self.optimizer_settings.scheduler is not None:
|
||||
scheduler.step(error)
|
||||
# if self.optimizer_settings.scheduler is not None:
|
||||
# scheduler.step(error)
|
||||
|
||||
trial.set_user_attr("mse", error)
|
||||
trial.set_user_attr("log_mse", np.log10(error + np.finfo(float).eps))
|
||||
@@ -645,10 +652,10 @@ class HyperTraining:
|
||||
if self.optuna_settings._multi_objective:
|
||||
return -np.log10(error + np.finfo(float).eps), trial.user_attrs.get("model_n_nodes", -1)
|
||||
|
||||
if self.pytorch_settings.save_models and model is not None:
|
||||
save_path = Path(self.pytorch_settings.model_dir) / f"{self.optuna_settings.study_name}_{trial.number}.pth"
|
||||
save_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
torch.save(model, save_path)
|
||||
# if self.pytorch_settings.save_models and model is not None:
|
||||
# save_path = Path(self.pytorch_settings.model_dir) / f"{self.optuna_settings.study_name}_{trial.number}.pth"
|
||||
# save_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
# torch.save(model, save_path)
|
||||
|
||||
return error
|
||||
|
||||
|
||||
@@ -8,7 +8,8 @@ from util.complexNN import (
|
||||
photodiode,
|
||||
EOActivation,
|
||||
polarimeter,
|
||||
normalize_by_first
|
||||
# normalize_by_first,
|
||||
rotate,
|
||||
)
|
||||
|
||||
|
||||
@@ -19,11 +20,11 @@ class polarisation_estimator2(Module):
|
||||
polarimeter(),
|
||||
torch.nn.Linear(4, 4),
|
||||
torch.nn.ReLU(),
|
||||
torch.nn.Dropout(p=0.01),
|
||||
# torch.nn.Dropout(p=0.01),
|
||||
torch.nn.Linear(4, 4),
|
||||
torch.nn.ReLU(),
|
||||
torch.nn.Dropout(p=0.01),
|
||||
torch.nn.Linear(4, 4),
|
||||
# torch.nn.Dropout(p=0.01),
|
||||
torch.nn.Linear(4, 1),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
@@ -124,6 +125,7 @@ class regenerator(Module):
|
||||
dtype=torch.float64,
|
||||
dropout_prob=0.01,
|
||||
scale_layers=False,
|
||||
rotate=False,
|
||||
):
|
||||
super(regenerator, self).__init__()
|
||||
self._n_hidden_layers = len(dims) - 2
|
||||
@@ -131,6 +133,8 @@ class regenerator(Module):
|
||||
layer_func_kwargs = layer_func_kwargs or {}
|
||||
act_func_kwargs = act_func_kwargs or {}
|
||||
|
||||
self.rotation = rotate
|
||||
|
||||
self.build_model(dims, layer_function, layer_func_kwargs, act_function, act_func_kwargs, parametrizations, dtype, dropout_prob, scale_layers)
|
||||
|
||||
def build_model(self, dims, layer_function, layer_func_kwargs, act_function, act_func_kwargs, parametrizations, dtype, dropout_prob, scale_layers):
|
||||
@@ -146,6 +150,7 @@ class regenerator(Module):
|
||||
module = act_function(size=dims[i + 1], **act_func_kwargs)
|
||||
self.get_submodule(f"layer_{i}").add_module("activation", module)
|
||||
|
||||
if dropout_prob is not None and dropout_prob > 0:
|
||||
module = DropoutComplex(p=dropout_prob)
|
||||
self.get_submodule(f"layer_{i}").add_module("dropout", module)
|
||||
|
||||
@@ -160,6 +165,10 @@ class regenerator(Module):
|
||||
module = act_function(size=dims[-1], **act_func_kwargs)
|
||||
self.get_submodule(f"layer_{self._n_hidden_layers}").add_module("activation", module)
|
||||
|
||||
if self.rotation:
|
||||
module = rotate()
|
||||
self.add_module("rotate", module)
|
||||
|
||||
# module = Scale(size=dims[-1])
|
||||
# self.get_submodule(f"layer_{self._n_hidden_layers}").add_module("out_scale", module)
|
||||
|
||||
@@ -190,15 +199,27 @@ class regenerator(Module):
|
||||
powers.append(x.abs().square().sum())
|
||||
return powers
|
||||
|
||||
def forward(self, x, trace_powers=False):
|
||||
def forward(self, x, angle=None, pre_rot=False, trace_powers=False):
|
||||
powers = self._trace_powers(trace_powers, x)
|
||||
x = self.layer_0(x)
|
||||
powers = self._trace_powers(trace_powers, x, powers)
|
||||
for i in range(1, self._n_hidden_layers):
|
||||
# x = self.layer_0(x)
|
||||
# powers = self._trace_powers(trace_powers, x, powers)
|
||||
for i in range(0, self._n_hidden_layers):
|
||||
x = getattr(self, f"layer_{i}")(x)
|
||||
powers = self._trace_powers(trace_powers, x, powers)
|
||||
x = getattr(self, f"layer_{self._n_hidden_layers}")(x)
|
||||
powers = self._trace_powers(trace_powers, x, powers)
|
||||
if trace_powers:
|
||||
return x, powers
|
||||
return x
|
||||
if self.rotation:
|
||||
try:
|
||||
x_rot = self.rotate(x, angle)
|
||||
except AttributeError:
|
||||
pass
|
||||
powers = self._trace_powers(trace_powers, x_rot, powers)
|
||||
else:
|
||||
x_rot = x
|
||||
|
||||
if pre_rot and trace_powers:
|
||||
return x_rot, x, powers
|
||||
if pre_rot and not trace_powers:
|
||||
return x_rot, x
|
||||
if not pre_rot and trace_powers:
|
||||
return x_rot, powers
|
||||
return x_rot
|
||||
@@ -22,6 +22,8 @@ class DataSettings:
|
||||
train_split: float = 0.8
|
||||
polarisations: tuple | list = (0,)
|
||||
randomise_polarisations: bool = False
|
||||
osnr: float | int = None
|
||||
seed: int = None
|
||||
|
||||
"""
|
||||
change to:
|
||||
|
||||
@@ -2,7 +2,6 @@ import copy
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
import random
|
||||
from typing import Literal
|
||||
import matplotlib
|
||||
from matplotlib.colors import LinearSegmentedColormap
|
||||
import torch.nn.utils.parametrize
|
||||
@@ -60,25 +59,26 @@ def traverse_dict_update(target, source):
|
||||
except TypeError:
|
||||
target.__dict__[k] = v
|
||||
|
||||
|
||||
def get_parameter_names_and_values(model):
|
||||
def is_parametrized(module):
|
||||
if hasattr(module, "parametrizations"):
|
||||
return True
|
||||
return False
|
||||
|
||||
def _get_param_info(module, prefix='', parametrization=False):
|
||||
def _get_param_info(module, prefix="", parametrization=False):
|
||||
param_list = []
|
||||
for name, param in module.named_parameters(recurse=parametrization):
|
||||
if parametrization and name.startswith("parametrizations"):
|
||||
name_parts = name.split('.')
|
||||
name_parts = name.split(".")
|
||||
name = name_parts[1]
|
||||
param = getattr(module, name)
|
||||
full_name = prefix + ('.' if prefix else '') + name
|
||||
full_name = prefix + ("." if prefix else "") + name
|
||||
param_value = param.data
|
||||
param_list.append((full_name, param_value))
|
||||
|
||||
for child_name, child_module in module.named_children():
|
||||
child_prefix = prefix + ('.' if prefix else '') + child_name
|
||||
child_prefix = prefix + ("." if prefix else "") + child_name
|
||||
if child_name == "parametrizations":
|
||||
continue
|
||||
param_list.extend(_get_param_info(child_module, child_prefix, is_parametrized(child_module)))
|
||||
@@ -87,6 +87,7 @@ def get_parameter_names_and_values(model):
|
||||
|
||||
return _get_param_info(model)
|
||||
|
||||
|
||||
class PolarizationTrainer:
|
||||
def __init__(
|
||||
self,
|
||||
@@ -219,7 +220,7 @@ class PolarizationTrainer:
|
||||
|
||||
# dims = self.model_kwargs.pop("dims")
|
||||
model_kwargs = copy.deepcopy(self.model_kwargs)
|
||||
self.model = models.polarisation_estimator(*model_kwargs.pop('dims'),**model_kwargs)
|
||||
self.model = models.polarisation_estimator(*model_kwargs.pop("dims"), **model_kwargs)
|
||||
# self.model = models.polarisation_estimator2()
|
||||
|
||||
if self.writer is not None:
|
||||
@@ -336,17 +337,20 @@ class PolarizationTrainer:
|
||||
write_div = 0
|
||||
loss_div = 0
|
||||
for batch_idx, batch in enumerate(train_loader):
|
||||
x = batch["x"]
|
||||
y = batch["sop"]
|
||||
x = batch["angle_data2"]
|
||||
y = batch["center_angle"]
|
||||
self.model.zero_grad(set_to_none=True)
|
||||
x, y = (
|
||||
x.to(self.pytorch_settings.device),
|
||||
y.to(self.pytorch_settings.device),
|
||||
)
|
||||
y_pred = self.model(x)
|
||||
y_pred = self.model(x).abs().real
|
||||
# y_pred = torch.fmod(y_pred, self.mod)
|
||||
y = y.abs().real
|
||||
# y = torch.fmod(y, self.mod)
|
||||
# loss = torch.nn.functional.mse_loss(torch.cos(y_pred), torch.cos(y))
|
||||
# loss = torch.nn.functional.smooth_l1_loss(torch.cos(torch.fmod(y_pred, torch.pi/2)).squeeze(), torch.cos(torch.fmod(y, torch.pi/2)).squeeze(), beta=0.5)
|
||||
loss = torch.nn.functional.mse_loss(y_pred, y)
|
||||
# loss = util.complexNN.naive_angle_loss(y_pred, y, mod=torch.pi/2)
|
||||
loss = util.complexNN.naive_angle_loss(y_pred, y, mod=self.mod)
|
||||
loss_value = loss.item()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
@@ -356,7 +360,7 @@ class PolarizationTrainer:
|
||||
loss_div += 1
|
||||
|
||||
if enable_progress:
|
||||
progress.update(task, advance=1, description=f"{loss_value:.3e}")
|
||||
progress.update(task, advance=1, description=f"{loss_value/np.pi*180:.3e} °")
|
||||
|
||||
if batch_idx % self.pytorch_settings.write_every == 0:
|
||||
self.writer.add_scalar(
|
||||
@@ -395,22 +399,26 @@ class PolarizationTrainer:
|
||||
loss_div = 0
|
||||
with torch.no_grad():
|
||||
for _, batch in enumerate(valid_loader):
|
||||
x = batch["x"]
|
||||
y = batch["sop"]
|
||||
# x = batch["angle_data2"]
|
||||
x = batch["angle_data2"]
|
||||
y = batch["center_angle"]
|
||||
x, y = (
|
||||
x.to(self.pytorch_settings.device),
|
||||
y.to(self.pytorch_settings.device),
|
||||
)
|
||||
y_pred = self.model(x)
|
||||
y_pred = self.model(x).abs().real
|
||||
# y_pred = torch.fmod(y_pred, self.mod)
|
||||
y = y.abs().real
|
||||
# y = torch.fmod(y, self.mod)
|
||||
# loss = torch.nn.functional.mse_loss(torch.cos(y_pred), torch.cos(y))
|
||||
# loss = torch.nn.functional.smooth_l1_loss(torch.cos(torch.fmod(y_pred, torch.pi/2)).squeeze(), torch.cos(torch.fmod(y, torch.pi/2)).squeeze(), beta=0.5)
|
||||
loss = torch.nn.functional.mse_loss(y_pred, y)
|
||||
# loss = util.complexNN.naive_angle_loss(y_pred, y, mod=torch.pi/2)
|
||||
loss = util.complexNN.naive_angle_loss(y_pred, y, mod=self.mod)
|
||||
loss_value = loss.item()
|
||||
running_loss += loss_value
|
||||
loss_div += 1
|
||||
|
||||
if enable_progress:
|
||||
progress.update(task, advance=1, description=f"{loss_value:.3e}")
|
||||
progress.update(task, advance=1, description=f"{loss_value/np.pi*180:.3e} °")
|
||||
|
||||
running_loss = running_loss / loss_div
|
||||
|
||||
@@ -506,19 +514,19 @@ class PolarizationTrainer:
|
||||
for i, config_path in enumerate(self.data_settings.config_path):
|
||||
paths = Path.cwd().glob(config_path)
|
||||
for j, path in enumerate(paths):
|
||||
text = str(path) + '\n'
|
||||
with open(path, 'r') as f:
|
||||
text = str(path) + "\n"
|
||||
with open(path, "r") as f:
|
||||
text += f.read()
|
||||
text += '\n'
|
||||
text += "\n"
|
||||
self.writer.add_text(f"config_{i * len(self.data_settings.config_path) + j}", text)
|
||||
|
||||
elif isinstance(self.data_settings.config_path, str):
|
||||
paths = Path.cwd().glob(self.data_settings.config_path)
|
||||
for j, path in enumerate(paths):
|
||||
text = str(path) + '\n'
|
||||
with open(path, 'r') as f:
|
||||
text = str(path) + "\n"
|
||||
with open(path, "r") as f:
|
||||
text += f.read()
|
||||
text += '\n'
|
||||
text += "\n"
|
||||
self.writer.add_text(f"config_{j}", text)
|
||||
|
||||
self.writer.flush()
|
||||
@@ -571,7 +579,8 @@ class PolarizationTrainer:
|
||||
if loss < self.best["loss"]:
|
||||
self.best = checkpoint
|
||||
save_path = (
|
||||
Path(self.pytorch_settings.model_dir) / f"best_pol_{self.writer.get_logdir().split('/')[-1]}.tar"
|
||||
Path(self.pytorch_settings.model_dir)
|
||||
/ f"best_pol_{self.writer.get_logdir().split('/')[-1]}.tar"
|
||||
)
|
||||
save_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
self.save_checkpoint(self.best, save_path)
|
||||
@@ -580,6 +589,7 @@ class PolarizationTrainer:
|
||||
self.writer.close()
|
||||
return self.best
|
||||
|
||||
|
||||
class RegenerationTrainer:
|
||||
def __init__(
|
||||
self,
|
||||
@@ -636,6 +646,10 @@ class RegenerationTrainer:
|
||||
self.model_settings: ModelSettings = model_settings
|
||||
self.optimizer_settings: OptimizerSettings = optimizer_settings
|
||||
|
||||
if self.global_settings.seed is not None:
|
||||
random.seed(self.global_settings.seed)
|
||||
np.random.seed(self.global_settings.seed)
|
||||
|
||||
self.console = console or Console()
|
||||
self.writer = None
|
||||
|
||||
@@ -706,10 +720,12 @@ class RegenerationTrainer:
|
||||
|
||||
# dims = self.model_kwargs.pop("dims")
|
||||
model_kwargs = copy.deepcopy(self.model_kwargs)
|
||||
self.model = models.regenerator(*model_kwargs.pop('dims'),**model_kwargs)
|
||||
self.model = models.regenerator(*model_kwargs.pop("dims"), **model_kwargs)
|
||||
|
||||
if self.writer is not None:
|
||||
self.writer.add_graph(self.model, torch.zeros(1, input_dim, dtype=dtype))
|
||||
self.writer.add_graph(
|
||||
self.model, (torch.rand(1, input_dim, dtype=dtype), torch.rand(1, 1, dtype=dtype.to_real()))
|
||||
)
|
||||
|
||||
self.model = self.model.to(self.pytorch_settings.device)
|
||||
if self.resume:
|
||||
@@ -728,12 +744,12 @@ class RegenerationTrainer:
|
||||
|
||||
num_symbols = None
|
||||
config_path = self.data_settings.config_path
|
||||
polarisations = self.data_settings.polarisations
|
||||
randomise_polarisations = self.data_settings.randomise_polarisations
|
||||
osnr = self.data_settings.osnr
|
||||
if override is not None:
|
||||
num_symbols = override.get("num_symbols", None)
|
||||
config_path = override.get("config_path", config_path)
|
||||
polarisations = override.get("polarisations", polarisations)
|
||||
# polarisations = override.get("polarisations", polarisations)
|
||||
randomise_polarisations = override.get("randomise_polarisation", randomise_polarisations)
|
||||
# get dataset
|
||||
dataset = FiberRegenerationDataset(
|
||||
@@ -746,8 +762,8 @@ class RegenerationTrainer:
|
||||
dtype=dtype,
|
||||
real=not dtype.is_complex,
|
||||
num_symbols=num_symbols,
|
||||
polarisations=polarisations,
|
||||
randomise_polarisations=randomise_polarisations,
|
||||
osnr = osnr,
|
||||
)
|
||||
|
||||
dataset_size = len(dataset)
|
||||
@@ -819,12 +835,14 @@ class RegenerationTrainer:
|
||||
for batch_idx, batch in enumerate(train_loader):
|
||||
x = batch["x"]
|
||||
y = batch["y"]
|
||||
angles = batch["mean_angle"]
|
||||
self.model.zero_grad(set_to_none=True)
|
||||
x, y = (
|
||||
x, y, angles = (
|
||||
x.to(self.pytorch_settings.device),
|
||||
y.to(self.pytorch_settings.device),
|
||||
angles.to(self.pytorch_settings.device),
|
||||
)
|
||||
y_pred = self.model(x)
|
||||
y_pred = self.model(x, -angles)
|
||||
loss = util.complexNN.complex_mse_loss(y_pred, y, power=True)
|
||||
loss_value = loss.item()
|
||||
loss.backward()
|
||||
@@ -872,11 +890,13 @@ class RegenerationTrainer:
|
||||
for _, batch in enumerate(valid_loader):
|
||||
x = batch["x"]
|
||||
y = batch["y"]
|
||||
x, y = (
|
||||
angles = batch["mean_angle"]
|
||||
x, y, angles = (
|
||||
x.to(self.pytorch_settings.device),
|
||||
y.to(self.pytorch_settings.device),
|
||||
angles.to(self.pytorch_settings.device),
|
||||
)
|
||||
y_pred = self.model(x)
|
||||
y_pred = self.model(x, -angles)
|
||||
error = util.complexNN.complex_mse_loss(y_pred, y, power=True)
|
||||
error_value = error.item()
|
||||
running_error += error_value
|
||||
@@ -928,45 +948,65 @@ class RegenerationTrainer:
|
||||
def run_model(self, model, loader, trace_powers=False):
|
||||
model.eval()
|
||||
fiber_out = []
|
||||
fiber_out_rot = []
|
||||
fiber_in = []
|
||||
regen = []
|
||||
timestamps = []
|
||||
angles = []
|
||||
|
||||
with torch.no_grad():
|
||||
model = model.to(self.pytorch_settings.device)
|
||||
for batch in loader:
|
||||
x = batch["x"]
|
||||
y = batch["y"]
|
||||
plot_target = batch["plot_target"]
|
||||
angle = batch["mean_angle"]
|
||||
center_angle = batch["center_angle"]
|
||||
timestamp = batch["timestamp"]
|
||||
plot_data = batch["plot_data"]
|
||||
x, y = (
|
||||
plot_data_rot = batch["plot_data_rot"]
|
||||
x, y, angle = (
|
||||
x.to(self.pytorch_settings.device),
|
||||
y.to(self.pytorch_settings.device),
|
||||
angle.to(self.pytorch_settings.device),
|
||||
)
|
||||
if trace_powers:
|
||||
y_pred, powers = model(x, trace_powers).cpu()
|
||||
y_pred, powers = model(x, angle, True).cpu()
|
||||
else:
|
||||
y_pred = model(x, trace_powers).cpu()
|
||||
y_pred = model(x, angle).cpu()
|
||||
# x = x.cpu()
|
||||
# y = y.cpu()
|
||||
y_pred = y_pred.view(y_pred.shape[0], -1, 2)
|
||||
y_pred = y_pred[:, y_pred.shape[1]//2, :]
|
||||
y = y.view(y.shape[0], -1, 2)
|
||||
plot_data = plot_data.view(plot_data.shape[0], -1, 2)
|
||||
# plot_data = plot_data.view(plot_data.shape[0], -1, 2)
|
||||
# c = torch.cos(-angle).cpu()
|
||||
# s = torch.sin(-angle).cpu()
|
||||
# rot = torch.stack([torch.stack([c, -s], dim=1), torch.stack([s, c], dim=1)], dim=2).squeeze(-1)
|
||||
# plot_data = torch.bmm(plot_data, rot.to(dtype=plot_data.dtype))
|
||||
# plot_data = plot_data
|
||||
# sines = torch.sin(-angle.cpu())
|
||||
# cosines = torch.cos(-angle.cpu())
|
||||
# plot_data = torch.stack((plot_data[..., 0] * cosines - plot_data[..., 1] * sines, plot_data[..., 0] * sines + plot_data[..., 1] * cosines), dim=-1)
|
||||
# x = x.view(x.shape[0], -1, 2)
|
||||
|
||||
# timestamp = timestamp.view(-1, 1)
|
||||
fiber_out.append(plot_data.squeeze())
|
||||
fiber_in.append(y.squeeze())
|
||||
fiber_out_rot.append(plot_data_rot.squeeze())
|
||||
fiber_in.append(plot_target.squeeze())
|
||||
regen.append(y_pred.squeeze())
|
||||
timestamps.append(timestamp.squeeze())
|
||||
angles.append(center_angle.squeeze())
|
||||
|
||||
fiber_out = torch.vstack(fiber_out).cpu()
|
||||
fiber_out_rot = torch.vstack(fiber_out_rot).cpu()
|
||||
fiber_in = torch.vstack(fiber_in).cpu()
|
||||
regen = torch.vstack(regen).cpu()
|
||||
angles = torch.vstack(angles).cpu()
|
||||
timestamps = torch.concat(timestamps).cpu()
|
||||
if trace_powers:
|
||||
return fiber_in, fiber_out, regen, timestamps, powers
|
||||
return fiber_in, fiber_out, regen, timestamps
|
||||
return fiber_in, fiber_out, fiber_out_rot, angles, regen, timestamps, powers
|
||||
return fiber_in, fiber_out, fiber_out_rot, angles, regen, timestamps
|
||||
|
||||
def write_parameters(self, epoch, attributes: list[str] | tuple[str] = None):
|
||||
parameter_list = get_parameter_names_and_values(self.model)
|
||||
@@ -1027,18 +1067,18 @@ class RegenerationTrainer:
|
||||
for i, config_path in enumerate(self.data_settings.config_path):
|
||||
paths = Path.cwd().glob(config_path)
|
||||
for j, path in enumerate(paths):
|
||||
text = str(path) + '\n'
|
||||
with open(path, 'r') as f:
|
||||
text = str(path) + "\n"
|
||||
with open(path, "r") as f:
|
||||
text += f.read()
|
||||
text += '\n'
|
||||
text += "\n"
|
||||
self.writer.add_text(f"config_{i * len(self.data_settings.config_path) + j}", text)
|
||||
elif isinstance(self.data_settings.config_path, str):
|
||||
paths = Path.cwd().glob(self.data_settings.config_path)
|
||||
for j, path in enumerate(paths):
|
||||
text = str(path) + '\n'
|
||||
with open(path, 'r') as f:
|
||||
text = str(path) + "\n"
|
||||
with open(path, "r") as f:
|
||||
text += f.read()
|
||||
text += '\n'
|
||||
text += "\n"
|
||||
self.writer.add_text(f"config_{j}", text)
|
||||
|
||||
self.writer.flush()
|
||||
@@ -1116,6 +1156,7 @@ class RegenerationTrainer:
|
||||
powers = [power / powers[0] for power in powers]
|
||||
fig, ax = plt.subplots()
|
||||
fig.set_figwidth(18)
|
||||
fig.set_figheight(4)
|
||||
fig.suptitle(
|
||||
f"Energy conservation{f' {title_append}' if title_append else ''}{f'\n{subtitle}' if subtitle else ''}"
|
||||
)
|
||||
@@ -1184,6 +1225,7 @@ class RegenerationTrainer:
|
||||
|
||||
fig, axs = plt.subplots(1, 2 * len(signals), sharex=True, sharey=True)
|
||||
fig.set_figwidth(18)
|
||||
fig.set_figheight(4)
|
||||
fig.suptitle(f"Eye diagram{f' {title_append}' if title_append else ''}{f'\n{subtitle}' if subtitle else ''}")
|
||||
# xaxis = timestamps / sps
|
||||
# xaxis = np.arange(2 * sps) / sps
|
||||
@@ -1253,7 +1295,7 @@ class RegenerationTrainer:
|
||||
xaxis = timestamps / sps
|
||||
else:
|
||||
xaxis = timestamps
|
||||
ax.plot(xaxis, np.abs(signal[:, i]) ** 2, label=label)
|
||||
ax.plot(xaxis, np.abs(signal[:, i]) ** 2, label=label, alpha=0.7)
|
||||
ax.set_xlabel("Sample" if sps is None else "Symbol")
|
||||
ax.set_ylabel("normalized power")
|
||||
ax.minorticks_on()
|
||||
@@ -1281,7 +1323,9 @@ class RegenerationTrainer:
|
||||
model = model.to(self.pytorch_settings.device)
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
_, powers = model(input_data, trace_powers=True)
|
||||
_, powers = model(
|
||||
input_data, torch.zeros(input_data.shape[0], 1).to(self.pytorch_settings.device), trace_powers=True
|
||||
)
|
||||
|
||||
powers = [power.item() for power in powers]
|
||||
layer_names = [name for (name, _) in model.named_children()]
|
||||
@@ -1296,8 +1340,12 @@ class RegenerationTrainer:
|
||||
self.data_settings.shuffle = False
|
||||
self.data_settings.train_split = 1.0
|
||||
self.pytorch_settings.batchsize = max(self.pytorch_settings.head_symbols, self.pytorch_settings.eye_symbols)
|
||||
config_path = random.choice(self.data_settings.config_path) if isinstance(self.data_settings.config_path, (list, tuple)) else self.data_settings.config_path
|
||||
fiber_length = int(float(str(config_path).split('-')[4])/1000)
|
||||
config_path = (
|
||||
random.choice(self.data_settings.config_path)
|
||||
if isinstance(self.data_settings.config_path, (list, tuple))
|
||||
else self.data_settings.config_path
|
||||
)
|
||||
fiber_length = int(float(str(config_path).split("-")[4]) / 1000)
|
||||
if not hasattr(self, "_plot_loader"):
|
||||
self._plot_loader, _ = self.get_sliced_data(
|
||||
override={
|
||||
@@ -1305,20 +1353,29 @@ class RegenerationTrainer:
|
||||
"config_path": config_path,
|
||||
"shuffle": False,
|
||||
"polarisations": (np.random.rand(1) * np.pi * 2,),
|
||||
"randomise_polarisation": False,
|
||||
"randomise_polarisation": self.data_settings.randomise_polarisations,
|
||||
}
|
||||
)
|
||||
self._sps = self._plot_loader.dataset.samples_per_symbol
|
||||
self.data_settings = data_settings_backup
|
||||
self.pytorch_settings = pytorch_settings_backup
|
||||
|
||||
fiber_in, fiber_out, regen, timestamps = self.run_model(model, self._plot_loader)
|
||||
fiber_in, fiber_out, fiber_out_rot, angles, regen, timestamps = self.run_model(model, self._plot_loader)
|
||||
fiber_in = fiber_in.view(-1, 2)
|
||||
fiber_out = fiber_out.view(-1, 2)
|
||||
fiber_out_rot = fiber_out_rot.view(-1, 2)
|
||||
angles = angles.view(-1, 1)
|
||||
angles = angles.real
|
||||
angles = torch.fmod(angles, 2 * torch.pi)
|
||||
angles = torch.div(angles, 2*torch.pi)
|
||||
angles = torch.repeat_interleave(angles, 2, dim=1)
|
||||
|
||||
regen = regen.view(-1, 2)
|
||||
|
||||
fiber_in = fiber_in.numpy()
|
||||
fiber_out = fiber_out.numpy()
|
||||
fiber_out_rot = fiber_out_rot.numpy()
|
||||
angles = angles.numpy()
|
||||
regen = regen.numpy()
|
||||
timestamps = timestamps.numpy()
|
||||
|
||||
@@ -1327,11 +1384,12 @@ class RegenerationTrainer:
|
||||
import gc
|
||||
|
||||
head_fig = self._plot_model_response_head(
|
||||
fiber_out_rot[: self.pytorch_settings.head_symbols * self._sps],
|
||||
fiber_in[: self.pytorch_settings.head_symbols * self._sps],
|
||||
fiber_out[:self.pytorch_settings.head_symbols*self._sps],
|
||||
regen[: self.pytorch_settings.head_symbols * self._sps],
|
||||
angles[: self.pytorch_settings.head_symbols * self._sps],
|
||||
timestamps=timestamps[: self.pytorch_settings.head_symbols * self._sps],
|
||||
labels=("fiber in", "fiber out", "regen"),
|
||||
labels=("fiber out", "fiber in", "regen", "normed angle"),
|
||||
sps=self._sps,
|
||||
title_append=title_append + f" ({fiber_length} km)",
|
||||
subtitle=subtitle,
|
||||
@@ -1340,7 +1398,7 @@ class RegenerationTrainer:
|
||||
# raise NotImplementedError("Eye diagram not implemented")
|
||||
eye_fig = self._plot_model_response_eye(
|
||||
fiber_in[: self.pytorch_settings.eye_symbols * self._sps],
|
||||
fiber_out[:self.pytorch_settings.eye_symbols*self._sps],
|
||||
fiber_out_rot[: self.pytorch_settings.eye_symbols * self._sps],
|
||||
regen[: self.pytorch_settings.eye_symbols * self._sps],
|
||||
timestamps=timestamps[: self.pytorch_settings.eye_symbols * self._sps],
|
||||
labels=("fiber in", "fiber out", "regen"),
|
||||
@@ -1361,7 +1419,7 @@ class RegenerationTrainer:
|
||||
self.model_settings.overrides.get(f"n_hidden_nodes_{i}", -1) for i in range(model_n_hidden_layers)
|
||||
]
|
||||
model_dims.insert(0, input_dim)
|
||||
model_dims.append(2)
|
||||
model_dims.append(self.model_settings.output_dim)
|
||||
model_dims = [str(dim) for dim in model_dims]
|
||||
model_activation_func = self.model_settings.model_activation_func
|
||||
model_dtype = self.data_settings.dtype
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
from datetime import datetime
|
||||
|
||||
import optuna
|
||||
import torch
|
||||
import util
|
||||
from hypertraining.hypertraining import HyperTraining
|
||||
from hypertraining.settings import (
|
||||
GlobalSettings,
|
||||
@@ -16,24 +18,29 @@ global_settings = GlobalSettings(
|
||||
)
|
||||
|
||||
data_settings = DataSettings(
|
||||
config_path="data/*-128-16384-100000-0-0-17-0-PAM4-0.ini",
|
||||
# config_path="data/*-128-16384-100000-0-0-17-0-PAM4-0.ini",
|
||||
config_path="data/20241204-131003-128-16384-100000-0-0-17-0-PAM4-0.ini",
|
||||
dtype="complex64",
|
||||
# symbols = (9, 20), # 13 symbol @ 10GBd <-> 1.3ns <-> 0.26m of fiber
|
||||
symbols=13, # study: single_core_regen_20241123_011232
|
||||
# symbols=13, # study: single_core_regen_20241123_011232
|
||||
# symbols = (3, 13),
|
||||
symbols=4,
|
||||
# output_size = (11, 32), # ballpark 26 taps -> 2 taps per input symbol -> 1 tap every 0.01m (model has 52 inputs)
|
||||
output_size=26, # study: single_core_regen_20241123_011232 (model_input_dim/2)
|
||||
# output_size=26, # study: single_core_regen_20241123_011232 (model_input_dim/2)
|
||||
output_size=(8, 30),
|
||||
shuffle=True,
|
||||
in_out_delay=0,
|
||||
xy_delay=0,
|
||||
drop_first=128 * 100,
|
||||
drop_first=256,
|
||||
train_split=0.8,
|
||||
randomise_polarisations=False,
|
||||
)
|
||||
|
||||
pytorch_settings = PytorchSettings(
|
||||
epochs=10000,
|
||||
epochs=10,
|
||||
batchsize=2**10,
|
||||
device="cuda",
|
||||
dataloader_workers=12,
|
||||
dataloader_workers=4,
|
||||
dataloader_prefetch=4,
|
||||
summary_dir=".runs",
|
||||
write_every=2**5,
|
||||
@@ -43,28 +50,70 @@ pytorch_settings = PytorchSettings(
|
||||
|
||||
model_settings = ModelSettings(
|
||||
output_dim=2,
|
||||
# n_hidden_layers = (3, 8),
|
||||
n_hidden_layers=4,
|
||||
overrides={
|
||||
"n_hidden_nodes_0": 8,
|
||||
"n_hidden_nodes_1": 6,
|
||||
"n_hidden_nodes_2": 4,
|
||||
"n_hidden_nodes_3": 8,
|
||||
n_hidden_layers = (2, 5),
|
||||
n_hidden_nodes=(2, 16),
|
||||
model_activation_func="EOActivation",
|
||||
dropout_prob=0,
|
||||
model_layer_function="ONNRect",
|
||||
model_layer_kwargs={"square": True},
|
||||
# scale=(False, True),
|
||||
scale=False,
|
||||
model_layer_parametrizations=[
|
||||
{
|
||||
"tensor_name": "weight",
|
||||
"parametrization": util.complexNN.energy_conserving,
|
||||
},
|
||||
model_activation_func="Mag",
|
||||
# satabsT0=(1e-6, 1),
|
||||
{
|
||||
"tensor_name": "alpha",
|
||||
"parametrization": util.complexNN.clamp,
|
||||
},
|
||||
{
|
||||
"tensor_name": "gain",
|
||||
"parametrization": util.complexNN.clamp,
|
||||
"kwargs": {
|
||||
"min": 0,
|
||||
"max": float("inf"),
|
||||
},
|
||||
},
|
||||
{
|
||||
"tensor_name": "phase_bias",
|
||||
"parametrization": util.complexNN.clamp,
|
||||
"kwargs": {
|
||||
"min": 0,
|
||||
"max": 2 * torch.pi,
|
||||
},
|
||||
},
|
||||
{
|
||||
"tensor_name": "scales",
|
||||
"parametrization": util.complexNN.clamp,
|
||||
},
|
||||
{
|
||||
"tensor_name": "angle",
|
||||
"parametrization": util.complexNN.clamp,
|
||||
"kwargs": {
|
||||
"min": -torch.pi,
|
||||
"max": torch.pi,
|
||||
},
|
||||
},
|
||||
{
|
||||
"tensor_name": "loss",
|
||||
"parametrization": util.complexNN.clamp,
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
optimizer_settings = OptimizerSettings(
|
||||
optimizer="Adam",
|
||||
# learning_rate = (1e-5, 1e-1),
|
||||
learning_rate=5e-3
|
||||
# learning_rate=5e-4,
|
||||
optimizer="AdamW",
|
||||
optimizer_kwargs={
|
||||
"lr": 5e-3,
|
||||
"amsgrad": True,
|
||||
# "weight_decay": 1e-7,
|
||||
},
|
||||
)
|
||||
|
||||
optuna_settings = OptunaSettings(
|
||||
n_trials=1,
|
||||
n_workers=1,
|
||||
n_trials=1024,
|
||||
n_workers=8,
|
||||
timeout=3600,
|
||||
directions=("minimize",),
|
||||
metrics_names=("mse",),
|
||||
|
||||
@@ -26,24 +26,26 @@ global_settings = GlobalSettings(
|
||||
)
|
||||
|
||||
data_settings = DataSettings(
|
||||
config_path="data/*-128-16384-10000-0-0-0-0-PAM4-0-0.4.ini",
|
||||
# config_path="data/*-128-16384-1-0-0-0-0-PAM4-0-0.ini",
|
||||
config_path="data/*-128-16384-10000-0-0-17-0-PAM4-0.ini",
|
||||
# config_path=[f"data/20241202-*-128-16384-{length}-0-0-17-0-PAM4-0.ini" for length in range(48000, 53000, 1000)],
|
||||
dtype="complex64",
|
||||
# symbols = (9, 20), # 13 symbol @ 10GBd <-> 1.3ns <-> 0.26m of fiber
|
||||
symbols=13, # study: single_core_regen_20241123_011232
|
||||
symbols=4, # study: single_core_regen_20241123_011232
|
||||
# output_size = (11, 32), # 26 taps -> 2 taps per input symbol -> 1 tap every 1cm (model has 52 inputs (x/y))
|
||||
output_size=26, # study: single_core_regen_20241123_011232 (model_input_dim/2)
|
||||
output_size=20, # study: single_core_regen_20241123_011232 (model_input_dim/2)
|
||||
shuffle=True,
|
||||
drop_first=64,
|
||||
train_split=0.8,
|
||||
randomise_polarisations=True,
|
||||
osnr=10,
|
||||
)
|
||||
|
||||
pytorch_settings = PytorchSettings(
|
||||
epochs=10000,
|
||||
batchsize=2**14,
|
||||
device="cuda",
|
||||
dataloader_workers=16,
|
||||
dataloader_workers=24,
|
||||
dataloader_prefetch=8,
|
||||
summary_dir=".runs",
|
||||
write_every=2**5,
|
||||
@@ -53,17 +55,17 @@ pytorch_settings = PytorchSettings(
|
||||
|
||||
model_settings = ModelSettings(
|
||||
output_dim=2,
|
||||
n_hidden_layers=5,
|
||||
n_hidden_layers=3,
|
||||
overrides={
|
||||
# "hidden_layer_dims": (8, 8, 4, 4),
|
||||
"n_hidden_nodes_0": 8,
|
||||
"n_hidden_nodes_0": 16,
|
||||
"n_hidden_nodes_1": 8,
|
||||
"n_hidden_nodes_2": 4,
|
||||
"n_hidden_nodes_3": 4,
|
||||
"n_hidden_nodes_4": 2,
|
||||
"n_hidden_nodes_2": 8,
|
||||
# "n_hidden_nodes_3": 4,
|
||||
# "n_hidden_nodes_4": 2,
|
||||
},
|
||||
model_activation_func="EOActivation",
|
||||
dropout_prob=0.01,
|
||||
dropout_prob=0,
|
||||
model_layer_function="ONNRect",
|
||||
model_layer_kwargs={"square": True},
|
||||
scale=False,
|
||||
@@ -126,7 +128,7 @@ model_settings = ModelSettings(
|
||||
optimizer_settings = OptimizerSettings(
|
||||
optimizer="AdamW",
|
||||
optimizer_kwargs={
|
||||
"lr": 0.01,
|
||||
"lr": 0.005,
|
||||
"amsgrad": True,
|
||||
# "weight_decay": 1e-7,
|
||||
},
|
||||
@@ -242,7 +244,15 @@ if __name__ == "__main__":
|
||||
pytorch_settings=pytorch_settings,
|
||||
model_settings=model_settings,
|
||||
optimizer_settings=optimizer_settings,
|
||||
# checkpoint_path=".models/best_20241205_235929.tar",
|
||||
checkpoint_path=".models/best_20241216_221359.tar",
|
||||
reset_epoch=True,
|
||||
# settings_override={
|
||||
# "optimizer_settings": {
|
||||
# "optimizer_kwargs": {
|
||||
# "lr": 0.01,
|
||||
# },
|
||||
# }
|
||||
# }
|
||||
# 20241202_143149
|
||||
)
|
||||
trainer.train()
|
||||
|
||||
@@ -26,7 +26,7 @@ global_settings = GlobalSettings(
|
||||
)
|
||||
|
||||
data_settings = DataSettings(
|
||||
config_path="data/*-128-16384-10000-0-0-0-0-PAM4-0-0.4.ini",
|
||||
config_path="data/20241211-105524-128-16384-1-0-0-0-0-PAM4-0-0.ini",
|
||||
# config_path=[f"data/20241202-*-128-16384-{length}-0-0-17-0-PAM4-0.ini" for length in range(48000, 53000, 1000)],
|
||||
dtype="complex64",
|
||||
# symbols = (9, 20), # 13 symbol @ 10GBd <-> 1.3ns <-> 0.26m of fiber
|
||||
@@ -53,14 +53,14 @@ pytorch_settings = PytorchSettings(
|
||||
)
|
||||
|
||||
model_settings = ModelSettings(
|
||||
output_dim=3,
|
||||
output_dim=1,
|
||||
n_hidden_layers=3,
|
||||
overrides={
|
||||
"n_hidden_nodes_0": 2,
|
||||
"n_hidden_nodes_1": 2,
|
||||
"n_hidden_nodes_2": 2,
|
||||
"n_hidden_nodes_0": 4,
|
||||
"n_hidden_nodes_1": 4,
|
||||
"n_hidden_nodes_2": 4,
|
||||
},
|
||||
dropout_prob=0.01,
|
||||
dropout_prob=0,
|
||||
model_layer_function="ONNRect",
|
||||
model_activation_func="EOActivation",
|
||||
model_layer_kwargs={"square": True},
|
||||
@@ -110,20 +110,24 @@ model_settings = ModelSettings(
|
||||
)
|
||||
|
||||
optimizer_settings = OptimizerSettings(
|
||||
optimizer="AdamW",
|
||||
optimizer="RMSprop",
|
||||
# optimizer="AdamW",
|
||||
optimizer_kwargs={
|
||||
"lr": 0.005,
|
||||
"amsgrad": True,
|
||||
"lr": 0.01,
|
||||
"alpha": 0.9,
|
||||
"momentum": 0.1,
|
||||
"eps": 1e-8,
|
||||
"centered": True,
|
||||
# "amsgrad": True,
|
||||
# "weight_decay": 1e-7,
|
||||
},
|
||||
# learning_rate=0.05,
|
||||
scheduler="ReduceLROnPlateau",
|
||||
scheduler_kwargs={
|
||||
"patience": 2**6,
|
||||
"patience": 2**5,
|
||||
"factor": 0.75,
|
||||
# "threshold": 1e-3,
|
||||
"min_lr": 1e-6,
|
||||
"cooldown": 10,
|
||||
# "cooldown": 10,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@@ -320,6 +320,29 @@ class normalize_by_first(nn.Module):
|
||||
def forward(self, data):
|
||||
return data / data[:, 0].unsqueeze(1)
|
||||
|
||||
class rotate(nn.Module):
|
||||
def __init__(self):
|
||||
super(rotate, self).__init__()
|
||||
|
||||
def forward(self, data, angle):
|
||||
# data -> (batch, n*2)
|
||||
# angle -> (batch, n)
|
||||
data_ = data
|
||||
if angle.ndim == 1:
|
||||
angle_ = angle.unsqueeze(1)
|
||||
else:
|
||||
angle_ = angle
|
||||
angle_ = angle_.expand(-1, data_.shape[1]//2)
|
||||
c = torch.cos(angle_)
|
||||
s = torch.sin(angle_)
|
||||
rot = torch.stack([torch.stack([c, -s], dim=2),
|
||||
torch.stack([s, c], dim=2)], dim=3)
|
||||
d = torch.bmm(data_.reshape(-1, 1, 2), rot.view(-1, 2, 2).to(dtype=data_.dtype)).reshape(*data.shape)
|
||||
# d = torch.bmm(data.unsqueeze(-1).mT, rot.to(dtype=data.dtype).mT).mT.squeeze(-1)
|
||||
|
||||
return d
|
||||
|
||||
|
||||
class photodiode(nn.Module):
|
||||
def __init__(self, size, bias=True):
|
||||
super(photodiode, self).__init__()
|
||||
@@ -487,7 +510,7 @@ class MZISingle(nn.Module):
|
||||
return x * torch.exp(1j * self.phi) * torch.sin(self.omega + self.func(x))
|
||||
|
||||
def naive_angle_loss(x: torch.Tensor, target: torch.Tensor, mod=2*torch.pi):
|
||||
return torch.fmod((x - target), mod).square().mean()
|
||||
return torch.fmod((x.abs().real - target.abs().real), mod).abs().mean()
|
||||
|
||||
def cosine_loss(x: torch.Tensor, target: torch.Tensor):
|
||||
return (2*(1 - torch.cos(x - target))).mean()
|
||||
|
||||
@@ -5,6 +5,7 @@ from torch.utils.data import Dataset
|
||||
# from torch.utils.data import Sampler
|
||||
import numpy as np
|
||||
import configparser
|
||||
import multiprocessing as mp
|
||||
|
||||
# class SubsetSampler(Sampler[int]):
|
||||
# """
|
||||
@@ -113,8 +114,9 @@ class FiberRegenerationDataset(Dataset):
|
||||
dtype: torch.dtype = None,
|
||||
real: bool = False,
|
||||
device=None,
|
||||
polarisations: tuple | list = (0,),
|
||||
osnr: float = None,
|
||||
randomise_polarisations: bool = False,
|
||||
repeat_randoms: int = 1,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
@@ -190,18 +192,20 @@ class FiberRegenerationDataset(Dataset):
|
||||
files.append(config["data"]["file"].strip('"'))
|
||||
self.config["data"]["file"] = str(files)
|
||||
|
||||
for i, angle in enumerate(torch.tensor(np.array(polarisations))):
|
||||
data_raw_copy = data_raw.clone()
|
||||
if angle == 0:
|
||||
continue
|
||||
sine = torch.sin(angle)
|
||||
cosine = torch.cos(angle)
|
||||
data_raw_copy[:, 2] = data_raw[:, 2] * cosine - data_raw[:, 3] * sine
|
||||
data_raw_copy[:, 3] = data_raw[:, 2] * sine + data_raw[:, 3] * cosine
|
||||
if i == 0:
|
||||
data_raw = data_raw_copy
|
||||
else:
|
||||
data_raw = torch.cat([data_raw, data_raw_copy], dim=0)
|
||||
# if polarisations is not None:
|
||||
# self.angles = torch.tensor(polarisations).repeat(len(data_raw), 1)
|
||||
# for i, angle in enumerate(torch.tensor(np.array(polarisations))):
|
||||
# data_raw_copy = data_raw.clone()
|
||||
# if angle == 0:
|
||||
# continue
|
||||
# sine = torch.sin(angle)
|
||||
# cosine = torch.cos(angle)
|
||||
# data_raw_copy[:, 2] = data_raw[:, 2] * cosine - data_raw[:, 3] * sine
|
||||
# data_raw_copy[:, 3] = data_raw[:, 2] * sine + data_raw[:, 3] * cosine
|
||||
# if i == 0:
|
||||
# data_raw = data_raw_copy
|
||||
# else:
|
||||
# data_raw = torch.cat([data_raw, data_raw_copy], dim=0)
|
||||
|
||||
self.device = data_raw.device
|
||||
|
||||
@@ -278,23 +282,61 @@ class FiberRegenerationDataset(Dataset):
|
||||
timestamps = data_raw[4, :]
|
||||
data_raw = data_raw[:4, :]
|
||||
data_raw = data_raw.view(2, 2, -1)
|
||||
timestamps_doubled = torch.cat([timestamps.unsqueeze(dim=0), timestamps.unsqueeze(dim=0)], dim=0).unsqueeze(
|
||||
dim=1
|
||||
)
|
||||
data_raw = torch.cat([data_raw, timestamps_doubled], dim=1)
|
||||
fiber_in = data_raw[0, :, :]
|
||||
fiber_out = data_raw[1, :, :]
|
||||
# timestamps_doubled = torch.cat([timestamps.unsqueeze(dim=0), timestamps.unsqueeze(dim=0)], dim=0).unsqueeze(
|
||||
# dim=1
|
||||
# )
|
||||
fiber_in = torch.cat([fiber_in, timestamps.unsqueeze(0)], dim=0)
|
||||
fiber_out = torch.cat([fiber_out, timestamps.unsqueeze(0)], dim=0)
|
||||
|
||||
if repeat_randoms > 1:
|
||||
fiber_in = fiber_in.repeat(1, 1, repeat_randoms)
|
||||
fiber_out = fiber_out.repeat(1, 1, repeat_randoms)
|
||||
# review: potential problems with repeated timestamps when plotting
|
||||
else:
|
||||
repeat_randoms = 1
|
||||
|
||||
if self.randomise_polarisations:
|
||||
angles = torch.fmod(torch.randperm(data_raw.shape[-1]*repeat_randoms), 2) * torch.pi
|
||||
# start_angle = torch.rand(1) * 2 * torch.pi
|
||||
# angles = start_angle + torch.cumsum(torch.randn(data_raw.shape[-1])/333, dim=0) * torch.pi * 2 # random walk
|
||||
# self.angles = torch.rand(self.data.shape[0]) * 2 * torch.pi
|
||||
else:
|
||||
angles = torch.zeros(data_raw.shape[-1])
|
||||
|
||||
sin = torch.sin(angles)
|
||||
cos = torch.cos(angles)
|
||||
rot = torch.stack([torch.stack([cos, -sin], dim=1), torch.stack([sin, cos], dim=1)], dim=2)
|
||||
data_rot = torch.bmm(fiber_out[:2, :].T.unsqueeze(1), rot.to(dtype=data_raw.dtype)).squeeze(1).T
|
||||
fiber_out = torch.cat((fiber_out, data_rot), dim=0)
|
||||
fiber_out = torch.cat([fiber_out, angles.unsqueeze(0)], dim=0)
|
||||
|
||||
if osnr is not None:
|
||||
popt = torch.mean(fiber_out[:2, :, :].abs().flatten(), dim=-1)
|
||||
noise = torch.randn_like(fiber_out[:2, :, :])
|
||||
pn = torch.mean(noise.abs().flatten(), dim=-1)
|
||||
noise = noise * (popt / pn) * 10 ** (-osnr / 20)
|
||||
fiber_out[:2, :, :] = torch.add(fiber_out[:2, :, :], noise)
|
||||
|
||||
|
||||
|
||||
|
||||
# data_raw = torch.cat([data_raw, timestamps_doubled], dim=1)
|
||||
# data layout
|
||||
# [ [E_in_x, E_in_y, timestamps],
|
||||
# [E_out_x, E_out_y, timestamps] ]
|
||||
|
||||
self.data = data_raw.unfold(dimension=-1, size=self.samples_per_slice, step=1)
|
||||
self.data = self.data.movedim(-2, 0)
|
||||
self.fiber_in = fiber_in.unfold(dimension=-1, size=self.samples_per_slice, step=1)
|
||||
self.fiber_in = self.fiber_in.movedim(-2, 0)
|
||||
|
||||
if randomise_polarisations:
|
||||
self.angles = torch.rand(self.data.shape[0]) * np.pi * 2
|
||||
# self.data[:, 1, :2, :] = self.rotate(self.data[:, 1, :2, :], self.angles)
|
||||
else:
|
||||
self.angles = torch.zeros(self.data.shape[0])
|
||||
self.fiber_out = fiber_out.unfold(dimension=-1, size=self.samples_per_slice, step=1)
|
||||
self.fiber_out = self.fiber_out.movedim(-2, 0)
|
||||
|
||||
# self.data = data_raw.unfold(dimension=-1, size=self.samples_per_slice, step=1)
|
||||
# self.data = self.data.movedim(-2, 0)
|
||||
# self.angles = torch.zeros(self.data.shape[0])
|
||||
...
|
||||
# ...
|
||||
# -> [no_slices, 2, 3, samples_per_slice]
|
||||
|
||||
@@ -305,51 +347,56 @@ class FiberRegenerationDataset(Dataset):
|
||||
# ...
|
||||
# ] -> [no_slices, 2, 3, samples_per_slice]
|
||||
|
||||
...
|
||||
|
||||
def __len__(self):
|
||||
return self.data.shape[0]
|
||||
return self.fiber_in.shape[0]
|
||||
|
||||
def __getitem__(self, idx):
|
||||
if isinstance(idx, slice):
|
||||
return [self.__getitem__(i) for i in range(*idx.indices(len(self)))]
|
||||
else:
|
||||
data_slice = self.data[idx].squeeze()
|
||||
# fiber in: [E_in_x, E_in_y, timestamps]
|
||||
# fiber out: [E_out_x, E_out_y, timestamps, E_out_x_rot, E_out_y_rot, angle]
|
||||
|
||||
data_slice = data_slice[:, :, : data_slice.shape[2] // self.output_dim * self.output_dim]
|
||||
fiber_in = self.fiber_in[idx].squeeze()
|
||||
fiber_out = self.fiber_out[idx].squeeze()
|
||||
|
||||
data_slice = data_slice.view(data_slice.shape[0], data_slice.shape[1], self.output_dim, -1)
|
||||
fiber_in = fiber_in[..., : fiber_in.shape[-1] // self.output_dim * self.output_dim]
|
||||
fiber_out = fiber_out[..., : fiber_out.shape[-1] // self.output_dim * self.output_dim]
|
||||
|
||||
fiber_in = fiber_in.view(fiber_in.shape[0], self.output_dim, -1)
|
||||
fiber_out = fiber_out.view(fiber_out.shape[0], self.output_dim, -1)
|
||||
|
||||
# data_slice = data_slice[:, :, : data_slice.shape[2] // self.output_dim * self.output_dim]
|
||||
|
||||
# data_slice = data_slice.view(data_slice.shape[0], data_slice.shape[1], self.output_dim, -1)
|
||||
|
||||
# angle = self.angles[idx]
|
||||
|
||||
center_angle = fiber_out[5, self.output_dim // 2, 0]
|
||||
angles = fiber_out[5, :, 0]
|
||||
plot_data = fiber_out[:2, self.output_dim // 2, 0].detach().clone()
|
||||
plot_data_rot = fiber_out[3:5, self.output_dim // 2, 0].detach().clone()
|
||||
data = fiber_out[3:5, :, 0]
|
||||
|
||||
# if self.randomise_polarisations:
|
||||
# angle = torch.rand(1) * torch.pi * 2
|
||||
# sine = torch.sin(angle)
|
||||
# cosine = torch.cos(angle)
|
||||
# data_slice_ = data_slice[1]
|
||||
# data_slice[1, 0] = data_slice_[0] * cosine - data_slice_[1] * sine
|
||||
# data_slice[1,1] = data_slice_[0] * sine + data_slice_[1] * cosine
|
||||
# else:
|
||||
# angle = torch.zeros(1)
|
||||
# data = data.mT
|
||||
# c = torch.cos(angle).unsqueeze(-1)
|
||||
# s = torch.sin(angle).unsqueeze(-1)
|
||||
# rot = torch.stack([torch.stack([c, -s], dim=1), torch.stack([s, c], dim=1)], dim=2).squeeze(-1)
|
||||
# data = torch.bmm(data.mT.unsqueeze(0), rot.to(dtype=data.dtype)).squeeze(-1)
|
||||
...
|
||||
|
||||
# data = data_slice[1, :2, :, 0]
|
||||
|
||||
angle = self.angles[idx]
|
||||
|
||||
data_index = 1
|
||||
|
||||
data_slice[1, :2, :, :] = self.rotate(data_slice[data_index, :2, :, :], angle)
|
||||
|
||||
data = data_slice[1, :2, :, 0]
|
||||
# data = self.rotate(data, angle)
|
||||
# angle = torch.zeros_like(angle)
|
||||
|
||||
# for both polarisations (x/y), calculate the mean of the signal around the current symbol (-> corresponding to a lowpass filter)
|
||||
angle_data = data_slice[1, :2, :, :].reshape(2, -1).mean(dim=1)
|
||||
angle_data2 = self.complex_max(data_slice[1, :2, :, :].reshape(2, -1))
|
||||
plot_data = data_slice[1, :2, self.output_dim // 2, 0]
|
||||
sop = self.polarimeter(plot_data)
|
||||
angle_data = fiber_out[:2, :, :].reshape(2, -1).mean(dim=1).repeat(1, self.output_dim)
|
||||
angle_data2 = self.complex_max(fiber_out[:2, :, :].reshape(2, -1)).repeat(1, self.output_dim)
|
||||
# sop = self.polarimeter(plot_data)
|
||||
# angle_data = data_slice[1, :2, :self.output_dim//2, :].squeeze().reshape(2, -1).mean(dim=-1)
|
||||
# angle = data_slice[1, 3, self.output_dim // 2, 0].real
|
||||
target = data_slice[0, :2, self.output_dim // 2, 0]
|
||||
target_timestamp = data_slice[0, 2, self.output_dim // 2, 0].real
|
||||
target = fiber_in[:2, self.output_dim // 2, 0]
|
||||
plot_target = fiber_in[:2, self.output_dim // 2, 0].detach().clone()
|
||||
target_timestamp = fiber_in[2, self.output_dim // 2, 0].real
|
||||
...
|
||||
|
||||
# data_timestamps = data[-1,:].real
|
||||
@@ -360,14 +407,32 @@ class FiberRegenerationDataset(Dataset):
|
||||
|
||||
# transpose to interleave the x and y data in the output tensor
|
||||
data = data.transpose(0, 1).flatten().squeeze()
|
||||
angle_data = angle_data.flatten().squeeze()
|
||||
angle_data2 = angle_data.flatten().squeeze()
|
||||
angle = angle.flatten().squeeze()
|
||||
angle_data = angle_data.transpose(0, 1).flatten().squeeze()
|
||||
angle_data2 = angle_data2.transpose(0,1).flatten().squeeze()
|
||||
center_angle = center_angle.flatten().squeeze()
|
||||
angles = angles.flatten().squeeze()
|
||||
# data_timestamps = data_timestamps.flatten().squeeze()
|
||||
# target = target.transpose(0,1).flatten().squeeze()
|
||||
target = target.flatten().squeeze()
|
||||
target_timestamp = target_timestamp.flatten().squeeze()
|
||||
plot_target = plot_target.flatten().squeeze()
|
||||
plot_data = plot_data.flatten().squeeze()
|
||||
plot_data_rot = plot_data_rot.flatten().squeeze()
|
||||
|
||||
return {"x": data, "y": target, "angle": angle, "sop": sop, "angle_data": angle_data, "angle_data2": angle_data2, "timestamp": target_timestamp, "plot_data": plot_data}
|
||||
return {
|
||||
"x": data,
|
||||
"y": target,
|
||||
"center_angle": center_angle,
|
||||
"angles": angles,
|
||||
"mean_angle": angles.mean(),
|
||||
# "sop": sop,
|
||||
"angle_data": angle_data,
|
||||
"angle_data2": angle_data2,
|
||||
"timestamp": target_timestamp,
|
||||
"plot_target": plot_target,
|
||||
"plot_data": plot_data,
|
||||
"plot_data_rot": plot_data_rot,
|
||||
}
|
||||
|
||||
def complex_max(self, data, dim=-1):
|
||||
# returns element(s) with the maximum absolute value along a given dimension
|
||||
@@ -376,7 +441,6 @@ class FiberRegenerationDataset(Dataset):
|
||||
# return max_values
|
||||
return torch.gather(data, dim, torch.argmax(data.abs(), dim=dim, keepdim=True)).squeeze(dim=dim)
|
||||
|
||||
|
||||
def rotate(self, data, angle):
|
||||
# rotates a 2d tensor by a given angle
|
||||
# data: [2, ...]
|
||||
@@ -389,6 +453,24 @@ class FiberRegenerationDataset(Dataset):
|
||||
|
||||
return torch.stack([data[0] * cosine - data[1] * sine, data[0] * sine + data[1] * cosine], dim=0)
|
||||
|
||||
def rotate_all(self):
|
||||
def do_rotation(j, num_processes):
|
||||
for i in range(len(self) // num_processes):
|
||||
index = i * num_processes + j
|
||||
self.data[index, 1, :2, :] = self.rotate(self.data[index, 1, :2, :], self.angles[index])
|
||||
|
||||
self.processes = []
|
||||
|
||||
for j in range(mp.cpu_count()):
|
||||
self.processes.append(mp.Process(target=do_rotation, args=(j, mp.cpu_count())))
|
||||
self.processes[-1].start()
|
||||
|
||||
for p in self.processes:
|
||||
p.join()
|
||||
|
||||
for i in range(len(self) // mp.cpu_count() * mp.cpu_count(), len(self)):
|
||||
self.data[i, 1, :2, :] = self.rotate(self.data[i, 1, :2, :], self.angles[i])
|
||||
|
||||
def polarimeter(self, data):
|
||||
# data: [2, ...] -> x, y
|
||||
# returns [4] -> S0, S1, S2, S3
|
||||
@@ -404,4 +486,4 @@ class FiberRegenerationDataset(Dataset):
|
||||
S2 = (2 * I_45 - S0) / S0
|
||||
S3 = (2 * I_RHC - S0) / S0
|
||||
|
||||
return torch.stack([S1, S2, S3], dim=0)
|
||||
return torch.stack([S0, S1, S2, S3], dim=0)
|
||||
|
||||
Reference in New Issue
Block a user