Compare commits

..

2 Commits

10 changed files with 561 additions and 305 deletions

View File

@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:746ea83013870351296e01e294905ee027291ef79fd78c1e6b69dd9ebaa1cba0
size 10240000
oid sha256:76934d1d202aea1311ba67f5ea35eeb99a9c5c856f491565032e7d54ca6f072d
size 13598720

View File

@@ -26,6 +26,8 @@ import torch
import torch.optim as optim
import torch.utils.data
import hypertraining.models as models
from torch.utils.tensorboard import SummaryWriter
import multiprocessing
@@ -253,14 +255,17 @@ class HyperTraining:
model_kwargs = {
"dims": (input_dim, *hidden_dims, self.model_settings.output_dim),
"layer_function": layer_func,
"layer_parametrizations": layer_parametrizations,
"activation_function": afunc,
"layer_func_kwargs": self.model_settings.model_layer_kwargs,
"act_function": afunc,
"act_func_kwargs": None,
"parametrizations": layer_parametrizations,
"dtype": dtype,
"droupout_prob": self.model_settings.dropout_prob,
"scale": scale_layers,
"dropout_prob": self.model_settings.dropout_prob,
"scale_layers": scale_layers,
"rotate": False,
}
model = util.complexNN.regenerator(**model_kwargs)
model = models.regenerator(*model_kwargs.pop("dims"), **model_kwargs)
n_nodes = sum(hidden_dims)
if writer is not None:
@@ -381,7 +386,10 @@ class HyperTraining:
running_loss = 0.0
model.train()
loader_len = len(train_loader)
for batch_idx, (x, y, _) in enumerate(train_loader):
for batch_idx, batch in enumerate(train_loader):
x = batch["x"]
y = batch["y"]
if batch_idx >= self.optuna_settings._n_train_batches:
break
model.zero_grad(set_to_none=True)
@@ -390,7 +398,7 @@ class HyperTraining:
y.to(self.pytorch_settings.device),
)
y_pred = model(x)
loss = util.complexNN.complex_mse_loss(y_pred, y)
loss = util.complexNN.complex_mse_loss(y_pred, y, power=True)
loss_value = loss.item()
loss.backward()
optimizer.step()
@@ -444,7 +452,9 @@ class HyperTraining:
model.eval()
running_error = 0
with torch.no_grad():
for batch_idx, (x, y, _) in enumerate(valid_loader):
for batch_idx, batch in enumerate(valid_loader):
x = batch["x"]
y = batch["y"]
if batch_idx >= self.optuna_settings._n_valid_batches:
break
x, y = (
@@ -452,50 +462,44 @@ class HyperTraining:
y.to(self.pytorch_settings.device),
)
y_pred = model(x)
error = util.complexNN.complex_mse_loss(y_pred, y)
error = util.complexNN.complex_mse_loss(y_pred, y, power=True)
error_value = error.item()
running_error += error_value
running_error /= min(len(valid_loader), self.optuna_settings._n_valid_batches)
if writer is not None:
title_append, subtitle = self.build_title(trial)
writer.add_figure(
"fiber response",
self.plot_model_response(
trial,
model=model,
title_append=title_append,
subtitle=subtitle,
show=False,
),
epoch + 1,
)
writer.add_figure(
"eye diagram",
self.plot_model_response(
trial,
model=self.model,
title_append=title_append,
subtitle=subtitle,
show=False,
mode="eye",
),
epoch + 1,
writer.add_scalar(
"eval loss",
running_error,
epoch,
)
# if (epoch + 1) % 10 == 0 or epoch < 10:
# # plotting is slow, so only do it every 10 epochs
# title_append, subtitle = self.build_title(trial)
# head_fig, eye_fig, powers_fig = self.plot_model_response(
# model=model,
# title_append=title_append,
# subtitle=subtitle,
# show=False,
# )
# writer.add_figure(
# "fiber response",
# head_fig,
# epoch + 1,
# )
# writer.add_figure(
# "eye diagram",
# eye_fig,
# epoch + 1,
# )
writer.add_figure(
"powers",
self.plot_model_response(
trial,
model=self.model,
title_append=title_append,
subtitle=subtitle,
mode="powers",
show=False,
),
epoch + 1,
)
# writer.add_figure(
# "powers",
# powers_fig,
# epoch + 1,
# )
# writer.flush()
# if enable_progress:
# progress.stop()
@@ -511,15 +515,18 @@ class HyperTraining:
with torch.no_grad():
model = model.to(self.pytorch_settings.device)
for x, y, timestamp in loader:
for batch in loader:
x = batch["x"]
y = batch["y"]
timestamp = batch["timestamp"]
x, y = (
x.to(self.pytorch_settings.device),
y.to(self.pytorch_settings.device),
)
if trace_powers:
y_pred, powers = model(x, trace_powers).cpu()
y_pred, powers = model(x, trace_powers=True).cpu()
else:
y_pred = model(x, trace_powers).cpu()
y_pred = model(x, trace_powers=True).cpu()
# x = x.cpu()
# y = y.cpu()
y_pred = y_pred.view(y_pred.shape[0], -1, 2)
@@ -539,7 +546,7 @@ class HyperTraining:
return fiber_in, fiber_out, regen, timestamps, powers
return fiber_in, fiber_out, regen, timestamps
def objective(self, trial: optuna.Trial, plot_before=False):
def objective(self, trial: optuna.Trial):
if self.stop_study:
trial.study.stop()
model = None
@@ -555,54 +562,54 @@ class HyperTraining:
title_append, subtitle = self.build_title(trial)
writer.add_figure(
"fiber response",
self.plot_model_response(
trial,
model=model,
title_append=title_append,
subtitle=subtitle,
show=False,
),
0,
)
writer.add_figure(
"eye diagram",
self.plot_model_response(
trial,
model=self.model,
title_append=title_append,
subtitle=subtitle,
mode="eye",
show=False,
),
0,
)
# writer.add_figure(
# "fiber response",
# self.plot_model_response(
# trial,
# model=model,
# title_append=title_append,
# subtitle=subtitle,
# show=False,
# ),
# 0,
# )
# writer.add_figure(
# "eye diagram",
# self.plot_model_response(
# trial,
# model=self.model,
# title_append=title_append,
# subtitle=subtitle,
# mode="eye",
# show=False,
# ),
# 0,
# )
writer.add_figure(
"powers",
self.plot_model_response(
trial,
model=self.model,
title_append=title_append,
subtitle=subtitle,
mode="powers",
show=False,
),
0,
)
# writer.add_figure(
# "powers",
# self.plot_model_response(
# trial,
# model=self.model,
# title_append=title_append,
# subtitle=subtitle,
# mode="powers",
# show=False,
# ),
# 0,
# )
train_loader, valid_loader = self.get_sliced_data(trial)
optimizer_name = trial.suggest_categorical_optional("optimizer", self.optimizer_settings.optimizer)
lr = trial.suggest_float_optional("lr", self.optimizer_settings.learning_rate, log=True)
lr = trial.suggest_float_optional("lr", self.optimizer_settings.optimizer_kwargs["lr"], log=True)
optimizer = getattr(optim, optimizer_name)(model.parameters(), lr=lr)
if self.optimizer_settings.scheduler is not None:
scheduler = getattr(optim.lr_scheduler, self.optimizer_settings.scheduler)(
optimizer, **self.optimizer_settings.scheduler_kwargs
)
# if self.optimizer_settings.scheduler is not None:
# scheduler = getattr(optim.lr_scheduler, self.optimizer_settings.scheduler)(
# optimizer, **self.optimizer_settings.scheduler_kwargs
# )
for epoch in range(self.pytorch_settings.epochs):
trial.set_user_attr("epoch", epoch)
@@ -628,8 +635,8 @@ class HyperTraining:
writer,
# enable_progress=enable_progress,
)
if self.optimizer_settings.scheduler is not None:
scheduler.step(error)
# if self.optimizer_settings.scheduler is not None:
# scheduler.step(error)
trial.set_user_attr("mse", error)
trial.set_user_attr("log_mse", np.log10(error + np.finfo(float).eps))
@@ -645,10 +652,10 @@ class HyperTraining:
if self.optuna_settings._multi_objective:
return -np.log10(error + np.finfo(float).eps), trial.user_attrs.get("model_n_nodes", -1)
if self.pytorch_settings.save_models and model is not None:
save_path = Path(self.pytorch_settings.model_dir) / f"{self.optuna_settings.study_name}_{trial.number}.pth"
save_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(model, save_path)
# if self.pytorch_settings.save_models and model is not None:
# save_path = Path(self.pytorch_settings.model_dir) / f"{self.optuna_settings.study_name}_{trial.number}.pth"
# save_path.parent.mkdir(parents=True, exist_ok=True)
# torch.save(model, save_path)
return error

