Compare commits
2 commits: 33141bdf41...7a0b65f82d

| Author | SHA1 | Date |
|---|---|---|
|  | 7a0b65f82d |  |
|  | 98305fdf47 |  |
Data file (Git LFS pointer):

```diff
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:746ea83013870351296e01e294905ee027291ef79fd78c1e6b69dd9ebaa1cba0
-size 10240000
+oid sha256:76934d1d202aea1311ba67f5ea35eeb99a9c5c856f491565032e7d54ca6f072d
+size 13598720
```
Changes to `class HyperTraining`:

```diff
@@ -26,6 +26,8 @@ import torch
 import torch.optim as optim
 import torch.utils.data
 
+import hypertraining.models as models
+
 from torch.utils.tensorboard import SummaryWriter
 
 import multiprocessing
```
```diff
@@ -253,14 +255,17 @@ class HyperTraining:
        model_kwargs = {
            "dims": (input_dim, *hidden_dims, self.model_settings.output_dim),
            "layer_function": layer_func,
-           "layer_parametrizations": layer_parametrizations,
-           "activation_function": afunc,
+           "layer_func_kwargs": self.model_settings.model_layer_kwargs,
+           "act_function": afunc,
+           "act_func_kwargs": None,
+           "parametrizations": layer_parametrizations,
            "dtype": dtype,
-           "droupout_prob": self.model_settings.dropout_prob,
-           "scale": scale_layers,
+           "dropout_prob": self.model_settings.dropout_prob,
+           "scale_layers": scale_layers,
+           "rotate": False,
        }
 
-       model = util.complexNN.regenerator(**model_kwargs)
+       model = models.regenerator(*model_kwargs.pop("dims"), **model_kwargs)
        n_nodes = sum(hidden_dims)
 
        if writer is not None:
```
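The new constructor call moves `dims` out of the kwargs dict and passes it positionally. A minimal sketch of the pattern (the `regenerator` stub below stands in for `models.regenerator`, whose full signature is only partly visible in this diff):

```python
# Stand-in for models.regenerator, just to show the calling convention.
def regenerator(*dims, **kwargs):
    print("dims:", dims, "| remaining kwargs:", sorted(kwargs))

model_kwargs = {
    "dims": (52, 8, 6, 2),
    "dropout_prob": 0.01,
    "rotate": False,
}

# pop("dims") removes the tuple from the dict, so the remaining entries
# can safely be splatted as keyword arguments afterwards.
regenerator(*model_kwargs.pop("dims"), **model_kwargs)
# dims: (52, 8, 6, 2) | remaining kwargs: ['dropout_prob', 'rotate']
```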
```diff
@@ -381,7 +386,10 @@ class HyperTraining:
        running_loss = 0.0
        model.train()
        loader_len = len(train_loader)
-       for batch_idx, (x, y, _) in enumerate(train_loader):
+       for batch_idx, batch in enumerate(train_loader):
+           x = batch["x"]
+           y = batch["y"]
+
            if batch_idx >= self.optuna_settings._n_train_batches:
                break
            model.zero_grad(set_to_none=True)
```
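The loop now reads named fields from a dict batch instead of unpacking a positional tuple. With PyTorch's default collate function, a `Dataset` that returns dict samples yields batches that are dicts of stacked tensors; a self-contained sketch:

```python
import torch
from torch.utils.data import DataLoader, Dataset

class DictDataset(Dataset):
    """Toy dataset returning dict samples, as the new training loop expects."""
    def __len__(self):
        return 8

    def __getitem__(self, idx):
        return {"x": torch.randn(4), "y": torch.randn(2), "timestamp": torch.tensor(idx)}

loader = DataLoader(DictDataset(), batch_size=4)
for batch in loader:
    # default_collate stacks each key across the batch dimension
    x = batch["x"]  # shape (4, 4)
    y = batch["y"]  # shape (4, 2)
    print(x.shape, y.shape, batch["timestamp"])
```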
```diff
@@ -390,7 +398,7 @@ class HyperTraining:
                y.to(self.pytorch_settings.device),
            )
            y_pred = model(x)
-           loss = util.complexNN.complex_mse_loss(y_pred, y)
+           loss = util.complexNN.complex_mse_loss(y_pred, y, power=True)
            loss_value = loss.item()
            loss.backward()
            optimizer.step()
```
```diff
@@ -444,7 +452,9 @@ class HyperTraining:
        model.eval()
        running_error = 0
        with torch.no_grad():
-           for batch_idx, (x, y, _) in enumerate(valid_loader):
+           for batch_idx, batch in enumerate(valid_loader):
+               x = batch["x"]
+               y = batch["y"]
                if batch_idx >= self.optuna_settings._n_valid_batches:
                    break
                x, y = (
```
```diff
@@ -452,50 +462,44 @@ class HyperTraining:
                    y.to(self.pytorch_settings.device),
                )
                y_pred = model(x)
-               error = util.complexNN.complex_mse_loss(y_pred, y)
+               error = util.complexNN.complex_mse_loss(y_pred, y, power=True)
                error_value = error.item()
                running_error += error_value
 
        running_error /= min(len(valid_loader), self.optuna_settings._n_valid_batches)
 
        if writer is not None:
-           title_append, subtitle = self.build_title(trial)
-           writer.add_figure(
-               "fiber response",
-               self.plot_model_response(
-                   trial,
-                   model=model,
-                   title_append=title_append,
-                   subtitle=subtitle,
-                   show=False,
-               ),
-               epoch + 1,
-           )
-           writer.add_figure(
-               "eye diagram",
-               self.plot_model_response(
-                   trial,
-                   model=self.model,
-                   title_append=title_append,
-                   subtitle=subtitle,
-                   show=False,
-                   mode="eye",
-               ),
-               epoch + 1,
-           )
+           writer.add_scalar(
+               "eval loss",
+               running_error,
+               epoch,
+           )
+           # if (epoch + 1) % 10 == 0 or epoch < 10:
+           #     # plotting is slow, so only do it every 10 epochs
+           #     title_append, subtitle = self.build_title(trial)
+           #     head_fig, eye_fig, powers_fig = self.plot_model_response(
+           #         model=model,
+           #         title_append=title_append,
+           #         subtitle=subtitle,
+           #         show=False,
+           #     )
+           #     writer.add_figure(
+           #         "fiber response",
+           #         head_fig,
+           #         epoch + 1,
+           #     )
+           #     writer.add_figure(
+           #         "eye diagram",
+           #         eye_fig,
+           #         epoch + 1,
+           #     )
 
-           writer.add_figure(
-               "powers",
-               self.plot_model_response(
-                   trial,
-                   model=self.model,
-                   title_append=title_append,
-                   subtitle=subtitle,
-                   mode="powers",
-                   show=False,
-               ),
-               epoch + 1,
-           )
+           # writer.add_figure(
+           #     "powers",
+           #     powers_fig,
+           #     epoch + 1,
+           # )
+           # writer.flush()
 
        # if enable_progress:
        #     progress.stop()
```
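Per-epoch figure logging is replaced by a single scalar; `SummaryWriter.add_scalar(tag, value, step)` records one point per step and is far cheaper than rendering matplotlib figures through `add_figure` at every evaluation. Minimal usage (log directory hypothetical):

```python
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(".runs/example")  # hypothetical log dir
for epoch, running_error in enumerate([0.9, 0.5, 0.3]):
    # One scalar point per epoch replaces three rendered figures.
    writer.add_scalar("eval loss", running_error, epoch)
writer.flush()
writer.close()
```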
```diff
@@ -511,15 +515,18 @@ class HyperTraining:
 
        with torch.no_grad():
            model = model.to(self.pytorch_settings.device)
-           for x, y, timestamp in loader:
+           for batch in loader:
+               x = batch["x"]
+               y = batch["y"]
+               timestamp = batch["timestamp"]
                x, y = (
                    x.to(self.pytorch_settings.device),
                    y.to(self.pytorch_settings.device),
                )
                if trace_powers:
-                   y_pred, powers = model(x, trace_powers).cpu()
+                   y_pred, powers = model(x, trace_powers=True).cpu()
                else:
-                   y_pred = model(x, trace_powers).cpu()
+                   y_pred = model(x, trace_powers=True).cpu()
                # x = x.cpu()
                # y = y.cpu()
                y_pred = y_pred.view(y_pred.shape[0], -1, 2)
```
```diff
@@ -539,7 +546,7 @@ class HyperTraining:
            return fiber_in, fiber_out, regen, timestamps, powers
        return fiber_in, fiber_out, regen, timestamps
 
-   def objective(self, trial: optuna.Trial, plot_before=False):
+   def objective(self, trial: optuna.Trial):
        if self.stop_study:
            trial.study.stop()
        model = None
```
```diff
@@ -555,54 +562,54 @@ class HyperTraining:
 
        title_append, subtitle = self.build_title(trial)
 
-       writer.add_figure(
-           "fiber response",
-           self.plot_model_response(
-               trial,
-               model=model,
-               title_append=title_append,
-               subtitle=subtitle,
-               show=False,
-           ),
-           0,
-       )
-       writer.add_figure(
-           "eye diagram",
-           self.plot_model_response(
-               trial,
-               model=self.model,
-               title_append=title_append,
-               subtitle=subtitle,
-               mode="eye",
-               show=False,
-           ),
-           0,
-       )
+       # writer.add_figure(
+       #     "fiber response",
+       #     self.plot_model_response(
+       #         trial,
+       #         model=model,
+       #         title_append=title_append,
+       #         subtitle=subtitle,
+       #         show=False,
+       #     ),
+       #     0,
+       # )
+       # writer.add_figure(
+       #     "eye diagram",
+       #     self.plot_model_response(
+       #         trial,
+       #         model=self.model,
+       #         title_append=title_append,
+       #         subtitle=subtitle,
+       #         mode="eye",
+       #         show=False,
+       #     ),
+       #     0,
+       # )
 
-       writer.add_figure(
-           "powers",
-           self.plot_model_response(
-               trial,
-               model=self.model,
-               title_append=title_append,
-               subtitle=subtitle,
-               mode="powers",
-               show=False,
-           ),
-           0,
-       )
+       # writer.add_figure(
+       #     "powers",
+       #     self.plot_model_response(
+       #         trial,
+       #         model=self.model,
+       #         title_append=title_append,
+       #         subtitle=subtitle,
+       #         mode="powers",
+       #         show=False,
+       #     ),
+       #     0,
+       # )
 
        train_loader, valid_loader = self.get_sliced_data(trial)
 
        optimizer_name = trial.suggest_categorical_optional("optimizer", self.optimizer_settings.optimizer)
 
-       lr = trial.suggest_float_optional("lr", self.optimizer_settings.learning_rate, log=True)
+       lr = trial.suggest_float_optional("lr", self.optimizer_settings.optimizer_kwargs["lr"], log=True)
 
        optimizer = getattr(optim, optimizer_name)(model.parameters(), lr=lr)
-       if self.optimizer_settings.scheduler is not None:
-           scheduler = getattr(optim.lr_scheduler, self.optimizer_settings.scheduler)(
-               optimizer, **self.optimizer_settings.scheduler_kwargs
-           )
+       # if self.optimizer_settings.scheduler is not None:
+       #     scheduler = getattr(optim.lr_scheduler, self.optimizer_settings.scheduler)(
+       #         optimizer, **self.optimizer_settings.scheduler_kwargs
+       #     )
 
        for epoch in range(self.pytorch_settings.epochs):
            trial.set_user_attr("epoch", epoch)
```
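`suggest_float_optional` and `suggest_categorical_optional` are project-specific wrappers not shown in this diff; the standard Optuna calls they presumably defer to look like this, with `log=True` giving log-uniform sampling for the learning rate:

```python
import optuna

def objective(trial: optuna.Trial) -> float:
    # Standard Optuna API underlying the project's *_optional wrappers
    # (which presumably skip the suggestion when the setting is a fixed
    # scalar rather than a (low, high) search range).
    lr = trial.suggest_float("lr", 1e-5, 1e-1, log=True)  # log-uniform
    optimizer_name = trial.suggest_categorical("optimizer", ["Adam", "AdamW"])
    return lr  # stand-in for the real validation error

study = optuna.create_study(direction="minimize")
study.optimize(objective, n_trials=5)
```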
```diff
@@ -628,8 +635,8 @@ class HyperTraining:
                writer,
                # enable_progress=enable_progress,
            )
-           if self.optimizer_settings.scheduler is not None:
-               scheduler.step(error)
+           # if self.optimizer_settings.scheduler is not None:
+           #     scheduler.step(error)
 
            trial.set_user_attr("mse", error)
            trial.set_user_attr("log_mse", np.log10(error + np.finfo(float).eps))
```
```diff
@@ -645,10 +652,10 @@ class HyperTraining:
        if self.optuna_settings._multi_objective:
            return -np.log10(error + np.finfo(float).eps), trial.user_attrs.get("model_n_nodes", -1)
 
-       if self.pytorch_settings.save_models and model is not None:
-           save_path = Path(self.pytorch_settings.model_dir) / f"{self.optuna_settings.study_name}_{trial.number}.pth"
-           save_path.parent.mkdir(parents=True, exist_ok=True)
-           torch.save(model, save_path)
+       # if self.pytorch_settings.save_models and model is not None:
+       #     save_path = Path(self.pytorch_settings.model_dir) / f"{self.optuna_settings.study_name}_{trial.number}.pth"
+       #     save_path.parent.mkdir(parents=True, exist_ok=True)
+       #     torch.save(model, save_path)
 
        return error
 
```
Changes to the models module (`regenerator`, `polarisation_estimator2`):

```diff
@@ -8,7 +8,8 @@ from util.complexNN import (
    photodiode,
    EOActivation,
    polarimeter,
-   normalize_by_first
+   # normalize_by_first,
+   rotate,
 )
 
 
```
```diff
@@ -19,11 +20,11 @@ class polarisation_estimator2(Module):
            polarimeter(),
            torch.nn.Linear(4, 4),
            torch.nn.ReLU(),
-           torch.nn.Dropout(p=0.01),
+           # torch.nn.Dropout(p=0.01),
            torch.nn.Linear(4, 4),
            torch.nn.ReLU(),
-           torch.nn.Dropout(p=0.01),
-           torch.nn.Linear(4, 4),
+           # torch.nn.Dropout(p=0.01),
+           torch.nn.Linear(4, 1),
        )
 
    def forward(self, x):
```
```diff
@@ -124,6 +125,7 @@ class regenerator(Module):
        dtype=torch.float64,
        dropout_prob=0.01,
        scale_layers=False,
+       rotate=False,
    ):
        super(regenerator, self).__init__()
        self._n_hidden_layers = len(dims) - 2
```
```diff
@@ -131,6 +133,8 @@ class regenerator(Module):
        layer_func_kwargs = layer_func_kwargs or {}
        act_func_kwargs = act_func_kwargs or {}
 
+       self.rotation = rotate
+
        self.build_model(dims, layer_function, layer_func_kwargs, act_function, act_func_kwargs, parametrizations, dtype, dropout_prob, scale_layers)
 
    def build_model(self, dims, layer_function, layer_func_kwargs, act_function, act_func_kwargs, parametrizations, dtype, dropout_prob, scale_layers):
```
```diff
@@ -146,8 +150,9 @@ class regenerator(Module):
            module = act_function(size=dims[i + 1], **act_func_kwargs)
            self.get_submodule(f"layer_{i}").add_module("activation", module)
 
-           module = DropoutComplex(p=dropout_prob)
-           self.get_submodule(f"layer_{i}").add_module("dropout", module)
+           if dropout_prob is not None and dropout_prob > 0:
+               module = DropoutComplex(p=dropout_prob)
+               self.get_submodule(f"layer_{i}").add_module("dropout", module)
 
        self.add_module(f"layer_{self._n_hidden_layers}", Sequential())
 
```
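The dropout module is now registered only when the probability is positive. A `Dropout(p=0)` module is a no-op at run time but still appears in the module tree, in state dicts, and in logged graphs; a sketch with the standard `torch.nn.Dropout` standing in for the project's `DropoutComplex`:

```python
import torch

layer = torch.nn.Sequential()
dropout_prob = 0.0

# Skipping registration entirely keeps identity modules out of the
# model definition when dropout is disabled.
if dropout_prob is not None and dropout_prob > 0:
    layer.add_module("dropout", torch.nn.Dropout(p=dropout_prob))

print(list(layer.named_children()))  # [] when dropout_prob == 0
```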
```diff
@@ -160,6 +165,10 @@ class regenerator(Module):
        module = act_function(size=dims[-1], **act_func_kwargs)
        self.get_submodule(f"layer_{self._n_hidden_layers}").add_module("activation", module)
 
+       if self.rotation:
+           module = rotate()
+           self.add_module("rotate", module)
+
        # module = Scale(size=dims[-1])
        # self.get_submodule(f"layer_{self._n_hidden_layers}").add_module("out_scale", module)
 
```
```diff
@@ -190,15 +199,27 @@ class regenerator(Module):
            powers.append(x.abs().square().sum())
        return powers
 
-   def forward(self, x, trace_powers=False):
+   def forward(self, x, angle=None, pre_rot=False, trace_powers=False):
        powers = self._trace_powers(trace_powers, x)
-       x = self.layer_0(x)
-       powers = self._trace_powers(trace_powers, x, powers)
-       for i in range(1, self._n_hidden_layers):
+       # x = self.layer_0(x)
+       # powers = self._trace_powers(trace_powers, x, powers)
+       for i in range(0, self._n_hidden_layers):
            x = getattr(self, f"layer_{i}")(x)
            powers = self._trace_powers(trace_powers, x, powers)
        x = getattr(self, f"layer_{self._n_hidden_layers}")(x)
-       powers = self._trace_powers(trace_powers, x, powers)
-       if trace_powers:
-           return x, powers
-       return x
+       if self.rotation:
+           try:
+               x_rot = self.rotate(x, angle)
+           except AttributeError:
+               pass
+           powers = self._trace_powers(trace_powers, x_rot, powers)
+       else:
+           x_rot = x
+
+       if pre_rot and trace_powers:
+           return x_rot, x, powers
+       if pre_rot and not trace_powers:
+           return x_rot, x
+       if not pre_rot and trace_powers:
+           return x_rot, powers
+       return x_rot
```
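The reworked `forward` returns a different tuple depending on `pre_rot` and `trace_powers`. A stub that mirrors the four branches, useful as a caller-side reference:

```python
def forward_returns(pre_rot: bool, trace_powers: bool):
    """Mirror of the branch structure at the end of regenerator.forward."""
    x_rot, x, powers = "x_rot", "x", "powers"  # stand-in values
    if pre_rot and trace_powers:
        return x_rot, x, powers
    if pre_rot and not trace_powers:
        return x_rot, x
    if not pre_rot and trace_powers:
        return x_rot, powers
    return x_rot

print(forward_returns(False, False))  # 'x_rot'
print(forward_returns(True, True))    # ('x_rot', 'x', 'powers')
```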
Changes to the settings module (`class DataSettings`):

```diff
@@ -22,6 +22,8 @@ class DataSettings:
    train_split: float = 0.8
    polarisations: tuple | list = (0,)
    randomise_polarisations: bool = False
+   osnr: float | int = None
+   seed: int = None
 
    """
    change to:
```
Changes to the trainer module (`PolarizationTrainer`, `RegenerationTrainer`):

```diff
@@ -2,7 +2,6 @@ import copy
 from datetime import datetime
 from pathlib import Path
 import random
-from typing import Literal
 import matplotlib
 from matplotlib.colors import LinearSegmentedColormap
 import torch.nn.utils.parametrize
```
```diff
@@ -60,25 +59,26 @@ def traverse_dict_update(target, source):
        except TypeError:
            target.__dict__[k] = v
 
+
 def get_parameter_names_and_values(model):
    def is_parametrized(module):
        if hasattr(module, "parametrizations"):
            return True
        return False
 
-   def _get_param_info(module, prefix='', parametrization=False):
+   def _get_param_info(module, prefix="", parametrization=False):
        param_list = []
-       for name, param in module.named_parameters(recurse = parametrization):
+       for name, param in module.named_parameters(recurse=parametrization):
            if parametrization and name.startswith("parametrizations"):
-               name_parts = name.split('.')
+               name_parts = name.split(".")
                name = name_parts[1]
                param = getattr(module, name)
-           full_name = prefix + ('.' if prefix else '') + name
+           full_name = prefix + ("." if prefix else "") + name
            param_value = param.data
            param_list.append((full_name, param_value))
 
        for child_name, child_module in module.named_children():
-           child_prefix = prefix + ('.' if prefix else '') + child_name
+           child_prefix = prefix + ("." if prefix else "") + child_name
            if child_name == "parametrizations":
                continue
            param_list.extend(_get_param_info(child_module, child_prefix, is_parametrized(child_module)))
```
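The `name.split(".")[1]` trick works because PyTorch's parametrization machinery renames a parametrized tensor: `named_parameters` then reports it as `parametrizations.<tensor_name>.original`, while `getattr(module, <tensor_name>)` returns the transformed value. A runnable check:

```python
import torch
from torch.nn.utils import parametrize

class Clamp(torch.nn.Module):
    def forward(self, w):
        return w.clamp(-1.0, 1.0)

lin = torch.nn.Linear(2, 2)
parametrize.register_parametrization(lin, "weight", Clamp())

print([name for name, _ in lin.named_parameters()])
# e.g. ['bias', 'parametrizations.weight.original']

# Splitting on "." and taking index 1 recovers the public tensor name,
# and the attribute access yields the parametrized (clamped) tensor:
name = "parametrizations.weight.original"
print(name.split(".")[1])            # weight
print(lin.weight.abs().max() <= 1)   # tensor(True)
```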
```diff
@@ -87,6 +87,7 @@ def get_parameter_names_and_values(model):
 
    return _get_param_info(model)
 
+
 class PolarizationTrainer:
    def __init__(
        self,
```
```diff
@@ -101,7 +102,7 @@ class PolarizationTrainer:
        settings_override=None,
        reset_epoch=False,
    ):
-       self.mod = torch.pi/2
+       self.mod = torch.pi / 2
        self.resume = checkpoint_path is not None
        torch.serialization.add_safe_globals([
            *util.complexNN.__all__,
```
```diff
@@ -219,7 +220,7 @@ class PolarizationTrainer:
 
        # dims = self.model_kwargs.pop("dims")
        model_kwargs = copy.deepcopy(self.model_kwargs)
-       self.model = models.polarisation_estimator(*model_kwargs.pop('dims'),**model_kwargs)
+       self.model = models.polarisation_estimator(*model_kwargs.pop("dims"), **model_kwargs)
        # self.model = models.polarisation_estimator2()
 
        if self.writer is not None:
```
```diff
@@ -336,17 +337,20 @@ class PolarizationTrainer:
        write_div = 0
        loss_div = 0
        for batch_idx, batch in enumerate(train_loader):
-           x = batch["x"]
-           y = batch["sop"]
+           x = batch["angle_data2"]
+           y = batch["center_angle"]
            self.model.zero_grad(set_to_none=True)
            x, y = (
                x.to(self.pytorch_settings.device),
                y.to(self.pytorch_settings.device),
            )
-           y_pred = self.model(x)
+           y_pred = self.model(x).abs().real
+           # y_pred = torch.fmod(y_pred, self.mod)
+           y = y.abs().real
+           # y = torch.fmod(y, self.mod)
+           # loss = torch.nn.functional.mse_loss(torch.cos(y_pred), torch.cos(y))
            # loss = torch.nn.functional.smooth_l1_loss(torch.cos(torch.fmod(y_pred, torch.pi/2)).squeeze(), torch.cos(torch.fmod(y, torch.pi/2)).squeeze(), beta=0.5)
-           loss = torch.nn.functional.mse_loss(y_pred, y)
-           # loss = util.complexNN.naive_angle_loss(y_pred, y, mod=torch.pi/2)
+           loss = util.complexNN.naive_angle_loss(y_pred, y, mod=self.mod)
            loss_value = loss.item()
            loss.backward()
            optimizer.step()
```
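Both loss sites switch from plain MSE to `util.complexNN.naive_angle_loss` with modulus `self.mod` (π/2). Its implementation is not part of this diff; one plausible reading (an assumption, not the project's code) is an MSE on the angular difference wrapped into the modulus, so predictions just across a period boundary count as close:

```python
import torch

def naive_angle_loss_sketch(pred, target, mod=torch.pi / 2):
    # ASSUMPTION: util.complexNN.naive_angle_loss is not shown in this
    # diff; this sketch implements one common choice -- MSE on the
    # difference wrapped into [-mod/2, mod/2), so angles just below and
    # just above a multiple of `mod` are treated as nearly equal.
    diff = torch.remainder(pred - target + mod / 2, mod) - mod / 2
    return diff.square().mean()

pred = torch.tensor([0.01, torch.pi / 2 - 0.01])
target = torch.tensor([torch.pi / 2 - 0.01, 0.01])
print(naive_angle_loss_sketch(pred, target))  # small, despite |pred - target| close to pi/2
```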
```diff
@@ -356,7 +360,7 @@ class PolarizationTrainer:
            loss_div += 1
 
            if enable_progress:
-               progress.update(task, advance=1, description=f"{loss_value:.3e}")
+               progress.update(task, advance=1, description=f"{loss_value/np.pi*180:.3e} °")
 
            if batch_idx % self.pytorch_settings.write_every == 0:
                self.writer.add_scalar(
```
```diff
@@ -395,24 +399,28 @@ class PolarizationTrainer:
        loss_div = 0
        with torch.no_grad():
            for _, batch in enumerate(valid_loader):
-               x = batch["x"]
-               y = batch["sop"]
+               # x = batch["angle_data2"]
+               x = batch["angle_data2"]
+               y = batch["center_angle"]
                x, y = (
                    x.to(self.pytorch_settings.device),
                    y.to(self.pytorch_settings.device),
                )
-               y_pred = self.model(x)
+               y_pred = self.model(x).abs().real
+               # y_pred = torch.fmod(y_pred, self.mod)
+               y = y.abs().real
+               # y = torch.fmod(y, self.mod)
+               # loss = torch.nn.functional.mse_loss(torch.cos(y_pred), torch.cos(y))
                # loss = torch.nn.functional.smooth_l1_loss(torch.cos(torch.fmod(y_pred, torch.pi/2)).squeeze(), torch.cos(torch.fmod(y, torch.pi/2)).squeeze(), beta=0.5)
-               loss = torch.nn.functional.mse_loss(y_pred, y)
-               # loss = util.complexNN.naive_angle_loss(y_pred, y, mod=torch.pi/2)
+               loss = util.complexNN.naive_angle_loss(y_pred, y, mod=self.mod)
                loss_value = loss.item()
                running_loss += loss_value
                loss_div += 1
 
                if enable_progress:
-                   progress.update(task, advance=1, description=f"{loss_value:.3e}")
+                   progress.update(task, advance=1, description=f"{loss_value/np.pi*180:.3e} °")
 
-       running_loss = running_loss/loss_div
+       running_loss = running_loss / loss_div
 
        self.writer.add_scalar(
            "eval loss",
```
```diff
@@ -506,19 +514,19 @@ class PolarizationTrainer:
            for i, config_path in enumerate(self.data_settings.config_path):
                paths = Path.cwd().glob(config_path)
                for j, path in enumerate(paths):
-                   text = str(path) + '\n'
-                   with open(path, 'r') as f:
+                   text = str(path) + "\n"
+                   with open(path, "r") as f:
                        text += f.read()
-                   text += '\n'
-                   self.writer.add_text(f"config_{i*len(self.data_settings.config_path)+j}", text)
+                   text += "\n"
+                   self.writer.add_text(f"config_{i * len(self.data_settings.config_path) + j}", text)
 
        elif isinstance(self.data_settings.config_path, str):
            paths = Path.cwd().glob(self.data_settings.config_path)
            for j, path in enumerate(paths):
-               text = str(path) + '\n'
-               with open(path, 'r') as f:
+               text = str(path) + "\n"
+               with open(path, "r") as f:
                    text += f.read()
-               text += '\n'
+               text += "\n"
                self.writer.add_text(f"config_{j}", text)
 
        self.writer.flush()
```
```diff
@@ -571,7 +579,8 @@ class PolarizationTrainer:
            if loss < self.best["loss"]:
                self.best = checkpoint
                save_path = (
-                   Path(self.pytorch_settings.model_dir) / f"best_pol_{self.writer.get_logdir().split('/')[-1]}.tar"
+                   Path(self.pytorch_settings.model_dir)
+                   / f"best_pol_{self.writer.get_logdir().split('/')[-1]}.tar"
                )
                save_path.parent.mkdir(parents=True, exist_ok=True)
                self.save_checkpoint(self.best, save_path)
```
```diff
@@ -580,6 +589,7 @@ class PolarizationTrainer:
        self.writer.close()
        return self.best
 
+
 class RegenerationTrainer:
    def __init__(
        self,
```
```diff
@@ -636,6 +646,10 @@ class RegenerationTrainer:
        self.model_settings: ModelSettings = model_settings
        self.optimizer_settings: OptimizerSettings = optimizer_settings
 
+       if self.global_settings.seed is not None:
+           random.seed(self.global_settings.seed)
+           np.random.seed(self.global_settings.seed)
+
        self.console = console or Console()
        self.writer = None
 
```
```diff
@@ -706,10 +720,12 @@ class RegenerationTrainer:
 
        # dims = self.model_kwargs.pop("dims")
        model_kwargs = copy.deepcopy(self.model_kwargs)
-       self.model = models.regenerator(*model_kwargs.pop('dims'),**model_kwargs)
+       self.model = models.regenerator(*model_kwargs.pop("dims"), **model_kwargs)
 
        if self.writer is not None:
-           self.writer.add_graph(self.model, torch.zeros(1, input_dim, dtype=dtype))
+           self.writer.add_graph(
+               self.model, (torch.rand(1, input_dim, dtype=dtype), torch.rand(1, 1, dtype=dtype.to_real()))
+           )
 
        self.model = self.model.to(self.pytorch_settings.device)
        if self.resume:
```
```diff
@@ -728,12 +744,12 @@ class RegenerationTrainer:
 
        num_symbols = None
        config_path = self.data_settings.config_path
-       polarisations = self.data_settings.polarisations
        randomise_polarisations = self.data_settings.randomise_polarisations
+       osnr = self.data_settings.osnr
        if override is not None:
            num_symbols = override.get("num_symbols", None)
            config_path = override.get("config_path", config_path)
-           polarisations = override.get("polarisations", polarisations)
+           # polarisations = override.get("polarisations", polarisations)
            randomise_polarisations = override.get("randomise_polarisation", randomise_polarisations)
        # get dataset
        dataset = FiberRegenerationDataset(
```
```diff
@@ -746,8 +762,8 @@ class RegenerationTrainer:
            dtype=dtype,
            real=not dtype.is_complex,
            num_symbols=num_symbols,
-           polarisations=polarisations,
            randomise_polarisations=randomise_polarisations,
+           osnr = osnr,
        )
 
        dataset_size = len(dataset)
```
```diff
@@ -819,12 +835,14 @@ class RegenerationTrainer:
        for batch_idx, batch in enumerate(train_loader):
            x = batch["x"]
            y = batch["y"]
+           angles = batch["mean_angle"]
            self.model.zero_grad(set_to_none=True)
-           x, y = (
+           x, y, angles = (
                x.to(self.pytorch_settings.device),
                y.to(self.pytorch_settings.device),
+               angles.to(self.pytorch_settings.device),
            )
-           y_pred = self.model(x)
+           y_pred = self.model(x, -angles)
            loss = util.complexNN.complex_mse_loss(y_pred, y, power=True)
            loss_value = loss.item()
            loss.backward()
```
```diff
@@ -872,11 +890,13 @@ class RegenerationTrainer:
            for _, batch in enumerate(valid_loader):
                x = batch["x"]
                y = batch["y"]
-               x, y = (
+               angles = batch["mean_angle"]
+               x, y, angles = (
                    x.to(self.pytorch_settings.device),
                    y.to(self.pytorch_settings.device),
+                   angles.to(self.pytorch_settings.device),
                )
-               y_pred = self.model(x)
+               y_pred = self.model(x, -angles)
                error = util.complexNN.complex_mse_loss(y_pred, y, power=True)
                error_value = error.item()
                running_error += error_value
```
```diff
@@ -884,7 +904,7 @@ class RegenerationTrainer:
                if enable_progress:
                    progress.update(task, advance=1, description=f"{error_value:.3e}")
 
-       running_error = running_error/len(valid_loader)
+       running_error = running_error / len(valid_loader)
 
        self.writer.add_scalar(
            "eval loss",
```
```diff
@@ -928,45 +948,65 @@ class RegenerationTrainer:
    def run_model(self, model, loader, trace_powers=False):
        model.eval()
        fiber_out = []
+       fiber_out_rot = []
        fiber_in = []
        regen = []
        timestamps = []
+       angles = []
 
        with torch.no_grad():
            model = model.to(self.pytorch_settings.device)
            for batch in loader:
                x = batch["x"]
                y = batch["y"]
+               plot_target = batch["plot_target"]
+               angle = batch["mean_angle"]
+               center_angle = batch["center_angle"]
                timestamp = batch["timestamp"]
                plot_data = batch["plot_data"]
-               x, y = (
+               plot_data_rot = batch["plot_data_rot"]
+               x, y, angle = (
                    x.to(self.pytorch_settings.device),
                    y.to(self.pytorch_settings.device),
+                   angle.to(self.pytorch_settings.device),
                )
                if trace_powers:
-                   y_pred, powers = model(x, trace_powers).cpu()
+                   y_pred, powers = model(x, angle, True).cpu()
                else:
-                   y_pred = model(x, trace_powers).cpu()
+                   y_pred = model(x, angle).cpu()
                # x = x.cpu()
                # y = y.cpu()
                y_pred = y_pred.view(y_pred.shape[0], -1, 2)
+               y_pred = y_pred[:, y_pred.shape[1]//2, :]
                y = y.view(y.shape[0], -1, 2)
-               plot_data = plot_data.view(plot_data.shape[0], -1, 2)
+               # plot_data = plot_data.view(plot_data.shape[0], -1, 2)
+               # c = torch.cos(-angle).cpu()
+               # s = torch.sin(-angle).cpu()
+               # rot = torch.stack([torch.stack([c, -s], dim=1), torch.stack([s, c], dim=1)], dim=2).squeeze(-1)
+               # plot_data = torch.bmm(plot_data, rot.to(dtype=plot_data.dtype))
+               # plot_data = plot_data
+               # sines = torch.sin(-angle.cpu())
+               # cosines = torch.cos(-angle.cpu())
+               # plot_data = torch.stack((plot_data[..., 0] * cosines - plot_data[..., 1] * sines, plot_data[..., 0] * sines + plot_data[..., 1] * cosines), dim=-1)
                # x = x.view(x.shape[0], -1, 2)
 
                # timestamp = timestamp.view(-1, 1)
                fiber_out.append(plot_data.squeeze())
-               fiber_in.append(y.squeeze())
+               fiber_out_rot.append(plot_data_rot.squeeze())
+               fiber_in.append(plot_target.squeeze())
                regen.append(y_pred.squeeze())
                timestamps.append(timestamp.squeeze())
+               angles.append(center_angle.squeeze())
 
        fiber_out = torch.vstack(fiber_out).cpu()
+       fiber_out_rot = torch.vstack(fiber_out_rot).cpu()
        fiber_in = torch.vstack(fiber_in).cpu()
        regen = torch.vstack(regen).cpu()
+       angles = torch.vstack(angles).cpu()
        timestamps = torch.concat(timestamps).cpu()
        if trace_powers:
-           return fiber_in, fiber_out, regen, timestamps, powers
-       return fiber_in, fiber_out, regen, timestamps
+           return fiber_in, fiber_out, fiber_out_rot, angles, regen, timestamps, powers
+       return fiber_in, fiber_out, fiber_out_rot, angles, regen, timestamps
 
    def write_parameters(self, epoch, attributes: list[str] | tuple[str] = None):
        parameter_list = get_parameter_names_and_values(self.model)
```
```diff
@@ -1027,18 +1067,18 @@ class RegenerationTrainer:
            for i, config_path in enumerate(self.data_settings.config_path):
                paths = Path.cwd().glob(config_path)
                for j, path in enumerate(paths):
-                   text = str(path) + '\n'
-                   with open(path, 'r') as f:
+                   text = str(path) + "\n"
+                   with open(path, "r") as f:
                        text += f.read()
-                   text += '\n'
-                   self.writer.add_text(f"config_{i*len(self.data_settings.config_path)+j}", text)
+                   text += "\n"
+                   self.writer.add_text(f"config_{i * len(self.data_settings.config_path) + j}", text)
        elif isinstance(self.data_settings.config_path, str):
            paths = Path.cwd().glob(self.data_settings.config_path)
            for j, path in enumerate(paths):
-               text = str(path) + '\n'
-               with open(path, 'r') as f:
+               text = str(path) + "\n"
+               with open(path, "r") as f:
                    text += f.read()
-               text += '\n'
+               text += "\n"
                self.writer.add_text(f"config_{j}", text)
 
        self.writer.flush()
```
```diff
@@ -1116,6 +1156,7 @@ class RegenerationTrainer:
        powers = [power / powers[0] for power in powers]
        fig, ax = plt.subplots()
        fig.set_figwidth(18)
+       fig.set_figheight(4)
        fig.suptitle(
            f"Energy conservation{f' {title_append}' if title_append else ''}{f'\n{subtitle}' if subtitle else ''}"
        )
```
@@ -1184,6 +1225,7 @@ class RegenerationTrainer:
|
|||||||
|
|
||||||
fig, axs = plt.subplots(1, 2 * len(signals), sharex=True, sharey=True)
|
fig, axs = plt.subplots(1, 2 * len(signals), sharex=True, sharey=True)
|
||||||
fig.set_figwidth(18)
|
fig.set_figwidth(18)
|
||||||
|
fig.set_figheight(4)
|
||||||
fig.suptitle(f"Eye diagram{f' {title_append}' if title_append else ''}{f'\n{subtitle}' if subtitle else ''}")
|
fig.suptitle(f"Eye diagram{f' {title_append}' if title_append else ''}{f'\n{subtitle}' if subtitle else ''}")
|
||||||
# xaxis = timestamps / sps
|
# xaxis = timestamps / sps
|
||||||
# xaxis = np.arange(2 * sps) / sps
|
# xaxis = np.arange(2 * sps) / sps
|
||||||
```diff
@@ -1253,7 +1295,7 @@ class RegenerationTrainer:
                xaxis = timestamps / sps
            else:
                xaxis = timestamps
-           ax.plot(xaxis, np.abs(signal[:, i]) ** 2, label=label)
+           ax.plot(xaxis, np.abs(signal[:, i]) ** 2, label=label, alpha=0.7)
            ax.set_xlabel("Sample" if sps is None else "Symbol")
            ax.set_ylabel("normalized power")
            ax.minorticks_on()
```
```diff
@@ -1269,7 +1311,7 @@ class RegenerationTrainer:
 
    def plot_model_response(
        self,
-       model:torch.nn.Module=None,
+       model: torch.nn.Module = None,
        title_append="",
        subtitle="",
        # mode: Literal["eye", "head", "powers"] = "head",
```
```diff
@@ -1281,7 +1323,9 @@ class RegenerationTrainer:
        model = model.to(self.pytorch_settings.device)
        model.eval()
        with torch.no_grad():
-           _, powers = model(input_data, trace_powers=True)
+           _, powers = model(
+               input_data, torch.zeros(input_data.shape[0], 1).to(self.pytorch_settings.device), trace_powers=True
+           )
 
        powers = [power.item() for power in powers]
        layer_names = [name for (name, _) in model.named_children()]
```
```diff
@@ -1296,29 +1340,42 @@ class RegenerationTrainer:
        self.data_settings.shuffle = False
        self.data_settings.train_split = 1.0
        self.pytorch_settings.batchsize = max(self.pytorch_settings.head_symbols, self.pytorch_settings.eye_symbols)
-       config_path = random.choice(self.data_settings.config_path) if isinstance(self.data_settings.config_path, (list, tuple)) else self.data_settings.config_path
-       fiber_length = int(float(str(config_path).split('-')[4])/1000)
+       config_path = (
+           random.choice(self.data_settings.config_path)
+           if isinstance(self.data_settings.config_path, (list, tuple))
+           else self.data_settings.config_path
+       )
+       fiber_length = int(float(str(config_path).split("-")[4]) / 1000)
        if not hasattr(self, "_plot_loader"):
            self._plot_loader, _ = self.get_sliced_data(
                override={
                    "num_symbols": self.pytorch_settings.batchsize,
                    "config_path": config_path,
                    "shuffle": False,
-                   "polarisations": (np.random.rand(1)*np.pi*2,),
-                   "randomise_polarisation": False,
+                   "polarisations": (np.random.rand(1) * np.pi * 2,),
+                   "randomise_polarisation": self.data_settings.randomise_polarisations,
                }
            )
            self._sps = self._plot_loader.dataset.samples_per_symbol
        self.data_settings = data_settings_backup
        self.pytorch_settings = pytorch_settings_backup
 
-       fiber_in, fiber_out, regen, timestamps = self.run_model(model, self._plot_loader)
+       fiber_in, fiber_out, fiber_out_rot, angles, regen, timestamps = self.run_model(model, self._plot_loader)
        fiber_in = fiber_in.view(-1, 2)
        fiber_out = fiber_out.view(-1, 2)
+       fiber_out_rot = fiber_out_rot.view(-1, 2)
+       angles = angles.view(-1, 1)
+       angles = angles.real
+       angles = torch.fmod(angles, 2 * torch.pi)
+       angles = torch.div(angles, 2*torch.pi)
+       angles = torch.repeat_interleave(angles, 2, dim=1)
+
        regen = regen.view(-1, 2)
 
        fiber_in = fiber_in.numpy()
        fiber_out = fiber_out.numpy()
+       fiber_out_rot = fiber_out_rot.numpy()
+       angles = angles.numpy()
        regen = regen.numpy()
        timestamps = timestamps.numpy()
 
```
```diff
@@ -1327,28 +1384,29 @@
        import gc
 
        head_fig = self._plot_model_response_head(
-           fiber_in[:self.pytorch_settings.head_symbols*self._sps],
-           fiber_out[:self.pytorch_settings.head_symbols*self._sps],
-           regen[:self.pytorch_settings.head_symbols*self._sps],
-           timestamps=timestamps[:self.pytorch_settings.head_symbols*self._sps],
+           fiber_out_rot[: self.pytorch_settings.head_symbols * self._sps],
+           fiber_in[: self.pytorch_settings.head_symbols * self._sps],
+           regen[: self.pytorch_settings.head_symbols * self._sps],
+           angles[: self.pytorch_settings.head_symbols * self._sps],
+           timestamps=timestamps[: self.pytorch_settings.head_symbols * self._sps],
+           labels=("fiber out", "fiber in", "regen", "normed angle"),
+           sps=self._sps,
+           title_append=title_append + f" ({fiber_length} km)",
+           subtitle=subtitle,
+           show=show,
+       )
+       # raise NotImplementedError("Eye diagram not implemented")
+       eye_fig = self._plot_model_response_eye(
+           fiber_in[: self.pytorch_settings.eye_symbols * self._sps],
+           fiber_out_rot[: self.pytorch_settings.eye_symbols * self._sps],
+           regen[: self.pytorch_settings.eye_symbols * self._sps],
+           timestamps=timestamps[: self.pytorch_settings.eye_symbols * self._sps],
            labels=("fiber in", "fiber out", "regen"),
            sps=self._sps,
            title_append=title_append + f" ({fiber_length} km)",
            subtitle=subtitle,
            show=show,
        )
-       # raise NotImplementedError("Eye diagram not implemented")
-       eye_fig = self._plot_model_response_eye(
-           fiber_in[:self.pytorch_settings.eye_symbols*self._sps],
-           fiber_out[:self.pytorch_settings.eye_symbols*self._sps],
-           regen[:self.pytorch_settings.eye_symbols*self._sps],
-           timestamps=timestamps[:self.pytorch_settings.eye_symbols*self._sps],
-           labels=("fiber in", "fiber out", "regen"),
-           sps=self._sps,
-           title_append=title_append + f" ({fiber_length} km)",
-           subtitle=subtitle,
-           show=show,
-       )
        gc.collect()
 
        return head_fig, eye_fig, power_fig
@@ -1361,7 +1419,7 @@ class RegenerationTrainer:
            self.model_settings.overrides.get(f"n_hidden_nodes_{i}", -1) for i in range(model_n_hidden_layers)
        ]
        model_dims.insert(0, input_dim)
-       model_dims.append(2)
+       model_dims.append(self.model_settings.output_dim)
        model_dims = [str(dim) for dim in model_dims]
        model_activation_func = self.model_settings.model_activation_func
        model_dtype = self.data_settings.dtype
```
Optuna study script (`HyperTraining`):

```diff
@@ -1,6 +1,8 @@
 from datetime import datetime
 
 import optuna
+import torch
+import util
 from hypertraining.hypertraining import HyperTraining
 from hypertraining.settings import (
    GlobalSettings,
```
```diff
@@ -16,24 +18,29 @@ global_settings = GlobalSettings(
 )
 
 data_settings = DataSettings(
-   config_path="data/*-128-16384-100000-0-0-17-0-PAM4-0.ini",
+   # config_path="data/*-128-16384-100000-0-0-17-0-PAM4-0.ini",
+   config_path="data/20241204-131003-128-16384-100000-0-0-17-0-PAM4-0.ini",
    dtype="complex64",
    # symbols = (9, 20), # 13 symbol @ 10GBd <-> 1.3ns <-> 0.26m of fiber
-   symbols=13, # study: single_core_regen_20241123_011232
+   # symbols=13, # study: single_core_regen_20241123_011232
+   # symbols = (3, 13),
+   symbols=4,
    # output_size = (11, 32), # ballpark 26 taps -> 2 taps per input symbol -> 1 tap every 0.01m (model has 52 inputs)
-   output_size=26, # study: single_core_regen_20241123_011232 (model_input_dim/2)
+   # output_size=26, # study: single_core_regen_20241123_011232 (model_input_dim/2)
+   output_size=(8, 30),
    shuffle=True,
    in_out_delay=0,
    xy_delay=0,
-   drop_first=128 * 100,
+   drop_first=256,
    train_split=0.8,
+   randomise_polarisations=False,
 )
 
 pytorch_settings = PytorchSettings(
-   epochs=10000,
+   epochs=10,
    batchsize=2**10,
    device="cuda",
-   dataloader_workers=12,
+   dataloader_workers=4,
    dataloader_prefetch=4,
    summary_dir=".runs",
    write_every=2**5,
```
```diff
@@ -43,28 +50,70 @@ pytorch_settings = PytorchSettings(
 
 model_settings = ModelSettings(
    output_dim=2,
-   # n_hidden_layers = (3, 8),
-   n_hidden_layers=4,
-   overrides={
-       "n_hidden_nodes_0": 8,
-       "n_hidden_nodes_1": 6,
-       "n_hidden_nodes_2": 4,
-       "n_hidden_nodes_3": 8,
-   },
-   model_activation_func="Mag",
-   # satabsT0=(1e-6, 1),
+   n_hidden_layers = (2, 5),
+   n_hidden_nodes=(2, 16),
+   model_activation_func="EOActivation",
+   dropout_prob=0,
+   model_layer_function="ONNRect",
+   model_layer_kwargs={"square": True},
+   # scale=(False, True),
+   scale=False,
+   model_layer_parametrizations=[
+       {
+           "tensor_name": "weight",
+           "parametrization": util.complexNN.energy_conserving,
+       },
+       {
+           "tensor_name": "alpha",
+           "parametrization": util.complexNN.clamp,
+       },
+       {
+           "tensor_name": "gain",
+           "parametrization": util.complexNN.clamp,
+           "kwargs": {
+               "min": 0,
+               "max": float("inf"),
+           },
+       },
+       {
+           "tensor_name": "phase_bias",
+           "parametrization": util.complexNN.clamp,
+           "kwargs": {
+               "min": 0,
+               "max": 2 * torch.pi,
+           },
+       },
+       {
+           "tensor_name": "scales",
+           "parametrization": util.complexNN.clamp,
+       },
+       {
+           "tensor_name": "angle",
+           "parametrization": util.complexNN.clamp,
+           "kwargs": {
+               "min": -torch.pi,
+               "max": torch.pi,
+           },
+       },
+       {
+           "tensor_name": "loss",
+           "parametrization": util.complexNN.clamp,
+       },
+   ],
 )
 
 optimizer_settings = OptimizerSettings(
-   optimizer="Adam",
-   # learning_rate = (1e-5, 1e-1),
-   learning_rate=5e-3
-   # learning_rate=5e-4,
+   optimizer="AdamW",
+   optimizer_kwargs={
+       "lr": 5e-3,
+       "amsgrad": True,
+       # "weight_decay": 1e-7,
+   },
 )
 
 optuna_settings = OptunaSettings(
-   n_trials=1,
-   n_workers=1,
+   n_trials=1024,
+   n_workers=8,
    timeout=3600,
    directions=("minimize",),
    metrics_names=("mse",),
```
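Each entry of `model_layer_parametrizations` names a tensor, a parametrization, and optional kwargs. A sketch of how such an entry maps onto `torch.nn.utils.parametrize.register_parametrization` (the `ClampParametrization` module here is a stand-in for `util.complexNN.clamp`, whose real definition is not in this diff):

```python
import torch
from torch.nn.utils import parametrize

class ClampParametrization(torch.nn.Module):
    """Stand-in for util.complexNN.clamp (assumed to clamp into [min, max])."""
    def __init__(self, min=0.0, max=1.0):
        super().__init__()
        self.min, self.max = min, max

    def forward(self, t):
        return t.clamp(self.min, self.max)

# One settings entry, shaped like model_layer_parametrizations above:
entry = {
    "tensor_name": "weight",
    "parametrization": ClampParametrization,
    "kwargs": {"min": -torch.pi, "max": torch.pi},
}

lin = torch.nn.Linear(4, 4)
# Hypothetical application loop: build the module from the entry and
# register it on every layer that owns a tensor with that name.
if hasattr(lin, entry["tensor_name"]):
    parametrize.register_parametrization(
        lin, entry["tensor_name"], entry["parametrization"](**entry.get("kwargs", {}))
    )
print(lin.weight.abs().max() <= torch.pi)  # tensor(True)
```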
Regeneration training script:

```diff
@@ -26,24 +26,26 @@ global_settings = GlobalSettings(
 )
 
 data_settings = DataSettings(
-   config_path="data/*-128-16384-10000-0-0-0-0-PAM4-0-0.4.ini",
+   # config_path="data/*-128-16384-1-0-0-0-0-PAM4-0-0.ini",
+   config_path="data/*-128-16384-10000-0-0-17-0-PAM4-0.ini",
    # config_path=[f"data/20241202-*-128-16384-{length}-0-0-17-0-PAM4-0.ini" for length in range(48000, 53000, 1000)],
    dtype="complex64",
    # symbols = (9, 20), # 13 symbol @ 10GBd <-> 1.3ns <-> 0.26m of fiber
-   symbols=13, # study: single_core_regen_20241123_011232
+   symbols=4, # study: single_core_regen_20241123_011232
    # output_size = (11, 32), # 26 taps -> 2 taps per input symbol -> 1 tap every 1cm (model has 52 inputs (x/y))
-   output_size=26, # study: single_core_regen_20241123_011232 (model_input_dim/2)
+   output_size=20, # study: single_core_regen_20241123_011232 (model_input_dim/2)
    shuffle=True,
    drop_first=64,
    train_split=0.8,
    randomise_polarisations=True,
+   osnr=10,
 )
 
 pytorch_settings = PytorchSettings(
    epochs=10000,
    batchsize=2**14,
    device="cuda",
-   dataloader_workers=16,
+   dataloader_workers=24,
    dataloader_prefetch=8,
    summary_dir=".runs",
    write_every=2**5,
```
```diff
@@ -53,17 +55,17 @@ pytorch_settings = PytorchSettings(
 
 model_settings = ModelSettings(
    output_dim=2,
-   n_hidden_layers=5,
+   n_hidden_layers=3,
    overrides={
        # "hidden_layer_dims": (8, 8, 4, 4),
-       "n_hidden_nodes_0": 8,
+       "n_hidden_nodes_0": 16,
        "n_hidden_nodes_1": 8,
-       "n_hidden_nodes_2": 4,
-       "n_hidden_nodes_3": 4,
-       "n_hidden_nodes_4": 2,
+       "n_hidden_nodes_2": 8,
+       # "n_hidden_nodes_3": 4,
+       # "n_hidden_nodes_4": 2,
    },
    model_activation_func="EOActivation",
-   dropout_prob=0.01,
+   dropout_prob=0,
    model_layer_function="ONNRect",
    model_layer_kwargs={"square": True},
    scale=False,
```
@@ -126,7 +128,7 @@ model_settings = ModelSettings(
|
|||||||
optimizer_settings = OptimizerSettings(
|
optimizer_settings = OptimizerSettings(
|
||||||
optimizer="AdamW",
|
optimizer="AdamW",
|
||||||
optimizer_kwargs={
|
optimizer_kwargs={
|
||||||
"lr": 0.01,
|
"lr": 0.005,
|
||||||
"amsgrad": True,
|
"amsgrad": True,
|
||||||
# "weight_decay": 1e-7,
|
# "weight_decay": 1e-7,
|
||||||
},
|
},
|
||||||
@@ -242,7 +244,15 @@ if __name__ == "__main__":
|
|||||||
pytorch_settings=pytorch_settings,
|
pytorch_settings=pytorch_settings,
|
||||||
model_settings=model_settings,
|
model_settings=model_settings,
|
||||||
optimizer_settings=optimizer_settings,
|
optimizer_settings=optimizer_settings,
|
||||||
# checkpoint_path=".models/best_20241205_235929.tar",
|
checkpoint_path=".models/best_20241216_221359.tar",
|
||||||
|
reset_epoch=True,
|
||||||
|
# settings_override={
|
||||||
|
# "optimizer_settings": {
|
||||||
|
# "optimizer_kwargs": {
|
||||||
|
# "lr": 0.01,
|
||||||
|
# },
|
||||||
|
# }
|
||||||
|
# }
|
||||||
# 20241202_143149
|
# 20241202_143149
|
||||||
)
|
)
|
||||||
trainer.train()
|
trainer.train()
|
||||||
|
|||||||
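Note: the `checkpoint_path`/`reset_epoch` combination suggests warm-starting from saved weights while restarting the epoch counter. The `HyperTraining` internals are not shown in this compare, so the loading sketch below (including the checkpoint key names) is an assumption:

import torch
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(4, 2)
optimizer = optim.AdamW(model.parameters(), lr=0.005)
reset_epoch = True

checkpoint = torch.load(".models/best_20241216_221359.tar", map_location="cpu")
model.load_state_dict(checkpoint["model_state_dict"])          # assumed key name
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])  # assumed key name
start_epoch = 0 if reset_epoch else checkpoint.get("epoch", 0)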
@@ -26,7 +26,7 @@ global_settings = GlobalSettings(
 )

 data_settings = DataSettings(
-    config_path="data/*-128-16384-10000-0-0-0-0-PAM4-0-0.4.ini",
+    config_path="data/20241211-105524-128-16384-1-0-0-0-0-PAM4-0-0.ini",
     # config_path=[f"data/20241202-*-128-16384-{length}-0-0-17-0-PAM4-0.ini" for length in range(48000, 53000, 1000)],
     dtype="complex64",
     # symbols = (9, 20), # 13 symbol @ 10GBd <-> 1.3ns <-> 0.26m of fiber
@@ -53,14 +53,14 @@ pytorch_settings = PytorchSettings(
 )

 model_settings = ModelSettings(
-    output_dim=3,
+    output_dim=1,
     n_hidden_layers=3,
     overrides={
-        "n_hidden_nodes_0": 2,
-        "n_hidden_nodes_1": 2,
-        "n_hidden_nodes_2": 2,
+        "n_hidden_nodes_0": 4,
+        "n_hidden_nodes_1": 4,
+        "n_hidden_nodes_2": 4,
     },
-    dropout_prob=0.01,
+    dropout_prob=0,
     model_layer_function="ONNRect",
     model_activation_func="EOActivation",
     model_layer_kwargs={"square": True},
@@ -110,20 +110,24 @@ model_settings = ModelSettings(
 )

 optimizer_settings = OptimizerSettings(
-    optimizer="AdamW",
+    optimizer="RMSprop",
+    # optimizer="AdamW",
     optimizer_kwargs={
-        "lr": 0.005,
-        "amsgrad": True,
+        "lr": 0.01,
+        "alpha": 0.9,
+        "momentum": 0.1,
+        "eps": 1e-8,
+        "centered": True,
+        # "amsgrad": True,
         # "weight_decay": 1e-7,
     },
-    # learning_rate=0.05,
     scheduler="ReduceLROnPlateau",
     scheduler_kwargs={
-        "patience": 2**6,
+        "patience": 2**5,
         "factor": 0.75,
         # "threshold": 1e-3,
         "min_lr": 1e-6,
-        "cooldown": 10,
+        # "cooldown": 10,
     },
 )

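Note: the keyword sets above match the signatures of `torch.optim.RMSprop` and `torch.optim.lr_scheduler.ReduceLROnPlateau`. A minimal sketch of how such settings are typically wired together (the model and training loop are placeholders):

import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau

model = nn.Linear(4, 1)
optimizer = optim.RMSprop(model.parameters(), lr=0.01, alpha=0.9,
                          momentum=0.1, eps=1e-8, centered=True)
scheduler = ReduceLROnPlateau(optimizer, patience=2**5, factor=0.75, min_lr=1e-6)

# once per epoch, after validation:
# scheduler.step(val_loss)  # cuts the lr by `factor` after `patience` epochs without improvement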
@@ -320,6 +320,29 @@ class normalize_by_first(nn.Module):
     def forward(self, data):
         return data / data[:, 0].unsqueeze(1)

+class rotate(nn.Module):
+    def __init__(self):
+        super(rotate, self).__init__()
+
+    def forward(self, data, angle):
+        # data -> (batch, n*2)
+        # angle -> (batch, n)
+        data_ = data
+        if angle.ndim == 1:
+            angle_ = angle.unsqueeze(1)
+        else:
+            angle_ = angle
+        angle_ = angle_.expand(-1, data_.shape[1]//2)
+        c = torch.cos(angle_)
+        s = torch.sin(angle_)
+        rot = torch.stack([torch.stack([c, -s], dim=2),
+                           torch.stack([s, c], dim=2)], dim=3)
+        d = torch.bmm(data_.reshape(-1, 1, 2), rot.view(-1, 2, 2).to(dtype=data_.dtype)).reshape(*data.shape)
+        # d = torch.bmm(data.unsqueeze(-1).mT, rot.to(dtype=data.dtype).mT).mT.squeeze(-1)
+
+        return d
+
+
 class photodiode(nn.Module):
     def __init__(self, size, bias=True):
         super(photodiode, self).__init__()
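Note: the new `rotate` module mixes interleaved (x, y) pairs with a per-pair 2x2 rotation; since the `bmm` right-multiplies row vectors by [[c, -s], [s, c]], the effective map is x' = x·cosθ + y·sinθ and y' = -x·sinθ + y·cosθ. A usage sketch, assuming the module above is in scope:

import torch

rot = rotate()
data = torch.randn(3, 4)               # (batch, n*2) with n = 2 pairs per sample
angle = torch.rand(3) * 2 * torch.pi   # one angle per sample, broadcast over pairs
out = rot(data, angle)
assert out.shape == data.shape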
@@ -487,7 +510,7 @@ class MZISingle(nn.Module):
         return x * torch.exp(1j * self.phi) * torch.sin(self.omega + self.func(x))

 def naive_angle_loss(x: torch.Tensor, target: torch.Tensor, mod=2*torch.pi):
-    return torch.fmod((x - target), mod).square().mean()
+    return torch.fmod((x.abs().real - target.abs().real), mod).abs().mean()

 def cosine_loss(x: torch.Tensor, target: torch.Tensor):
     return (2*(1 - torch.cos(x - target))).mean()
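Note: with the fmod-based loss, angle pairs that are physically close but straddle the wrap point are penalised heavily, whereas `cosine_loss` is smooth and 2π-periodic. A quick numerical check (values chosen purely for illustration, assuming the functions above are in scope):

import torch

a = torch.tensor([0.1])
b = torch.tensor([2 * torch.pi - 0.1])  # same physical angle up to 0.2 rad modulo 2*pi

print(cosine_loss(a, b))                                # ~0.04: treats the pair as nearly equal
print(torch.fmod(a - b, 2 * torch.pi).square().mean())  # ~37: penalises the raw difference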
@@ -5,6 +5,7 @@ from torch.utils.data import Dataset
 # from torch.utils.data import Sampler
 import numpy as np
 import configparser
+import multiprocessing as mp

 # class SubsetSampler(Sampler[int]):
 #     """
@@ -113,8 +114,9 @@ class FiberRegenerationDataset(Dataset):
         dtype: torch.dtype = None,
         real: bool = False,
         device=None,
-        polarisations: tuple | list = (0,),
+        osnr: float = None,
         randomise_polarisations: bool = False,
+        repeat_randoms: int = 1,
         **kwargs,
     ):
         """
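Note: the constructor now models impairments on the fly instead of expanding a fixed `polarisations` list. A hypothetical instantiation with the new keywords (the remaining arguments are assumed from the configs earlier in this compare):

import torch

dataset = FiberRegenerationDataset(
    config_path="data/*-128-16384-10000-0-0-17-0-PAM4-0.ini",  # assumed
    dtype=torch.complex64,
    osnr=10,                       # inject noise at a 10 dB optical SNR
    randomise_polarisations=True,  # rotate the output polarisation randomly
    repeat_randoms=1,              # >1 repeats the data with fresh random draws
)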
@@ -190,18 +192,20 @@ class FiberRegenerationDataset(Dataset):
             files.append(config["data"]["file"].strip('"'))
         self.config["data"]["file"] = str(files)

-        for i, angle in enumerate(torch.tensor(np.array(polarisations))):
-            data_raw_copy = data_raw.clone()
-            if angle == 0:
-                continue
-            sine = torch.sin(angle)
-            cosine = torch.cos(angle)
-            data_raw_copy[:, 2] = data_raw[:, 2] * cosine - data_raw[:, 3] * sine
-            data_raw_copy[:, 3] = data_raw[:, 2] * sine + data_raw[:, 3] * cosine
-            if i == 0:
-                data_raw = data_raw_copy
-            else:
-                data_raw = torch.cat([data_raw, data_raw_copy], dim=0)
+        # if polarisations is not None:
+        #     self.angles = torch.tensor(polarisations).repeat(len(data_raw), 1)
+        # for i, angle in enumerate(torch.tensor(np.array(polarisations))):
+        #     data_raw_copy = data_raw.clone()
+        #     if angle == 0:
+        #         continue
+        #     sine = torch.sin(angle)
+        #     cosine = torch.cos(angle)
+        #     data_raw_copy[:, 2] = data_raw[:, 2] * cosine - data_raw[:, 3] * sine
+        #     data_raw_copy[:, 3] = data_raw[:, 2] * sine + data_raw[:, 3] * cosine
+        #     if i == 0:
+        #         data_raw = data_raw_copy
+        #     else:
+        #         data_raw = torch.cat([data_raw, data_raw_copy], dim=0)

         self.device = data_raw.device

@@ -278,23 +282,61 @@ class FiberRegenerationDataset(Dataset):
         timestamps = data_raw[4, :]
         data_raw = data_raw[:4, :]
         data_raw = data_raw.view(2, 2, -1)
-        timestamps_doubled = torch.cat([timestamps.unsqueeze(dim=0), timestamps.unsqueeze(dim=0)], dim=0).unsqueeze(
-            dim=1
-        )
-        data_raw = torch.cat([data_raw, timestamps_doubled], dim=1)
+        fiber_in = data_raw[0, :, :]
+        fiber_out = data_raw[1, :, :]
+        # timestamps_doubled = torch.cat([timestamps.unsqueeze(dim=0), timestamps.unsqueeze(dim=0)], dim=0).unsqueeze(
+        #     dim=1
+        # )
+        fiber_in = torch.cat([fiber_in, timestamps.unsqueeze(0)], dim=0)
+        fiber_out = torch.cat([fiber_out, timestamps.unsqueeze(0)], dim=0)
+
+        if repeat_randoms > 1:
+            fiber_in = fiber_in.repeat(1, 1, repeat_randoms)
+            fiber_out = fiber_out.repeat(1, 1, repeat_randoms)
+            # review: potential problems with repeated timestamps when plotting
+        else:
+            repeat_randoms = 1
+
+        if self.randomise_polarisations:
+            angles = torch.fmod(torch.randperm(data_raw.shape[-1]*repeat_randoms), 2) * torch.pi
+            # start_angle = torch.rand(1) * 2 * torch.pi
+            # angles = start_angle + torch.cumsum(torch.randn(data_raw.shape[-1])/333, dim=0) * torch.pi * 2 # random walk
+            # self.angles = torch.rand(self.data.shape[0]) * 2 * torch.pi
+        else:
+            angles = torch.zeros(data_raw.shape[-1])
+
+        sin = torch.sin(angles)
+        cos = torch.cos(angles)
+        rot = torch.stack([torch.stack([cos, -sin], dim=1), torch.stack([sin, cos], dim=1)], dim=2)
+        data_rot = torch.bmm(fiber_out[:2, :].T.unsqueeze(1), rot.to(dtype=data_raw.dtype)).squeeze(1).T
+        fiber_out = torch.cat((fiber_out, data_rot), dim=0)
+        fiber_out = torch.cat([fiber_out, angles.unsqueeze(0)], dim=0)
+
+        if osnr is not None:
+            popt = torch.mean(fiber_out[:2, :, :].abs().flatten(), dim=-1)
+            noise = torch.randn_like(fiber_out[:2, :, :])
+            pn = torch.mean(noise.abs().flatten(), dim=-1)
+            noise = noise * (popt / pn) * 10 ** (-osnr / 20)
+            fiber_out[:2, :, :] = torch.add(fiber_out[:2, :, :], noise)
+
         # data_raw = torch.cat([data_raw, timestamps_doubled], dim=1)
         # data layout
         # [ [E_in_x, E_in_y, timestamps],
         #   [E_out_x, E_out_y, timestamps] ]

-        self.data = data_raw.unfold(dimension=-1, size=self.samples_per_slice, step=1)
-        self.data = self.data.movedim(-2, 0)
-
-        if randomise_polarisations:
-            self.angles = torch.rand(self.data.shape[0]) * np.pi * 2
-            # self.data[:, 1, :2, :] = self.rotate(self.data[:, 1, :2, :], self.angles)
-        else:
-            self.angles = torch.zeros(self.data.shape[0])
+        self.fiber_in = fiber_in.unfold(dimension=-1, size=self.samples_per_slice, step=1)
+        self.fiber_in = self.fiber_in.movedim(-2, 0)
+
+        self.fiber_out = fiber_out.unfold(dimension=-1, size=self.samples_per_slice, step=1)
+        self.fiber_out = self.fiber_out.movedim(-2, 0)
+
+        # self.data = data_raw.unfold(dimension=-1, size=self.samples_per_slice, step=1)
+        # self.data = self.data.movedim(-2, 0)
+        # self.angles = torch.zeros(self.data.shape[0])
+        ...
         # ...
         # -> [no_slices, 2, 3, samples_per_slice]

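Note: the `osnr` branch scales unit Gaussian noise so that its mean amplitude sits `osnr` dB below the mean signal amplitude (an amplitude ratio of 10^(-osnr/20)). A standalone sketch of the same scaling (the helper is hypothetical):

import torch

def add_osnr_noise(signal: torch.Tensor, osnr_db: float) -> torch.Tensor:
    noise = torch.randn_like(signal)
    p_signal = signal.abs().mean()   # mean signal amplitude
    p_noise = noise.abs().mean()     # mean amplitude of the raw noise draw
    # scale so the noise amplitude sits osnr_db below the signal amplitude
    return signal + noise * (p_signal / p_noise) * 10 ** (-osnr_db / 20)

noisy = add_osnr_noise(torch.randn(2, 1024), osnr_db=10)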
@@ -305,51 +347,56 @@ class FiberRegenerationDataset(Dataset):
         # ...
         # ] -> [no_slices, 2, 3, samples_per_slice]

-        ...

     def __len__(self):
-        return self.data.shape[0]
+        return self.fiber_in.shape[0]

     def __getitem__(self, idx):
         if isinstance(idx, slice):
             return [self.__getitem__(i) for i in range(*idx.indices(len(self)))]
         else:
-            data_slice = self.data[idx].squeeze()
-
-            data_slice = data_slice[:, :, : data_slice.shape[2] // self.output_dim * self.output_dim]
-
-            data_slice = data_slice.view(data_slice.shape[0], data_slice.shape[1], self.output_dim, -1)
+            # fiber in: [E_in_x, E_in_y, timestamps]
+            # fiber out: [E_out_x, E_out_y, timestamps, E_out_x_rot, E_out_y_rot, angle]
+
+            fiber_in = self.fiber_in[idx].squeeze()
+            fiber_out = self.fiber_out[idx].squeeze()
+
+            fiber_in = fiber_in[..., : fiber_in.shape[-1] // self.output_dim * self.output_dim]
+            fiber_out = fiber_out[..., : fiber_out.shape[-1] // self.output_dim * self.output_dim]
+
+            fiber_in = fiber_in.view(fiber_in.shape[0], self.output_dim, -1)
+            fiber_out = fiber_out.view(fiber_out.shape[0], self.output_dim, -1)
+
+            # data_slice = data_slice[:, :, : data_slice.shape[2] // self.output_dim * self.output_dim]
+
+            # data_slice = data_slice.view(data_slice.shape[0], data_slice.shape[1], self.output_dim, -1)
+
+            # angle = self.angles[idx]
+
+            center_angle = fiber_out[5, self.output_dim // 2, 0]
+            angles = fiber_out[5, :, 0]
+            plot_data = fiber_out[:2, self.output_dim // 2, 0].detach().clone()
+            plot_data_rot = fiber_out[3:5, self.output_dim // 2, 0].detach().clone()
+            data = fiber_out[3:5, :, 0]

             # if self.randomise_polarisations:
-            # angle = torch.rand(1) * torch.pi * 2
-            # sine = torch.sin(angle)
-            # cosine = torch.cos(angle)
-            # data_slice_ = data_slice[1]
-            # data_slice[1, 0] = data_slice_[0] * cosine - data_slice_[1] * sine
-            # data_slice[1,1] = data_slice_[0] * sine + data_slice_[1] * cosine
-            # else:
-            # angle = torch.zeros(1)
-
-            # data = data_slice[1, :2, :, 0]
-
-            angle = self.angles[idx]
-
-            data_index = 1
-
-            data_slice[1, :2, :, :] = self.rotate(data_slice[data_index, :2, :, :], angle)
-
-            data = data_slice[1, :2, :, 0]
-            # data = self.rotate(data, angle)
+            # data = data.mT
+            # c = torch.cos(angle).unsqueeze(-1)
+            # s = torch.sin(angle).unsqueeze(-1)
+            # rot = torch.stack([torch.stack([c, -s], dim=1), torch.stack([s, c], dim=1)], dim=2).squeeze(-1)
+            # data = torch.bmm(data.mT.unsqueeze(0), rot.to(dtype=data.dtype)).squeeze(-1)
+            ...
+
+            # angle = torch.zeros_like(angle)

             # for both polarisations (x/y), calculate the mean of the signal around the current symbol (-> corresponding to a lowpass filter)
-            angle_data = data_slice[1, :2, :, :].reshape(2, -1).mean(dim=1)
-            angle_data2 = self.complex_max(data_slice[1, :2, :, :].reshape(2, -1))
-            plot_data = data_slice[1, :2, self.output_dim // 2, 0]
-            sop = self.polarimeter(plot_data)
+            angle_data = fiber_out[:2, :, :].reshape(2, -1).mean(dim=1).repeat(1, self.output_dim)
+            angle_data2 = self.complex_max(fiber_out[:2, :, :].reshape(2, -1)).repeat(1, self.output_dim)
+            # sop = self.polarimeter(plot_data)

             # angle_data = data_slice[1, :2, :self.output_dim//2, :].squeeze().reshape(2, -1).mean(dim=-1)
             # angle = data_slice[1, 3, self.output_dim // 2, 0].real
-            target = data_slice[0, :2, self.output_dim // 2, 0]
-            target_timestamp = data_slice[0, 2, self.output_dim // 2, 0].real
+            target = fiber_in[:2, self.output_dim // 2, 0]
+            plot_target = fiber_in[:2, self.output_dim // 2, 0].detach().clone()
+            target_timestamp = fiber_in[2, self.output_dim // 2, 0].real
             ...

             # data_timestamps = data[-1,:].real
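Note: both buffers are sliced with `Tensor.unfold`, which builds overlapping stride-1 windows over the sample axis as a view (no copy). A small demonstration of the shapes involved:

import torch

x = torch.arange(10).reshape(1, 10)                # (channels, samples)
windows = x.unfold(dimension=-1, size=4, step=1)   # -> (1, 7, 4): 7 overlapping slices
windows = windows.movedim(-2, 0)                   # -> (7, 1, 4): slice index first
print(windows[0])                                  # tensor([[0, 1, 2, 3]])
print(windows[1])                                  # tensor([[1, 2, 3, 4]])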
@@ -360,14 +407,32 @@ class FiberRegenerationDataset(Dataset):

             # transpose to interleave the x and y data in the output tensor
             data = data.transpose(0, 1).flatten().squeeze()
-            angle_data = angle_data.flatten().squeeze()
-            angle_data2 = angle_data.flatten().squeeze()
-            angle = angle.flatten().squeeze()
+            angle_data = angle_data.transpose(0, 1).flatten().squeeze()
+            angle_data2 = angle_data2.transpose(0,1).flatten().squeeze()
+            center_angle = center_angle.flatten().squeeze()
+            angles = angles.flatten().squeeze()
             # data_timestamps = data_timestamps.flatten().squeeze()
+            # target = target.transpose(0,1).flatten().squeeze()
             target = target.flatten().squeeze()
             target_timestamp = target_timestamp.flatten().squeeze()
+            plot_target = plot_target.flatten().squeeze()
+            plot_data = plot_data.flatten().squeeze()
+            plot_data_rot = plot_data_rot.flatten().squeeze()

-            return {"x": data, "y": target, "angle": angle, "sop": sop, "angle_data": angle_data, "angle_data2": angle_data2, "timestamp": target_timestamp, "plot_data": plot_data}
+            return {
+                "x": data,
+                "y": target,
+                "center_angle": center_angle,
+                "angles": angles,
+                "mean_angle": angles.mean(),
+                # "sop": sop,
+                "angle_data": angle_data,
+                "angle_data2": angle_data2,
+                "timestamp": target_timestamp,
+                "plot_target": plot_target,
+                "plot_data": plot_data,
+                "plot_data_rot": plot_data_rot,
+            }

     def complex_max(self, data, dim=-1):
         # returns element(s) with the maximum absolute value along a given dimension
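Note: since `__getitem__` now returns a dict, the default `collate_fn` of `torch.utils.data.DataLoader` batches each key independently. A minimal sketch with stand-in tensors (not the real dataset):

import torch
from torch.utils.data import DataLoader, Dataset

class DictDataset(Dataset):
    # stand-in for the dict-style samples returned above
    def __len__(self):
        return 8

    def __getitem__(self, idx):
        return {"x": torch.randn(4), "y": torch.randn(2)}

loader = DataLoader(DictDataset(), batch_size=4)
for batch in loader:
    print(batch["x"].shape, batch["y"].shape)  # torch.Size([4, 4]) torch.Size([4, 2])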
@@ -376,7 +441,6 @@ class FiberRegenerationDataset(Dataset):
         # return max_values
         return torch.gather(data, dim, torch.argmax(data.abs(), dim=dim, keepdim=True)).squeeze(dim=dim)

-
     def rotate(self, data, angle):
         # rotates a 2d tensor by a given angle
         # data: [2, ...]
@@ -389,6 +453,24 @@ class FiberRegenerationDataset(Dataset):

         return torch.stack([data[0] * cosine - data[1] * sine, data[0] * sine + data[1] * cosine], dim=0)

+    def rotate_all(self):
+        def do_rotation(j, num_processes):
+            for i in range(len(self) // num_processes):
+                index = i * num_processes + j
+                self.data[index, 1, :2, :] = self.rotate(self.data[index, 1, :2, :], self.angles[index])
+
+        self.processes = []
+
+        for j in range(mp.cpu_count()):
+            self.processes.append(mp.Process(target=do_rotation, args=(j, mp.cpu_count())))
+            self.processes[-1].start()
+
+        for p in self.processes:
+            p.join()
+
+        for i in range(len(self) // mp.cpu_count() * mp.cpu_count(), len(self)):
+            self.data[i, 1, :2, :] = self.rotate(self.data[i, 1, :2, :], self.angles[i])
+
     def polarimeter(self, data):
         # data: [2, ...] -> x, y
         # returns [4] -> S0, S1, S2, S3
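Note: `mp.Process` workers started this way operate on a copy-on-write view of `self.data` after fork, so the in-place writes in `do_rotation` only become visible to the parent if the tensor's storage is shared beforehand. The usual precaution, assuming that is the intent here:

import torch

data = torch.randn(16, 2, 4)
data.share_memory_()  # move the storage into shared memory so child-process writes are visible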
@@ -396,12 +478,12 @@ class FiberRegenerationDataset(Dataset):
         y = data[1].mean()
         I_X = x.abs().square()
         I_Y = y.abs().square()
-        I_45 = (x+y).abs().square()
-        I_RHC = (x + 1j*y).abs().square()
+        I_45 = (x + y).abs().square()
+        I_RHC = (x + 1j * y).abs().square()

         S0 = I_X + I_Y
-        S1 = (2*I_X - S0) / S0
-        S2 = (2*I_45 - S0) / S0
-        S3 = (2*I_RHC - S0) / S0
+        S1 = (2 * I_X - S0) / S0
+        S2 = (2 * I_45 - S0) / S0
+        S3 = (2 * I_RHC - S0) / S0

-        return torch.stack([S1, S2, S3], dim=0)
+        return torch.stack([S0, S1, S2, S3], dim=0)
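Note: the polarimeter estimates Stokes parameters from simulated polariser intensities and now returns S0 (total power) along with the normalised S1-S3, matching its `# returns [4]` comment. For reference, a sketch of the textbook form computed directly from a mean Jones vector (the sign of S3 is convention-dependent):

import torch

def stokes(ex: torch.Tensor, ey: torch.Tensor) -> torch.Tensor:
    # normalised Stokes vector from a mean Jones vector (E_x, E_y)
    s0 = ex.abs().square() + ey.abs().square()
    s1 = ex.abs().square() - ey.abs().square()
    s2 = 2 * (ex * ey.conj()).real
    s3 = 2 * (ex * ey.conj()).imag
    return torch.stack([s0, s1, s2, s3]) / s0

print(stokes(torch.tensor(1 + 0j), torch.tensor(0j)))  # -> [1, 1, 0, 0] for horizontal polarisation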