add training.py for defining and running models without hyperparameter tuning
src/single-core-regen/hypertraining/training.py (new file, +739)
@@ -0,0 +1,739 @@
import copy
from datetime import datetime
from pathlib import Path
from typing import Literal

import matplotlib
import torch.nn.utils.parametrize

try:
    matplotlib.use("cairo")
except ImportError:
    matplotlib.use("Agg")
import matplotlib.pyplot as plt

import numpy as np

import torch
import torch.nn as nn

# import torch.nn.functional as F  # mse_loss doesn't support complex numbers
import torch.optim as optim
import torch.utils.data

from torch.utils.tensorboard import SummaryWriter

from rich.progress import (
    Progress,
    TextColumn,
    BarColumn,
    TaskProgressColumn,
    TimeRemainingColumn,
    MofNCompleteColumn,
    TimeElapsedColumn,
)
from rich.console import Console

from util.datasets import FiberRegenerationDataset
import util

from .settings import (
    GlobalSettings,
    DataSettings,
    ModelSettings,
    OptimizerSettings,
    PytorchSettings,
)


class regenerator(nn.Module):
    def __init__(
        self,
        *dims,
        layer_function=util.complexNN.ONN,
        layer_parametrizations: list[dict] = None,
        # [
        #     {
        #         "tensor_name": "weight",
        #         "parametrization": util.complexNN.Unitary,
        #     },
        #     {
        #         "tensor_name": "scale",
        #         "parametrization": util.complexNN.Clamp,
        #     },
        # ],
        activation_function=util.complexNN.Pow,
        dtype=torch.float64,
        dropout_prob=0.01,
        **kwargs,
    ):
        super(regenerator, self).__init__()
        if len(dims) == 0:
            try:
                dims = kwargs["dims"]
            except KeyError:
                raise ValueError("dims must be provided")
        self._n_hidden_layers = len(dims) - 2
        self._layers = nn.Sequential()

        for i in range(self._n_hidden_layers + 1):
            self._layers.append(layer_function(dims[i], dims[i + 1], dtype=dtype))
            if i < self._n_hidden_layers:
                if dropout_prob is not None:
                    self._layers.append(util.complexNN.DropoutComplex(p=dropout_prob))
                self._layers.append(activation_function())

        # add parametrizations
        if layer_parametrizations is not None:
            for layer_parametrization in layer_parametrizations:
                tensor_name = layer_parametrization.get("tensor_name", None)
                parametrization = layer_parametrization.get("parametrization", None)
                param_kwargs = layer_parametrization.get("kwargs", {})
                if (
                    tensor_name is not None
                    and tensor_name in self._layers[-1]._parameters
                    and parametrization is not None
                ):
                    parametrization(self._layers[-1], tensor_name, **param_kwargs)

    def forward(self, input_x):
        x = input_x
        # check if tracing
        if torch.jit.is_tracing():
            for layer in self._layers:
                x = layer(x)
        else:
            # with torch.nn.utils.parametrize.cached():
            for layer in self._layers:
                x = layer(x)
        return x


def traverse_dict_update(target, source):
    for k, v in source.items():
        if isinstance(v, dict):
            if k not in target:
                target[k] = {}
            traverse_dict_update(target[k], v)
        else:
            try:
                target[k] = v
            except TypeError:
                target.__dict__[k] = v


class Trainer:
    def __init__(
        self,
        *,
        global_settings=None,
        data_settings=None,
        pytorch_settings=None,
        model_settings=None,
        optimizer_settings=None,
        console=None,
        checkpoint_path=None,
        settings_override=None,
        reset_epoch=False,
    ):
        self.resume = checkpoint_path is not None
        torch.serialization.add_safe_globals([
            *util.complexNN.__all__,
            GlobalSettings,
            DataSettings,
            ModelSettings,
            OptimizerSettings,
            PytorchSettings,
            regenerator,
            torch.nn.utils.parametrizations.orthogonal
        ])
        if self.resume:
            self.checkpoint_dict = torch.load(checkpoint_path, weights_only=True)
            if settings_override is not None:
                traverse_dict_update(self.checkpoint_dict["settings"], settings_override)
            if reset_epoch:
                self.checkpoint_dict["epoch"] = -1

            self.global_settings: GlobalSettings = self.checkpoint_dict["settings"]["global_settings"]
            self.data_settings: DataSettings = self.checkpoint_dict["settings"]["data_settings"]
            self.pytorch_settings: PytorchSettings = self.checkpoint_dict["settings"]["pytorch_settings"]
            self.model_settings: ModelSettings = self.checkpoint_dict["settings"]["model_settings"]
            self.optimizer_settings: OptimizerSettings = self.checkpoint_dict["settings"]["optimizer_settings"]
        else:
            if global_settings is None:
                raise ValueError("global_settings must be provided")
            if data_settings is None:
                raise ValueError("data_settings must be provided")
            if pytorch_settings is None:
                raise ValueError("pytorch_settings must be provided")
            if model_settings is None:
                raise ValueError("model_settings must be provided")
            if optimizer_settings is None:
                raise ValueError("optimizer_settings must be provided")

            self.global_settings: GlobalSettings = global_settings
            self.data_settings: DataSettings = data_settings
            self.pytorch_settings: PytorchSettings = pytorch_settings
            self.model_settings: ModelSettings = model_settings
            self.optimizer_settings: OptimizerSettings = optimizer_settings

        self.console = console or Console()
        self.writer = None

    def setup_tb_writer(self, append=None):
        log_dir = self.pytorch_settings.summary_dir + "/" + (datetime.now().strftime("%Y%m%d_%H%M%S"))
        if append is not None:
            log_dir += "_" + str(append)

        print(f"Logging to {log_dir}")
        self.writer = SummaryWriter(log_dir=log_dir)

    def save_checkpoint(self, save_dict, filename):
        torch.save(save_dict, filename)

    def build_checkpoint_dict(self, loss=None, epoch=None):
        return {
            "epoch": -1 if epoch is None else epoch,
            "loss": float("inf") if loss is None else loss,
            "model_state_dict": copy.deepcopy(self.model.state_dict()),
            "optimizer_state_dict": copy.deepcopy(self.optimizer.state_dict()),
            "scheduler_state_dict": copy.deepcopy(self.scheduler.state_dict()) if hasattr(self, "scheduler") else None,
            "model_kwargs": copy.deepcopy(self.model_kwargs),
            "settings": {
                "global_settings": copy.deepcopy(self.global_settings),
                "data_settings": copy.deepcopy(self.data_settings),
                "pytorch_settings": copy.deepcopy(self.pytorch_settings),
                "model_settings": copy.deepcopy(self.model_settings),
                "optimizer_settings": copy.deepcopy(self.optimizer_settings),
            },
        }

    def define_model(self, model_kwargs=None):
        if model_kwargs is None:
            n_hidden_layers = self.model_settings.n_hidden_layers

            input_dim = 2 * self.data_settings.output_size

            dtype = getattr(torch, self.data_settings.dtype)

            afunc = getattr(util.complexNN, self.model_settings.model_activation_func)

            layer_func = getattr(util.complexNN, self.model_settings.model_layer_function)

            layer_parametrizations = self.model_settings.model_layer_parametrizations

            hidden_dims = [self.model_settings.overrides.get(f"n_hidden_nodes_{i}") for i in range(n_hidden_layers)]

            self.model_kwargs = {
                "dims": (input_dim, *hidden_dims, self.model_settings.output_dim),
                "layer_function": layer_func,
                "layer_parametrizations": layer_parametrizations,
                "activation_function": afunc,
                "dtype": dtype,
                "dropout_prob": self.model_settings.dropout_prob,
            }
        else:
            self.model_kwargs = model_kwargs
            input_dim = self.model_kwargs["dims"][0]
            dtype = self.model_kwargs["dtype"]

        # dims = self.model_kwargs.pop("dims")
        self.model = regenerator(**self.model_kwargs)

        self.writer.add_graph(self.model, torch.zeros(1, input_dim, dtype=dtype))

        self.model = self.model.to(self.pytorch_settings.device)

    def get_sliced_data(self, override=None):
        symbols = self.data_settings.symbols

        in_out_delay = self.data_settings.in_out_delay

        xy_delay = self.data_settings.xy_delay

        data_size = self.data_settings.output_size

        dtype = getattr(torch, self.data_settings.dtype)

        num_symbols = None
        if override is not None:
            num_symbols = override.get("num_symbols", None)
        # get dataset
        dataset = FiberRegenerationDataset(
            file_path=self.data_settings.config_path,
            symbols=symbols,
            output_dim=data_size,
            target_delay=in_out_delay,
            xy_delay=xy_delay,
            drop_first=self.data_settings.drop_first,
            dtype=dtype,
            real=not dtype.is_complex,
            num_symbols=num_symbols,
        )

        dataset_size = len(dataset)
        indices = list(range(dataset_size))
        split = int(np.floor(self.data_settings.train_split * dataset_size))
        if self.data_settings.shuffle:
            np.random.seed(self.global_settings.seed)
            np.random.shuffle(indices)

        train_indices, valid_indices = indices[:split], indices[split:]

        if self.data_settings.shuffle:
            train_sampler = torch.utils.data.SubsetRandomSampler(train_indices)
            valid_sampler = torch.utils.data.SubsetRandomSampler(valid_indices)
        else:
            train_sampler = train_indices
            valid_sampler = valid_indices

        train_loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=self.pytorch_settings.batchsize,
            sampler=train_sampler,
            drop_last=True,
            pin_memory=True,
            num_workers=self.pytorch_settings.dataloader_workers,
            prefetch_factor=self.pytorch_settings.dataloader_prefetch,
        )

        valid_loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=self.pytorch_settings.batchsize,
            sampler=valid_sampler,
            drop_last=True,
            pin_memory=True,
            num_workers=self.pytorch_settings.dataloader_workers,
            prefetch_factor=self.pytorch_settings.dataloader_prefetch,
        )

        return train_loader, valid_loader

    def train_model(
        self,
        optimizer,
        train_loader,
        epoch,
        enable_progress=False,
    ):
        if enable_progress:
            progress = Progress(
                TextColumn("[yellow] Training..."),
                TextColumn("Error: {task.description}"),
                BarColumn(),
                TaskProgressColumn(),
                TextColumn("[green]Batch"),
                MofNCompleteColumn(),
                TimeRemainingColumn(),
                TimeElapsedColumn(),
                transient=False,
                console=self.console,
                refresh_per_second=10,
            )
            task = progress.add_task("-.---e--", total=len(train_loader))
            progress.start()

        running_loss2 = 0.0
        running_loss = 0.0
        self.model.train()
        for batch_idx, (x, y) in enumerate(train_loader):
            self.model.zero_grad(set_to_none=True)
            x, y = (
                x.to(self.pytorch_settings.device),
                y.to(self.pytorch_settings.device),
            )
            y_pred = self.model(x)
            loss = util.complexNN.complex_mse_loss(y_pred, y, power=True)
            loss_value = loss.item()
            loss.backward()
            optimizer.step()
            running_loss2 += loss_value
            running_loss += loss_value

            if enable_progress:
                progress.update(task, advance=1, description=f"{loss_value:.3e}")

            if batch_idx % self.pytorch_settings.write_every == 0:
                self.writer.add_scalar(
                    "training loss",
                    running_loss2 / (self.pytorch_settings.write_every if batch_idx > 0 else 1),
                    epoch * len(train_loader) + batch_idx,
                )
                running_loss2 = 0.0

        if enable_progress:
            progress.stop()

        return running_loss / len(train_loader)

    def eval_model(self, valid_loader, epoch, enable_progress=True):
        if enable_progress:
            progress = Progress(
                TextColumn("[green]Evaluating..."),
                TextColumn("Error: {task.description}"),
                BarColumn(),
                TaskProgressColumn(),
                TextColumn("[green]Batch"),
                MofNCompleteColumn(),
                TimeRemainingColumn(),
                TimeElapsedColumn(),
                transient=False,
                console=self.console,
                refresh_per_second=10,
            )
            progress.start()
            task = progress.add_task("-.---e--", total=len(valid_loader))

        self.model.eval()
        running_error = 0
        with torch.no_grad():
            for batch_idx, (x, y) in enumerate(valid_loader):
                x, y = (
                    x.to(self.pytorch_settings.device),
                    y.to(self.pytorch_settings.device),
                )
                y_pred = self.model(x)
                error = util.complexNN.complex_mse_loss(y_pred, y, power=True)
                error_value = error.item()
                running_error += error_value

                if enable_progress:
                    progress.update(task, advance=1, description=f"{error_value:.3e}")

        running_error /= len(valid_loader)
        self.writer.add_scalar(
            "eval loss",
            running_error,
            epoch,
        )
        title_append, subtitle = self.build_title(epoch + 1)
        self.writer.add_figure(
            "fiber response",
            self.plot_model_response(
                model=self.model,
                title_append=title_append,
                subtitle=subtitle,
                show=False,
            ),
            epoch + 1,
        )
        self.writer.add_figure(
            "eye diagram",
            self.plot_model_response(
                model=self.model,
                title_append=title_append,
                subtitle=subtitle,
                show=False,
                mode="eye",
            ),
            epoch + 1,
        )
        self.writer_histograms(epoch + 1)

        if enable_progress:
            progress.stop()

        return running_error

    def run_model(self, model, loader):
        model.eval()
        xs = []
        ys = []
        y_preds = []
        with torch.no_grad():
            model = model.to(self.pytorch_settings.device)
            for x, y in loader:
                x, y = (
                    x.to(self.pytorch_settings.device),
                    y.to(self.pytorch_settings.device),
                )
                y_pred = model(x).cpu()
                # x = x.cpu()
                # y = y.cpu()
                y_pred = y_pred.view(y_pred.shape[0], -1, 2)
                y = y.view(y.shape[0], -1, 2)
                x = x.view(x.shape[0], -1, 2)
                xs.append(x[:, 0, :].squeeze())
                ys.append(y.squeeze())
                y_preds.append(y_pred.squeeze())

        xs = torch.vstack(xs).cpu()
        ys = torch.vstack(ys).cpu()
        y_preds = torch.vstack(y_preds).cpu()
        return ys, xs, y_preds

    def writer_histograms(self, epoch, attributes=["weight", "weight_U", "weight_V", "bias", "sigma", "scale"]):
        for i, layer in enumerate(self.model._layers):
            tag = f"layer {i}"
            for attribute in attributes:
                if hasattr(layer, attribute):
                    vals: np.ndarray = getattr(layer, attribute).detach().cpu().numpy().flatten()
                    if vals.ndim <= 1 and len(vals) == 1:
                        if np.iscomplexobj(vals):
                            self.writer.add_scalar(f"{tag} {attribute} (Mag)", np.abs(vals), epoch)
                            self.writer.add_scalar(f"{tag} {attribute} (Phase)", np.angle(vals), epoch)
                        else:
                            self.writer.add_scalar(f"{tag} {attribute}", vals, epoch)
                    else:
                        if np.iscomplexobj(vals):
                            self.writer.add_histogram(f"{tag} {attribute} (Mag)", np.abs(vals), epoch, bins="fd")
                            self.writer.add_histogram(f"{tag} {attribute} (Phase)", np.angle(vals), epoch, bins="fd")
                        else:
                            self.writer.add_histogram(f"{tag} {attribute}", vals, epoch, bins="fd")

    def train(self):
        if self.writer is None:
            self.setup_tb_writer()

        if self.resume:
            model_kwargs = self.checkpoint_dict["model_kwargs"]
        else:
            model_kwargs = None

        self.define_model(model_kwargs=model_kwargs)

        print(f"number of parameters (trainable): {sum(p.numel() for p in self.model.parameters())} ({sum(p.numel() for p in self.model.parameters() if p.requires_grad)})")

        title_append, subtitle = self.build_title(0)

        self.writer.add_figure(
            "fiber response",
            self.plot_model_response(
                model=self.model,
                title_append=title_append,
                subtitle=subtitle,
                show=False,
            ),
            0,
        )
        self.writer.add_figure(
            "eye diagram",
            self.plot_model_response(
                model=self.model,
                title_append=title_append,
                subtitle=subtitle,
                mode="eye",
                show=False,
            ),
            0,
        )
        self.writer_histograms(0)

        train_loader, valid_loader = self.get_sliced_data()

        optimizer_name = self.optimizer_settings.optimizer

        lr = self.optimizer_settings.learning_rate

        self.optimizer: optim.Optimizer = getattr(optim, optimizer_name)(self.model.parameters(), lr=lr)
        if self.optimizer_settings.scheduler is not None:
            self.scheduler = getattr(optim.lr_scheduler, self.optimizer_settings.scheduler)(
                self.optimizer, **self.optimizer_settings.scheduler_kwargs
            )
            if self.resume:
                try:
                    self.scheduler.load_state_dict(self.checkpoint_dict["scheduler_state_dict"])
                except ValueError:
                    pass
            self.writer.add_scalar("learning rate", self.scheduler.get_last_lr()[0], -1)

        if not self.resume:
            self.best = self.build_checkpoint_dict()
        else:
            self.best = self.checkpoint_dict
            self.model.load_state_dict(self.best["model_state_dict"], strict=False)
            try:
                self.optimizer.load_state_dict(self.best["optimizer_state_dict"])
            except ValueError:
                pass

        for epoch in range(self.best["epoch"] + 1, self.pytorch_settings.epochs):
            enable_progress = True
            if enable_progress:
                self.console.rule(f"Epoch {epoch + 1}/{self.pytorch_settings.epochs}")
            self.train_model(
                self.optimizer,
                train_loader,
                epoch,
                enable_progress=enable_progress,
            )
            loss = self.eval_model(
                valid_loader,
                epoch,
                enable_progress=enable_progress,
            )
            if self.optimizer_settings.scheduler is not None:
                lr_old = self.scheduler.get_last_lr()
                self.scheduler.step(loss)
                lr_new = self.scheduler.get_last_lr()
                if lr_old[0] != lr_new[0]:
                    self.writer.add_scalar("learning rate", lr_new[0], epoch)

            if self.pytorch_settings.save_models and self.model is not None:
                save_path = (
                    Path(self.pytorch_settings.model_dir) / f"{self.writer.get_logdir().split('/')[-1]}_{epoch}.tar"
                )
                save_path.parent.mkdir(parents=True, exist_ok=True)
                checkpoint = self.build_checkpoint_dict(loss, epoch)
                self.save_checkpoint(checkpoint, save_path)

                if loss < self.best["loss"]:
                    self.best = checkpoint
                    save_path = (
                        Path(self.pytorch_settings.model_dir) / f"best_{self.writer.get_logdir().split('/')[-1]}.tar"
                    )
                    save_path.parent.mkdir(parents=True, exist_ok=True)
                    self.save_checkpoint(self.best, save_path)
            self.writer.flush()

        self.writer.close()
        return self.best

    def _plot_model_response_eye(self, *signals, labels=None, sps=None, title_append="", subtitle="", show=True):
        if sps is None:
            raise ValueError("sps must be provided")
        if not hasattr(labels, "__iter__") or isinstance(labels, (str, type(None))):
            labels = [labels]
        else:
            labels = list(labels)

        while len(labels) < len(signals):
            labels.append(None)

        # check if there are any labels
        if not any(labels):
            labels = [f"signal {i + 1}" for i in range(len(signals))]

        fig, axs = plt.subplots(1, 2 * len(signals), sharex=True, sharey=True)
        fig.set_figwidth(18)
        fig.suptitle(f"Eye diagram{f' {title_append}' if title_append else ''}{f'\n{subtitle}' if subtitle else ''}")
        xaxis = np.linspace(0, 2, 2 * sps, endpoint=False)
        for j, (label, signal) in enumerate(zip(labels, signals)):
            # signal = signal.cpu().numpy()
            for i in range(len(signal) // sps - 1):
                x, y = signal[i * sps : (i + 2) * sps].T
                axs[0 + 2 * j].plot(xaxis, np.abs(x) ** 2, color=f"C{j}", alpha=1 / (len(signal) // sps) * 10)
                axs[1 + 2 * j].plot(xaxis, np.abs(y) ** 2, color=f"C{j}", alpha=1 / (len(signal) // sps) * 10)
            axs[0 + 2 * j].set_title(label + " x")
            axs[1 + 2 * j].set_title(label + " y")
            axs[0 + 2 * j].set_xlabel("Symbol")
            axs[1 + 2 * j].set_xlabel("Symbol")
            axs[0 + 2 * j].set_box_aspect(1)
            axs[1 + 2 * j].set_box_aspect(1)
        axs[0].set_ylabel("normalized power")
        fig.tight_layout()
        # axs[1+2*len(labels)-1].set_ylabel("normalized power")

        if show:
            plt.show()
        return fig

    def _plot_model_response_head(self, *signals, labels=None, sps=None, title_append="", subtitle="", show=True):
        if not hasattr(labels, "__iter__") or isinstance(labels, (str, type(None))):
            labels = [labels]
        else:
            labels = list(labels)

        while len(labels) < len(signals):
            labels.append(None)

        # check if there are any labels
        if not any(labels):
            labels = [f"signal {i + 1}" for i in range(len(signals))]

        fig, axs = plt.subplots(1, 2, sharex=True, sharey=True)
        fig.set_figwidth(18)
        fig.set_figheight(4)
        fig.suptitle(f"Fiber response{f' {title_append}' if title_append else ''}{f'\n{subtitle}' if subtitle else ''}")
        for i, ax in enumerate(axs):
            for signal, label in zip(signals, labels):
                if sps is not None:
                    xaxis = np.linspace(0, len(signal) / sps, len(signal), endpoint=False)
                else:
                    xaxis = np.arange(len(signal))
                ax.plot(xaxis, np.abs(signal[:, i]) ** 2, label=label)
            ax.set_xlabel("Sample" if sps is None else "Symbol")
            ax.set_ylabel("normalized power")
            ax.legend(loc="upper right")
        fig.tight_layout()
        if show:
            plt.show()
        return fig

    def plot_model_response(
        self,
        model=None,
        title_append="",
        subtitle="",
        mode: Literal["eye", "head"] = "head",
        show=False,
    ):
        data_settings_backup = copy.deepcopy(self.data_settings)
        pytorch_settings_backup = copy.deepcopy(self.pytorch_settings)
        self.data_settings.drop_first = 100 * 128
        self.data_settings.shuffle = False
        self.data_settings.train_split = 1.0
        self.pytorch_settings.batchsize = (
            self.pytorch_settings.eye_symbols if mode == "eye" else self.pytorch_settings.head_symbols
        )
        plot_loader, _ = self.get_sliced_data(override={"num_symbols": self.pytorch_settings.batchsize})
        self.data_settings = data_settings_backup
        self.pytorch_settings = pytorch_settings_backup

        fiber_in, fiber_out, regen = self.run_model(model, plot_loader)
        fiber_in = fiber_in.view(-1, 2)
        fiber_out = fiber_out.view(-1, 2)
        regen = regen.view(-1, 2)

        fiber_in = fiber_in.numpy()
        fiber_out = fiber_out.numpy()
        regen = regen.numpy()

        # https://github.com/matplotlib/matplotlib/issues/27713#issue-2104110987
        # https://github.com/matplotlib/matplotlib/issues/27713#issuecomment-1915497463
        import gc

        if mode == "head":
            fig = self._plot_model_response_head(
                fiber_in,
                fiber_out,
                regen,
                labels=("fiber in", "fiber out", "regen"),
                sps=plot_loader.dataset.samples_per_symbol,
                title_append=title_append,
                subtitle=subtitle,
                show=show,
            )
        elif mode == "eye":
            # raise NotImplementedError("Eye diagram not implemented")
            fig = self._plot_model_response_eye(
                fiber_in,
                fiber_out,
                regen,
                labels=("fiber in", "fiber out", "regen"),
                sps=plot_loader.dataset.samples_per_symbol,
                title_append=title_append,
                subtitle=subtitle,
                show=show,
            )
        else:
            raise ValueError(f"Unknown mode: {mode}")
        gc.collect()

        return fig

    def build_title(self, number: int):
        title_append = f"epoch {number}"
        model_n_hidden_layers = self.model_settings.n_hidden_layers
        input_dim = 2 * self.data_settings.output_size
        model_dims = [
            self.model_settings.overrides.get(f"n_hidden_nodes_{i}", -1) for i in range(model_n_hidden_layers)
        ]
        model_dims.insert(0, input_dim)
        model_dims.append(2)
        model_dims = [str(dim) for dim in model_dims]
        model_activation_func = self.model_settings.model_activation_func
        model_dtype = self.data_settings.dtype

        subtitle = f"{model_n_hidden_layers + 2} layers à ({', '.join(model_dims)}) units, {model_activation_func}, {model_dtype}"

        return title_append, subtitle
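
Usage note (not part of the committed file): a minimal driver sketch for running a model without hyperparameter tuning, assuming the settings classes in hypertraining/settings.py can be constructed directly and that the package imports as shown. The constructor defaults and the checkpoint path below are illustrative assumptions, not values taken from this commit.

    # hypothetical driver script; settings defaults and the checkpoint path are assumptions
    from hypertraining.settings import (
        GlobalSettings,
        DataSettings,
        ModelSettings,
        OptimizerSettings,
        PytorchSettings,
    )
    from hypertraining.training import Trainer

    trainer = Trainer(
        global_settings=GlobalSettings(),
        data_settings=DataSettings(),
        pytorch_settings=PytorchSettings(),
        model_settings=ModelSettings(),
        optimizer_settings=OptimizerSettings(),
    )
    best = trainer.train()  # runs the full training loop; returns the best checkpoint dict

    # resuming from a previously saved checkpoint instead of fresh settings:
    # trainer = Trainer(checkpoint_path="path/to/best_checkpoint.tar")
    # trainer.train()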