View File

@@ -8,7 +8,8 @@ from util.complexNN import (
photodiode,
EOActivation,
polarimeter,
normalize_by_first
# normalize_by_first,
rotate,
)
@@ -19,11 +20,11 @@ class polarisation_estimator2(Module):
polarimeter(),
torch.nn.Linear(4, 4),
torch.nn.ReLU(),
torch.nn.Dropout(p=0.01),
# torch.nn.Dropout(p=0.01),
torch.nn.Linear(4, 4),
torch.nn.ReLU(),
torch.nn.Dropout(p=0.01),
torch.nn.Linear(4, 4),
# torch.nn.Dropout(p=0.01),
torch.nn.Linear(4, 1),
)
def forward(self, x):
@@ -124,6 +125,7 @@ class regenerator(Module):
dtype=torch.float64,
dropout_prob=0.01,
scale_layers=False,
rotate=False,
):
super(regenerator, self).__init__()
self._n_hidden_layers = len(dims) - 2
@@ -131,6 +133,8 @@ class regenerator(Module):
layer_func_kwargs = layer_func_kwargs or {}
act_func_kwargs = act_func_kwargs or {}
self.rotation = rotate
self.build_model(dims, layer_function, layer_func_kwargs, act_function, act_func_kwargs, parametrizations, dtype, dropout_prob, scale_layers)
def build_model(self, dims, layer_function, layer_func_kwargs, act_function, act_func_kwargs, parametrizations, dtype, dropout_prob, scale_layers):
@@ -146,8 +150,9 @@ class regenerator(Module):
module = act_function(size=dims[i + 1], **act_func_kwargs)
self.get_submodule(f"layer_{i}").add_module("activation", module)
module = DropoutComplex(p=dropout_prob)
self.get_submodule(f"layer_{i}").add_module("dropout", module)
if dropout_prob is not None and dropout_prob > 0:
module = DropoutComplex(p=dropout_prob)
self.get_submodule(f"layer_{i}").add_module("dropout", module)
self.add_module(f"layer_{self._n_hidden_layers}", Sequential())
@@ -160,6 +165,10 @@ class regenerator(Module):
module = act_function(size=dims[-1], **act_func_kwargs)
self.get_submodule(f"layer_{self._n_hidden_layers}").add_module("activation", module)
if self.rotation:
module = rotate()
self.add_module("rotate", module)
# module = Scale(size=dims[-1])
# self.get_submodule(f"layer_{self._n_hidden_layers}").add_module("out_scale", module)
@@ -190,15 +199,27 @@ class regenerator(Module):
powers.append(x.abs().square().sum())
return powers
def forward(self, x, trace_powers=False):
def forward(self, x, angle=None, pre_rot=False, trace_powers=False):
powers = self._trace_powers(trace_powers, x)
x = self.layer_0(x)
powers = self._trace_powers(trace_powers, x, powers)
for i in range(1, self._n_hidden_layers):
# x = self.layer_0(x)
# powers = self._trace_powers(trace_powers, x, powers)
for i in range(0, self._n_hidden_layers):
x = getattr(self, f"layer_{i}")(x)
powers = self._trace_powers(trace_powers, x, powers)
x = getattr(self, f"layer_{self._n_hidden_layers}")(x)
powers = self._trace_powers(trace_powers, x, powers)
if trace_powers:
return x, powers
return x
if self.rotation:
try:
x_rot = self.rotate(x, angle)
except AttributeError:
pass
powers = self._trace_powers(trace_powers, x_rot, powers)
else:
x_rot = x
if pre_rot and trace_powers:
return x_rot, x, powers
if pre_rot and not trace_powers:
return x_rot, x
if not pre_rot and trace_powers:
return x_rot, powers
return x_rot

View File

