add training script for polarization estimation, refactor model definitions, randomised polarisation support in data_loader
src/single-core-regen/hypertraining/lighning_models.py (new file, 443 lines)
@@ -0,0 +1,443 @@
from typing import Any

import lightning as L
import numpy as np
import torch
import torch.nn as nn
# import torch.nn.functional as F

from util.complexNN import DropoutComplex, Scale, ONNRect, EOActivation, energy_conserving, clamp, complex_mse_loss
from util.datasets import FiberRegenerationDataset


class regeneratorData(L.LightningDataModule):
    def __init__(
        self,
        config_globs,
        output_symbols,
        output_dim,
        dtype,
        drop_first,
        shuffle=True,
        train_split=None,
        batch_size=None,
        loader_settings=None,
        seed=None,
        num_symbols=None,
        test_globs=None,
    ):
        super().__init__()
        self._config_globs = config_globs
        self._test_globs = test_globs
        self._test_data_available = test_globs is not None
        if self._test_data_available:
            self.test_dataloader = self._test_dataloader
        self._output_symbols = output_symbols
        self._output_dim = output_dim
        self._dtype = dtype
        self._drop_first = drop_first
        self._seed = seed
        self._shuffle = shuffle
        self._num_symbols = num_symbols
        self._train_split = train_split if train_split is not None else 0.8
        self.batch_size = batch_size if batch_size is not None else 1024
        self._loader_settings = loader_settings if loader_settings is not None else {}

    def _get_data(self):
        self._data_train = FiberRegenerationDataset(
            file_path=self._config_globs,
            symbols=self._output_symbols,
            output_dim=self._output_dim,
            dtype=self._dtype,
            real=not self._dtype.is_complex,
            drop_first=self._drop_first,
            num_symbols=self._num_symbols,
        )
        # self._data_plot = FiberRegenerationDataset(
        #     file_path=self._config_globs,
        #     symbols=self._output_symbols,
        #     output_dim=self._output_dim,
        #     dtype=self._dtype,
        #     real=not self._dtype.is_complex,
        #     drop_first=self._drop_first,
        #     num_symbols=400,
        # )
        if self._test_data_available:
            self._data_test = FiberRegenerationDataset(
                file_path=self._test_globs,
                symbols=self._output_symbols,
                output_dim=self._output_dim,
                dtype=self._dtype,
                real=not self._dtype.is_complex,
                drop_first=self._drop_first,
                num_symbols=self._num_symbols,
            )
            return self._data_train, self._data_test
        return self._data_train

    def _split_data(self, stage="fit", split=None, shuffle=None):
        _split = split if split is not None else self._train_split
        _shuffle = shuffle if shuffle is not None else self._shuffle

        dataset_size = len(self._data_train)
        indices = list(range(dataset_size))
        split_index = int(np.floor(_split * dataset_size))
        train_indices, valid_indices = indices[:split_index], indices[split_index:]
        if _shuffle:
            np.random.seed(self._seed)
            np.random.shuffle(train_indices)

        if _shuffle:
            if stage == "fit" or stage == "predict":
                self._train_sampler = torch.utils.data.SubsetRandomSampler(train_indices)
            # if stage == "fit" or stage == "validate":
            #     self._valid_sampler = torch.utils.data.SubsetRandomSampler(valid_indices)
        else:
            if stage == "fit" or stage == "predict":
                self._train_sampler = train_indices
            if stage == "fit" or stage == "validate":
                self._valid_sampler = valid_indices

        if stage == "fit":
            return self._train_sampler, self._valid_sampler
        elif stage == "validate":
            return self._valid_sampler
        elif stage == "predict":
            return self._train_sampler

    def prepare_data(self):
        self._get_data()

    def setup(self, stage=None):
        stage = stage or "fit"
        self._split_data(stage=stage)

    def train_dataloader(self):
        return torch.utils.data.DataLoader(
            self._data_train,
            batch_size=self.batch_size,
            sampler=self._train_sampler,
            **self._loader_settings
        )

    def val_dataloader(self):
        return torch.utils.data.DataLoader(
            self._data_train,
            batch_size=self.batch_size,
            sampler=self._valid_sampler,
            **self._loader_settings
        )

    def _test_dataloader(self):
        return torch.utils.data.DataLoader(
            self._data_test,
            shuffle=self._shuffle,
            batch_size=self.batch_size,
            **self._loader_settings
        )

    def predict_dataloader(self):
        return torch.utils.data.DataLoader(
            self._data_plot,
            shuffle=False,
            batch_size=40,
            pin_memory=True,
            drop_last=True,
            num_workers=4,
            prefetch_factor=2,
        )

    # def plot_dataloader(self):


class regenerator(L.LightningModule):
    def __init__(
        self,
        *dims,
        layer_function=ONNRect,
        layer_func_kwargs: dict | None = {"square": True},
        act_function=EOActivation,
        act_func_kwargs: dict | None = None,
        parametrizations: list[dict] | None = [
            {
                "tensor_name": "weight",
                "parametrization": energy_conserving,
            },
            {
                "tensor_name": "alpha",
                "parametrization": clamp,
            },
            {
                "tensor_name": "alpha",
                "parametrization": clamp,
            },
        ],
        dtype=torch.complex64,
        dropout_prob=0.01,
        scale_layers=False,
        optimizer=torch.optim.AdamW,
        optimizer_kwargs: dict | None = {
            "lr": 0.01,
            "amsgrad": True,
        },
        lr_scheduler=None,
        lr_scheduler_kwargs: dict | None = {
            "patience": 20,
            "factor": 0.5,
            "min_lr": 1e-6,
            "cooldown": 10,
        },
        sps=128,
        # **kwargs,
    ):
        torch.set_float32_matmul_precision('high')
        layer_func_kwargs = layer_func_kwargs if layer_func_kwargs is not None else {}
        act_func_kwargs = act_func_kwargs if act_func_kwargs is not None else {}
        optimizer_kwargs = optimizer_kwargs if optimizer_kwargs is not None else {}
        lr_scheduler_kwargs = lr_scheduler_kwargs if lr_scheduler_kwargs is not None else {}
        super().__init__()

        self.example_input_array = torch.randn(1, dims[0], dtype=dtype)
        self._sps = sps

        self.optimizer_settings = {
            "optimizer": optimizer,
            "optimizer_kwargs": optimizer_kwargs,
            "lr_scheduler": lr_scheduler,
            "lr_scheduler_kwargs": lr_scheduler_kwargs,
        }

        # if len(dims) == 0:
        #     try:
        #         dims = kwargs["dims"]
        #     except KeyError:
        #         raise ValueError("dims must be provided")
        self._n_hidden_layers = len(dims) - 2

        self.build_model(dims, layer_function, layer_func_kwargs, act_function, act_func_kwargs, parametrizations, dtype, dropout_prob, scale_layers)

    def build_model(self, dims, layer_function, layer_func_kwargs, act_function, act_func_kwargs, parametrizations, dtype, dropout_prob, scale_layers):
        input_layer = nn.Sequential(
            layer_function(dims[0], dims[1], dtype=dtype, **layer_func_kwargs),
            act_function(size=dims[1], **act_func_kwargs),
            DropoutComplex(p=dropout_prob),
        )

        if scale_layers:
            input_layer = nn.Sequential(Scale(dims[0]), input_layer)

        self.layer_0 = input_layer

        for i in range(1, self._n_hidden_layers):
            layer = nn.Sequential(
                layer_function(dims[i], dims[i + 1], dtype=dtype, **layer_func_kwargs),
                act_function(size=dims[i + 1], **act_func_kwargs),
                DropoutComplex(p=dropout_prob),
            )
            if scale_layers:
                layer = nn.Sequential(Scale(dims[i]), layer)
            setattr(self, f"layer_{i}", layer)

        output_layer = nn.Sequential(
            layer_function(dims[-2], dims[-1], dtype=dtype, **layer_func_kwargs),
            act_function(size=dims[-1], **act_func_kwargs),
            Scale(dims[-1]),
        )
        setattr(self, f"layer_{self._n_hidden_layers}", output_layer)

        if parametrizations is not None:
            self._apply_parametrizations(self, parametrizations)

    def _apply_parametrizations(self, layer, parametrizations):
        for sub_layer in layer.children():
            if len(sub_layer._modules) > 0:
                self._apply_parametrizations(sub_layer, parametrizations)
            else:
                for parametrization in parametrizations:
                    tensor_name = parametrization.get("tensor_name", None)
                    if tensor_name is None:
                        continue
                    parametrization_func = parametrization.get("parametrization", None)
                    if parametrization_func is None:
                        continue
                    param_kwargs = parametrization.get("kwargs", {})
                    if tensor_name in sub_layer._parameters:
                        parametrization_func(sub_layer, tensor_name, **param_kwargs)

    def _trace_powers(self, enable, x, powers=None):
        if not enable:
            return
        if powers is None:
            powers = []
        powers.append(x.abs().square().sum())
        return powers

    # def plot(self, mode):
    #     self.predict_step()

    # def validation_epoch_end(self, outputs):
    #     x = torch.vstack([output['x'].view(output['x'].shape[0], -1, 2)[:, output['x'].shape[1]//2, :].squeeze() for output in outputs])
    #     y = torch.vstack([output['y'].view(output['y'].shape[0], -1, 2).squeeze() for output in outputs])
    #     y_hat = torch.vstack([output['y_hat'].view(output['y_hat'].shape[0], -1, 2).squeeze() for output in outputs])
    #     timesteps = torch.vstack([output['timesteps'].squeeze() for output in outputs])
    #     powers = torch.vstack([output['powers'] for output in outputs])

    #     return {'x': x, 'y': y, 'y_hat': y_hat, 'timesteps': timesteps, 'powers': powers}

    def on_validation_epoch_end(self):
        if self.current_epoch % 10 == 0 or self.current_epoch == self.trainer.max_epochs - 1 or self.current_epoch < 10:
            x = self.val_outputs['x']
            # x = x.view(x.shape[0], -1, 2)
            # x = x[:, x.shape[1]//2, :].squeeze()
            y = self.val_outputs['y']
            # y = y.view(y.shape[0], -1, 2).squeeze()
            y_hat = self.val_outputs['y_hat']
            # y_hat = y_hat.view(y_hat.shape[0], -1, 2).squeeze()
            timesteps = self.val_outputs['timesteps']
            # timesteps = timesteps.squeeze()
            powers = self.val_outputs['powers']
            # powers = powers.squeeze()

            fiber_in = x.detach().cpu().numpy()
            fiber_out = y.detach().cpu().numpy()
            regen = y_hat.detach().cpu().numpy()
            timesteps = timesteps.detach().cpu().numpy()
            # powers = np.array([power.detach().cpu().numpy() for power in powers])

            # fiber_in = np.concat(fiber_in, axis=0)
            # fiber_out = np.concat(fiber_out, axis=0)
            # regen = np.concat(regen, axis=0)
            # timesteps = np.concat(timesteps, axis=0)
            # powers = powers.detach().cpu().numpy()

            import gc

            fig = self.plot_model_head(fiber_in, fiber_out, regen, timesteps, sps=self._sps)

            self.logger.experiment.add_figure("model response", fig, self.current_epoch)

            # fig = self.plot_model_eye(fiber_in, fiber_out, regen, timesteps, sps=self._sps)

            # self.logger.experiment.add_figure("model eye", fig, self.current_epoch)

            # fig = self.plot_model_powers(powers)

            # self.logger.experiment.add_figure("powers", fig, self.current_epoch)

            gc.collect()
            # x, y, y_hat, timesteps, powers = self.validation_epoch_end(self.outputs)
            # self.plot(x, y, y_hat, timesteps, powers)

    def plot_model_head(self, fiber_in, fiber_out, regen, timesteps, sps):
        import matplotlib
        matplotlib.use("TkCairo")
        import matplotlib.pyplot as plt

        ordering = np.argsort(timesteps)
        signals = [signal[ordering] for signal in [fiber_in, fiber_out, regen]]
        timesteps = timesteps[ordering]

        signals = [signal[:sps*40] for signal in signals]
        timesteps = timesteps[:sps*40]

        fig, axs = plt.subplots(1, 2, sharex=True, sharey=True)
        fig.set_figwidth(16)
        fig.set_figheight(4)

        for i, ax in enumerate(axs):
            for j, signal in enumerate(signals):
                ax.plot(timesteps / sps, np.square(np.abs(signal[:,i])), label=["fiber in", "fiber out", "regen"][j] + [" x", " y"][i])
            ax.set_xlabel("symbol")
            ax.set_ylabel("amplitude")
            ax.minorticks_on()
            ax.tick_params(axis="y", which="minor", left=False, right=False)
            ax.grid(which="major", axis="x")
            ax.grid(which="minor", axis="x", linestyle=":")
            ax.grid(which="major", axis="y")
            ax.legend(loc="upper right")
        fig.tight_layout()

        return fig

    def plot_model_eye(self, fiber_in, fiber_out, regen, timesteps, sps):
        ...

    def plot_model_powers(self, powers):
        ...

    def forward(self, x, trace_powers=False):
        powers = self._trace_powers(trace_powers, x)
        x = self.layer_0(x)
        powers = self._trace_powers(trace_powers, x, powers)
        for i in range(1, self._n_hidden_layers):
            x = getattr(self, f"layer_{i}")(x)
            powers = self._trace_powers(trace_powers, x, powers)
        x = getattr(self, f"layer_{self._n_hidden_layers}")(x)
        powers = self._trace_powers(trace_powers, x, powers)
        if trace_powers:
            return x, powers
        return x

    def configure_optimizers(self):
        optimizer = self.optimizer_settings["optimizer"](
            self.parameters(), **self.optimizer_settings["optimizer_kwargs"]
        )
        if self.optimizer_settings["lr_scheduler"] is not None:
            lr_scheduler = self.optimizer_settings["lr_scheduler"](
                optimizer, **self.optimizer_settings["lr_scheduler_kwargs"]
            )
            return {
                "optimizer": optimizer,
                "lr_scheduler": {
                    "scheduler": lr_scheduler,
                    "monitor": "val_loss",
                }
            }
        return {"optimizer": optimizer}

    def training_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0):
        x, y, timesteps = batch
        y_hat = self(x)
        loss = complex_mse_loss(y_hat, y, power=True)
        self.log("train_loss", loss, on_epoch=True, on_step=True)
        return loss

    def validation_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0):
        x, y, timesteps = batch
        if batch_idx == 0:
            y_hat, powers = self.forward(x, trace_powers=True)
        else:
            y_hat = self.forward(x)
        loss = complex_mse_loss(y_hat, y, power=True)
        self.log("val_loss", loss, on_epoch=True)
        y = y.view(y.shape[0], -1, 2).squeeze()
        x = x.view(x.shape[0], -1, 2)
        x = x[:, x.shape[1]//2, :].squeeze()
        y_hat = y_hat.view(y_hat.shape[0], -1, 2).squeeze()
        timesteps = timesteps.squeeze()
        if batch_idx == 0:
            powers = np.array([power.detach().cpu() for power in powers])
            self.val_outputs = {"y": y, "x": x, "y_hat": y_hat, "timesteps": timesteps, "powers": powers}
        else:
            self.val_outputs["y"] = torch.vstack([self.val_outputs["y"], y])
            self.val_outputs["x"] = torch.vstack([self.val_outputs["x"], x])
            self.val_outputs["y_hat"] = torch.vstack([self.val_outputs["y_hat"], y_hat])
            self.val_outputs["timesteps"] = torch.concat([self.val_outputs["timesteps"], timesteps], dim=0)
        return loss

    def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0):
        x, y, timesteps = batch
        y_hat = self(x)
        loss = complex_mse_loss(y_hat, y, power=True)
        self.log("test_loss", loss, on_epoch=True)
        return loss

    # def predict_step(self, batch, batch_idx):
    #     x, y, timesteps = batch
    #     y_hat = self(x)
    #     return y, x, y_hat, timesteps
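For orientation, a minimal sketch of how the data module and Lightning module above could be wired together. This is not part of the commit: the glob pattern, layer widths, dtype and epoch count are illustrative placeholders.

# Hypothetical wiring of regeneratorData and regenerator; all values are examples.
import torch
import lightning as L

data = regeneratorData(
    config_globs="data/configs/*.json",   # placeholder glob, not from this commit
    output_symbols=1,
    output_dim=64,
    dtype=torch.complex64,
    drop_first=1000,
    batch_size=512,
)
# dims: input width (x and y polarisation samples), two hidden layers, output width
model = regenerator(128, 64, 64, 128, sps=128)
trainer = L.Trainer(max_epochs=100)
trainer.fit(model, datamodule=data)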
src/single-core-regen/hypertraining/models.py (new file, 204 lines)
@@ -0,0 +1,204 @@
import torch
from torch.nn import Module, Sequential

from util.complexNN import (
    DropoutComplex,
    Scale,
    ONNRect,
    photodiode,
    EOActivation,
    polarimeter,
    normalize_by_first
)


class polarisation_estimator2(Module):
    def __init__(self):
        super(polarisation_estimator2, self).__init__()
        self.layers = Sequential(
            polarimeter(),
            torch.nn.Linear(4, 4),
            torch.nn.ReLU(),
            torch.nn.Dropout(p=0.01),
            torch.nn.Linear(4, 4),
            torch.nn.ReLU(),
            torch.nn.Dropout(p=0.01),
            torch.nn.Linear(4, 4),
        )

    def forward(self, x):
        # x = self.polarimeter(x)
        for layer in self.layers:
            x = layer(x)
        return x


class polarisation_estimator(Module):
    def __init__(
        self,
        *dims,
        layer_function=ONNRect,
        layer_func_kwargs: dict | None = None,
        output_layer_function=photodiode,
        # output_layer_func_kwargs: dict | None = None,
        act_function=EOActivation,
        act_func_kwargs: dict | None = None,
        parametrizations: list[dict] = None,
        dtype=torch.float64,
        dropout_prob=0.01,
        scale_layers=False,
    ):
        super(polarisation_estimator, self).__init__()
        self._n_hidden_layers = len(dims) - 2

        layer_func_kwargs = layer_func_kwargs or {}
        act_func_kwargs = act_func_kwargs or {}

        self.build_model(dims, layer_function, layer_func_kwargs, output_layer_function, act_function, act_func_kwargs, parametrizations, dtype, dropout_prob, scale_layers)

    def forward(self, x):
        x = self.layer_0(x)
        for i in range(1, self._n_hidden_layers):
            x = getattr(self, f"layer_{i}")(x)
        x = getattr(self, f"layer_{self._n_hidden_layers}")(x)
        x = torch.remainder(x, torch.ones_like(x) * 2 * torch.pi)
        return x.squeeze()

    def build_model(self, dims, layer_function, layer_func_kwargs, output_layer_function, act_function, act_func_kwargs, parametrizations, dtype, dropout_prob, scale_layers):
        for i in range(0, self._n_hidden_layers):
            self.add_module(f"layer_{i}", Sequential())

            if scale_layers:
                self.get_submodule(f"layer_{i}").add_module("scale", Scale(dims[i]))

            module = layer_function(dims[i], dims[i + 1], dtype=dtype, **layer_func_kwargs)
            self.get_submodule(f"layer_{i}").add_module("ONN", module)

            module = act_function(size=dims[i + 1], **act_func_kwargs)
            self.get_submodule(f"layer_{i}").add_module("activation", module)

            module = DropoutComplex(p=dropout_prob)
            self.get_submodule(f"layer_{i}").add_module("dropout", module)

        self.add_module(f"layer_{self._n_hidden_layers}", Sequential())

        if scale_layers:
            self.get_submodule(f"layer_{self._n_hidden_layers}").add_module("scale", Scale(dims[-2]))

        module = layer_function(dims[-2], dims[-1], dtype=dtype, **layer_func_kwargs)
        self.get_submodule(f"layer_{self._n_hidden_layers}").add_module("ONN", module)

        module = output_layer_function(size=dims[-1])
        self.get_submodule(f"layer_{self._n_hidden_layers}").add_module("photodiode", module)

        # module = normalize_by_first()
        # self.get_submodule(f"layer_{self._n_hidden_layers}").add_module("normalize", module)

        if parametrizations is not None:
            self._apply_parametrizations(self, parametrizations)

    def _apply_parametrizations(self, layer, parametrizations):
        for sub_layer in layer.children():
            if len(sub_layer._modules) > 0:
                self._apply_parametrizations(sub_layer, parametrizations)
            else:
                for parametrization in parametrizations:
                    tensor_name = parametrization.get("tensor_name", None)
                    if tensor_name is None:
                        continue
                    parametrization_func = parametrization.get("parametrization", None)
                    if parametrization_func is None:
                        continue
                    param_kwargs = parametrization.get("kwargs", {})
                    if tensor_name in sub_layer._parameters:
                        parametrization_func(sub_layer, tensor_name, **param_kwargs)


class regenerator(Module):
    def __init__(
        self,
        *dims,
        layer_function=ONNRect,
        layer_func_kwargs: dict | None = None,
        act_function=EOActivation,
        act_func_kwargs: dict | None = None,
        parametrizations: list[dict] = None,
        dtype=torch.float64,
        dropout_prob=0.01,
        scale_layers=False,
    ):
        super(regenerator, self).__init__()
        self._n_hidden_layers = len(dims) - 2

        layer_func_kwargs = layer_func_kwargs or {}
        act_func_kwargs = act_func_kwargs or {}

        self.build_model(dims, layer_function, layer_func_kwargs, act_function, act_func_kwargs, parametrizations, dtype, dropout_prob, scale_layers)

    def build_model(self, dims, layer_function, layer_func_kwargs, act_function, act_func_kwargs, parametrizations, dtype, dropout_prob, scale_layers):
        for i in range(0, self._n_hidden_layers):
            self.add_module(f"layer_{i}", Sequential())

            if scale_layers:
                self.get_submodule(f"layer_{i}").add_module("scale", Scale(dims[i]))

            module = layer_function(dims[i], dims[i + 1], dtype=dtype, **layer_func_kwargs)
            self.get_submodule(f"layer_{i}").add_module("ONN", module)

            module = act_function(size=dims[i + 1], **act_func_kwargs)
            self.get_submodule(f"layer_{i}").add_module("activation", module)

            module = DropoutComplex(p=dropout_prob)
            self.get_submodule(f"layer_{i}").add_module("dropout", module)

        self.add_module(f"layer_{self._n_hidden_layers}", Sequential())

        if scale_layers:
            self.get_submodule(f"layer_{self._n_hidden_layers}").add_module("scale", Scale(dims[-2]))

        module = layer_function(dims[-2], dims[-1], dtype=dtype, **layer_func_kwargs)
        self.get_submodule(f"layer_{self._n_hidden_layers}").add_module("ONN", module)

        module = act_function(size=dims[-1], **act_func_kwargs)
        self.get_submodule(f"layer_{self._n_hidden_layers}").add_module("activation", module)

        # module = Scale(size=dims[-1])
        # self.get_submodule(f"layer_{self._n_hidden_layers}").add_module("out_scale", module)

        if parametrizations is not None:
            self._apply_parametrizations(self, parametrizations)

    def _apply_parametrizations(self, layer, parametrizations):
        for sub_layer in layer.children():
            if len(sub_layer._modules) > 0:
                self._apply_parametrizations(sub_layer, parametrizations)
            else:
                for parametrization in parametrizations:
                    tensor_name = parametrization.get("tensor_name", None)
                    if tensor_name is None:
                        continue
                    parametrization_func = parametrization.get("parametrization", None)
                    if parametrization_func is None:
                        continue
                    param_kwargs = parametrization.get("kwargs", {})
                    if tensor_name in sub_layer._parameters:
                        parametrization_func(sub_layer, tensor_name, **param_kwargs)

    def _trace_powers(self, enable, x, powers=None):
        if not enable:
            return
        if powers is None:
            powers = []
        powers.append(x.abs().square().sum())
        return powers

    def forward(self, x, trace_powers=False):
        powers = self._trace_powers(trace_powers, x)
        x = self.layer_0(x)
        powers = self._trace_powers(trace_powers, x, powers)
        for i in range(1, self._n_hidden_layers):
            x = getattr(self, f"layer_{i}")(x)
            powers = self._trace_powers(trace_powers, x, powers)
        x = getattr(self, f"layer_{self._n_hidden_layers}")(x)
        powers = self._trace_powers(trace_powers, x, powers)
        if trace_powers:
            return x, powers
        return x
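As a quick illustration of the intended call pattern for the new polarisation estimator (the widths, dtype and batch below are made up, not taken from the commit): the variadic dims give the input, hidden and output sizes, and the forward pass folds the photodiode output into [0, 2*pi).

# Hypothetical instantiation of polarisation_estimator; values are illustrative only.
import torch

est = polarisation_estimator(
    64, 32, 4,                  # input, hidden and output widths (example)
    dtype=torch.complex64,      # matches a complex-valued field input (assumption)
)
fields = torch.randn(8, 64, dtype=torch.complex64)   # batch of complex field samples
angles = est(fields)                                 # predicted angles, wrapped into [0, 2*pi)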
@@ -20,6 +20,22 @@ class DataSettings:
     xy_delay: tuple | float | int = 0
     drop_first: int = 1000
     train_split: float = 0.8
+    polarisations: tuple | list = (0,)
+    randomise_polarisations: bool = False
+
+    """
+    change to:
+
+    config_path: tuple | list | None = None
+    dtype: torch.dtype | None = None
+    symbols: int | float = 1
+    output_dim: int = 2
+    shuffle: bool = True
+    drop_first: float | int = 0
+    train_split: float = 0.8
+    randomise_polarisations: bool = False
+
+    """


 # pytorch settings
@@ -30,8 +46,8 @@ class PytorchSettings:

     device: str = "cuda"

-    dataloader_workers: int = 2
-    dataloader_prefetch: int = 2
+    dataloader_workers: int = 1
+    dataloader_prefetch: int = 1

     save_models: bool = True
     model_dir: str = ".models"
@@ -56,6 +72,24 @@ class ModelSettings:
     model_layer_kwargs: dict | None = None
     model_layer_parametrizations: list= field(default_factory=list)
+
+    """
+    change to:
+
+    dims: tuple | list | None = None
+    layer_function: nn.Module | None = None
+    layer_func_kwargs: dict | None = None
+    activation_function: nn.Module | None = None
+    activation_func_kwargs: dict | None = None
+    output_function: nn.Module | None = None
+    output_func_kwargs: dict | None = None
+    dropout_function: nn.Module | None = None
+    dropout_func_kwargs: dict | None = None
+    scale_function: nn.Module | None = None
+    scale_func_kwargs: dict | None = None
+    parametrizations: list | None = None
+
+    """


 @dataclass
 class OptimizerSettings:
@@ -65,6 +99,17 @@ class OptimizerSettings:
     scheduler: str | None = None
     scheduler_kwargs: dict | None = None
+
+    """
+    change to:
+
+    optimizer: torch.optim.Optimizer | None = None
+    optimizer_kwargs: dict | None = None
+    learning_rate: float | None = None
+    scheduler: torch.optim.lr_scheduler | None = None
+    scheduler_kwargs: dict | None = None
+
+    """


 def _pruner_default_kwargs():
     # MedianPruner
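To show how the two new DataSettings fields are expected to flow into the dataset, here is a short sketch mirroring get_sliced_data() in the training script below. It is not part of the diff; the glob, angles and dtype string are placeholders.

# Sketch only: new polarisation fields passed through to FiberRegenerationDataset.
import torch

settings = DataSettings(
    config_path="data/configs/*.json",     # placeholder glob
    dtype="complex64",
    symbols=1,
    polarisations=(0.0, 0.785),            # fixed launch polarisation angles (example)
    randomise_polarisations=True,          # or draw a random SOP per realisation
)
dtype = getattr(torch, settings.dtype)
dataset = FiberRegenerationDataset(
    file_path=settings.config_path,
    symbols=settings.symbols,
    output_dim=settings.output_size,
    drop_first=settings.drop_first,
    dtype=dtype,
    real=not dtype.is_complex,
    polarisations=settings.polarisations,
    randomise_polarisations=settings.randomise_polarisations,
)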
|
||||||
|
|||||||
@@ -37,6 +37,7 @@ from rich.console import Console
|
|||||||
|
|
||||||
from util.datasets import FiberRegenerationDataset
|
from util.datasets import FiberRegenerationDataset
|
||||||
import util
|
import util
|
||||||
|
import hypertraining.models as models
|
||||||
|
|
||||||
from .settings import (
|
from .settings import (
|
||||||
GlobalSettings,
|
GlobalSettings,
|
||||||
@@ -59,8 +60,527 @@ def traverse_dict_update(target, source):
|
|||||||
except TypeError:
|
except TypeError:
|
||||||
target.__dict__[k] = v
|
target.__dict__[k] = v
|
||||||
|
|
||||||
|
def get_parameter_names_and_values(model):
|
||||||
|
def is_parametrized(module):
|
||||||
|
if hasattr(module, "parametrizations"):
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
class Trainer:
|
def _get_param_info(module, prefix='', parametrization=False):
|
||||||
|
param_list = []
|
||||||
|
for name, param in module.named_parameters(recurse = parametrization):
|
||||||
|
if parametrization and name.startswith("parametrizations"):
|
||||||
|
name_parts = name.split('.')
|
||||||
|
name = name_parts[1]
|
||||||
|
param = getattr(module, name)
|
||||||
|
full_name = prefix + ('.' if prefix else '') + name
|
||||||
|
param_value = param.data
|
||||||
|
param_list.append((full_name, param_value))
|
||||||
|
|
||||||
|
for child_name, child_module in module.named_children():
|
||||||
|
child_prefix = prefix + ('.' if prefix else '') + child_name
|
||||||
|
if child_name == "parametrizations":
|
||||||
|
continue
|
||||||
|
param_list.extend(_get_param_info(child_module, child_prefix, is_parametrized(child_module)))
|
||||||
|
|
||||||
|
return param_list
|
||||||
|
|
||||||
|
return _get_param_info(model)
|
||||||
|
|
||||||
|
class PolarizationTrainer:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
global_settings=None,
|
||||||
|
data_settings=None,
|
||||||
|
pytorch_settings=None,
|
||||||
|
model_settings=None,
|
||||||
|
optimizer_settings=None,
|
||||||
|
console=None,
|
||||||
|
checkpoint_path=None,
|
||||||
|
settings_override=None,
|
||||||
|
reset_epoch=False,
|
||||||
|
):
|
||||||
|
self.mod = torch.pi/2
|
||||||
|
self.resume = checkpoint_path is not None
|
||||||
|
torch.serialization.add_safe_globals([
|
||||||
|
*util.complexNN.__all__,
|
||||||
|
GlobalSettings,
|
||||||
|
DataSettings,
|
||||||
|
ModelSettings,
|
||||||
|
OptimizerSettings,
|
||||||
|
PytorchSettings,
|
||||||
|
models.regenerator,
|
||||||
|
torch.nn.utils.parametrizations.orthogonal,
|
||||||
|
])
|
||||||
|
if self.resume:
|
||||||
|
print(f"loading checkpoint from {checkpoint_path}")
|
||||||
|
self.checkpoint_dict = torch.load(checkpoint_path, weights_only=True)
|
||||||
|
if settings_override is not None:
|
||||||
|
traverse_dict_update(self.checkpoint_dict["settings"], settings_override)
|
||||||
|
if reset_epoch:
|
||||||
|
self.checkpoint_dict["epoch"] = -1
|
||||||
|
|
||||||
|
self.global_settings: GlobalSettings = self.checkpoint_dict["settings"]["global_settings"]
|
||||||
|
self.data_settings: DataSettings = self.checkpoint_dict["settings"]["data_settings"]
|
||||||
|
self.pytorch_settings: PytorchSettings = self.checkpoint_dict["settings"]["pytorch_settings"]
|
||||||
|
self.model_settings: ModelSettings = self.checkpoint_dict["settings"]["model_settings"]
|
||||||
|
self.optimizer_settings: OptimizerSettings = self.checkpoint_dict["settings"]["optimizer_settings"]
|
||||||
|
else:
|
||||||
|
if global_settings is None:
|
||||||
|
global_settings = GlobalSettings()
|
||||||
|
raise UserWarning("Global settings not provided, using default settings")
|
||||||
|
if data_settings is None:
|
||||||
|
data_settings = DataSettings()
|
||||||
|
raise UserWarning("Data settings not provided, using default settings")
|
||||||
|
if pytorch_settings is None:
|
||||||
|
pytorch_settings = PytorchSettings()
|
||||||
|
raise UserWarning("Pytorch settings not provided, using default settings")
|
||||||
|
if model_settings is None:
|
||||||
|
model_settings = ModelSettings()
|
||||||
|
raise UserWarning("Model settings not provided, using default settings")
|
||||||
|
if optimizer_settings is None:
|
||||||
|
optimizer_settings = OptimizerSettings()
|
||||||
|
raise UserWarning("Optimizer settings not provided, using default settings")
|
||||||
|
|
||||||
|
self.global_settings: GlobalSettings = global_settings
|
||||||
|
self.data_settings: DataSettings = data_settings
|
||||||
|
self.pytorch_settings: PytorchSettings = pytorch_settings
|
||||||
|
self.model_settings: ModelSettings = model_settings
|
||||||
|
self.optimizer_settings: OptimizerSettings = optimizer_settings
|
||||||
|
|
||||||
|
self.console = console or Console()
|
||||||
|
self.writer = None
|
||||||
|
|
||||||
|
def setup_tb_writer(self, append=None):
|
||||||
|
log_dir = self.pytorch_settings.summary_dir + "/pol_" + (datetime.now().strftime("%Y%m%d_%H%M%S"))
|
||||||
|
if append is not None:
|
||||||
|
log_dir += "_" + str(append)
|
||||||
|
|
||||||
|
print(f"Logging to {log_dir}")
|
||||||
|
self.writer = SummaryWriter(log_dir=log_dir)
|
||||||
|
|
||||||
|
def save_checkpoint(self, save_dict, filename):
|
||||||
|
torch.save(save_dict, filename)
|
||||||
|
|
||||||
|
def build_checkpoint_dict(self, loss=None, epoch=None):
|
||||||
|
return {
|
||||||
|
"epoch": -1 if epoch is None else epoch,
|
||||||
|
"loss": float("inf") if loss is None else loss,
|
||||||
|
"model_state_dict": copy.deepcopy(self.model.state_dict()),
|
||||||
|
"optimizer_state_dict": copy.deepcopy(self.optimizer.state_dict()),
|
||||||
|
"scheduler_state_dict": copy.deepcopy(self.scheduler.state_dict()) if hasattr(self, "scheduler") else None,
|
||||||
|
"model_kwargs": copy.deepcopy(self.model_kwargs),
|
||||||
|
"settings": {
|
||||||
|
"global_settings": copy.deepcopy(self.global_settings),
|
||||||
|
"data_settings": copy.deepcopy(self.data_settings),
|
||||||
|
"pytorch_settings": copy.deepcopy(self.pytorch_settings),
|
||||||
|
"model_settings": copy.deepcopy(self.model_settings),
|
||||||
|
"optimizer_settings": copy.deepcopy(self.optimizer_settings),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
def define_model(self, model_kwargs=None):
|
||||||
|
if self.resume:
|
||||||
|
model_kwargs = self.checkpoint_dict["model_kwargs"]
|
||||||
|
else:
|
||||||
|
model_kwargs = model_kwargs
|
||||||
|
|
||||||
|
if model_kwargs is None:
|
||||||
|
n_hidden_layers = self.model_settings.n_hidden_layers
|
||||||
|
|
||||||
|
input_dim = 2 * self.data_settings.output_size
|
||||||
|
|
||||||
|
dtype = getattr(torch, self.data_settings.dtype)
|
||||||
|
|
||||||
|
afunc = getattr(util.complexNN, self.model_settings.model_activation_func)
|
||||||
|
|
||||||
|
layer_func = getattr(util.complexNN, self.model_settings.model_layer_function)
|
||||||
|
|
||||||
|
layer_parametrizations = self.model_settings.model_layer_parametrizations
|
||||||
|
|
||||||
|
hidden_dims = [self.model_settings.overrides.get(f"n_hidden_nodes_{i}") for i in range(n_hidden_layers)]
|
||||||
|
|
||||||
|
self.model_kwargs = {
|
||||||
|
"dims": (input_dim, *hidden_dims, self.model_settings.output_dim),
|
||||||
|
"layer_function": layer_func,
|
||||||
|
"layer_func_kwargs": self.model_settings.model_layer_kwargs,
|
||||||
|
"act_function": afunc,
|
||||||
|
"act_func_kwargs": None,
|
||||||
|
"parametrizations": layer_parametrizations,
|
||||||
|
"dtype": dtype,
|
||||||
|
"dropout_prob": self.model_settings.dropout_prob,
|
||||||
|
"scale_layers": self.model_settings.scale,
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
self.model_kwargs = model_kwargs
|
||||||
|
input_dim = self.model_kwargs["dims"][0]
|
||||||
|
dtype = self.model_kwargs["dtype"]
|
||||||
|
|
||||||
|
# dims = self.model_kwargs.pop("dims")
|
||||||
|
model_kwargs = copy.deepcopy(self.model_kwargs)
|
||||||
|
self.model = models.polarisation_estimator(*model_kwargs.pop('dims'),**model_kwargs)
|
||||||
|
# self.model = models.polarisation_estimator2()
|
||||||
|
|
||||||
|
if self.writer is not None:
|
||||||
|
try:
|
||||||
|
self.writer.add_graph(self.model, torch.rand(1, input_dim, dtype=dtype), use_strict_trace=False)
|
||||||
|
except RuntimeError:
|
||||||
|
self.writer.add_graph(self.model, torch.rand(1, 2, dtype=dtype), use_strict_trace=False)
|
||||||
|
|
||||||
|
self.model = self.model.to(self.pytorch_settings.device)
|
||||||
|
if self.resume:
|
||||||
|
self.model.load_state_dict(self.checkpoint_dict["model_state_dict"], strict=False)
|
||||||
|
|
||||||
|
def get_sliced_data(self, override=None):
|
||||||
|
symbols = self.data_settings.symbols
|
||||||
|
|
||||||
|
in_out_delay = self.data_settings.in_out_delay
|
||||||
|
|
||||||
|
xy_delay = self.data_settings.xy_delay
|
||||||
|
|
||||||
|
data_size = self.data_settings.output_size
|
||||||
|
|
||||||
|
dtype = getattr(torch, self.data_settings.dtype)
|
||||||
|
|
||||||
|
num_symbols = None
|
||||||
|
config_path = self.data_settings.config_path
|
||||||
|
polarisations = self.data_settings.polarisations
|
||||||
|
randomise_polarisations = self.data_settings.randomise_polarisations
|
||||||
|
if override is not None:
|
||||||
|
num_symbols = override.get("num_symbols", None)
|
||||||
|
config_path = override.get("config_path", config_path)
|
||||||
|
polarisations = override.get("polarisations", polarisations)
|
||||||
|
randomise_polarisations = override.get("randomise_polarisation", randomise_polarisations)
|
||||||
|
# get dataset
|
||||||
|
dataset = FiberRegenerationDataset(
|
||||||
|
file_path=config_path,
|
||||||
|
symbols=symbols,
|
||||||
|
output_dim=data_size,
|
||||||
|
target_delay=in_out_delay,
|
||||||
|
xy_delay=xy_delay,
|
||||||
|
drop_first=self.data_settings.drop_first,
|
||||||
|
dtype=dtype,
|
||||||
|
real=not dtype.is_complex,
|
||||||
|
num_symbols=num_symbols,
|
||||||
|
polarisations=polarisations,
|
||||||
|
randomise_polarisations=randomise_polarisations,
|
||||||
|
)
|
||||||
|
|
||||||
|
dataset_size = len(dataset)
|
||||||
|
indices = list(range(dataset_size))
|
||||||
|
split = int(np.floor(self.data_settings.train_split * dataset_size))
|
||||||
|
if self.data_settings.shuffle:
|
||||||
|
np.random.seed(self.global_settings.seed)
|
||||||
|
np.random.shuffle(indices)
|
||||||
|
|
||||||
|
train_indices, valid_indices = indices[:split], indices[split:]
|
||||||
|
|
||||||
|
if self.data_settings.shuffle:
|
||||||
|
train_sampler = torch.utils.data.SubsetRandomSampler(train_indices)
|
||||||
|
valid_sampler = torch.utils.data.SubsetRandomSampler(valid_indices)
|
||||||
|
else:
|
||||||
|
train_sampler = train_indices
|
||||||
|
valid_sampler = valid_indices
|
||||||
|
|
||||||
|
train_loader = torch.utils.data.DataLoader(
|
||||||
|
dataset,
|
||||||
|
batch_size=self.pytorch_settings.batchsize,
|
||||||
|
sampler=train_sampler,
|
||||||
|
drop_last=True,
|
||||||
|
pin_memory=True,
|
||||||
|
num_workers=self.pytorch_settings.dataloader_workers,
|
||||||
|
prefetch_factor=self.pytorch_settings.dataloader_prefetch,
|
||||||
|
)
|
||||||
|
|
||||||
|
valid_loader = torch.utils.data.DataLoader(
|
||||||
|
dataset,
|
||||||
|
batch_size=self.pytorch_settings.batchsize,
|
||||||
|
sampler=valid_sampler,
|
||||||
|
drop_last=True,
|
||||||
|
pin_memory=True,
|
||||||
|
num_workers=self.pytorch_settings.dataloader_workers,
|
||||||
|
prefetch_factor=self.pytorch_settings.dataloader_prefetch,
|
||||||
|
)
|
||||||
|
|
||||||
|
return train_loader, valid_loader
|
||||||
|
|
||||||
|
def train_model(
|
||||||
|
self,
|
||||||
|
optimizer,
|
||||||
|
train_loader,
|
||||||
|
epoch,
|
||||||
|
enable_progress=False,
|
||||||
|
):
|
||||||
|
if enable_progress:
|
||||||
|
progress = Progress(
|
||||||
|
TextColumn("[yellow] Training..."),
|
||||||
|
TextColumn("Error: {task.description}"),
|
||||||
|
BarColumn(),
|
||||||
|
TaskProgressColumn(),
|
||||||
|
TextColumn("[green]Batch"),
|
||||||
|
MofNCompleteColumn(),
|
||||||
|
TimeRemainingColumn(),
|
||||||
|
TimeElapsedColumn(),
|
||||||
|
transient=False,
|
||||||
|
console=self.console,
|
||||||
|
refresh_per_second=10,
|
||||||
|
)
|
||||||
|
task = progress.add_task("-.---e--", total=len(train_loader))
|
||||||
|
progress.start()
|
||||||
|
|
||||||
|
running_loss2 = 0.0
|
||||||
|
running_loss = 0.0
|
||||||
|
self.model.train()
|
||||||
|
loader_len = len(train_loader)
|
||||||
|
write_div = 0
|
||||||
|
loss_div = 0
|
||||||
|
for batch_idx, batch in enumerate(train_loader):
|
||||||
|
x = batch["x"]
|
||||||
|
y = batch["sop"]
|
||||||
|
self.model.zero_grad(set_to_none=True)
|
||||||
|
x, y = (
|
||||||
|
x.to(self.pytorch_settings.device),
|
||||||
|
y.to(self.pytorch_settings.device),
|
||||||
|
)
|
||||||
|
y_pred = self.model(x)
|
||||||
|
# loss = torch.nn.functional.smooth_l1_loss(torch.cos(torch.fmod(y_pred, torch.pi/2)).squeeze(), torch.cos(torch.fmod(y, torch.pi/2)).squeeze(), beta=0.5)
|
||||||
|
loss = torch.nn.functional.mse_loss(y_pred, y)
|
||||||
|
# loss = util.complexNN.naive_angle_loss(y_pred, y, mod=torch.pi/2)
|
||||||
|
loss_value = loss.item()
|
||||||
|
loss.backward()
|
||||||
|
optimizer.step()
|
||||||
|
running_loss += loss_value
|
||||||
|
running_loss2 += loss_value
|
||||||
|
write_div += 1
|
||||||
|
loss_div += 1
|
||||||
|
|
||||||
|
if enable_progress:
|
||||||
|
progress.update(task, advance=1, description=f"{loss_value:.3e}")
|
||||||
|
|
||||||
|
if batch_idx % self.pytorch_settings.write_every == 0:
|
||||||
|
self.writer.add_scalar(
|
||||||
|
"training loss",
|
||||||
|
running_loss2 / write_div,
|
||||||
|
epoch * loader_len + batch_idx,
|
||||||
|
)
|
||||||
|
running_loss2 = 0.0
|
||||||
|
write_div = 0
|
||||||
|
|
||||||
|
if enable_progress:
|
||||||
|
progress.stop()
|
||||||
|
|
||||||
|
return running_loss / loss_div
|
||||||
|
|
||||||
|
def eval_model(self, valid_loader, epoch, enable_progress=True):
|
||||||
|
if enable_progress:
|
||||||
|
progress = Progress(
|
||||||
|
TextColumn("[green]Evaluating..."),
|
||||||
|
TextColumn("Error: {task.description}"),
|
||||||
|
BarColumn(),
|
||||||
|
TaskProgressColumn(),
|
||||||
|
TextColumn("[green]Batch"),
|
||||||
|
MofNCompleteColumn(),
|
||||||
|
TimeRemainingColumn(),
|
||||||
|
TimeElapsedColumn(),
|
||||||
|
transient=False,
|
||||||
|
console=self.console,
|
||||||
|
refresh_per_second=10,
|
||||||
|
)
|
||||||
|
progress.start()
|
||||||
|
task = progress.add_task("-.---e--", total=len(valid_loader))
|
||||||
|
|
||||||
|
self.model.eval()
|
||||||
|
running_loss = 0
|
||||||
|
loss_div = 0
|
||||||
|
with torch.no_grad():
|
||||||
|
for _, batch in enumerate(valid_loader):
|
||||||
|
x = batch["x"]
|
||||||
|
y = batch["sop"]
|
||||||
|
x, y = (
|
||||||
|
x.to(self.pytorch_settings.device),
|
||||||
|
y.to(self.pytorch_settings.device),
|
||||||
|
)
|
||||||
|
y_pred = self.model(x)
|
||||||
|
# loss = torch.nn.functional.smooth_l1_loss(torch.cos(torch.fmod(y_pred, torch.pi/2)).squeeze(), torch.cos(torch.fmod(y, torch.pi/2)).squeeze(), beta=0.5)
|
||||||
|
loss = torch.nn.functional.mse_loss(y_pred, y)
|
||||||
|
# loss = util.complexNN.naive_angle_loss(y_pred, y, mod=torch.pi/2)
|
||||||
|
loss_value = loss.item()
|
||||||
|
running_loss += loss_value
|
||||||
|
loss_div += 1
|
||||||
|
|
||||||
|
if enable_progress:
|
||||||
|
progress.update(task, advance=1, description=f"{loss_value:.3e}")
|
||||||
|
|
||||||
|
running_loss = running_loss/loss_div
|
||||||
|
|
||||||
|
self.writer.add_scalar(
|
||||||
|
"eval loss",
|
||||||
|
running_loss,
|
||||||
|
epoch,
|
||||||
|
)
|
||||||
|
|
||||||
|
# self.write_parameters(epoch + 1)
|
||||||
|
self.writer.flush()
|
||||||
|
|
||||||
|
if enable_progress:
|
||||||
|
progress.stop()
|
||||||
|
|
||||||
|
return running_loss
|
||||||
|
|
||||||
|
# def run_model(self, model, loader, trace_powers=False):
|
||||||
|
# model.eval()
|
||||||
|
# fiber_out = []
|
||||||
|
# fiber_in = []
|
||||||
|
# regen = []
|
||||||
|
# timestamps = []
|
||||||
|
|
||||||
|
# with torch.no_grad():
|
||||||
|
# model = model.to(self.pytorch_settings.device)
|
||||||
|
# for batch in loader:
|
||||||
|
# x = batch["x"]
|
||||||
|
# y = batch["angle"]
|
||||||
|
# timestamp = batch["timestamp"]
|
||||||
|
# plot_data = batch["plot_data"]
|
||||||
|
# x, y = (
|
||||||
|
# x.to(self.pytorch_settings.device),
|
||||||
|
# y.to(self.pytorch_settings.device),
|
||||||
|
# )
|
||||||
|
# if trace_powers:
|
||||||
|
# y_pred, powers = model(x, trace_powers).cpu()
|
||||||
|
# else:
|
||||||
|
# y_pred = model(x, trace_powers).cpu()
|
||||||
|
# # x = x.cpu()
|
||||||
|
# # y = y.cpu()
|
||||||
|
# y_pred = y_pred.view(y_pred.shape[0], -1, 2)
|
||||||
|
# y = y.view(y.shape[0], -1, 2)
|
||||||
|
# plot_data = plot_data.view(plot_data.shape[0], -1, 2)
|
||||||
|
# # x = x.view(x.shape[0], -1, 2)
|
||||||
|
|
||||||
|
# # timestamp = timestamp.view(-1, 1)
|
||||||
|
# fiber_out.append(plot_data.squeeze())
|
||||||
|
# fiber_in.append(y.squeeze())
|
||||||
|
# regen.append(y_pred.squeeze())
|
||||||
|
# timestamps.append(timestamp.squeeze())
|
||||||
|
|
||||||
|
# fiber_out = torch.vstack(fiber_out).cpu()
|
||||||
|
# fiber_in = torch.vstack(fiber_in).cpu()
|
||||||
|
# regen = torch.vstack(regen).cpu()
|
||||||
|
# timestamps = torch.concat(timestamps).cpu()
|
||||||
|
# if trace_powers:
|
||||||
|
# return fiber_in, fiber_out, regen, timestamps, powers
|
||||||
|
# return fiber_in, fiber_out, regen, timestamps
|
||||||
|
|
||||||
|
def write_parameters(self, epoch, attributes: list[str] | tuple[str] = None):
|
||||||
|
parameter_list = get_parameter_names_and_values(self.model)
|
||||||
|
for name, value in parameter_list:
|
||||||
|
plot = (attributes is None) or (name in attributes)
|
||||||
|
if plot:
|
||||||
|
vals: np.ndarray = value.detach().cpu().numpy().flatten()
|
||||||
|
if vals.ndim <= 1 and len(vals) == 1:
|
||||||
|
if np.iscomplexobj(vals):
|
||||||
|
self.writer.add_scalar(f"{name} (Mag)", np.abs(vals), epoch)
|
||||||
|
self.writer.add_scalar(f"{name} (Phase)", np.angle(vals), epoch)
|
||||||
|
else:
|
||||||
|
self.writer.add_scalar(f"{name}", vals, epoch)
|
||||||
|
else:
|
||||||
|
if np.iscomplexobj(vals):
|
||||||
|
self.writer.add_histogram(f"{name} (Mag)", np.abs(vals), epoch, bins="fd")
|
||||||
|
self.writer.add_histogram(f"{name} (Phase)", np.angle(vals), epoch, bins="fd")
|
||||||
|
else:
|
||||||
|
self.writer.add_histogram(f"{name}", vals, epoch, bins="fd")
|
||||||
|
|
||||||
|
def train(self):
|
||||||
|
if self.writer is None:
|
||||||
|
self.setup_tb_writer()
|
||||||
|
|
||||||
|
self.define_model()
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"number of parameters (trainable): {sum(p.numel() for p in self.model.parameters())} ({sum(p.numel() for p in self.model.parameters() if p.requires_grad)})"
|
||||||
|
)
|
||||||
|
|
||||||
|
# self.write_parameters(0)
|
||||||
|
|
||||||
|
if isinstance(self.data_settings.config_path, (list, tuple)):
|
||||||
|
for i, config_path in enumerate(self.data_settings.config_path):
|
||||||
|
paths = Path.cwd().glob(config_path)
|
||||||
|
for j, path in enumerate(paths):
|
||||||
|
text = str(path) + '\n'
|
||||||
|
with open(path, 'r') as f:
|
||||||
|
text += f.read()
|
||||||
|
text += '\n'
|
||||||
|
self.writer.add_text(f"config_{i*len(self.data_settings.config_path)+j}", text)
|
||||||
|
|
||||||
|
elif isinstance(self.data_settings.config_path, str):
|
||||||
|
paths = Path.cwd().glob(self.data_settings.config_path)
|
||||||
|
for j, path in enumerate(paths):
|
||||||
|
text = str(path) + '\n'
|
||||||
|
with open(path, 'r') as f:
|
||||||
|
text += f.read()
|
||||||
|
text += '\n'
|
||||||
|
self.writer.add_text(f"config_{j}", text)
|
||||||
|
|
||||||
|
self.writer.flush()
|
||||||
|
|
||||||
|
train_loader, valid_loader = self.get_sliced_data()
|
||||||
|
|
||||||
|
optimizer_name = self.optimizer_settings.optimizer
|
||||||
|
|
||||||
|
self.optimizer: optim.Optimizer = getattr(optim, optimizer_name)(
|
||||||
|
self.model.parameters(), **self.optimizer_settings.optimizer_kwargs
|
||||||
|
)
|
||||||
|
if self.optimizer_settings.scheduler is not None:
|
||||||
|
self.scheduler = getattr(optim.lr_scheduler, self.optimizer_settings.scheduler)(
|
||||||
|
self.optimizer, **self.optimizer_settings.scheduler_kwargs
|
||||||
|
)
|
||||||
|
self.writer.add_scalar("learning rate", self.optimizer.param_groups[0]["lr"], -1)
|
||||||
|
|
||||||
|
if not self.resume:
|
||||||
|
self.best = self.build_checkpoint_dict()
|
||||||
|
else:
|
||||||
|
self.best = self.checkpoint_dict
|
||||||
|
self.best["loss"] = float("inf")
|
||||||
|
|
||||||
|
for epoch in range(self.best["epoch"] + 1, self.pytorch_settings.epochs):
|
||||||
|
enable_progress = True
|
||||||
|
if enable_progress:
|
||||||
|
self.console.rule(f"Epoch {epoch + 1}/{self.pytorch_settings.epochs}")
|
||||||
|
self.train_model(
|
||||||
|
self.optimizer,
|
||||||
|
train_loader,
|
||||||
|
epoch,
|
||||||
|
enable_progress=enable_progress,
|
||||||
|
)
|
||||||
|
loss = self.eval_model(
|
||||||
|
valid_loader,
|
||||||
|
epoch,
|
||||||
|
enable_progress=enable_progress,
|
||||||
|
)
|
||||||
|
if self.optimizer_settings.scheduler is not None:
|
||||||
|
self.scheduler.step(loss)
|
||||||
|
self.writer.add_scalar("learning rate", self.optimizer.param_groups[0]["lr"], epoch)
|
||||||
|
if self.pytorch_settings.save_models and self.model is not None:
|
||||||
|
save_path = (
|
||||||
|
Path(self.pytorch_settings.model_dir) / f"pol_{self.writer.get_logdir().split('/')[-1]}_{epoch}.tar"
|
||||||
|
)
|
||||||
|
save_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
checkpoint = self.build_checkpoint_dict(loss, epoch)
|
||||||
|
self.save_checkpoint(checkpoint, save_path)
|
||||||
|
|
||||||
|
if loss < self.best["loss"]:
|
||||||
|
self.best = checkpoint
|
||||||
|
save_path = (
|
||||||
|
Path(self.pytorch_settings.model_dir) / f"best_pol_{self.writer.get_logdir().split('/')[-1]}.tar"
|
||||||
|
)
|
||||||
|
save_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
self.save_checkpoint(self.best, save_path)
|
||||||
|
self.writer.flush()
|
||||||
|
|
||||||
|
self.writer.close()
|
||||||
|
return self.best
|
||||||
|
|
||||||
|
class RegenerationTrainer:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
@@ -82,10 +602,11 @@ class Trainer:
|
|||||||
ModelSettings,
|
ModelSettings,
|
||||||
OptimizerSettings,
|
OptimizerSettings,
|
||||||
PytorchSettings,
|
PytorchSettings,
|
||||||
util.complexNN.regenerator,
|
models.regenerator,
|
||||||
torch.nn.utils.parametrizations.orthogonal,
|
torch.nn.utils.parametrizations.orthogonal,
|
||||||
])
|
])
|
||||||
if self.resume:
|
if self.resume:
|
||||||
|
print(f"loading checkpoint from {checkpoint_path}")
|
||||||
self.checkpoint_dict = torch.load(checkpoint_path, weights_only=True)
|
self.checkpoint_dict = torch.load(checkpoint_path, weights_only=True)
|
||||||
if settings_override is not None:
|
if settings_override is not None:
|
||||||
traverse_dict_update(self.checkpoint_dict["settings"], settings_override)
|
traverse_dict_update(self.checkpoint_dict["settings"], settings_override)
|
||||||
@@ -170,11 +691,13 @@ class Trainer:
|
|||||||
self.model_kwargs = {
|
self.model_kwargs = {
|
||||||
"dims": (input_dim, *hidden_dims, self.model_settings.output_dim),
|
"dims": (input_dim, *hidden_dims, self.model_settings.output_dim),
|
||||||
"layer_function": layer_func,
|
"layer_function": layer_func,
|
||||||
"layer_parametrizations": layer_parametrizations,
|
"layer_func_kwargs": self.model_settings.model_layer_kwargs,
|
||||||
"activation_function": afunc,
|
"act_function": afunc,
|
||||||
|
"act_func_kwargs": None,
|
||||||
|
"parametrizations": layer_parametrizations,
|
||||||
"dtype": dtype,
|
"dtype": dtype,
|
||||||
"dropout_prob": self.model_settings.dropout_prob,
|
"dropout_prob": self.model_settings.dropout_prob,
|
||||||
"scale": self.model_settings.scale,
|
"scale_layers": self.model_settings.scale,
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
self.model_kwargs = model_kwargs
|
self.model_kwargs = model_kwargs
|
||||||
@@ -182,7 +705,8 @@ class Trainer:
|
|||||||
dtype = self.model_kwargs["dtype"]
|
dtype = self.model_kwargs["dtype"]
|
||||||
|
|
||||||
# dims = self.model_kwargs.pop("dims")
|
# dims = self.model_kwargs.pop("dims")
|
||||||
self.model = util.complexNN.regenerator(**self.model_kwargs)
|
model_kwargs = copy.deepcopy(self.model_kwargs)
|
||||||
|
self.model = models.regenerator(*model_kwargs.pop('dims'),**model_kwargs)
|
||||||
|
|
||||||
if self.writer is not None:
|
if self.writer is not None:
|
||||||
self.writer.add_graph(self.model, torch.zeros(1, input_dim, dtype=dtype))
|
self.writer.add_graph(self.model, torch.zeros(1, input_dim, dtype=dtype))
|
||||||
@@ -204,9 +728,13 @@ class Trainer:
|
|||||||
|
|
||||||
num_symbols = None
|
num_symbols = None
|
||||||
config_path = self.data_settings.config_path
|
config_path = self.data_settings.config_path
|
||||||
|
polarisations = self.data_settings.polarisations
|
||||||
|
randomise_polarisations = self.data_settings.randomise_polarisations
|
||||||
if override is not None:
|
if override is not None:
|
||||||
num_symbols = override.get("num_symbols", None)
|
num_symbols = override.get("num_symbols", None)
|
||||||
config_path = override.get("config_path", config_path)
|
config_path = override.get("config_path", config_path)
|
||||||
|
polarisations = override.get("polarisations", polarisations)
|
||||||
|
randomise_polarisations = override.get("randomise_polarisation", randomise_polarisations)
|
||||||
# get dataset
|
# get dataset
|
||||||
dataset = FiberRegenerationDataset(
|
dataset = FiberRegenerationDataset(
|
||||||
file_path=config_path,
|
file_path=config_path,
|
||||||
@@ -218,6 +746,8 @@ class Trainer:
|
|||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
real=not dtype.is_complex,
|
real=not dtype.is_complex,
|
||||||
num_symbols=num_symbols,
|
num_symbols=num_symbols,
|
||||||
|
polarisations=polarisations,
|
||||||
|
randomise_polarisations=randomise_polarisations,
|
||||||
)
|
)
|
||||||
|
|
||||||
dataset_size = len(dataset)
|
dataset_size = len(dataset)
|
||||||
@@ -286,7 +816,9 @@ class Trainer:
|
|||||||
running_loss = 0.0
|
running_loss = 0.0
|
||||||
self.model.train()
|
self.model.train()
|
||||||
loader_len = len(train_loader)
|
loader_len = len(train_loader)
|
||||||
for batch_idx, (x, y, _) in enumerate(train_loader):
|
for batch_idx, batch in enumerate(train_loader):
|
||||||
|
x = batch["x"]
|
||||||
|
y = batch["y"]
|
||||||
self.model.zero_grad(set_to_none=True)
|
self.model.zero_grad(set_to_none=True)
|
||||||
x, y = (
|
x, y = (
|
||||||
x.to(self.pytorch_settings.device),
|
x.to(self.pytorch_settings.device),
|
||||||
@@ -307,7 +839,7 @@ class Trainer:
|
|||||||
self.writer.add_scalar(
|
self.writer.add_scalar(
|
||||||
"training loss",
|
"training loss",
|
||||||
running_loss2 / (self.pytorch_settings.write_every if batch_idx > 0 else 1),
|
running_loss2 / (self.pytorch_settings.write_every if batch_idx > 0 else 1),
|
||||||
epoch + batch_idx/loader_len,
|
epoch * loader_len + batch_idx,
|
||||||
)
|
)
|
||||||
running_loss2 = 0.0
|
running_loss2 = 0.0
|
||||||
|
|
||||||
@@ -337,7 +869,9 @@ class Trainer:
|
|||||||
self.model.eval()
|
self.model.eval()
|
||||||
running_error = 0
|
running_error = 0
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
for _, (x, y, _) in enumerate(valid_loader):
|
for _, batch in enumerate(valid_loader):
|
||||||
|
x = batch["x"]
|
||||||
|
y = batch["y"]
|
||||||
x, y = (
|
x, y = (
|
||||||
x.to(self.pytorch_settings.device),
|
x.to(self.pytorch_settings.device),
|
||||||
y.to(self.pytorch_settings.device),
|
y.to(self.pytorch_settings.device),
|
||||||
@@ -360,37 +894,26 @@ class Trainer:
|
|||||||
if (epoch + 1) % 10 == 0 or epoch < 10:
|
if (epoch + 1) % 10 == 0 or epoch < 10:
|
||||||
# plotting is slow, so only do it every 10 epochs
|
# plotting is slow, so only do it every 10 epochs
|
||||||
title_append, subtitle = self.build_title(epoch + 1)
|
title_append, subtitle = self.build_title(epoch + 1)
|
||||||
self.writer.add_figure(
|
head_fig, eye_fig, powers_fig = self.plot_model_response(
|
||||||
"fiber response",
|
|
||||||
self.plot_model_response(
|
|
||||||
model=self.model,
|
model=self.model,
|
||||||
title_append=title_append,
|
title_append=title_append,
|
||||||
subtitle=subtitle,
|
subtitle=subtitle,
|
||||||
show=False,
|
show=False,
|
||||||
),
|
)
|
||||||
|
self.writer.add_figure(
|
||||||
|
"fiber response",
|
||||||
|
head_fig,
|
||||||
epoch + 1,
|
epoch + 1,
|
||||||
)
|
)
|
||||||
self.writer.add_figure(
|
self.writer.add_figure(
|
||||||
"eye diagram",
|
"eye diagram",
|
||||||
self.plot_model_response(
|
eye_fig,
|
||||||
model=self.model,
|
|
||||||
title_append=title_append,
|
|
||||||
subtitle=subtitle,
|
|
||||||
show=False,
|
|
||||||
mode="eye",
|
|
||||||
),
|
|
||||||
epoch + 1,
|
epoch + 1,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.writer.add_figure(
|
self.writer.add_figure(
|
||||||
"powers",
|
"powers",
|
||||||
self.plot_model_response(
|
powers_fig,
|
||||||
model=self.model,
|
|
||||||
title_append=title_append,
|
|
||||||
subtitle=subtitle,
|
|
||||||
mode="powers",
|
|
||||||
show=False,
|
|
||||||
),
|
|
||||||
epoch + 1,
|
epoch + 1,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -411,7 +934,11 @@ class Trainer:
|
|||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
model = model.to(self.pytorch_settings.device)
|
model = model.to(self.pytorch_settings.device)
|
||||||
for x, y, timestamp in loader:
|
for batch in loader:
|
||||||
|
x = batch["x"]
|
||||||
|
y = batch["y"]
|
||||||
|
timestamp = batch["timestamp"]
|
||||||
|
plot_data = batch["plot_data"]
|
||||||
x, y = (
|
x, y = (
|
||||||
x.to(self.pytorch_settings.device),
|
x.to(self.pytorch_settings.device),
|
||||||
y.to(self.pytorch_settings.device),
|
y.to(self.pytorch_settings.device),
|
||||||
@@ -424,9 +951,11 @@ class Trainer:
|
|||||||
# y = y.cpu()
|
# y = y.cpu()
|
||||||
y_pred = y_pred.view(y_pred.shape[0], -1, 2)
|
y_pred = y_pred.view(y_pred.shape[0], -1, 2)
|
||||||
y = y.view(y.shape[0], -1, 2)
|
y = y.view(y.shape[0], -1, 2)
|
||||||
x = x.view(x.shape[0], -1, 2)
|
plot_data = plot_data.view(plot_data.shape[0], -1, 2)
|
||||||
|
# x = x.view(x.shape[0], -1, 2)
|
||||||
|
|
||||||
# timestamp = timestamp.view(-1, 1)
|
# timestamp = timestamp.view(-1, 1)
|
||||||
fiber_out.append(x[:, x.shape[1] // 2, :].squeeze())
|
fiber_out.append(plot_data.squeeze())
|
||||||
fiber_in.append(y.squeeze())
|
fiber_in.append(y.squeeze())
|
||||||
regen.append(y_pred.squeeze())
|
regen.append(y_pred.squeeze())
|
||||||
timestamps.append(timestamp.squeeze())
|
timestamps.append(timestamp.squeeze())
|
||||||
@@ -440,28 +969,23 @@ class Trainer:
|
|||||||
return fiber_in, fiber_out, regen, timestamps
|
return fiber_in, fiber_out, regen, timestamps
|
||||||
|
|
||||||
def write_parameters(self, epoch, attributes: list[str] | tuple[str] = None):
|
def write_parameters(self, epoch, attributes: list[str] | tuple[str] = None):
|
||||||
for i, layer in enumerate(self.model._layers):
|
parameter_list = get_parameter_names_and_values(self.model)
|
||||||
tag = f"layer {i}"
|
for name, value in parameter_list:
|
||||||
if hasattr(layer, "parametrizations"):
|
plot = (attributes is None) or (name in attributes)
|
||||||
attribute_pool = set(layer.parametrizations._modules) | set(layer._parameters)
|
|
||||||
else:
|
|
||||||
attribute_pool = set(layer._parameters)
|
|
||||||
for attribute in attribute_pool:
|
|
||||||
plot = (attributes is None) or (attribute in attributes)
|
|
||||||
if plot:
|
if plot:
|
||||||
vals: np.ndarray = getattr(layer, attribute).detach().cpu().numpy().flatten()
|
vals: np.ndarray = value.detach().cpu().numpy().flatten()
|
||||||
if vals.ndim <= 1 and len(vals) == 1:
|
if vals.ndim <= 1 and len(vals) == 1:
|
||||||
if np.iscomplexobj(vals):
|
if np.iscomplexobj(vals):
|
||||||
self.writer.add_scalar(f"{tag} {attribute} (Mag)", np.abs(vals), epoch)
|
self.writer.add_scalar(f"{name} (Mag)", np.abs(vals), epoch)
|
||||||
self.writer.add_scalar(f"{tag} {attribute} (Phase)", np.angle(vals), epoch)
|
self.writer.add_scalar(f"{name} (Phase)", np.angle(vals), epoch)
|
||||||
else:
|
else:
|
||||||
self.writer.add_scalar(f"{tag} {attribute}", vals, epoch)
|
self.writer.add_scalar(f"{name}", vals, epoch)
|
||||||
else:
|
else:
|
||||||
if np.iscomplexobj(vals):
|
if np.iscomplexobj(vals):
|
||||||
self.writer.add_histogram(f"{tag} {attribute} (Mag)", np.abs(vals), epoch, bins="fd")
|
self.writer.add_histogram(f"{name} (Mag)", np.abs(vals), epoch, bins="fd")
|
||||||
self.writer.add_histogram(f"{tag} {attribute} (Phase)", np.angle(vals), epoch, bins="fd")
|
self.writer.add_histogram(f"{name} (Phase)", np.angle(vals), epoch, bins="fd")
|
||||||
else:
|
else:
|
||||||
self.writer.add_histogram(f"{tag} {attribute}", vals, epoch, bins="fd")
|
self.writer.add_histogram(f"{name}", vals, epoch, bins="fd")
|
||||||
|
|
||||||
def train(self):
|
def train(self):
|
||||||
if self.writer is None:
|
if self.writer is None:
|
||||||
@@ -474,44 +998,48 @@ class Trainer:
|
|||||||
)
|
)
|
||||||
|
|
||||||
title_append, subtitle = self.build_title(0)
|
title_append, subtitle = self.build_title(0)
|
||||||
|
head_fig, eye_fig, powers_fig = self.plot_model_response(
|
||||||
self.writer.add_figure(
|
|
||||||
"fiber response",
|
|
||||||
self.plot_model_response(
|
|
||||||
model=self.model,
|
model=self.model,
|
||||||
title_append=title_append,
|
title_append=title_append,
|
||||||
subtitle=subtitle,
|
subtitle=subtitle,
|
||||||
show=False,
|
show=False,
|
||||||
),
|
)
|
||||||
|
self.writer.add_figure(
|
||||||
|
"fiber response",
|
||||||
|
head_fig,
|
||||||
0,
|
0,
|
||||||
)
|
)
|
||||||
self.writer.add_figure(
|
self.writer.add_figure(
|
||||||
"eye diagram",
|
"eye diagram",
|
||||||
self.plot_model_response(
|
eye_fig,
|
||||||
model=self.model,
|
|
||||||
title_append=title_append,
|
|
||||||
subtitle=subtitle,
|
|
||||||
mode="eye",
|
|
||||||
show=False,
|
|
||||||
),
|
|
||||||
0,
|
0,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.writer.add_figure(
|
self.writer.add_figure(
|
||||||
"powers",
|
"powers",
|
||||||
self.plot_model_response(
|
powers_fig,
|
||||||
model=self.model,
|
|
||||||
title_append=title_append,
|
|
||||||
subtitle=subtitle,
|
|
||||||
mode="powers",
|
|
||||||
show=False,
|
|
||||||
),
|
|
||||||
0,
|
0,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.write_parameters(0)
|
self.write_parameters(0)
|
||||||
|
|
||||||
self.writer.add_text("datasets", '\n'.join(self.data_settings.config_path))
|
if isinstance(self.data_settings.config_path, (list, tuple)):
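# attach every matching config file verbatim to the run so the simulation settings are stored with the results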
|
||||||
|
for i, config_path in enumerate(self.data_settings.config_path):
|
||||||
|
paths = Path.cwd().glob(config_path)
|
||||||
|
for j, path in enumerate(paths):
|
||||||
|
text = str(path) + '\n'
|
||||||
|
with open(path, 'r') as f:
|
||||||
|
text += f.read()
|
||||||
|
text += '\n'
|
||||||
|
self.writer.add_text(f"config_{i*len(self.data_settings.config_path)+j}", text)
|
||||||
|
elif isinstance(self.data_settings.config_path, str):
|
||||||
|
paths = Path.cwd().glob(self.data_settings.config_path)
|
||||||
|
for j, path in enumerate(paths):
|
||||||
|
text = str(path) + '\n'
|
||||||
|
with open(path, 'r') as f:
|
||||||
|
text += f.read()
|
||||||
|
text += '\n'
|
||||||
|
self.writer.add_text(f"config_{j}", text)
|
||||||
|
|
||||||
self.writer.flush()
|
self.writer.flush()
|
||||||
|
|
||||||
@@ -741,13 +1269,12 @@ class Trainer:
|
|||||||
|
|
||||||
def plot_model_response(
|
def plot_model_response(
|
||||||
self,
|
self,
|
||||||
model=None,
|
model:torch.nn.Module=None,
|
||||||
title_append="",
|
title_append="",
|
||||||
subtitle="",
|
subtitle="",
|
||||||
mode: Literal["eye", "head", "powers"] = "head",
|
# mode: Literal["eye", "head", "powers"] = "head",
|
||||||
show=False,
|
show=False,
|
||||||
):
|
):
|
||||||
if mode == "powers":
|
|
||||||
input_data = torch.ones(
|
input_data = torch.ones(
|
||||||
1, 2 * self.data_settings.output_size, dtype=getattr(torch, self.data_settings.dtype)
|
1, 2 * self.data_settings.output_size, dtype=getattr(torch, self.data_settings.dtype)
|
||||||
).to(self.pytorch_settings.device)
|
).to(self.pytorch_settings.device)
|
||||||
@@ -757,38 +1284,35 @@ class Trainer:
|
|||||||
_, powers = model(input_data, trace_powers=True)
|
_, powers = model(input_data, trace_powers=True)
|
||||||
|
|
||||||
powers = [power.item() for power in powers]
|
powers = [power.item() for power in powers]
|
||||||
layer_names = ["input", *[str(x).split("(")[0] for x in model._layers._modules.values()]]
|
layer_names = [name for (name, _) in model.named_children()]
|
||||||
|
|
||||||
# remove dropout layers
|
power_fig = self._plot_model_response_powers(
|
||||||
mask = [1 if "Dropout" not in layer_name else 0 for layer_name in layer_names]
|
|
||||||
layer_names = [layer_name for layer_name, m in zip(layer_names, mask) if m]
|
|
||||||
powers = [power for power, m in zip(powers, mask) if m]
|
|
||||||
|
|
||||||
fig = self._plot_model_response_powers(
|
|
||||||
powers, layer_names, title_append=title_append, subtitle=subtitle, show=show
|
powers, layer_names, title_append=title_append, subtitle=subtitle, show=show
|
||||||
)
|
)
|
||||||
return fig
|
|
||||||
|
|
||||||
data_settings_backup = copy.deepcopy(self.data_settings)
|
data_settings_backup = copy.deepcopy(self.data_settings)
|
||||||
pytorch_settings_backup = copy.deepcopy(self.pytorch_settings)
|
pytorch_settings_backup = copy.deepcopy(self.pytorch_settings)
|
||||||
self.data_settings.drop_first = 99.5 + random.randint(0, 1000)
|
self.data_settings.drop_first = 99.5 + random.randint(0, 1000)
|
||||||
self.data_settings.shuffle = False
|
self.data_settings.shuffle = False
|
||||||
self.data_settings.train_split = 1.0
|
self.data_settings.train_split = 1.0
|
||||||
self.pytorch_settings.batchsize = (
|
self.pytorch_settings.batchsize = max(self.pytorch_settings.head_symbols, self.pytorch_settings.eye_symbols)
|
||||||
self.pytorch_settings.eye_symbols if mode == "eye" else self.pytorch_settings.head_symbols
|
|
||||||
)
|
|
||||||
config_path = random.choice(self.data_settings.config_path) if isinstance(self.data_settings.config_path, (list, tuple)) else self.data_settings.config_path
|
config_path = random.choice(self.data_settings.config_path) if isinstance(self.data_settings.config_path, (list, tuple)) else self.data_settings.config_path
|
||||||
fiber_length = int(float(str(config_path).split('-')[-7])/1000)
|
fiber_length = int(float(str(config_path).split('-')[4])/1000)
|
||||||
plot_loader, _ = self.get_sliced_data(
|
if not hasattr(self, "_plot_loader"):
|
||||||
|
self._plot_loader, _ = self.get_sliced_data(
|
||||||
override={
|
override={
|
||||||
"num_symbols": self.pytorch_settings.batchsize,
|
"num_symbols": self.pytorch_settings.batchsize,
|
||||||
"config_path": config_path,
|
"config_path": config_path,
|
||||||
|
"shuffle": False,
|
||||||
|
"polarisations": (np.random.rand(1)*np.pi*2,),
|
||||||
|
"randomise_polarisation": False,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
self._sps = self._plot_loader.dataset.samples_per_symbol
|
||||||
self.data_settings = data_settings_backup
|
self.data_settings = data_settings_backup
|
||||||
self.pytorch_settings = pytorch_settings_backup
|
self.pytorch_settings = pytorch_settings_backup
|
||||||
|
|
||||||
fiber_in, fiber_out, regen, timestamps = self.run_model(model, plot_loader)
|
fiber_in, fiber_out, regen, timestamps = self.run_model(model, self._plot_loader)
|
||||||
fiber_in = fiber_in.view(-1, 2)
|
fiber_in = fiber_in.view(-1, 2)
|
||||||
fiber_out = fiber_out.view(-1, 2)
|
fiber_out = fiber_out.view(-1, 2)
|
||||||
regen = regen.view(-1, 2)
|
regen = regen.view(-1, 2)
|
||||||
@@ -802,36 +1326,32 @@ class Trainer:
|
|||||||
# https://github.com/matplotlib/matplotlib/issues/27713#issuecomment-1915497463
|
# https://github.com/matplotlib/matplotlib/issues/27713#issuecomment-1915497463
|
||||||
import gc
|
import gc
|
||||||
|
|
||||||
if mode == "head":
|
head_fig = self._plot_model_response_head(
|
||||||
fig = self._plot_model_response_head(
|
fiber_in[:self.pytorch_settings.head_symbols*self._sps],
|
||||||
fiber_in,
|
fiber_out[:self.pytorch_settings.head_symbols*self._sps],
|
||||||
fiber_out,
|
regen[:self.pytorch_settings.head_symbols*self._sps],
|
||||||
regen,
|
timestamps=timestamps[:self.pytorch_settings.head_symbols*self._sps],
|
||||||
timestamps=timestamps,
|
|
||||||
labels=("fiber in", "fiber out", "regen"),
|
labels=("fiber in", "fiber out", "regen"),
|
||||||
sps=plot_loader.dataset.samples_per_symbol,
|
sps=self._sps,
|
||||||
title_append=title_append + f" ({fiber_length} km)",
|
title_append=title_append + f" ({fiber_length} km)",
|
||||||
subtitle=subtitle,
|
subtitle=subtitle,
|
||||||
show=show,
|
show=show,
|
||||||
)
|
)
|
||||||
elif mode == "eye":
|
|
||||||
# raise NotImplementedError("Eye diagram not implemented")
|
# raise NotImplementedError("Eye diagram not implemented")
|
||||||
fig = self._plot_model_response_eye(
|
eye_fig = self._plot_model_response_eye(
|
||||||
fiber_in,
|
fiber_in[:self.pytorch_settings.eye_symbols*self._sps],
|
||||||
fiber_out,
|
fiber_out[:self.pytorch_settings.eye_symbols*self._sps],
|
||||||
regen,
|
regen[:self.pytorch_settings.eye_symbols*self._sps],
|
||||||
timestamps=timestamps,
|
timestamps=timestamps[:self.pytorch_settings.eye_symbols*self._sps],
|
||||||
labels=("fiber in", "fiber out", "regen"),
|
labels=("fiber in", "fiber out", "regen"),
|
||||||
sps=plot_loader.dataset.samples_per_symbol,
|
sps=self._sps,
|
||||||
title_append=title_append + f" ({fiber_length} km)",
|
title_append=title_append + f" ({fiber_length} km)",
|
||||||
subtitle=subtitle,
|
subtitle=subtitle,
|
||||||
show=show,
|
show=show,
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
raise ValueError(f"Unknown mode: {mode}")
|
|
||||||
gc.collect()
|
gc.collect()
|
||||||
|
|
||||||
return fig
|
return head_fig, eye_fig, power_fig
|
||||||
|
|
||||||
def build_title(self, number: int):
|
def build_title(self, number: int):
|
||||||
title_append = f"epoch {number}"
|
title_append = f"epoch {number}"
|
||||||
|
|||||||
@@ -1,7 +1,10 @@
|
|||||||
|
from datetime import datetime
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import matplotlib
|
import matplotlib
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
import torch.utils.tensorboard
|
||||||
|
import torch.utils.tensorboard.summary
|
||||||
from hypertraining.settings import (
|
from hypertraining.settings import (
|
||||||
GlobalSettings,
|
GlobalSettings,
|
||||||
DataSettings,
|
DataSettings,
|
||||||
@@ -10,7 +13,7 @@ from hypertraining.settings import (
|
|||||||
OptimizerSettings,
|
OptimizerSettings,
|
||||||
)
|
)
|
||||||
|
|
||||||
from hypertraining.training import Trainer
|
from hypertraining.training import RegenerationTrainer, PolarizationTrainer
|
||||||
|
|
||||||
# import torch
|
# import torch
|
||||||
import json
|
import json
|
||||||
@@ -23,7 +26,7 @@ global_settings = GlobalSettings(
|
|||||||
)
|
)
|
||||||
|
|
||||||
data_settings = DataSettings(
|
data_settings = DataSettings(
|
||||||
config_path="data/20241204-13*-128-16384-100000-0-0-17-0-PAM4-0.ini",
|
config_path="data/*-128-16384-10000-0-0-0-0-PAM4-0-0.4.ini",
|
||||||
# config_path=[f"data/20241202-*-128-16384-{length}-0-0-17-0-PAM4-0.ini" for length in range(48000, 53000, 1000)],
|
# config_path=[f"data/20241202-*-128-16384-{length}-0-0-17-0-PAM4-0.ini" for length in range(48000, 53000, 1000)],
|
||||||
dtype="complex64",
|
dtype="complex64",
|
||||||
# symbols = (9, 20), # 13 symbol @ 10GBd <-> 1.3ns <-> 0.26m of fiber
|
# symbols = (9, 20), # 13 symbol @ 10GBd <-> 1.3ns <-> 0.26m of fiber
|
||||||
@@ -31,17 +34,16 @@ data_settings = DataSettings(
|
|||||||
# output_size = (11, 32), # 26 taps -> 2 taps per input symbol -> 1 tap every 1cm (model has 52 inputs (x/y))
|
# output_size = (11, 32), # 26 taps -> 2 taps per input symbol -> 1 tap every 1cm (model has 52 inputs (x/y))
|
||||||
output_size=26, # study: single_core_regen_20241123_011232 (model_input_dim/2)
|
output_size=26, # study: single_core_regen_20241123_011232 (model_input_dim/2)
|
||||||
shuffle=True,
|
shuffle=True,
|
||||||
in_out_delay=0,
|
drop_first=64,
|
||||||
xy_delay=0,
|
|
||||||
drop_first=128 * 64,
|
|
||||||
train_split=0.8,
|
train_split=0.8,
|
||||||
|
randomise_polarisations=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
pytorch_settings = PytorchSettings(
|
pytorch_settings = PytorchSettings(
|
||||||
epochs=10000,
|
epochs=10000,
|
||||||
batchsize=2**12,
|
batchsize=2**14,
|
||||||
device="cuda",
|
device="cuda",
|
||||||
dataloader_workers=12,
|
dataloader_workers=16,
|
||||||
dataloader_prefetch=8,
|
dataloader_prefetch=8,
|
||||||
summary_dir=".runs",
|
summary_dir=".runs",
|
||||||
write_every=2**5,
|
write_every=2**5,
|
||||||
@@ -51,12 +53,14 @@ pytorch_settings = PytorchSettings(
|
|||||||
|
|
||||||
model_settings = ModelSettings(
|
model_settings = ModelSettings(
|
||||||
output_dim=2,
|
output_dim=2,
|
||||||
n_hidden_layers=4,
|
n_hidden_layers=5,
|
||||||
overrides={
|
overrides={
|
||||||
|
# "hidden_layer_dims": (8, 8, 4, 4),
|
||||||
"n_hidden_nodes_0": 8,
|
"n_hidden_nodes_0": 8,
|
||||||
"n_hidden_nodes_1": 8,
|
"n_hidden_nodes_1": 8,
|
||||||
"n_hidden_nodes_2": 4,
|
"n_hidden_nodes_2": 4,
|
||||||
"n_hidden_nodes_3": 4,
|
"n_hidden_nodes_3": 4,
|
||||||
|
"n_hidden_nodes_4": 2,
|
||||||
},
|
},
|
||||||
model_activation_func="EOActivation",
|
model_activation_func="EOActivation",
|
||||||
dropout_prob=0.01,
|
dropout_prob=0.01,
|
||||||
@@ -92,6 +96,14 @@ model_settings = ModelSettings(
|
|||||||
"tensor_name": "scales",
|
"tensor_name": "scales",
|
||||||
"parametrization": util.complexNN.clamp,
|
"parametrization": util.complexNN.clamp,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"tensor_name": "angle",
|
||||||
|
"parametrization": util.complexNN.clamp,
|
||||||
|
"kwargs": {
|
||||||
|
"min": -torch.pi,
|
||||||
|
"max": torch.pi,
|
||||||
|
},
|
||||||
|
},
|
||||||
# {
|
# {
|
||||||
# "tensor_name": "scale",
|
# "tensor_name": "scale",
|
||||||
# "parametrization": util.complexNN.clamp,
|
# "parametrization": util.complexNN.clamp,
|
||||||
@@ -153,7 +165,7 @@ def sweep_lengths(*lengths, model=None, data_glob:str=None, strategy="newest"):
|
|||||||
regens = {}
|
regens = {}
|
||||||
timestampss = {}
|
timestampss = {}
|
||||||
|
|
||||||
trainer = Trainer(
|
trainer = RegenerationTrainer(
|
||||||
checkpoint_path=model,
|
checkpoint_path=model,
|
||||||
)
|
)
|
||||||
trainer.define_model()
|
trainer.define_model()
|
||||||
@@ -165,13 +177,13 @@ def sweep_lengths(*lengths, model=None, data_glob:str=None, strategy="newest"):
|
|||||||
continue
|
continue
|
||||||
if strategy == "newest":
|
if strategy == "newest":
|
||||||
sorted_kwargs = {
|
sorted_kwargs = {
|
||||||
'key': lambda x: x.stat().st_mtime,
|
"key": lambda x: x.stat().st_mtime,
|
||||||
'reverse': True,
|
"reverse": True,
|
||||||
}
|
}
|
||||||
elif strategy == "oldest":
|
elif strategy == "oldest":
|
||||||
sorted_kwargs = {
|
sorted_kwargs = {
|
||||||
'key': lambda x: x.stat().st_mtime,
|
"key": lambda x: x.stat().st_mtime,
|
||||||
'reverse': False,
|
"reverse": False,
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown strategy {strategy}.")
|
raise ValueError(f"Unknown strategy {strategy}.")
|
||||||
@@ -193,7 +205,6 @@ def sweep_lengths(*lengths, model=None, data_glob:str=None, strategy="newest"):
|
|||||||
|
|
||||||
channel_names[1] = "fiber in x"
|
channel_names[1] = "fiber in x"
|
||||||
|
|
||||||
|
|
||||||
for li, length in enumerate(timestampss.keys()):
|
for li, length in enumerate(timestampss.keys()):
|
||||||
data[2 + 2 * li, 0, :] = timestampss[length] / 128
|
data[2 + 2 * li, 0, :] = timestampss[length] / 128
|
||||||
data[2 + 2 * li, 1, :] = fiber_outs[length][:, 0].abs().square()
|
data[2 + 2 * li, 1, :] = fiber_outs[length][:, 0].abs().square()
|
||||||
@@ -210,7 +221,7 @@ def sweep_lengths(*lengths, model=None, data_glob:str=None, strategy="newest"):
|
|||||||
eye = util.eye_diagram.eye_diagram(data.to(dtype=torch.float32).detach().cpu().numpy(), channel_names=channel_names)
|
eye = util.eye_diagram.eye_diagram(data.to(dtype=torch.float32).detach().cpu().numpy(), channel_names=channel_names)
|
||||||
|
|
||||||
print_attrs = ("channel_name", "success", "min_area")
|
print_attrs = ("channel_name", "success", "min_area")
|
||||||
with np.printoptions(precision=3, suppress=True, formatter={'float': '{:0.3e}'.format}):
|
with np.printoptions(precision=3, suppress=True, formatter={"float": "{:0.3e}".format}):
|
||||||
for result in eye.eye_stats:
|
for result in eye.eye_stats:
|
||||||
print_dict = {attr: result[attr] for attr in print_attrs}
|
print_dict = {attr: result[attr] for attr in print_attrs}
|
||||||
rprint(print_dict)
|
rprint(print_dict)
|
||||||
@@ -221,18 +232,77 @@ def sweep_lengths(*lengths, model=None, data_glob:str=None, strategy="newest"):
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
# lengths = range(90000, 100000+10000, 10000)
|
||||||
lengths = range(90000, 100000+10000, 10000)
|
|
||||||
# lengths = [100000]
|
# lengths = [100000]
|
||||||
sweep_lengths(*lengths, model=".models/best_20241204_132605.tar", data_glob="data/202412*-{length}-*-0.ini", strategy="newest")
|
# sweep_lengths(*lengths, model=".models/best_20241204_132605.tar", data_glob="data/202412*-{length}-*-0.ini", strategy="newest")
|
||||||
|
|
||||||
# trainer = Trainer(
|
trainer = RegenerationTrainer(
|
||||||
# global_settings=global_settings,
|
global_settings=global_settings,
|
||||||
# data_settings=data_settings,
|
data_settings=data_settings,
|
||||||
# pytorch_settings=pytorch_settings,
|
pytorch_settings=pytorch_settings,
|
||||||
# model_settings=model_settings,
|
model_settings=model_settings,
|
||||||
# optimizer_settings=optimizer_settings,
|
optimizer_settings=optimizer_settings,
|
||||||
# # checkpoint_path=".models/best_20241202_143149.tar",
|
# checkpoint_path=".models/best_20241205_235929.tar",
|
||||||
# # 20241202_143149
|
# 20241202_143149
|
||||||
|
)
|
||||||
|
trainer.train()
|
||||||
|
|
||||||
|
# from hypertraining.lighning_models import regenerator, regeneratorData
|
||||||
|
# import lightning as L
|
||||||
|
|
||||||
|
# model = regenerator(
|
||||||
|
# 2 * data_settings.output_size,
|
||||||
|
# *model_settings.overrides["hidden_layer_dims"],
|
||||||
|
# model_settings.output_dim,
|
||||||
|
# layer_function=getattr(util.complexNN, model_settings.model_layer_function),
|
||||||
|
# layer_func_kwargs=model_settings.model_layer_kwargs,
|
||||||
|
# act_function=getattr(util.complexNN, model_settings.model_activation_func),
|
||||||
|
# act_func_kwargs=None,
|
||||||
|
# parametrizations=model_settings.model_layer_parametrizations,
|
||||||
|
# dtype=getattr(torch, data_settings.dtype),
|
||||||
|
# dropout_prob=model_settings.dropout_prob,
|
||||||
|
# scale_layers=model_settings.scale,
|
||||||
|
# optimizer=getattr(torch.optim, optimizer_settings.optimizer),
|
||||||
|
# optimizer_kwargs=optimizer_settings.optimizer_kwargs,
|
||||||
|
# lr_scheduler=getattr(torch.optim.lr_scheduler, optimizer_settings.scheduler),
|
||||||
|
# lr_scheduler_kwargs=optimizer_settings.scheduler_kwargs,
|
||||||
# )
|
# )
|
||||||
# trainer.train()
|
|
||||||
|
# dm = regeneratorData(
|
||||||
|
# config_globs=data_settings.config_path,
|
||||||
|
# output_symbols=data_settings.symbols,
|
||||||
|
# output_dim=data_settings.output_size,
|
||||||
|
# dtype=getattr(torch, data_settings.dtype),
|
||||||
|
# drop_first=data_settings.drop_first,
|
||||||
|
# shuffle=data_settings.shuffle,
|
||||||
|
# train_split=data_settings.train_split,
|
||||||
|
# batch_size=pytorch_settings.batchsize,
|
||||||
|
# loader_settings={
|
||||||
|
# "num_workers": pytorch_settings.dataloader_workers,
|
||||||
|
# "prefetch_factor": pytorch_settings.dataloader_prefetch,
|
||||||
|
# "pin_memory": True,
|
||||||
|
# "drop_last": True,
|
||||||
|
# },
|
||||||
|
# seed=global_settings.seed,
|
||||||
|
# )
|
||||||
|
|
||||||
|
# # writer = L.SummaryWriter(pytorch_settings.summary_dir + f"/{datetime.now().strftime('%Y%m%d_%H%M%S')}")
|
||||||
|
|
||||||
|
# # from torch.utils.tensorboard import SummaryWriter
|
||||||
|
|
||||||
|
# subdir = f"{datetime.now().strftime('%Y%m%d_%H%M%S')}"
|
||||||
|
|
||||||
|
# # writer = SummaryWriter(pytorch_settings.summary_dir + f"/{subdir}")
|
||||||
|
|
||||||
|
# logger = L.pytorch.loggers.TensorBoardLogger(pytorch_settings.summary_dir, name=subdir, log_graph=True)
|
||||||
|
|
||||||
|
# trainer = L.Trainer(
|
||||||
|
# fast_dev_run=False,
|
||||||
|
# # max_epochs=pytorch_settings.epochs,
|
||||||
|
# max_epochs=2,
|
||||||
|
# enable_checkpointing=True,
|
||||||
|
# default_root_dir=f".models/{subdir}/",
|
||||||
|
# logger=logger,
|
||||||
|
# )
|
||||||
|
|
||||||
|
# trainer.fit(model, dm)
|
||||||
|
|||||||
230
src/single-core-regen/train_pol_estimator.py
Normal file
@@ -0,0 +1,230 @@
|
|||||||
|
from datetime import datetime
|
||||||
|
from pathlib import Path
|
||||||
|
import matplotlib
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.utils.tensorboard
|
||||||
|
import torch.utils.tensorboard.summary
|
||||||
|
from hypertraining.settings import (
|
||||||
|
GlobalSettings,
|
||||||
|
DataSettings,
|
||||||
|
PytorchSettings,
|
||||||
|
ModelSettings,
|
||||||
|
OptimizerSettings,
|
||||||
|
)
|
||||||
|
|
||||||
|
from hypertraining.training import RegenerationTrainer, PolarizationTrainer
|
||||||
|
|
||||||
|
# import torch
|
||||||
|
import json
|
||||||
|
import util
|
||||||
|
|
||||||
|
from rich import print as rprint
|
||||||
|
|
||||||
|
global_settings = GlobalSettings(
|
||||||
|
seed=0xC0FFEE,
|
||||||
|
)
|
||||||
|
|
||||||
|
data_settings = DataSettings(
|
||||||
|
config_path="data/*-128-16384-10000-0-0-0-0-PAM4-0-0.4.ini",
|
||||||
|
# config_path=[f"data/20241202-*-128-16384-{length}-0-0-17-0-PAM4-0.ini" for length in range(48000, 53000, 1000)],
|
||||||
|
dtype="complex64",
|
||||||
|
# symbols = (9, 20), # 13 symbol @ 10GBd <-> 1.3ns <-> 0.26m of fiber
|
||||||
|
symbols=13, # study: single_core_regen_20241123_011232
|
||||||
|
# output_size = (11, 32), # 26 taps -> 2 taps per input symbol -> 1 tap every 1cm (model has 52 inputs (x/y))
|
||||||
|
output_size=26, # study: single_core_regen_20241123_011232 (model_input_dim/2)
|
||||||
|
shuffle=True,
|
||||||
|
drop_first=64,
|
||||||
|
train_split=0.8,
|
||||||
|
# polarisations=tuple(np.random.rand(2)*2*np.pi),
|
||||||
|
randomise_polarisations=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
pytorch_settings = PytorchSettings(
|
||||||
|
epochs=10000,
|
||||||
|
batchsize=2**12,
|
||||||
|
device="cuda",
|
||||||
|
dataloader_workers=16,
|
||||||
|
dataloader_prefetch=8,
|
||||||
|
summary_dir=".runs",
|
||||||
|
write_every=2**5,
|
||||||
|
save_models=True,
|
||||||
|
model_dir=".models",
|
||||||
|
)
|
||||||
|
|
||||||
|
model_settings = ModelSettings(
|
||||||
|
output_dim=3,
|
||||||
|
n_hidden_layers=3,
|
||||||
|
overrides={
|
||||||
|
"n_hidden_nodes_0": 2,
|
||||||
|
"n_hidden_nodes_1": 2,
|
||||||
|
"n_hidden_nodes_2": 2,
|
||||||
|
},
|
||||||
|
dropout_prob=0.01,
|
||||||
|
model_layer_function="ONNRect",
|
||||||
|
model_activation_func="EOActivation",
|
||||||
|
model_layer_kwargs={"square": True},
|
||||||
|
scale=False,
|
||||||
|
model_layer_parametrizations=[
|
||||||
|
{
|
||||||
|
"tensor_name": "weight",
|
||||||
|
"parametrization": util.complexNN.energy_conserving,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"tensor_name": "alpha",
|
||||||
|
"parametrization": util.complexNN.clamp,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"tensor_name": "gain",
|
||||||
|
"parametrization": util.complexNN.clamp,
|
||||||
|
"kwargs": {
|
||||||
|
"min": 0,
|
||||||
|
"max": float("inf"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"tensor_name": "phase_bias",
|
||||||
|
"parametrization": util.complexNN.clamp,
|
||||||
|
"kwargs": {
|
||||||
|
"min": 0,
|
||||||
|
"max": 2 * torch.pi,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"tensor_name": "scales",
|
||||||
|
"parametrization": util.complexNN.clamp,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"tensor_name": "angle",
|
||||||
|
"parametrization": util.complexNN.clamp,
|
||||||
|
"kwargs": {
|
||||||
|
"min": 0,
|
||||||
|
"max": 2*torch.pi,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"tensor_name": "loss",
|
||||||
|
"parametrization": util.complexNN.clamp,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
optimizer_settings = OptimizerSettings(
|
||||||
|
optimizer="AdamW",
|
||||||
|
optimizer_kwargs={
|
||||||
|
"lr": 0.005,
|
||||||
|
"amsgrad": True,
|
||||||
|
# "weight_decay": 1e-7,
|
||||||
|
},
|
||||||
|
# learning_rate=0.05,
|
||||||
|
scheduler="ReduceLROnPlateau",
|
||||||
|
scheduler_kwargs={
|
||||||
|
"patience": 2**6,
|
||||||
|
"factor": 0.75,
|
||||||
|
# "threshold": 1e-3,
|
||||||
|
"min_lr": 1e-6,
|
||||||
|
"cooldown": 10,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def save_dict_to_file(dictionary, filename):
|
||||||
|
"""
|
||||||
|
Save a dictionary to a JSON file.
|
||||||
|
|
||||||
|
:param dictionary: Dictionary to save (e.g. the best training results).
|
||||||
|
:type dictionary: dict
|
||||||
|
:param filename: Path to the JSON file where the dictionary will be saved.
|
||||||
|
:type filename: str
|
||||||
|
"""
|
||||||
|
with open(filename, "w") as f:
|
||||||
|
json.dump(dictionary, f, indent=4)
|
||||||
|
|
||||||
|
|
||||||
|
def sweep_lengths(*lengths, model=None, data_glob: str = None, strategy="newest"):
|
||||||
|
assert model is not None, "Model must be provided."
|
||||||
|
assert data_glob is not None, "Data glob must be provided."
|
||||||
|
model = model
|
||||||
|
|
||||||
|
fiber_ins = {}
|
||||||
|
fiber_outs = {}
|
||||||
|
regens = {}
|
||||||
|
timestampss = {}
|
||||||
|
|
||||||
|
trainer = RegenerationTrainer(
|
||||||
|
checkpoint_path=model,
|
||||||
|
)
|
||||||
|
trainer.define_model()
|
||||||
|
|
||||||
|
for length in lengths:
|
||||||
|
data_glob_length = data_glob.replace("{length}", str(length))
|
||||||
|
files = list(Path.cwd().glob(data_glob_length))
|
||||||
|
if len(files) == 0:
|
||||||
|
continue
|
||||||
|
if strategy == "newest":
|
||||||
|
sorted_kwargs = {
|
||||||
|
"key": lambda x: x.stat().st_mtime,
|
||||||
|
"reverse": True,
|
||||||
|
}
|
||||||
|
elif strategy == "oldest":
|
||||||
|
sorted_kwargs = {
|
||||||
|
"key": lambda x: x.stat().st_mtime,
|
||||||
|
"reverse": False,
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown strategy {strategy}.")
|
||||||
|
file = sorted(files, **sorted_kwargs)[0]
|
||||||
|
|
||||||
|
loader, _ = trainer.get_sliced_data(override={"config_path": file})
|
||||||
|
fiber_in, fiber_out, regen, timestamps = trainer.run_model(trainer.model, loader=loader)
|
||||||
|
|
||||||
|
fiber_ins[length] = fiber_in
|
||||||
|
fiber_outs[length] = fiber_out
|
||||||
|
regens[length] = regen
|
||||||
|
timestampss[length] = timestamps
|
||||||
|
|
||||||
|
data = torch.zeros(2 * len(timestampss.keys()) + 2, 2, tuple(fiber_outs.values())[-1].shape[0])
|
||||||
|
channel_names = ["" for _ in range(2 * len(timestampss.keys()) + 2)]
|
||||||
|
|
||||||
|
data[1, 0, :] = timestampss[tuple(timestampss.keys())[-1]] / 128
|
||||||
|
data[1, 1, :] = fiber_ins[tuple(timestampss.keys())[-1]][:, 0].abs().square()
|
||||||
|
|
||||||
|
channel_names[1] = "fiber in x"
|
||||||
|
|
||||||
|
for li, length in enumerate(timestampss.keys()):
|
||||||
|
data[2 + 2 * li, 0, :] = timestampss[length] / 128
|
||||||
|
data[2 + 2 * li, 1, :] = fiber_outs[length][:, 0].abs().square()
|
||||||
|
data[2 + 2 * li + 1, 0, :] = timestampss[length] / 128
|
||||||
|
data[2 + 2 * li + 1, 1, :] = regens[length][:, 0].abs().square()
|
||||||
|
|
||||||
|
channel_names[2 + 2 * li + 1] = f"regen x {length}"
|
||||||
|
channel_names[2 + 2 * li] = f"fiber out x {length}"
|
||||||
|
|
||||||
|
# get current backend
|
||||||
|
backend = matplotlib.get_backend()
|
||||||
|
|
||||||
|
matplotlib.use("TkCairo")
|
||||||
|
eye = util.eye_diagram.eye_diagram(data.to(dtype=torch.float32).detach().cpu().numpy(), channel_names=channel_names)
|
||||||
|
|
||||||
|
print_attrs = ("channel_name", "success", "min_area")
|
||||||
|
with np.printoptions(precision=3, suppress=True, formatter={"float": "{:0.3e}".format}):
|
||||||
|
for result in eye.eye_stats:
|
||||||
|
print_dict = {attr: result[attr] for attr in print_attrs}
|
||||||
|
rprint(print_dict)
|
||||||
|
rprint()
|
||||||
|
|
||||||
|
eye.plot(all_stats=False)
|
||||||
|
matplotlib.use(backend)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
trainer = PolarizationTrainer(
|
||||||
|
global_settings=global_settings,
|
||||||
|
data_settings=data_settings,
|
||||||
|
pytorch_settings=pytorch_settings,
|
||||||
|
model_settings=model_settings,
|
||||||
|
optimizer_settings=optimizer_settings,
|
||||||
|
# checkpoint_path='.models/pol_pol_20241208_122418_1116.tar',
|
||||||
|
# reset_epoch=True
|
||||||
|
)
|
||||||
|
trainer.train()
|
||||||
@@ -260,12 +260,94 @@ class ONNRect(nn.Module):
|
|||||||
self.crop = lambda x: x
|
self.crop = lambda x: x
|
||||||
self.crop.__doc__ = "No cropping"
|
self.crop.__doc__ = "No cropping"
|
||||||
|
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
x = self.pad(x)
|
x = self.pad(x).to(dtype=self.weight.dtype)
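# cast the padded input to the weight dtype so real-valued batches can be multiplied with complex weights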
|
||||||
out = self.crop((self.weight @ x.mT).mT)
|
out = self.crop((self.weight @ x.mT).mT)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
class polarimeter(nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super(polarimeter, self).__init__()
|
||||||
|
# self.input_length = input_length
|
||||||
|
|
||||||
|
def forward(self, data):
|
||||||
|
# S0 = I
|
||||||
|
# S1 = (2*I_x - I)/I
|
||||||
|
# S2 = (2*I_45 - I)/I
|
||||||
|
# S3 = (2*I_RHC - I)/I
|
||||||
|
|
||||||
|
# # data: (batch, input_length*2) -> (batch, input_length, 2)
|
||||||
|
data = data.view(data.shape[0], -1, 2)
|
||||||
|
x = data[:, :, 0].mean(dim=1)
|
||||||
|
y = data[:, :, 1].mean(dim=1)
|
||||||
|
|
||||||
|
# x = x.mean(dim=1)
|
||||||
|
# y = y.mean(dim=1)
|
||||||
|
|
||||||
|
# angle = torch.atan2(y.abs().square().real, x.abs().square().real)
|
||||||
|
|
||||||
|
# return torch.stack([angle, angle, angle, angle], dim=1)
|
||||||
|
|
||||||
|
# horizontal polarisation
|
||||||
|
I_x = x.abs().square()
|
||||||
|
|
||||||
|
# vertical polarisation
|
||||||
|
I_y = y.abs().square()
|
||||||
|
|
||||||
|
# 45 degree polarisation
|
||||||
|
I_45 = (x + y).abs().square()
|
||||||
|
|
||||||
|
|
||||||
|
# right hand circular polarisation
|
||||||
|
I_RHC = (x + 1j*y).abs().square()
|
||||||
|
|
||||||
|
# S0 = I_x + I_y
|
||||||
|
# S1 = I_x - I_y
|
||||||
|
# S2 = I_45 - I_m45
|
||||||
|
# S3 = I_RHC - I_LHC
|
||||||
|
|
||||||
|
S0 = (I_x + I_y)
|
||||||
|
S1 = ((2*I_x - S0)/S0)
|
||||||
|
S2 = ((2*I_45 - S0)/S0)
|
||||||
|
S3 = ((2*I_RHC - S0)/S0)
|
||||||
|
|
||||||
|
return torch.stack([S0/S0, S1, S2, S3], dim=1)
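# Shape sketch (illustrative only): for a batch of interleaved x/y field samples,
#   pol = polarimeter()
#   fields = torch.randn(8, 52, dtype=torch.complex64)   # 8 slices, 26 x/y sample pairs
#   pol(fields).shape                                     # -> (8, 4): [S0/S0, S1, S2, S3]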
|
||||||
|
|
||||||
|
class normalize_by_first(nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super(normalize_by_first, self).__init__()
|
||||||
|
|
||||||
|
def forward(self, data):
|
||||||
|
return data / data[:, 0].unsqueeze(1)
|
||||||
|
|
||||||
|
class photodiode(nn.Module):
|
||||||
|
def __init__(self, size, bias=True):
|
||||||
|
super(photodiode, self).__init__()
|
||||||
|
self.input_dim = size
|
||||||
|
self.scale = nn.Parameter(torch.rand(size))
|
||||||
|
self.pd_bias = nn.Parameter(torch.rand(size))
|
||||||
|
|
||||||
|
def forward(self, x):
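# square-law detection: intensity |x|^2, cast to the matching real dtype, scaled by a learnable
# per-element responsivity (scale) and offset by a learnable dark level (pd_bias)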
|
||||||
|
return x.abs().square().to(dtype=x.dtype.to_real()).mul(self.scale).add(self.pd_bias)
|
||||||
|
|
||||||
|
|
||||||
|
class input_rotator(nn.Module):
|
||||||
|
def __init__(self, input_dim):
|
||||||
|
super(input_rotator, self).__init__()
|
||||||
|
assert input_dim % 2 == 0, "Input dimension must be even"
|
||||||
|
self.input_dim = input_dim
|
||||||
|
# self.angle = nn.Parameter(torch.randn(1, dtype=self.dtype.to_real()))
|
||||||
|
|
||||||
|
def forward(self, x, angle=None):
|
||||||
|
# take channels (0,1), (2,3), ... and rotate them by the angle
|
||||||
|
angle = angle if angle is not None else self.angle
|
||||||
|
sine = torch.sin(angle)
|
||||||
|
cosine = torch.cos(angle)
|
||||||
|
rot = torch.tensor([[cosine, -sine], [sine, cosine]], dtype=x.dtype)
|
||||||
|
return torch.matmul(x.view(-1, 2), rot).view(x.shape)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# def __repr__(self):
|
# def __repr__(self):
|
||||||
# return f"ONNRect({self.input_dim}, {self.output_dim})"
|
# return f"ONNRect({self.input_dim}, {self.output_dim})"
|
||||||
|
|
||||||
@@ -371,7 +453,7 @@ class Identity(nn.Module):
|
|||||||
M(z) = z
|
M(z) = z
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self, size=None):
|
||||||
super(Identity, self).__init__()
|
super(Identity, self).__init__()
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
@@ -404,9 +486,28 @@ class MZISingle(nn.Module):
|
|||||||
def forward(self, x: torch.Tensor):
|
def forward(self, x: torch.Tensor):
|
||||||
return x * torch.exp(1j * self.phi) * torch.sin(self.omega + self.func(x))
|
return x * torch.exp(1j * self.phi) * torch.sin(self.omega + self.func(x))
|
||||||
|
|
||||||
|
def naive_angle_loss(x: torch.Tensor, target: torch.Tensor, mod=2*torch.pi):
|
||||||
|
return torch.fmod((x - target), mod).square().mean()
|
||||||
|
|
||||||
|
def cosine_loss(x: torch.Tensor, target: torch.Tensor):
|
||||||
|
return (2*(1 - torch.cos(x - target))).mean()
|
||||||
|
|
||||||
|
def angle_mse_loss(x: torch.Tensor, target: torch.Tensor):
|
||||||
|
x = torch.fmod(x, 2*torch.pi)
|
||||||
|
target = torch.fmod(target, 2*torch.pi)
|
||||||
|
|
||||||
|
x_cos = torch.cos(x)
|
||||||
|
x_sin = torch.sin(x)
|
||||||
|
target_cos = torch.cos(target)
|
||||||
|
target_sin = torch.sin(target)
|
||||||
|
|
||||||
|
cos_diff = x_cos - target_cos
|
||||||
|
sin_diff = x_sin - target_sin
|
||||||
|
squared_diff = cos_diff**2 + sin_diff**2
|
||||||
|
return squared_diff.mean()
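# Quick sanity check (a sketch, not part of the library; assumes the three losses above are in scope):
#   a, b = torch.tensor([0.1]), torch.tensor([2 * torch.pi - 0.1])   # the same physical angle, wrapped once
#   naive_angle_loss(a, b)  # ~37.0 -- fmod does not remove the 2*pi wrap-around
#   cosine_loss(a, b)       # ~0.04 -- periodic, small for nearly equal angles
#   angle_mse_loss(a, b)    # ~0.04 -- chord distance on the unit circle, algebraically equal to 2*(1 - cos(a - b))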
|
||||||
|
|
||||||
class EOActivation(nn.Module):
|
class EOActivation(nn.Module):
|
||||||
def __init__(self, bias, size=None):
|
def __init__(self, size=None):
|
||||||
# 10.1109/SiPhotonics60897.2024.10543376
|
# 10.1109/SiPhotonics60897.2024.10543376
|
||||||
super(EOActivation, self).__init__()
|
super(EOActivation, self).__init__()
|
||||||
if size is None:
|
if size is None:
|
||||||
@@ -571,81 +672,10 @@ class ZReLU(nn.Module):
|
|||||||
return torch.relu(x)
|
return torch.relu(x)
|
||||||
|
|
||||||
|
|
||||||
class regenerator(nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
*dims,
|
|
||||||
layer_function=ONN,
|
|
||||||
layer_kwargs: dict | None = None,
|
|
||||||
layer_parametrizations: list[dict] = None,
|
|
||||||
activation_function=Pow,
|
|
||||||
dtype=torch.float64,
|
|
||||||
dropout_prob=0.01,
|
|
||||||
scale=False,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
super(regenerator, self).__init__()
|
|
||||||
if len(dims) == 0:
|
|
||||||
try:
|
|
||||||
dims = kwargs["dims"]
|
|
||||||
except KeyError:
|
|
||||||
raise ValueError("dims must be provided")
|
|
||||||
self._n_hidden_layers = len(dims) - 2
|
|
||||||
self._layers = nn.Sequential()
|
|
||||||
if layer_kwargs is None:
|
|
||||||
layer_kwargs = {}
|
|
||||||
# self.powers = []
|
|
||||||
|
|
||||||
for i in range(self._n_hidden_layers + 1):
|
|
||||||
if scale:
|
|
||||||
self._layers.append(Scale(dims[i]))
|
|
||||||
self._layers.append(layer_function(dims[i], dims[i + 1], dtype=dtype, **layer_kwargs))
|
|
||||||
if i < self._n_hidden_layers:
|
|
||||||
if dropout_prob is not None:
|
|
||||||
self._layers.append(DropoutComplex(p=dropout_prob))
|
|
||||||
self._layers.append(activation_function(bias=True, size=dims[i + 1]))
|
|
||||||
|
|
||||||
self._layers.append(Scale(dims[-1]))
|
|
||||||
|
|
||||||
# add parametrizations
|
|
||||||
if layer_parametrizations is not None:
|
|
||||||
for layer in self._layers:
|
|
||||||
for layer_parametrization in layer_parametrizations:
|
|
||||||
tensor_name = layer_parametrization.get("tensor_name", None)
|
|
||||||
parametrization = layer_parametrization.get("parametrization", None)
|
|
||||||
param_kwargs = layer_parametrization.get("kwargs", {})
|
|
||||||
if tensor_name is not None and tensor_name in layer._parameters and parametrization is not None:
|
|
||||||
parametrization(layer, tensor_name, **param_kwargs)
|
|
||||||
|
|
||||||
# def __call__(self, input_x, **kwargs):
|
|
||||||
# return self.forward(input_x, **kwargs)
|
|
||||||
|
|
||||||
def forward(self, input_x, trace_powers=False):
|
|
||||||
x = input_x
|
|
||||||
|
|
||||||
if trace_powers:
|
|
||||||
powers = [x.abs().square().sum()]
|
|
||||||
|
|
||||||
# check if tracing
|
|
||||||
if torch.jit.is_tracing():
|
|
||||||
for layer in self._layers:
|
|
||||||
x = layer(x)
|
|
||||||
if trace_powers:
|
|
||||||
powers.append(x.abs().square().sum())
|
|
||||||
else:
|
|
||||||
# with torch.nn.utils.parametrize.cached():
|
|
||||||
for layer in self._layers:
|
|
||||||
x = layer(x)
|
|
||||||
if trace_powers:
|
|
||||||
powers.append(x.abs().square().sum())
|
|
||||||
if trace_powers:
|
|
||||||
return x, powers
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
complex_sse_loss,
|
complex_sse_loss,
|
||||||
complex_mse_loss,
|
complex_mse_loss,
|
||||||
|
angle_mse_loss,
|
||||||
UnitaryLayer,
|
UnitaryLayer,
|
||||||
unitary,
|
unitary,
|
||||||
energy_conserving,
|
energy_conserving,
|
||||||
@@ -662,6 +692,7 @@ __all__ = [
|
|||||||
ZReLU,
|
ZReLU,
|
||||||
MZISingle,
|
MZISingle,
|
||||||
EOActivation,
|
EOActivation,
|
||||||
|
photodiode,
|
||||||
# SaturableAbsorberLambertW,
|
# SaturableAbsorberLambertW,
|
||||||
# SaturableAbsorber,
|
# SaturableAbsorber,
|
||||||
# SpreadLayer,
|
# SpreadLayer,
|
||||||
|
|||||||
@@ -113,6 +113,8 @@ class FiberRegenerationDataset(Dataset):
|
|||||||
dtype: torch.dtype = None,
|
dtype: torch.dtype = None,
|
||||||
real: bool = False,
|
real: bool = False,
|
||||||
device=None,
|
device=None,
|
||||||
|
polarisations: tuple | list = (0,),
|
||||||
|
randomise_polarisations: bool = False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@@ -145,6 +147,8 @@ class FiberRegenerationDataset(Dataset):
|
|||||||
assert output_dim is None or output_dim > 0, "output_len must be positive or None"
|
assert output_dim is None or output_dim > 0, "output_len must be positive or None"
|
||||||
assert drop_first >= 0, "drop_first must be non-negative"
|
assert drop_first >= 0, "drop_first must be non-negative"
|
||||||
|
|
||||||
|
self.randomise_polarisations = randomise_polarisations
|
||||||
|
|
||||||
faux = kwargs.pop("faux", False)
|
faux = kwargs.pop("faux", False)
|
||||||
|
|
||||||
if faux:
|
if faux:
|
||||||
@@ -165,7 +169,7 @@ class FiberRegenerationDataset(Dataset):
|
|||||||
data_raw = None
|
data_raw = None
|
||||||
self.config = None
|
self.config = None
|
||||||
files = []
|
files = []
|
||||||
for file_path in (file_path if isinstance(file_path, (tuple, list)) else [file_path]):
|
for file_path in file_path if isinstance(file_path, (tuple, list)) else [file_path]:
|
||||||
data, config = load_data(
|
data, config = load_data(
|
||||||
file_path,
|
file_path,
|
||||||
skipfirst=drop_first,
|
skipfirst=drop_first,
|
||||||
@@ -186,6 +190,19 @@ class FiberRegenerationDataset(Dataset):
|
|||||||
files.append(config["data"]["file"].strip('"'))
|
files.append(config["data"]["file"].strip('"'))
|
||||||
self.config["data"]["file"] = str(files)
|
self.config["data"]["file"] = str(files)
|
||||||
|
|
||||||
|
for i, angle in enumerate(torch.tensor(np.array(polarisations))):
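# rotate the stored x/y field columns (indices 2 and 3) by this launch angle; the first angle
# replaces the data in place, later angles append a rotated copy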
|
||||||
|
data_raw_copy = data_raw.clone()
|
||||||
|
if angle == 0:
|
||||||
|
continue
|
||||||
|
sine = torch.sin(angle)
|
||||||
|
cosine = torch.cos(angle)
|
||||||
|
data_raw_copy[:, 2] = data_raw[:, 2] * cosine - data_raw[:, 3] * sine
|
||||||
|
data_raw_copy[:, 3] = data_raw[:, 2] * sine + data_raw[:, 3] * cosine
|
||||||
|
if i == 0:
|
||||||
|
data_raw = data_raw_copy
|
||||||
|
else:
|
||||||
|
data_raw = torch.cat([data_raw, data_raw_copy], dim=0)
|
||||||
|
|
||||||
self.device = data_raw.device
|
self.device = data_raw.device
|
||||||
|
|
||||||
self.samples_per_symbol = int(self.config["glova"]["sps"])
|
self.samples_per_symbol = int(self.config["glova"]["sps"])
|
||||||
@@ -258,17 +275,27 @@ class FiberRegenerationDataset(Dataset):
|
|||||||
elif self.target_delay_samples < 0:
|
elif self.target_delay_samples < 0:
|
||||||
data_raw = data_raw[:, : self.target_delay_samples]
|
data_raw = data_raw[:, : self.target_delay_samples]
|
||||||
|
|
||||||
timestamps = data_raw[-1, :]
|
timestamps = data_raw[4, :]
|
||||||
data_raw = data_raw[:-1, :]
|
data_raw = data_raw[:4, :]
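# rows 0-3 hold E_in_x, E_in_y, E_out_x, E_out_y; row 4 holds the shared sample timestamps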
|
||||||
data_raw = data_raw.view(2, 2, -1)
|
data_raw = data_raw.view(2, 2, -1)
|
||||||
timestamps_doubled = torch.cat([timestamps.unsqueeze(dim=0), timestamps.unsqueeze(dim=0)], dim=0).unsqueeze(dim=1)
|
timestamps_doubled = torch.cat([timestamps.unsqueeze(dim=0), timestamps.unsqueeze(dim=0)], dim=0).unsqueeze(
|
||||||
|
dim=1
|
||||||
|
)
|
||||||
data_raw = torch.cat([data_raw, timestamps_doubled], dim=1)
|
data_raw = torch.cat([data_raw, timestamps_doubled], dim=1)
|
||||||
|
# data_raw = torch.cat([data_raw, timestamps_doubled], dim=1)
|
||||||
# data layout
|
# data layout
|
||||||
# [ [E_in_x, E_in_y, timestamps],
|
# [ [E_in_x, E_in_y, timestamps],
|
||||||
# [E_out_x, E_out_y, timestamps] ]
|
# [E_out_x, E_out_y, timestamps] ]
|
||||||
|
|
||||||
self.data = data_raw.unfold(dimension=-1, size=self.samples_per_slice, step=1)
|
self.data = data_raw.unfold(dimension=-1, size=self.samples_per_slice, step=1)
|
||||||
self.data = self.data.movedim(-2, 0)
|
self.data = self.data.movedim(-2, 0)
|
||||||
|
|
||||||
|
if randomise_polarisations:
|
||||||
|
self.angles = torch.rand(self.data.shape[0]) * np.pi * 2
|
||||||
|
# self.data[:, 1, :2, :] = self.rotate(self.data[:, 1, :2, :], self.angles)
|
||||||
|
else:
|
||||||
|
self.angles = torch.zeros(self.data.shape[0])
|
||||||
|
# ...
|
||||||
# -> [no_slices, 2, 3, samples_per_slice]
|
# -> [no_slices, 2, 3, samples_per_slice]
|
||||||
|
|
||||||
# data layout
|
# data layout
|
||||||
@@ -293,18 +320,88 @@ class FiberRegenerationDataset(Dataset):
|
|||||||
|
|
||||||
data_slice = data_slice.view(data_slice.shape[0], data_slice.shape[1], self.output_dim, -1)
|
data_slice = data_slice.view(data_slice.shape[0], data_slice.shape[1], self.output_dim, -1)
|
||||||
|
|
||||||
target = data_slice[0, :, self.output_dim//2, 0]
|
# if self.randomise_polarisations:
|
||||||
data = data_slice[1, :, :, 0]
|
# angle = torch.rand(1) * torch.pi * 2
|
||||||
|
# sine = torch.sin(angle)
|
||||||
|
# cosine = torch.cos(angle)
|
||||||
|
# data_slice_ = data_slice[1]
|
||||||
|
# data_slice[1, 0] = data_slice_[0] * cosine - data_slice_[1] * sine
|
||||||
|
# data_slice[1,1] = data_slice_[0] * sine + data_slice_[1] * cosine
|
||||||
|
# else:
|
||||||
|
# angle = torch.zeros(1)
|
||||||
|
|
||||||
|
# data = data_slice[1, :2, :, 0]
|
||||||
|
|
||||||
|
angle = self.angles[idx]
|
||||||
|
|
||||||
|
data_index = 1
|
||||||
|
|
||||||
|
data_slice[1, :2, :, :] = self.rotate(data_slice[data_index, :2, :, :], angle)
|
||||||
|
|
||||||
|
data = data_slice[1, :2, :, 0]
|
||||||
|
# data = self.rotate(data, angle)
|
||||||
|
|
||||||
|
# for both polarisations (x/y), calculate the mean of the signal around the current symbol (-> corresponding to a lowpass filter)
|
||||||
|
angle_data = data_slice[1, :2, :, :].reshape(2, -1).mean(dim=1)
|
||||||
|
angle_data2 = self.complex_max(data_slice[1, :2, :, :].reshape(2, -1))
|
||||||
|
plot_data = data_slice[1, :2, self.output_dim // 2, 0]
|
||||||
|
sop = self.polarimeter(plot_data)
|
||||||
|
# angle_data = data_slice[1, :2, :self.output_dim//2, :].squeeze().reshape(2, -1).mean(dim=-1)
|
||||||
|
# angle = data_slice[1, 3, self.output_dim // 2, 0].real
|
||||||
|
target = data_slice[0, :2, self.output_dim // 2, 0]
|
||||||
|
target_timestamp = data_slice[0, 2, self.output_dim // 2, 0].real
|
||||||
|
...
|
||||||
|
|
||||||
# data_timestamps = data[-1,:].real
|
# data_timestamps = data[-1,:].real
|
||||||
data = data[:-1, :]
|
# data = data[:-1, :]
|
||||||
target_timestamp = target[-1].real
|
# target_timestamp = target[-1].real
|
||||||
target = target[:-1]
|
# target = target[:-1]
|
||||||
|
# plot_data = plot_data[:-1]
|
||||||
|
|
||||||
|
# transpose to interleave the x and y data in the output tensor
|
||||||
data = data.transpose(0, 1).flatten().squeeze()
|
data = data.transpose(0, 1).flatten().squeeze()
|
||||||
|
angle_data = angle_data.flatten().squeeze()
|
||||||
|
angle_data2 = angle_data2.flatten().squeeze()
|
||||||
|
angle = angle.flatten().squeeze()
|
||||||
# data_timestamps = data_timestamps.flatten().squeeze()
|
# data_timestamps = data_timestamps.flatten().squeeze()
|
||||||
target = target.flatten().squeeze()
|
target = target.flatten().squeeze()
|
||||||
target_timestamp = target_timestamp.flatten().squeeze()
|
target_timestamp = target_timestamp.flatten().squeeze()
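# returned batch: "x" = rotated fiber-output taps (model input), "y" = fiber-input sample at the slice centre,
# "angle" = applied rotation, "sop" = (S1, S2, S3) of the rotated centre sample,
# "angle_data"/"angle_data2" = per-polarisation slice summaries, "timestamp" = target sample time,
# "plot_data" = rotated centre sample used for plotting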
|
||||||
|
|
||||||
return data, target, target_timestamp
|
return {"x": data, "y": target, "angle": angle, "sop": sop, "angle_data": angle_data, "angle_data2": angle_data2, "timestamp": target_timestamp, "plot_data": plot_data}
|
||||||
|
|
||||||
|
def complex_max(self, data, dim=-1):
|
||||||
|
# returns element(s) with the maximum absolute value along a given dimension
|
||||||
|
# ind = torch.argmax(data.abs(), dim=dim, keepdim=True)
|
||||||
|
# max_values = torch.gather(data, dim, ind).squeeze(dim=dim)
|
||||||
|
# return max_values
|
||||||
|
return torch.gather(data, dim, torch.argmax(data.abs(), dim=dim, keepdim=True)).squeeze(dim=dim)
|
||||||
|
|
||||||
|
|
||||||
|
def rotate(self, data, angle):
|
||||||
|
# rotates a 2d tensor by a given angle
|
||||||
|
# data: [2, ...]
|
||||||
|
# angle: [1]
|
||||||
|
# returns: [2, ...]
|
||||||
|
|
||||||
|
# get sine and cosine of the angle
|
||||||
|
sine = torch.sin(angle)
|
||||||
|
cosine = torch.cos(angle)
|
||||||
|
|
||||||
|
return torch.stack([data[0] * cosine - data[1] * sine, data[0] * sine + data[1] * cosine], dim=0)
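# e.g. rotating [x, 0] by pi/2 gives approximately [0, x]: the x-field moves onto the y-channel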
|
||||||
|
|
||||||
|
def polarimeter(self, data):
|
||||||
|
# data: [2, ...] -> x, y
|
||||||
|
# returns [4] -> S0, S1, S2, S3
|
||||||
|
x = data[0].mean()
|
||||||
|
y = data[1].mean()
|
||||||
|
I_X = x.abs().square()
|
||||||
|
I_Y = y.abs().square()
|
||||||
|
I_45 = (x+y).abs().square()
|
||||||
|
I_RHC = (x + 1j*y).abs().square()
|
||||||
|
|
||||||
|
S0 = I_X + I_Y
|
||||||
|
S1 = (2*I_X - S0) / S0
|
||||||
|
S2 = (2*I_45 - S0) / S0
|
||||||
|
S3 = (2*I_RHC - S0) / S0
|
||||||
|
|
||||||
|
return torch.stack([S1, S2, S3], dim=0)
|
||||||