@@ -22,6 +22,8 @@ class DataSettings:
train_split: float = 0.8
polarisations: tuple | list = (0,)
randomise_polarisations: bool = False
osnr: float | int = None
seed: int = None
"""
change to:

View File

@@ -2,7 +2,6 @@ import copy
from datetime import datetime
from pathlib import Path
import random
from typing import Literal
import matplotlib
from matplotlib.colors import LinearSegmentedColormap
import torch.nn.utils.parametrize
@@ -60,33 +59,35 @@ def traverse_dict_update(target, source):
except TypeError:
target.__dict__[k] = v
def get_parameter_names_and_values(model):
def is_parametrized(module):
if hasattr(module, "parametrizations"):
return True
return False
def _get_param_info(module, prefix='', parametrization=False):
def _get_param_info(module, prefix="", parametrization=False):
param_list = []
for name, param in module.named_parameters(recurse = parametrization):
for name, param in module.named_parameters(recurse=parametrization):
if parametrization and name.startswith("parametrizations"):
name_parts = name.split('.')
name_parts = name.split(".")
name = name_parts[1]
param = getattr(module, name)
full_name = prefix + ('.' if prefix else '') + name
full_name = prefix + ("." if prefix else "") + name
param_value = param.data
param_list.append((full_name, param_value))
for child_name, child_module in module.named_children():
child_prefix = prefix + ('.' if prefix else '') + child_name
child_prefix = prefix + ("." if prefix else "") + child_name
if child_name == "parametrizations":
continue
param_list.extend(_get_param_info(child_module, child_prefix, is_parametrized(child_module)))
return param_list
return _get_param_info(model)
class PolarizationTrainer:
def __init__(
self,
@@ -101,7 +102,7 @@ class PolarizationTrainer:
settings_override=None,
reset_epoch=False,
):
self.mod = torch.pi/2
self.mod = torch.pi / 2
self.resume = checkpoint_path is not None
torch.serialization.add_safe_globals([
*util.complexNN.__all__,
@@ -219,7 +220,7 @@ class PolarizationTrainer:
# dims = self.model_kwargs.pop("dims")
model_kwargs = copy.deepcopy(self.model_kwargs)
self.model = models.polarisation_estimator(*model_kwargs.pop('dims'),**model_kwargs)
self.model = models.polarisation_estimator(*model_kwargs.pop("dims"), **model_kwargs)
# self.model = models.polarisation_estimator2()
if self.writer is not None:
@@ -336,17 +337,20 @@ class PolarizationTrainer:
write_div = 0
loss_div = 0
for batch_idx, batch in enumerate(train_loader):
x = batch["x"]
y = batch["sop"]
x = batch["angle_data2"]
y = batch["center_angle"]
self.model.zero_grad(set_to_none=True)
x, y = (
x.to(self.pytorch_settings.device),
y.to(self.pytorch_settings.device),
)
y_pred = self.model(x)
y_pred = self.model(x).abs().real
# y_pred = torch.fmod(y_pred, self.mod)
y = y.abs().real
# y = torch.fmod(y, self.mod)
# loss = torch.nn.functional.mse_loss(torch.cos(y_pred), torch.cos(y))
# loss = torch.nn.functional.smooth_l1_loss(torch.cos(torch.fmod(y_pred, torch.pi/2)).squeeze(), torch.cos(torch.fmod(y, torch.pi/2)).squeeze(), beta=0.5)
loss = torch.nn.functional.mse_loss(y_pred, y)
# loss = util.complexNN.naive_angle_loss(y_pred, y, mod=torch.pi/2)
loss = util.complexNN.naive_angle_loss(y_pred, y, mod=self.mod)
loss_value = loss.item()
loss.backward()
optimizer.step()
@@ -356,7 +360,7 @@ class PolarizationTrainer:
loss_div += 1
if enable_progress:
progress.update(task, advance=1, description=f"{loss_value:.3e}")
progress.update(task, advance=1, description=f"{loss_value/np.pi*180:.3e} °")
if batch_idx % self.pytorch_settings.write_every == 0:
self.writer.add_scalar(
@@ -395,24 +399,28 @@ class PolarizationTrainer:
loss_div = 0
with torch.no_grad():
for _, batch in enumerate(valid_loader):
x = batch["x"]
y = batch["sop"]
# x = batch["angle_data2"]
x = batch["angle_data2"]
y = batch["center_angle"]
x, y = (
x.to(self.pytorch_settings.device),
y.to(self.pytorch_settings.device),
)
y_pred = self.model(x)
y_pred = self.model(x).abs().real
# y_pred = torch.fmod(y_pred, self.mod)
y = y.abs().real
# y = torch.fmod(y, self.mod)
# loss = torch.nn.functional.mse_loss(torch.cos(y_pred), torch.cos(y))
# loss = torch.nn.functional.smooth_l1_loss(torch.cos(torch.fmod(y_pred, torch.pi/2)).squeeze(), torch.cos(torch.fmod(y, torch.pi/2)).squeeze(), beta=0.5)
loss = torch.nn.functional.mse_loss(y_pred, y)
# loss = util.complexNN.naive_angle_loss(y_pred, y, mod=torch.pi/2)
loss = util.complexNN.naive_angle_loss(y_pred, y, mod=self.mod)
loss_value = loss.item()
running_loss += loss_value
loss_div += 1
if enable_progress:
progress.update(task, advance=1, description=f"{loss_value:.3e}")
progress.update(task, advance=1, description=f"{loss_value/np.pi*180:.3e} °")
running_loss = running_loss/loss_div
running_loss = running_loss / loss_div
self.writer.add_scalar(
"eval loss",
@@ -506,19 +514,19 @@ class PolarizationTrainer:
for i, config_path in enumerate(self.data_settings.config_path):
paths = Path.cwd().glob(config_path)
for j, path in enumerate(paths):
text = str(path) + '\n'
with open(path, 'r') as f:
text = str(path) + "\n"
with open(path, "r") as f:
text += f.read()
text += '\n'
self.writer.add_text(f"config_{i*len(self.data_settings.config_path)+j}", text)
text += "\n"
self.writer.add_text(f"config_{i * len(self.data_settings.config_path) + j}", text)
elif isinstance(self.data_settings.config_path, str):
paths = Path.cwd().glob(self.data_settings.config_path)
for j, path in enumerate(paths):
text = str(path) + '\n'
with open(path, 'r') as f:
text = str(path) + "\n"
with open(path, "r") as f:
text += f.read()
text += '\n'
text += "\n"
self.writer.add_text(f"config_{j}", text)
self.writer.flush()
@@ -571,7 +579,8 @@ class PolarizationTrainer:
if loss < self.best["loss"]:
self.best = checkpoint
save_path = (
Path(self.pytorch_settings.model_dir) / f"best_pol_{self.writer.get_logdir().split('/')[-1]}.tar"
Path(self.pytorch_settings.model_dir)
/ f"best_pol_{self.writer.get_logdir().split('/')[-1]}.tar"
)
save_path.parent.mkdir(parents=True, exist_ok=True)
self.save_checkpoint(self.best, save_path)
@@ -580,6 +589,7 @@ class PolarizationTrainer:
self.writer.close()
return self.best
class RegenerationTrainer:
def __init__(
self,
@@ -636,6 +646,10 @@ class RegenerationTrainer:
self.model_settings: ModelSettings = model_settings
self.optimizer_settings: OptimizerSettings = optimizer_settings
if self.global_settings.seed is not None:
random.seed(self.global_settings.seed)
np.random.seed(self.global_settings.seed)
self.console = console or Console()
self.writer = None
@@ -706,10 +720,12 @@ class RegenerationTrainer:
# dims = self.model_kwargs.pop("dims")
model_kwargs = copy.deepcopy(self.model_kwargs)
self.model = models.regenerator(*model_kwargs.pop('dims'),**model_kwargs)
self.model = models.regenerator(*model_kwargs.pop("dims"), **model_kwargs)
if self.writer is not None:
self.writer.add_graph(self.model, torch.zeros(1, input_dim, dtype=dtype))
self.writer.add_graph(
self.model, (torch.rand(1, input_dim, dtype=dtype), torch.rand(1, 1, dtype=dtype.to_real()))
)
self.model = self.model.to(self.pytorch_settings.device)
if self.resume:
@@ -728,12 +744,12 @@ class RegenerationTrainer:
num_symbols = None
config_path = self.data_settings.config_path
polarisations = self.data_settings.polarisations
randomise_polarisations = self.data_settings.randomise_polarisations
osnr = self.data_settings.osnr
if override is not None:
num_symbols = override.get("num_symbols", None)
config_path = override.get("config_path", config_path)
polarisations = override.get("polarisations", polarisations)
# polarisations = override.get("polarisations", polarisations)
randomise_polarisations = override.get("randomise_polarisation", randomise_polarisations)
# get dataset
dataset = FiberRegenerationDataset(
@@ -746,8 +762,8 @@ class RegenerationTrainer:
dtype=dtype,
real=not dtype.is_complex,
num_symbols=num_symbols,
polarisations=polarisations,
randomise_polarisations=randomise_polarisations,
osnr = osnr,
)
dataset_size = len(dataset)
@@ -819,12 +835,14 @@ class RegenerationTrainer:
for batch_idx, batch in enumerate(train_loader):
x = batch["x"]
y = batch["y"]
angles = batch["mean_angle"]
self.model.zero_grad(set_to_none=True)
x, y = (
x, y, angles = (
x.to(self.pytorch_settings.device),
y.to(self.pytorch_settings.device),
angles.to(self.pytorch_settings.device),
)
y_pred = self.model(x)
y_pred = self.model(x, -angles)
loss = util.complexNN.complex_mse_loss(y_pred, y, power=True)
loss_value = loss.item()
loss.backward()
@@ -872,11 +890,13 @@ class RegenerationTrainer:
for _, batch in enumerate(valid_loader):
x = batch["x"]
y = batch["y"]
x, y = (
angles = batch["mean_angle"]
x, y, angles = (
x.to(self.pytorch_settings.device),
y.to(self.pytorch_settings.device),
angles.to(self.pytorch_settings.device),
)
y_pred = self.model(x)
y_pred = self.model(x, -angles)
error = util.complexNN.complex_mse_loss(y_pred, y, power=True)
error_value = error.item()
running_error += error_value
@@ -884,7 +904,7 @@ class RegenerationTrainer:
if enable_progress:
progress.update(task, advance=1, description=f"{error_value:.3e}")
running_error = running_error/len(valid_loader)
running_error = running_error / len(valid_loader)
self.writer.add_scalar(
"eval loss",
@@ -928,45 +948,65 @@ class RegenerationTrainer:
def run_model(self, model, loader, trace_powers=False):
model.eval()
fiber_out = []
fiber_out_rot = []
fiber_in = []
regen = []
timestamps = []
angles = []
with torch.no_grad():
model = model.to(self.pytorch_settings.device)
for batch in loader:
x = batch["x"]
y = batch["y"]
plot_target = batch["plot_target"]
angle = batch["mean_angle"]
center_angle = batch["center_angle"]
timestamp = batch["timestamp"]
plot_data = batch["plot_data"]
x, y = (
plot_data_rot = batch["plot_data_rot"]
x, y, angle = (
x.to(self.pytorch_settings.device),
y.to(self.pytorch_settings.device),
angle.to(self.pytorch_settings.device),
)
if trace_powers:
y_pred, powers = model(x, trace_powers).cpu()
y_pred, powers = model(x, angle, True).cpu()
else:
y_pred = model(x, trace_powers).cpu()
y_pred = model(x, angle).cpu()
# x = x.cpu()
# y = y.cpu()
y_pred = y_pred.view(y_pred.shape[0], -1, 2)
y_pred = y_pred[:, y_pred.shape[1]//2, :]
y = y.view(y.shape[0], -1, 2)
plot_data = plot_data.view(plot_data.shape[0], -1, 2)
# plot_data = plot_data.view(plot_data.shape[0], -1, 2)
# c = torch.cos(-angle).cpu()
# s = torch.sin(-angle).cpu()
# rot = torch.stack([torch.stack([c, -s], dim=1), torch.stack([s, c], dim=1)], dim=2).squeeze(-1)
# plot_data = torch.bmm(plot_data, rot.to(dtype=plot_data.dtype))
# plot_data = plot_data
# sines = torch.sin(-angle.cpu())
# cosines = torch.cos(-angle.cpu())
# plot_data = torch.stack((plot_data[..., 0] * cosines - plot_data[..., 1] * sines, plot_data[..., 0] * sines + plot_data[..., 1] * cosines), dim=-1)
# x = x.view(x.shape[0], -1, 2)
# timestamp = timestamp.view(-1, 1)
fiber_out.append(plot_data.squeeze())
fiber_in.append(y.squeeze())
fiber_out_rot.append(plot_data_rot.squeeze())
fiber_in.append(plot_target.squeeze())
regen.append(y_pred.squeeze())
timestamps.append(timestamp.squeeze())
angles.append(center_angle.squeeze())
fiber_out = torch.vstack(fiber_out).cpu()
fiber_out_rot = torch.vstack(fiber_out_rot).cpu()
fiber_in = torch.vstack(fiber_in).cpu()
regen = torch.vstack(regen).cpu()
angles = torch.vstack(angles).cpu()
timestamps = torch.concat(timestamps).cpu()
if trace_powers:
return fiber_in, fiber_out, regen, timestamps, powers
return fiber_in, fiber_out, regen, timestamps
return fiber_in, fiber_out, fiber_out_rot, angles, regen, timestamps, powers
return fiber_in, fiber_out, fiber_out_rot, angles, regen, timestamps
def write_parameters(self, epoch, attributes: list[str] | tuple[str] = None):
parameter_list = get_parameter_names_and_values(self.model)
@@ -1027,18 +1067,18 @@ class RegenerationTrainer:
for i, config_path in enumerate(self.data_settings.config_path):
paths = Path.cwd().glob(config_path)
for j, path in enumerate(paths):
text = str(path) + '\n'
with open(path, 'r') as f:
text = str(path) + "\n"
with open(path, "r") as f:
text += f.read()
text += '\n'
self.writer.add_text(f"config_{i*len(self.data_settings.config_path)+j}", text)
text += "\n"
self.writer.add_text(f"config_{i * len(self.data_settings.config_path) + j}", text)
elif isinstance(self.data_settings.config_path, str):
paths = Path.cwd().glob(self.data_settings.config_path)
for j, path in enumerate(paths):
text = str(path) + '\n'
with open(path, 'r') as f:
text = str(path) + "\n"
with open(path, "r") as f:
text += f.read()
text += '\n'
text += "\n"
self.writer.add_text(f"config_{j}", text)
self.writer.flush()
@@ -1116,6 +1156,7 @@ class RegenerationTrainer:
powers = [power / powers[0] for power in powers]
fig, ax = plt.subplots()
fig.set_figwidth(18)
fig.set_figheight(4)
fig.suptitle(
f"Energy conservation{f' {title_append}' if title_append else ''}{f'\n{subtitle}' if subtitle else ''}"
)
@@ -1184,6 +1225,7 @@ class RegenerationTrainer:
fig, axs = plt.subplots(1, 2 * len(signals), sharex=True, sharey=True)
fig.set_figwidth(18)
fig.set_figheight(4)
fig.suptitle(f"Eye diagram{f' {title_append}' if title_append else ''}{f'\n{subtitle}' if subtitle else ''}")
# xaxis = timestamps / sps
# xaxis = np.arange(2 * sps) / sps
@@ -1253,7 +1295,7 @@ class RegenerationTrainer:
xaxis = timestamps / sps
else:
xaxis = timestamps
ax.plot(xaxis, np.abs(signal[:, i]) ** 2, label=label)
ax.plot(xaxis, np.abs(signal[:, i]) ** 2, label=label, alpha=0.7)
ax.set_xlabel("Sample" if sps is None else "Symbol")
ax.set_ylabel("normalized power")
ax.minorticks_on()
@@ -1269,7 +1311,7 @@ class RegenerationTrainer:
def plot_model_response(
self,
model:torch.nn.Module=None,
model: torch.nn.Module = None,
title_append="",
subtitle="",
# mode: Literal["eye", "head", "powers"] = "head",
@@ -1281,7 +1323,9 @@ class RegenerationTrainer:
model = model.to(self.pytorch_settings.device)
model.eval()
with torch.no_grad():
_, powers = model(input_data, trace_powers=True)
_, powers = model(
input_data, torch.zeros(input_data.shape[0], 1).to(self.pytorch_settings.device), trace_powers=True
)
powers = [power.item() for power in powers]
layer_names = [name for (name, _) in model.named_children()]
@@ -1296,29 +1340,42 @@ class RegenerationTrainer:
self.data_settings.shuffle = False
self.data_settings.train_split = 1.0
self.pytorch_settings.batchsize = max(self.pytorch_settings.head_symbols, self.pytorch_settings.eye_symbols)
config_path = random.choice(self.data_settings.config_path) if isinstance(self.data_settings.config_path, (list, tuple)) else self.data_settings.config_path
fiber_length = int(float(str(config_path).split('-')[4])/1000)
config_path = (
random.choice(self.data_settings.config_path)
if isinstance(self.data_settings.config_path, (list, tuple))
else self.data_settings.config_path
)
fiber_length = int(float(str(config_path).split("-")[4]) / 1000)
if not hasattr(self, "_plot_loader"):
self._plot_loader, _ = self.get_sliced_data(
override={
"num_symbols": self.pytorch_settings.batchsize,
"config_path": config_path,
"shuffle": False,
"polarisations": (np.random.rand(1)*np.pi*2,),
"randomise_polarisation": False,
"polarisations": (np.random.rand(1) * np.pi * 2,),
"randomise_polarisation": self.data_settings.randomise_polarisations,
}
)
self._sps = self._plot_loader.dataset.samples_per_symbol
self.data_settings = data_settings_backup
self.pytorch_settings = pytorch_settings_backup
fiber_in, fiber_out, regen, timestamps = self.run_model(model, self._plot_loader)
fiber_in, fiber_out, fiber_out_rot, angles, regen, timestamps = self.run_model(model, self._plot_loader)
fiber_in = fiber_in.view(-1, 2)
fiber_out = fiber_out.view(-1, 2)
fiber_out_rot = fiber_out_rot.view(-1, 2)
angles = angles.view(-1, 1)
angles = angles.real
angles = torch.fmod(angles, 2 * torch.pi)
angles = torch.div(angles, 2*torch.pi)
angles = torch.repeat_interleave(angles, 2, dim=1)
regen = regen.view(-1, 2)
fiber_in = fiber_in.numpy()
fiber_out = fiber_out.numpy()
fiber_out_rot = fiber_out_rot.numpy()
angles = angles.numpy()
regen = regen.numpy()
timestamps = timestamps.numpy()
@@ -1327,28 +1384,29 @@ class RegenerationTrainer:
import gc
head_fig = self._plot_model_response_head(
fiber_in[:self.pytorch_settings.head_symbols*self._sps],
fiber_out[:self.pytorch_settings.head_symbols*self._sps],
regen[:self.pytorch_settings.head_symbols*self._sps],
timestamps=timestamps[:self.pytorch_settings.head_symbols*self._sps],
fiber_out_rot[: self.pytorch_settings.head_symbols * self._sps],
fiber_in[: self.pytorch_settings.head_symbols * self._sps],
regen[: self.pytorch_settings.head_symbols * self._sps],
angles[: self.pytorch_settings.head_symbols * self._sps],
timestamps=timestamps[: self.pytorch_settings.head_symbols * self._sps],
labels=("fiber out", "fiber in", "regen", "normed angle"),
sps=self._sps,
title_append=title_append + f" ({fiber_length} km)",
subtitle=subtitle,
show=show,
)
# raise NotImplementedError("Eye diagram not implemented")
eye_fig = self._plot_model_response_eye(
fiber_in[: self.pytorch_settings.eye_symbols * self._sps],
fiber_out_rot[: self.pytorch_settings.eye_symbols * self._sps],
regen[: self.pytorch_settings.eye_symbols * self._sps],
timestamps=timestamps[: self.pytorch_settings.eye_symbols * self._sps],
labels=("fiber in", "fiber out", "regen"),
sps=self._sps,
title_append=title_append + f" ({fiber_length} km)",
subtitle=subtitle,
show=show,
)
# raise NotImplementedError("Eye diagram not implemented")
eye_fig = self._plot_model_response_eye(
fiber_in[:self.pytorch_settings.eye_symbols*self._sps],
fiber_out[:self.pytorch_settings.eye_symbols*self._sps],
regen[:self.pytorch_settings.eye_symbols*self._sps],
timestamps=timestamps[:self.pytorch_settings.eye_symbols*self._sps],
labels=("fiber in", "fiber out", "regen"),
sps=self._sps,
title_append=title_append + f" ({fiber_length} km)",
subtitle=subtitle,
show=show,
)
gc.collect()
return head_fig, eye_fig, power_fig
@@ -1361,7 +1419,7 @@ class RegenerationTrainer:
self.model_settings.overrides.get(f"n_hidden_nodes_{i}", -1) for i in range(model_n_hidden_layers)
]
model_dims.insert(0, input_dim)
model_dims.append(2)
model_dims.append(self.model_settings.output_dim)
model_dims = [str(dim) for dim in model_dims]
model_activation_func = self.model_settings.model_activation_func
model_dtype = self.data_settings.dtype

View File

@@ -1,6 +1,8 @@
from datetime import datetime
import optuna
import torch
import util
from hypertraining.hypertraining import HyperTraining
from hypertraining.settings import (
GlobalSettings,
@@ -16,24 +18,29 @@ global_settings = GlobalSettings(
)
data_settings = DataSettings(
config_path="data/*-128-16384-100000-0-0-17-0-PAM4-0.ini",
# config_path="data/*-128-16384-100000-0-0-17-0-PAM4-0.ini",
config_path="data/20241204-131003-128-16384-100000-0-0-17-0-PAM4-0.ini",
dtype="complex64",
# symbols = (9, 20), # 13 symbol @ 10GBd <-> 1.3ns <-> 0.26m of fiber
symbols=13, # study: single_core_regen_20241123_011232
# symbols=13, # study: single_core_regen_20241123_011232
# symbols = (3, 13),
symbols=4,
# output_size = (11, 32), # ballpark 26 taps -> 2 taps per input symbol -> 1 tap every 0.01m (model has 52 inputs)
output_size=26, # study: single_core_regen_20241123_011232 (model_input_dim/2)
# output_size=26, # study: single_core_regen_20241123_011232 (model_input_dim/2)
output_size=(8, 30),
shuffle=True,
in_out_delay=0,
xy_delay=0,
drop_first=128 * 100,
drop_first=256,
train_split=0.8,
randomise_polarisations=False,
)
pytorch_settings = PytorchSettings(
epochs=10000,
epochs=10,
batchsize=2**10,
device="cuda",
dataloader_workers=12,
dataloader_workers=4,
dataloader_prefetch=4,
summary_dir=".runs",
write_every=2**5,
@@ -43,28 +50,70 @@ pytorch_settings = PytorchSettings(
model_settings = ModelSettings(
output_dim=2,
# n_hidden_layers = (3, 8),
n_hidden_layers=4,
overrides={
"n_hidden_nodes_0": 8,
"n_hidden_nodes_1": 6,
"n_hidden_nodes_2": 4,
"n_hidden_nodes_3": 8,
},
model_activation_func="Mag",
# satabsT0=(1e-6, 1),
n_hidden_layers = (2, 5),
n_hidden_nodes=(2, 16),
model_activation_func="EOActivation",
dropout_prob=0,
model_layer_function="ONNRect",
model_layer_kwargs={"square": True},
# scale=(False, True),
scale=False,
model_layer_parametrizations=[
{
"tensor_name": "weight",
"parametrization": util.complexNN.energy_conserving,
},
{
"tensor_name": "alpha",
"parametrization": util.complexNN.clamp,
},
{
"tensor_name": "gain",
"parametrization": util.complexNN.clamp,
"kwargs": {
"min": 0,
"max": float("inf"),
},
},
{
"tensor_name": "phase_bias",
"parametrization": util.complexNN.clamp,
"kwargs": {
"min": 0,
"max": 2 * torch.pi,
},
},
{
"tensor_name": "scales",
"parametrization": util.complexNN.clamp,
},
{
"tensor_name": "angle",
"parametrization": util.complexNN.clamp,
"kwargs": {
"min": -torch.pi,
"max": torch.pi,
},
},
{
"tensor_name": "loss",
"parametrization": util.complexNN.clamp,
},
],
)
optimizer_settings = OptimizerSettings(
optimizer="Adam",
# learning_rate = (1e-5, 1e-1),
learning_rate=5e-3
# learning_rate=5e-4,
optimizer="AdamW",
optimizer_kwargs={
"lr": 5e-3,
"amsgrad": True,
# "weight_decay": 1e-7,
},
)
optuna_settings = OptunaSettings(
n_trials=1,
n_workers=1,
n_trials=1024,
n_workers=8,
timeout=3600,
directions=("minimize",),
metrics_names=("mse",),

View File

@@ -26,24 +26,26 @@ global_settings = GlobalSettings(
)
data_settings = DataSettings(
config_path="data/*-128-16384-10000-0-0-0-0-PAM4-0-0.4.ini",
# config_path="data/*-128-16384-1-0-0-0-0-PAM4-0-0.ini",
config_path="data/*-128-16384-10000-0-0-17-0-PAM4-0.ini",
# config_path=[f"data/20241202-*-128-16384-{length}-0-0-17-0-PAM4-0.ini" for length in range(48000, 53000, 1000)],
dtype="complex64",
# symbols = (9, 20), # 13 symbol @ 10GBd <-> 1.3ns <-> 0.26m of fiber
symbols=13, # study: single_core_regen_20241123_011232
symbols=4, # study: single_core_regen_20241123_011232
# output_size = (11, 32), # 26 taps -> 2 taps per input symbol -> 1 tap every 1cm (model has 52 inputs (x/y))
output_size=26, # study: single_core_regen_20241123_011232 (model_input_dim/2)
output_size=20, # study: single_core_regen_20241123_011232 (model_input_dim/2)
shuffle=True,
drop_first=64,
train_split=0.8,
randomise_polarisations=True,
osnr=10,
)
pytorch_settings = PytorchSettings(
epochs=10000,
batchsize=2**14,
device="cuda",
dataloader_workers=16,
dataloader_workers=24,
dataloader_prefetch=8,
summary_dir=".runs",
write_every=2**5,
@@ -53,17 +55,17 @@ pytorch_settings = PytorchSettings(
model_settings = ModelSettings(
output_dim=2,
n_hidden_layers=5,
n_hidden_layers=3,
overrides={
# "hidden_layer_dims": (8, 8, 4, 4),
"n_hidden_nodes_0": 8,
"n_hidden_nodes_0": 16,
"n_hidden_nodes_1": 8,
"n_hidden_nodes_2": 4,
"n_hidden_nodes_3": 4,
"n_hidden_nodes_4": 2,
"n_hidden_nodes_2": 8,
# "n_hidden_nodes_3": 4,
# "n_hidden_nodes_4": 2,
},
model_activation_func="EOActivation",
dropout_prob=0.01,
dropout_prob=0,
model_layer_function="ONNRect",
model_layer_kwargs={"square": True},
scale=False,
@@ -126,7 +128,7 @@ model_settings = ModelSettings(
optimizer_settings = OptimizerSettings(
optimizer="AdamW",
optimizer_kwargs={
"lr": 0.01,
"lr": 0.005,
"amsgrad": True,
# "weight_decay": 1e-7,
},
@@ -242,7 +244,15 @@ if __name__ == "__main__":
pytorch_settings=pytorch_settings,
model_settings=model_settings,
optimizer_settings=optimizer_settings,
# checkpoint_path=".models/best_20241205_235929.tar",
checkpoint_path=".models/best_20241216_221359.tar",
reset_epoch=True,
# settings_override={
# "optimizer_settings": {
# "optimizer_kwargs": {
# "lr": 0.01,
# },
# }
# }
# 20241202_143149
)
trainer.train()

View File

@@ -26,7 +26,7 @@ global_settings = GlobalSettings(
)
data_settings = DataSettings(
config_path="data/*-128-16384-10000-0-0-0-0-PAM4-0-0.4.ini",
config_path="data/20241211-105524-128-16384-1-0-0-0-0-PAM4-0-0.ini",
# config_path=[f"data/20241202-*-128-16384-{length}-0-0-17-0-PAM4-0.ini" for length in range(48000, 53000, 1000)],
dtype="complex64",
# symbols = (9, 20), # 13 symbol @ 10GBd <-> 1.3ns <-> 0.26m of fiber
@@ -53,14 +53,14 @@ pytorch_settings = PytorchSettings(
)
model_settings = ModelSettings(
output_dim=3,
output_dim=1,
n_hidden_layers=3,
overrides={
"n_hidden_nodes_0": 2,
"n_hidden_nodes_1": 2,
"n_hidden_nodes_2": 2,
"n_hidden_nodes_0": 4,
"n_hidden_nodes_1": 4,
"n_hidden_nodes_2": 4,
},
dropout_prob=0.01,
dropout_prob=0,
model_layer_function="ONNRect",
model_activation_func="EOActivation",
model_layer_kwargs={"square": True},
@@ -110,20 +110,24 @@ model_settings = ModelSettings(
)
optimizer_settings = OptimizerSettings(
optimizer="AdamW",
optimizer="RMSprop",
# optimizer="AdamW",
optimizer_kwargs={
"lr": 0.005,
"amsgrad": True,
"lr": 0.01,
"alpha": 0.9,
"momentum": 0.1,
"eps": 1e-8,
"centered": True,
# "amsgrad": True,
# "weight_decay": 1e-7,
},
# learning_rate=0.05,
scheduler="ReduceLROnPlateau",
scheduler_kwargs={
"patience": 2**6,
"patience": 2**5,
"factor": 0.75,
# "threshold": 1e-3,
"min_lr": 1e-6,
"cooldown": 10,
# "cooldown": 10,
},
)

View File

@@ -319,6 +319,29 @@ class normalize_by_first(nn.Module):
def forward(self, data):
return data / data[:, 0].unsqueeze(1)
class rotate(nn.Module):
def __init__(self):
super(rotate, self).__init__()
def forward(self, data, angle):
# data -> (batch, n*2)
# angle -> (batch, n)
data_ = data
if angle.ndim == 1:
angle_ = angle.unsqueeze(1)
else:
angle_ = angle
angle_ = angle_.expand(-1, data_.shape[1]//2)
c = torch.cos(angle_)
s = torch.sin(angle_)
rot = torch.stack([torch.stack([c, -s], dim=2),
torch.stack([s, c], dim=2)], dim=3)
d = torch.bmm(data_.reshape(-1, 1, 2), rot.view(-1, 2, 2).to(dtype=data_.dtype)).reshape(*data.shape)
# d = torch.bmm(data.unsqueeze(-1).mT, rot.to(dtype=data.dtype).mT).mT.squeeze(-1)
return d
class photodiode(nn.Module):
def __init__(self, size, bias=True):
@@ -487,7 +510,7 @@ class MZISingle(nn.Module):
return x * torch.exp(1j * self.phi) * torch.sin(self.omega + self.func(x))
def naive_angle_loss(x: torch.Tensor, target: torch.Tensor, mod=2*torch.pi):
return torch.fmod((x - target), mod).square().mean()
return torch.fmod((x.abs().real - target.abs().real), mod).abs().mean()
def cosine_loss(x: torch.Tensor, target: torch.Tensor):
return (2*(1 - torch.cos(x - target))).mean()

View File

@@ -5,6 +5,7 @@ from torch.utils.data import Dataset
# from torch.utils.data import Sampler
import numpy as np
import configparser
import multiprocessing as mp
# class SubsetSampler(Sampler[int]):
# """
@@ -113,8 +114,9 @@ class FiberRegenerationDataset(Dataset):
dtype: torch.dtype = None,
real: bool = False,
device=None,
polarisations: tuple | list = (0,),
osnr: float = None,
randomise_polarisations: bool = False,
repeat_randoms: int = 1,
**kwargs,
):
"""
@@ -190,18 +192,20 @@ class FiberRegenerationDataset(Dataset):
files.append(config["data"]["file"].strip('"'))
self.config["data"]["file"] = str(files)
for i, angle in enumerate(torch.tensor(np.array(polarisations))):
data_raw_copy = data_raw.clone()
if angle == 0:
continue
sine = torch.sin(angle)
cosine = torch.cos(angle)
data_raw_copy[:, 2] = data_raw[:, 2] * cosine - data_raw[:, 3] * sine
data_raw_copy[:, 3] = data_raw[:, 2] * sine + data_raw[:, 3] * cosine
if i == 0:
data_raw = data_raw_copy
else:
data_raw = torch.cat([data_raw, data_raw_copy], dim=0)
# if polarisations is not None:
# self.angles = torch.tensor(polarisations).repeat(len(data_raw), 1)
# for i, angle in enumerate(torch.tensor(np.array(polarisations))):
# data_raw_copy = data_raw.clone()
# if angle == 0:
# continue
# sine = torch.sin(angle)
# cosine = torch.cos(angle)
# data_raw_copy[:, 2] = data_raw[:, 2] * cosine - data_raw[:, 3] * sine
# data_raw_copy[:, 3] = data_raw[:, 2] * sine + data_raw[:, 3] * cosine
# if i == 0:
# data_raw = data_raw_copy
# else:
# data_raw = torch.cat([data_raw, data_raw_copy], dim=0)
self.device = data_raw.device
@@ -278,23 +282,61 @@ class FiberRegenerationDataset(Dataset):
timestamps = data_raw[4, :]
data_raw = data_raw[:4, :]
data_raw = data_raw.view(2, 2, -1)
timestamps_doubled = torch.cat([timestamps.unsqueeze(dim=0), timestamps.unsqueeze(dim=0)], dim=0).unsqueeze(
dim=1
)
data_raw = torch.cat([data_raw, timestamps_doubled], dim=1)
fiber_in = data_raw[0, :, :]
fiber_out = data_raw[1, :, :]
# timestamps_doubled = torch.cat([timestamps.unsqueeze(dim=0), timestamps.unsqueeze(dim=0)], dim=0).unsqueeze(
# dim=1
# )
fiber_in = torch.cat([fiber_in, timestamps.unsqueeze(0)], dim=0)
fiber_out = torch.cat([fiber_out, timestamps.unsqueeze(0)], dim=0)
if repeat_randoms > 1:
fiber_in = fiber_in.repeat(1, 1, repeat_randoms)
fiber_out = fiber_out.repeat(1, 1, repeat_randoms)
# review: potential problems with repeated timestamps when plotting
else:
repeat_randoms = 1
if self.randomise_polarisations:
angles = torch.fmod(torch.randperm(data_raw.shape[-1]*repeat_randoms), 2) * torch.pi
# start_angle = torch.rand(1) * 2 * torch.pi
# angles = start_angle + torch.cumsum(torch.randn(data_raw.shape[-1])/333, dim=0) * torch.pi * 2 # random walk
# self.angles = torch.rand(self.data.shape[0]) * 2 * torch.pi
else:
angles = torch.zeros(data_raw.shape[-1])
sin = torch.sin(angles)
cos = torch.cos(angles)
rot = torch.stack([torch.stack([cos, -sin], dim=1), torch.stack([sin, cos], dim=1)], dim=2)
data_rot = torch.bmm(fiber_out[:2, :].T.unsqueeze(1), rot.to(dtype=data_raw.dtype)).squeeze(1).T
fiber_out = torch.cat((fiber_out, data_rot), dim=0)
fiber_out = torch.cat([fiber_out, angles.unsqueeze(0)], dim=0)
if osnr is not None:
popt = torch.mean(fiber_out[:2, :, :].abs().flatten(), dim=-1)
noise = torch.randn_like(fiber_out[:2, :, :])
pn = torch.mean(noise.abs().flatten(), dim=-1)
noise = noise * (popt / pn) * 10 ** (-osnr / 20)
fiber_out[:2, :, :] = torch.add(fiber_out[:2, :, :], noise)
# data_raw = torch.cat([data_raw, timestamps_doubled], dim=1)
# data layout
# [ [E_in_x, E_in_y, timestamps],
# [E_out_x, E_out_y, timestamps] ]
self.data = data_raw.unfold(dimension=-1, size=self.samples_per_slice, step=1)
self.data = self.data.movedim(-2, 0)
self.fiber_in = fiber_in.unfold(dimension=-1, size=self.samples_per_slice, step=1)
self.fiber_in = self.fiber_in.movedim(-2, 0)
if randomise_polarisations:
self.angles = torch.rand(self.data.shape[0]) * np.pi * 2
# self.data[:, 1, :2, :] = self.rotate(self.data[:, 1, :2, :], self.angles)
else:
self.angles = torch.zeros(self.data.shape[0])
self.fiber_out = fiber_out.unfold(dimension=-1, size=self.samples_per_slice, step=1)
self.fiber_out = self.fiber_out.movedim(-2, 0)
# self.data = data_raw.unfold(dimension=-1, size=self.samples_per_slice, step=1)
# self.data = self.data.movedim(-2, 0)
# self.angles = torch.zeros(self.data.shape[0])
...
# ...
# -> [no_slices, 2, 3, samples_per_slice]
@@ -305,51 +347,56 @@ class FiberRegenerationDataset(Dataset):
# ...
# ] -> [no_slices, 2, 3, samples_per_slice]
...
def __len__(self):
return self.data.shape[0]
return self.fiber_in.shape[0]
def __getitem__(self, idx):
if isinstance(idx, slice):
return [self.__getitem__(i) for i in range(*idx.indices(len(self)))]
else:
data_slice = self.data[idx].squeeze()
data_slice = data_slice[:, :, : data_slice.shape[2] // self.output_dim * self.output_dim]
# fiber in: [E_in_x, E_in_y, timestamps]
# fiber out: [E_out_x, E_out_y, timestamps, E_out_x_rot, E_out_y_rot, angle]
data_slice = data_slice.view(data_slice.shape[0], data_slice.shape[1], self.output_dim, -1)
fiber_in = self.fiber_in[idx].squeeze()
fiber_out = self.fiber_out[idx].squeeze()
fiber_in = fiber_in[..., : fiber_in.shape[-1] // self.output_dim * self.output_dim]
fiber_out = fiber_out[..., : fiber_out.shape[-1] // self.output_dim * self.output_dim]
fiber_in = fiber_in.view(fiber_in.shape[0], self.output_dim, -1)
fiber_out = fiber_out.view(fiber_out.shape[0], self.output_dim, -1)
# data_slice = data_slice[:, :, : data_slice.shape[2] // self.output_dim * self.output_dim]
# data_slice = data_slice.view(data_slice.shape[0], data_slice.shape[1], self.output_dim, -1)
# angle = self.angles[idx]
center_angle = fiber_out[5, self.output_dim // 2, 0]
angles = fiber_out[5, :, 0]
plot_data = fiber_out[:2, self.output_dim // 2, 0].detach().clone()
plot_data_rot = fiber_out[3:5, self.output_dim // 2, 0].detach().clone()
data = fiber_out[3:5, :, 0]
# if self.randomise_polarisations:
# angle = torch.rand(1) * torch.pi * 2
# sine = torch.sin(angle)
# cosine = torch.cos(angle)
# data_slice_ = data_slice[1]
# data_slice[1, 0] = data_slice_[0] * cosine - data_slice_[1] * sine
# data_slice[1,1] = data_slice_[0] * sine + data_slice_[1] * cosine
# else:
# angle = torch.zeros(1)
# data = data.mT
# c = torch.cos(angle).unsqueeze(-1)
# s = torch.sin(angle).unsqueeze(-1)
# rot = torch.stack([torch.stack([c, -s], dim=1), torch.stack([s, c], dim=1)], dim=2).squeeze(-1)
# data = torch.bmm(data.mT.unsqueeze(0), rot.to(dtype=data.dtype)).squeeze(-1)
...
# data = data_slice[1, :2, :, 0]
angle = self.angles[idx]
data_index = 1
data_slice[1, :2, :, :] = self.rotate(data_slice[data_index, :2, :, :], angle)
data = data_slice[1, :2, :, 0]
# data = self.rotate(data, angle)
# angle = torch.zeros_like(angle)
# for both polarisations (x/y), calculate the mean of the signal around the current symbol (-> corresponding to a lowpass filter)
angle_data = data_slice[1, :2, :, :].reshape(2, -1).mean(dim=1)
angle_data2 = self.complex_max(data_slice[1, :2, :, :].reshape(2, -1))
plot_data = data_slice[1, :2, self.output_dim // 2, 0]
sop = self.polarimeter(plot_data)
angle_data = fiber_out[:2, :, :].reshape(2, -1).mean(dim=1).repeat(1, self.output_dim)
angle_data2 = self.complex_max(fiber_out[:2, :, :].reshape(2, -1)).repeat(1, self.output_dim)
# sop = self.polarimeter(plot_data)
# angle_data = data_slice[1, :2, :self.output_dim//2, :].squeeze().reshape(2, -1).mean(dim=-1)
# angle = data_slice[1, 3, self.output_dim // 2, 0].real
target = data_slice[0, :2, self.output_dim // 2, 0]
target_timestamp = data_slice[0, 2, self.output_dim // 2, 0].real
target = fiber_in[:2, self.output_dim // 2, 0]
plot_target = fiber_in[:2, self.output_dim // 2, 0].detach().clone()
target_timestamp = fiber_in[2, self.output_dim // 2, 0].real
...
# data_timestamps = data[-1,:].real
@@ -360,22 +407,39 @@ class FiberRegenerationDataset(Dataset):
# transpose to interleave the x and y data in the output tensor
data = data.transpose(0, 1).flatten().squeeze()
angle_data = angle_data.flatten().squeeze()
angle_data2 = angle_data.flatten().squeeze()
angle = angle.flatten().squeeze()
angle_data = angle_data.transpose(0, 1).flatten().squeeze()
angle_data2 = angle_data2.transpose(0,1).flatten().squeeze()
center_angle = center_angle.flatten().squeeze()
angles = angles.flatten().squeeze()
# data_timestamps = data_timestamps.flatten().squeeze()
# target = target.transpose(0,1).flatten().squeeze()
target = target.flatten().squeeze()
target_timestamp = target_timestamp.flatten().squeeze()
plot_target = plot_target.flatten().squeeze()
plot_data = plot_data.flatten().squeeze()
plot_data_rot = plot_data_rot.flatten().squeeze()
return {
"x": data,
"y": target,
"center_angle": center_angle,
"angles": angles,
"mean_angle": angles.mean(),
# "sop": sop,
"angle_data": angle_data,
"angle_data2": angle_data2,
"timestamp": target_timestamp,
"plot_target": plot_target,
"plot_data": plot_data,
"plot_data_rot": plot_data_rot,
}
return {"x": data, "y": target, "angle": angle, "sop": sop, "angle_data": angle_data, "angle_data2": angle_data2, "timestamp": target_timestamp, "plot_data": plot_data}
def complex_max(self, data, dim=-1):
# returns element(s) with the maximum absolute value along a given dimension
# ind = torch.argmax(data.abs(), dim=dim, keepdim=True)
# max_values = torch.gather(data, dim, ind).squeeze(dim=dim)
# return max_values
return torch.gather(data, dim, torch.argmax(data.abs(), dim=dim, keepdim=True)).squeeze(dim=dim)
def rotate(self, data, angle):
# rotates a 2d tensor by a given angle
@@ -388,7 +452,25 @@ class FiberRegenerationDataset(Dataset):
cosine = torch.cos(angle)
return torch.stack([data[0] * cosine - data[1] * sine, data[0] * sine + data[1] * cosine], dim=0)
def rotate_all(self):
def do_rotation(j, num_processes):
for i in range(len(self) // num_processes):
index = i * num_processes + j
self.data[index, 1, :2, :] = self.rotate(self.data[index, 1, :2, :], self.angles[index])
self.processes = []
for j in range(mp.cpu_count()):
self.processes.append(mp.Process(target=do_rotation, args=(j, mp.cpu_count())))
self.processes[-1].start()
for p in self.processes:
p.join()
for i in range(len(self) // mp.cpu_count() * mp.cpu_count(), len(self)):
self.data[i, 1, :2, :] = self.rotate(self.data[i, 1, :2, :], self.angles[i])
def polarimeter(self, data):
# data: [2, ...] -> x, y
# returns [4] -> S0, S1, S2, S3
@@ -396,12 +478,12 @@ class FiberRegenerationDataset(Dataset):
y = data[1].mean()
I_X = x.abs().square()
I_Y = y.abs().square()
I_45 = (x+y).abs().square()
I_RHC = (x + 1j*y).abs().square()
I_45 = (x + y).abs().square()
I_RHC = (x + 1j * y).abs().square()
S0 = I_X + I_Y
S1 = (2*I_X - S0) / S0
S2 = (2*I_45 - S0) / S0
S3 = (2*I_RHC - S0) / S0
S1 = (2 * I_X - S0) / S0
S2 = (2 * I_45 - S0) / S0
S3 = (2 * I_RHC - S0) / S0
return torch.stack([S1, S2, S3], dim=0)
return torch.stack([S0, S1, S2, S3], dim=0)