764 lines
28 KiB
Python
764 lines
28 KiB
Python
import copy
|
|
from datetime import datetime
|
|
from pathlib import Path
|
|
from typing import Literal
|
|
import matplotlib.pyplot as plt
|
|
|
|
import numpy as np
|
|
import optuna
|
|
|
|
# import optunahub
|
|
import warnings
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
# import torch.nn.functional as F # mse_loss doesn't support complex numbers
|
|
import torch.optim as optim
|
|
import torch.utils.data
|
|
|
|
from torch.utils.tensorboard import SummaryWriter
|
|
|
|
# from rich.progress import (
|
|
# Progress,
|
|
# TextColumn,
|
|
# BarColumn,
|
|
# TaskProgressColumn,
|
|
# TimeRemainingColumn,
|
|
# MofNCompleteColumn,
|
|
# TimeElapsedColumn,
|
|
# )
|
|
# from rich.console import Console
|
|
# from rich import print as rprint
|
|
|
|
import multiprocessing
|
|
|
|
from util.datasets import FiberRegenerationDataset
|
|
|
|
# from util.optuna_helpers import (
|
|
# suggest_categorical_optional, # noqa: F401
|
|
# suggest_float_optional, # noqa: F401
|
|
# suggest_int_optional, # noqa: F401
|
|
# )
|
|
from util.optuna_helpers import install_optional_suggests
|
|
import util
|
|
|
|
from .settings import (
|
|
GlobalSettings,
|
|
DataSettings,
|
|
ModelSettings,
|
|
OptunaSettings,
|
|
OptimizerSettings,
|
|
PytorchSettings,
|
|
)
|
|
|
|
install_optional_suggests()
|
|
|
|
|
|
class HyperTraining:
|
|
def __init__(
    self,
    *,
    global_settings,
    data_settings,
    pytorch_settings,
    model_settings,
    optimizer_settings,
    optuna_settings,
):
    """Bundle all settings objects needed to run an optuna hyper-search.

    All parameters are keyword-only settings containers (see .settings):
    global_settings, data_settings, pytorch_settings, model_settings,
    optimizer_settings, optuna_settings. Derived convenience fields
    (_direction, _parallel, ...) are computed once here via
    _extra_optuna_settings().
    """
    self.global_settings: GlobalSettings = global_settings
    self.data_settings: DataSettings = data_settings
    self.pytorch_settings: PytorchSettings = pytorch_settings
    self.model_settings: ModelSettings = model_settings
    self.optimizer_settings: OptimizerSettings = optimizer_settings
    self.optuna_settings: OptunaSettings = optuna_settings
    # Worker process handles; populated by setup_study() when parallel.
    self.processes = None

    # Derive private helper fields on optuna_settings for readability.
    self._extra_optuna_settings()

    # BUG FIX: this was initialized to True, which made objective() call
    # trial.study.stop() on the very first trial of every run. The flag
    # should only become True after a KeyboardInterrupt.
    self.stop_study = False
|
|
def setup_tb_writer(self, study_name=None, append=None):
    """Create a TensorBoard SummaryWriter under the configured summary dir.

    The log directory is <summary_dir>/<study name>; when *append* is
    given, "_<append>" is tacked onto the directory name (used to give
    each trial its own writer).
    """
    name = study_name or self.optuna_settings.study_name
    if append is not None:
        name = f"{name}_{append}"
    return SummaryWriter(self.pytorch_settings.summary_dir + "/" + name)
|
|
|
def resume_latest_study(self, verbose=True):
    """Point optuna_settings.study_name at the most recently started study.

    BUG FIX: get_latest_study() returns a Study object, but the original
    code assigned that object directly to optuna_settings.study_name
    (and printed its repr), where a string name is expected. No-op when
    no previous study exists.
    """
    study = self.get_latest_study()
    if study:
        if verbose:
            print(f"Resuming study: {study.study_name}")
        self.optuna_settings.study_name = study.study_name
|
|
|
def get_latest_study(self, verbose=False):
    """Load the most recently started study from storage.

    Returns the loaded optuna.Study, or None when the storage holds no
    studies (the original dereferenced None and raised AttributeError in
    that case). Summaries without a recorded start time sort as
    datetime.min instead of being mutated in place.
    """
    studies = self.get_studies()
    if not studies:
        if verbose:
            print("No previous studies found")
        return None
    latest = max(studies, key=lambda s: s.datetime_start or datetime.min)
    if verbose:
        print(f"Last study: {latest.study_name}")
    return optuna.load_study(study_name=latest.study_name, storage=self.optuna_settings.storage)
|
|
|
|
# def study(self) -> optuna.Study:
|
|
# return optuna.load_study(self.optuna_settings.study_name, storage=self.optuna_settings.storage)
|
|
|
|
def get_studies(self):
    # Lightweight summaries (not full Study objects) of every study in storage.
    return optuna.get_all_study_summaries(storage=self.optuna_settings.storage)
|
|
|
def setup_study(self):
    """Create or load the optuna study and attach pruner and metric names.

    Also initializes the worker-process list when running in parallel
    mode. The pruner class is resolved by name from optuna.pruners; an
    unknown name silently yields no pruner.
    """
    if self.optuna_settings._parallel:
        self.processes = []

    # Resolve the pruner class by name and instantiate it with the
    # configured kwargs (or defaults when no kwargs were given).
    pruner_cls = getattr(optuna.pruners, self.optuna_settings.pruner, None)
    pruner = None
    if pruner_cls is not None:
        kwargs = self.optuna_settings.pruner_kwargs
        pruner = pruner_cls(**kwargs) if kwargs is not None else pruner_cls()

    self.study = optuna.create_study(
        study_name=self.optuna_settings.study_name,
        storage=self.optuna_settings.storage,
        load_if_exists=True,
        direction=self.optuna_settings._direction,
        directions=self.optuna_settings._directions,
        pruner=pruner,
    )

    # set_metric_names is experimental and emits a warning; silence it
    # for just this call.
    with warnings.catch_warnings(action="ignore"):
        self.study.set_metric_names(self.optuna_settings.metrics_names)
|
|
|
def run_study(self):
    """Run the study, dispatching to the serial or parallel driver.

    BUG FIX: clears the stop flag before starting, so a previous
    KeyboardInterrupt (or the constructor's initial flag value) cannot
    stop the new run on its first trial. On Ctrl-C the flag is set so
    objective() stops the study gracefully once in-flight trials finish.
    """
    self.stop_study = False
    try:
        if self.optuna_settings._parallel:
            self._run_parallel_study()
        else:
            self._run_study()
    except KeyboardInterrupt:
        print("Stopping. Please wait for the processes to finish.")
        self.stop_study = True
|
|
|
|
def trials_left(self):
    # Remaining trial budget: configured n_trials minus the count of trials
    # already in the states listed in n_trials_filter.
    return self.optuna_settings.n_trials - len(self.study.get_trials(states=self.optuna_settings.n_trials_filter))
|
|
|
|
def remove_completed_processes(self):
    """Join and drop finished worker processes.

    BUG FIX: the original popped from self.processes while enumerating
    it, which skips the element immediately after each removed one.
    Rebuild the list instead, joining every dead process exactly once.
    No-op when parallel mode never initialized the process list.
    """
    if self.processes is None:
        return
    still_running = []
    for process in self.processes:
        if process.is_alive():
            still_running.append(process)
        else:
            process.join()
    self.processes = still_running
|
|
|
|
def remove_outliers(self):
    """Mark completed trials whose loss is a log-space outlier as failed.

    A trial counts as an outlier when log(value) exceeds the mean of all
    completed trials' log-values by more than ``remove_outliers`` standard
    deviations. Disabled when optuna_settings.remove_outliers is None.

    NOTE(review): study.get_trials() returns FrozenTrial copies; assigning
    ``trial.state`` here mutates only the local copy and presumably does
    NOT persist the FAIL state back to storage — verify against the optuna
    storage API. Likewise FrozenTrial may not expose set_user_attr —
    confirm with the installed optuna version.
    """
    if self.optuna_settings.remove_outliers is not None:
        trials = self.study.get_trials(states=(optuna.trial.TrialState.COMPLETE,))
        if len(trials) == 0:
            return
        vals = [trial.value for trial in trials]
        # Work in log space: losses span orders of magnitude.
        vals = np.log(vals)
        mean = np.mean(vals)
        std = np.std(vals)
        outliers = [
            trial for trial in trials if np.log(trial.value) > mean + self.optuna_settings.remove_outliers * std
        ]
        for trial in outliers:
            trial: optuna.trial.Trial = trial
            trial.state = optuna.trial.TrialState.FAIL
            trial.set_user_attr("outlier", True)
|
|
|
|
def _run_study(self):
    """Serial driver: keep optimizing until the trial budget is exhausted."""
    remaining = self.trials_left()
    while remaining:
        self.remove_outliers()
        self._run_optimize(n_trials=remaining, timeout=self.optuna_settings.timeout)
        remaining = self.trials_left()
|
|
|
|
def _run_parallel_study(self):
    """Parallel driver: keep _n_threads worker processes optimizing until
    the trial budget is exhausted.

    NOTE(review): this loop has no sleep — between remove_completed_processes()
    and respawning it effectively busy-waits on trials_left(); confirm that is
    acceptable for the intended storage backend.
    """
    while trials_left := self.trials_left():
        self.remove_outliers()
        self.remove_completed_processes()

        # Split the remaining budget evenly across workers (at least 1 each).
        n_trials = max(trials_left, self.optuna_settings._n_threads) // self.optuna_settings._n_threads

        # NOTE(review): a closure as a Process target is not picklable, so
        # this only works with the "fork" start method (Linux default) —
        # it will fail under "spawn" (Windows/macOS). Verify if portability
        # is needed.
        def target_fun():
            self._run_optimize(n_trials=n_trials, timeout=self.optuna_settings.timeout)

        # Top up the worker pool to _n_threads processes.
        for _ in range(self.optuna_settings._n_threads - len(self.processes)):
            self.processes.append(multiprocessing.Process(target=target_fun))
            self.processes[-1].start()
|
|
|
|
def _run_optimize(self, **kwargs):
    """Forward to study.optimize; the progress bar is only shown when
    running single-process (parallel workers would garble the output)."""
    show_bar = not self.optuna_settings._parallel
    self.study.optimize(self.objective, show_progress_bar=show_bar, **kwargs)
|
|
|
|
def _extra_optuna_settings(self):
    """Derive private convenience fields on optuna_settings.

    Splits ``directions`` into the single-objective ``_direction`` /
    multi-objective ``_directions`` pair expected by optuna.create_study,
    resolves the effective batch limits, and decides parallelism.
    """
    settings = self.optuna_settings
    settings._multi_objective = len(settings.directions) > 1

    # create_study accepts either direction= or directions=, never both.
    if settings._multi_objective:
        settings._direction = None
        settings._directions = settings.directions
    else:
        settings._direction = settings.directions[0]
        settings._directions = None

    # Batch caps: unlimited unless limit_examples is on.
    unlimited = float("inf")
    settings._n_train_batches = settings.n_train_batches if settings.limit_examples else unlimited
    settings._n_valid_batches = settings.n_valid_batches if settings.limit_examples else unlimited

    settings._n_threads = settings.n_workers
    settings._parallel = settings._n_threads > 1
|
|
|
|
def define_model(self, trial: optuna.Trial, writer=None):
    """Build the complex-valued feed-forward model for a trial.

    Layer count, widths, dtype and activation are drawn from the trial via
    the *_optional suggest helpers (installed by install_optional_suggests;
    set_new=False re-reads a value suggested elsewhere without re-registering
    it). Records model_n_params / model_n_nodes as trial user attrs and
    optionally traces the graph into TensorBoard.

    Returns the model moved to the configured device.
    """
    n_layers = trial.suggest_int_optional("model_n_hidden_layers", self.model_settings.n_hidden_layers)

    # Input width is shared with get_sliced_data (same suggest key).
    input_dim = trial.suggest_int_optional(
        "model_input_dim",
        self.data_settings.output_size,
        step=2,
        multiply=2,
        set_new=False,
    )

    # dtype is stored as a string setting and resolved to the torch dtype.
    dtype = trial.suggest_categorical_optional("model_dtype", self.data_settings.dtype, set_new=False)
    dtype = getattr(torch, dtype)

    afunc = trial.suggest_categorical_optional("model_activation_func", self.model_settings.model_activation_func)

    layers = []
    last_dim = input_dim
    n_nodes = last_dim  # running total of units across all layers
    for i in range(n_layers):
        # Per-layer width override from model_settings takes precedence.
        if hidden_dim_override := self.model_settings.overrides.get(f"n_hidden_nodes_{i}", False):
            hidden_dim = trial.suggest_int_optional(f"model_hidden_dim_{i}", hidden_dim_override)
        else:
            hidden_dim = trial.suggest_int_optional(
                f"model_hidden_dim_{i}",
                self.model_settings.n_hidden_nodes,
            )
        layers.append(util.complexNN.SemiUnitaryLayer(last_dim, hidden_dim, dtype=dtype))
        last_dim = hidden_dim
        # Activation class is looked up by name in util.complexNN.
        layers.append(getattr(util.complexNN, afunc)())
        n_nodes += last_dim

    # Final projection to the fixed output width (no activation).
    layers.append(util.complexNN.SemiUnitaryLayer(last_dim, self.model_settings.output_dim, dtype=dtype))

    model = nn.Sequential(*layers)

    if writer is not None:
        writer.add_graph(model, torch.zeros(1, input_dim, dtype=dtype), use_strict_trace=False)

    n_params = sum(p.numel() for p in model.parameters())
    trial.set_user_attr("model_n_params", n_params)
    trial.set_user_attr("model_n_nodes", n_nodes)

    return model.to(self.pytorch_settings.device)
|
|
|
|
def get_sliced_data(self, trial: optuna.Trial, override=None):
    """Build train/validation DataLoaders for a trial.

    Dataset geometry (symbol window, delays, dtype) comes from the trial's
    already-suggested values (set_new=False re-reads without registering).
    ``override`` may supply {"num_symbols": int} to cap the dataset size
    (used by plot_model_response).

    Returns (train_loader, valid_loader).
    """
    symbols = trial.suggest_float_optional("dataset_symbols", self.data_settings.symbols, set_new=False)

    in_out_delay = trial.suggest_float_optional(
        "dataset_in_out_delay", self.data_settings.in_out_delay, set_new=False
    )

    xy_delay = trial.suggest_float_optional("dataset_xy_delay", self.data_settings.xy_delay, set_new=False)

    # Dataset output width is half the model input width (the model sees
    # both polarizations interleaved — presumably; confirm against
    # FiberRegenerationDataset).
    data_size = int(
        0.5
        * trial.suggest_int_optional(
            "model_input_dim",
            self.data_settings.output_size,
            step=2,
            multiply=2,
            set_new=False,
        )
    )

    dtype = trial.suggest_categorical_optional("model_dtype", self.data_settings.dtype, set_new=False)
    dtype = getattr(torch, dtype)

    num_symbols = None
    if override is not None:
        num_symbols = override.get("num_symbols", None)
    # get dataset
    dataset = FiberRegenerationDataset(
        file_path=self.data_settings.config_path,
        symbols=symbols,
        output_dim=data_size,
        target_delay=in_out_delay,
        xy_delay=xy_delay,
        drop_first=self.data_settings.drop_first,
        dtype=dtype,
        real=not dtype.is_complex,
        num_symbols=num_symbols,
    )

    # Contiguous train/valid split over the index range.
    dataset_size = len(dataset)
    indices = list(range(dataset_size))
    split = int(np.floor(self.data_settings.train_split * dataset_size))
    if self.data_settings.shuffle:
        # NOTE(review): seeds numpy's GLOBAL RNG — affects any other numpy
        # randomness in the process; consider np.random.default_rng.
        np.random.seed(self.global_settings.seed)
        np.random.shuffle(indices)

    train_indices, valid_indices = indices[:split], indices[split:]

    # With shuffling, resample randomly each epoch; otherwise iterate the
    # fixed index lists in order (DataLoader accepts any sequence as sampler).
    if self.data_settings.shuffle:
        train_sampler = torch.utils.data.SubsetRandomSampler(train_indices)
        valid_sampler = torch.utils.data.SubsetRandomSampler(valid_indices)
    else:
        train_sampler = train_indices
        valid_sampler = valid_indices

    train_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=self.pytorch_settings.batchsize,
        sampler=train_sampler,
        drop_last=True,
        pin_memory=True,
        num_workers=self.pytorch_settings.dataloader_workers,
        prefetch_factor=self.pytorch_settings.dataloader_prefetch,
    )

    valid_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=self.pytorch_settings.batchsize,
        sampler=valid_sampler,
        drop_last=True,
        pin_memory=True,
        num_workers=self.pytorch_settings.dataloader_workers,
        prefetch_factor=self.pytorch_settings.dataloader_prefetch,
    )

    return train_loader, valid_loader
|
|
|
|
def train_model(
    self,
    trial,
    model,
    optimizer,
    train_loader,
    epoch,
    writer=None,
):
    """Train *model* for one epoch and return the mean training loss.

    Stops early after optuna_settings._n_train_batches batches (inf when
    limit_examples is off). Loss is the project's complex-capable MSE.
    When *writer* is given, a short-window running average is logged to
    TensorBoard every pytorch_settings.write_every batches.

    ``trial`` is currently unused here; kept for signature parity with
    eval_model.
    """
    running_loss2 = 0.0  # short-window accumulator, reset after each log
    running_loss = 0.0  # whole-epoch accumulator for the return value
    model.train()
    for batch_idx, (x, y) in enumerate(train_loader):
        if batch_idx >= self.optuna_settings._n_train_batches:
            break
        model.zero_grad(set_to_none=True)
        x, y = (
            x.to(self.pytorch_settings.device),
            y.to(self.pytorch_settings.device),
        )
        y_pred = model(x)
        # torch.nn.functional.mse_loss rejects complex tensors, hence the
        # project-specific loss.
        loss = util.complexNN.complex_mse_loss(y_pred, y)
        loss_value = loss.item()
        loss.backward()
        optimizer.step()
        running_loss2 += loss_value
        running_loss += loss_value

        if writer is not None:
            if batch_idx % self.pytorch_settings.write_every == 0:
                writer.add_scalar(
                    "training loss",
                    # First log point averages over a single batch.
                    running_loss2 / (self.pytorch_settings.write_every if batch_idx > 0 else 1),
                    # Global step continues across epochs.
                    epoch * min(len(train_loader), self.optuna_settings._n_train_batches) + batch_idx,
                )
                running_loss2 = 0.0

    return running_loss / min(len(train_loader), self.optuna_settings._n_train_batches)
|
|
|
|
def eval_model(
    self,
    trial,
    model,
    valid_loader,
    epoch,
    writer=None,
):
    """Evaluate *model* on the validation loader; return the mean error.

    Stops early after optuna_settings._n_valid_batches batches. When
    *writer* is given, a short-window running average is logged every
    pytorch_settings.write_every batches, and a fiber-response figure is
    rendered into TensorBoard once per epoch (step epoch + 1; step 0 is
    the pre-training figure written by objective()).
    """
    model.eval()
    running_error = 0  # whole-epoch accumulator for the return value
    running_error_2 = 0  # short-window accumulator for TensorBoard logging
    with torch.no_grad():
        for batch_idx, (x, y) in enumerate(valid_loader):
            if batch_idx >= self.optuna_settings._n_valid_batches:
                break
            x, y = (
                x.to(self.pytorch_settings.device),
                y.to(self.pytorch_settings.device),
            )
            y_pred = model(x)
            error = util.complexNN.complex_mse_loss(y_pred, y)
            error_value = error.item()
            running_error += error_value
            running_error_2 += error_value

            if writer is not None:
                if batch_idx % self.pytorch_settings.write_every == 0:
                    writer.add_scalar(
                        "eval loss",
                        # First log point averages over a single batch.
                        running_error_2 / (self.pytorch_settings.write_every if batch_idx > 0 else 1),
                        epoch * min(len(valid_loader), self.optuna_settings._n_valid_batches) + batch_idx,
                    )
                    running_error_2 = 0.0

    running_error /= min(len(valid_loader), self.optuna_settings._n_valid_batches)

    if writer is not None:
        title_append, subtitle = self.build_title(trial)
        writer.add_figure(
            "fiber response",
            self.plot_model_response(
                trial,
                model=model,
                title_append=title_append,
                subtitle=subtitle,
                show=False,
            ),
            epoch + 1,
        )

    return running_error
|
|
|
|
def run_model(self, model, loader):
    """Run *model* over every batch of *loader* without gradients.

    Each batch is reshaped to (batch, -1, 2) — presumably the trailing 2
    is the x/y polarization pair; confirm against FiberRegenerationDataset.
    Only the first window slice of each input is kept (x[:, 0, :]).

    Returns (ys, xs, y_preds) — note the order: targets first, then
    inputs, then predictions — all stacked on CPU.
    """
    model.eval()
    xs = []
    ys = []
    y_preds = []
    with torch.no_grad():
        model = model.to(self.pytorch_settings.device)
        for x, y in loader:
            x, y = (
                x.to(self.pytorch_settings.device),
                y.to(self.pytorch_settings.device),
            )
            y_pred = model(x).cpu()
            y_pred = y_pred.view(y_pred.shape[0], -1, 2)
            y = y.view(y.shape[0], -1, 2)
            x = x.view(x.shape[0], -1, 2)
            # Keep only the first slice of the input window per example.
            xs.append(x[:, 0, :].squeeze())
            ys.append(y.squeeze())
            y_preds.append(y_pred.squeeze())

    xs = torch.vstack(xs).cpu()
    ys = torch.vstack(ys).cpu()
    y_preds = torch.vstack(y_preds).cpu()
    return ys, xs, y_preds
|
|
|
|
def objective(self, trial: optuna.Trial, plot_before=False):
    """Optuna objective: build, train and evaluate one model.

    Returns the final validation error (single-objective) or the tuple
    (-log10(error), model_n_nodes) for multi-objective studies. Raises
    TrialPruned when the pruner fires (single-objective only — optuna
    does not support pruning with multiple objectives).
    """
    # Cooperative shutdown: after a KeyboardInterrupt the driver sets
    # stop_study, and the next trial asks the study to stop.
    if self.stop_study:
        trial.study.stop()
    model = None

    # One TensorBoard writer per trial, zero-padded so runs sort nicely.
    writer = self.setup_tb_writer(
        self.optuna_settings.study_name,
        f"{trial.number:0{len(str(self.optuna_settings.n_trials)) + 2}}",
    )

    model = self.define_model(trial, writer)

    title_append, subtitle = self.build_title(trial)

    # Snapshot of the untrained model's response at step 0.
    writer.add_figure(
        "fiber response",
        self.plot_model_response(
            trial,
            model=model,
            title_append=title_append,
            subtitle=subtitle,
            show=plot_before,
        ),
        0,
    )

    train_loader, valid_loader = self.get_sliced_data(trial)

    optimizer_name = trial.suggest_categorical_optional("optimizer", self.optimizer_settings.optimizer)

    lr = trial.suggest_float_optional("lr", self.optimizer_settings.learning_rate, log=True)

    optimizer = getattr(optim, optimizer_name)(model.parameters(), lr=lr)
    if self.optimizer_settings.scheduler is not None:
        scheduler = getattr(optim.lr_scheduler, self.optimizer_settings.scheduler)(
            optimizer, **self.optimizer_settings.scheduler_kwargs
        )

    for epoch in range(self.pytorch_settings.epochs):
        trial.set_user_attr("epoch", epoch)
        self.train_model(
            trial,
            model,
            optimizer,
            train_loader,
            epoch,
            writer,
        )
        error = self.eval_model(
            trial,
            model,
            valid_loader,
            epoch,
            writer,
        )
        # scheduler.step(error) assumes a metric-driven scheduler such as
        # ReduceLROnPlateau.
        if self.optimizer_settings.scheduler is not None:
            scheduler.step(error)

        # Expose the error under several monotone transforms so any of
        # them can be selected as a study metric.
        trial.set_user_attr("mse", error)
        trial.set_user_attr("log_mse", np.log10(error + np.finfo(float).eps))
        trial.set_user_attr("neg_mse", -error)
        trial.set_user_attr("neg_log_mse", -np.log10(error + np.finfo(float).eps))
        if not self.optuna_settings._multi_objective:
            trial.report(error, epoch)
            if trial.should_prune():
                raise optuna.exceptions.TrialPruned()

    writer.close()

    # NOTE(review): this early return means models are never saved in
    # multi-objective mode, even when save_models is set — confirm intended.
    if self.optuna_settings._multi_objective:
        return -np.log10(error + np.finfo(float).eps), trial.user_attrs.get("model_n_nodes", -1)

    if self.pytorch_settings.save_models and model is not None:
        save_path = Path(self.pytorch_settings.model_dir) / f"{self.optuna_settings.study_name}_{trial.number}.pth"
        save_path.parent.mkdir(parents=True, exist_ok=True)
        torch.save(model, save_path)

    return error
|
|
|
|
def _plot_model_response_eye(self, *signals, labels=None, sps=None, title_append="", subtitle="", show=True):
    """Draw per-polarization eye diagrams, one column per signal.

    Each signal is indexed as an (n_samples, 2) array — presumably the
    x/y polarization pair; TODO confirm against run_model's output.
    Raises ValueError when sps (samples per symbol) is missing.
    Returns the matplotlib figure.
    """
    if sps is None:
        raise ValueError("sps must be provided")

    # Normalize labels to a list padded to len(signals); a bare string or
    # None counts as a single (missing) label.
    if isinstance(labels, (str, type(None))) or not hasattr(labels, "__iter__"):
        labels = [labels]
    else:
        labels = list(labels)
    labels.extend([None] * (len(signals) - len(labels)))
    if not any(labels):
        labels = [f"signal {i + 1}" for i in range(len(signals))]

    fig, axs = plt.subplots(2, len(signals), sharex=True, sharey=True)
    title = "Eye diagram"
    if title_append:
        title += f" {title_append}"
    if subtitle:
        title += f"\n{subtitle}"
    fig.suptitle(title)

    # Two symbols per trace, overlaid with high transparency.
    xaxis = np.linspace(0, 2, 2 * sps, endpoint=False)
    for col, (label, signal) in enumerate(zip(labels, signals)):
        for start in range(len(signal) // sps - 1):
            x, y = signal[start * sps : (start + 2) * sps].T
            axs[0, col].plot(xaxis, np.abs(x) ** 2, color="C0", alpha=0.02)
            axs[1, col].plot(xaxis, np.abs(y) ** 2, color="C0", alpha=0.02)
        for row, pol in enumerate(("x", "y")):
            axs[row, col].set_title(label + " " + pol)
            axs[row, col].set_xlabel("Symbol")
            axs[row, col].set_ylabel("normalized power")

    if show:
        plt.show()
    return fig
|
|
|
|
def _plot_model_response_head(self, *signals, labels=None, sps=None, title_append="", subtitle="", show=True):
    """Plot the leading samples of each signal, one subplot per polarization.

    Signals are indexed as signal[:, i] with i in {0, 1} — presumably the
    x/y polarization pair; TODO confirm. With sps given, the x-axis is in
    symbols; otherwise in raw sample indices. Returns the figure.
    """
    # Normalize labels to a list padded to len(signals); a bare string or
    # None counts as a single (missing) label.
    if isinstance(labels, (str, type(None))) or not hasattr(labels, "__iter__"):
        labels = [labels]
    else:
        labels = list(labels)
    labels.extend([None] * (len(signals) - len(labels)))
    if not any(labels):
        labels = [f"signal {i + 1}" for i in range(len(signals))]

    fig, axs = plt.subplots(1, 2, sharex=True, sharey=True)
    fig.set_size_inches(18, 6)
    title = "Fiber response"
    if title_append:
        title += f" {title_append}"
    if subtitle:
        title += f"\n{subtitle}"
    fig.suptitle(title)

    for pol, ax in enumerate(axs):
        for signal, label in zip(signals, labels):
            if sps is not None:
                xaxis = np.linspace(0, len(signal) / sps, len(signal), endpoint=False)
            else:
                xaxis = np.arange(len(signal))
            ax.plot(xaxis, np.abs(signal[:, pol]) ** 2, label=label)
        ax.set_xlabel("Sample" if sps is None else "Symbol")
        ax.set_ylabel("normalized power")
        ax.legend(loc="upper right")

    if show:
        plt.show()
    return fig
|
|
|
|
def plot_model_response(
    self,
    trial,
    model=None,
    title_append="",
    subtitle="",
    mode: Literal["eye", "head"] = "head",
    show=True,
):
    """Run *model* over a fresh, unshuffled data slice and plot its response.

    mode "head" plots the leading samples; "eye" overlays symbol-aligned
    traces as an eye diagram. Returns the matplotlib figure.

    Raises ValueError for an unknown mode — validated up front now, so we
    fail before any data loading or model execution.
    """
    if mode not in ("eye", "head"):
        raise ValueError(f"Unknown mode: {mode}")

    # Temporarily override the data/pytorch settings so the loader yields
    # one deterministic, contiguous batch sized for the requested plot.
    data_settings_backup = copy.deepcopy(self.data_settings)
    pytorch_settings_backup = copy.deepcopy(self.pytorch_settings)
    try:
        self.data_settings.drop_first = 100 * 128
        self.data_settings.shuffle = False
        self.data_settings.train_split = 1.0
        self.pytorch_settings.batchsize = (
            self.pytorch_settings.eye_symbols if mode == "eye" else self.pytorch_settings.head_symbols
        )
        plot_loader, _ = self.get_sliced_data(trial, override={"num_symbols": self.pytorch_settings.batchsize})
    finally:
        # BUG FIX: restore even when get_sliced_data raises, so a failed
        # plot cannot leak the temporary overrides into later trials.
        self.data_settings = data_settings_backup
        self.pytorch_settings = pytorch_settings_backup

    fiber_in, fiber_out, regen = self.run_model(model, plot_loader)
    fiber_in = fiber_in.view(-1, 2).numpy()
    fiber_out = fiber_out.view(-1, 2).numpy()
    regen = regen.view(-1, 2).numpy()

    # Explicit gc after plotting works around matplotlib figure buildup:
    # https://github.com/matplotlib/matplotlib/issues/27713#issue-2104110987
    # https://github.com/matplotlib/matplotlib/issues/27713#issuecomment-1915497463
    import gc

    plotter = self._plot_model_response_head if mode == "head" else self._plot_model_response_eye
    fig = plotter(
        fiber_in,
        fiber_out,
        regen,
        labels=("fiber in", "fiber out", "regen"),
        sps=plot_loader.dataset.samples_per_symbol,
        title_append=title_append,
        subtitle=subtitle,
        show=show,
    )
    gc.collect()

    return fig
|
|
|
|
@staticmethod
def build_title(trial: optuna.trial.Trial):
    """Compose plot-title fragments from a trial's params and user attrs.

    Returns (title_append, subtitle): a short "for trial N" suffix and a
    one-line architecture summary (layer count, per-layer widths,
    activation, dtype). Missing values fall back to 0 / placeholder text.
    """
    sources = (trial.params, trial.user_attrs)
    lookup = util.misc.multi_getattr

    n_hidden = lookup(sources, "model_n_hidden_layers", 0)
    input_dim = lookup(sources, "model_input_dim", 0)
    hidden_dims = [lookup(sources, f"model_hidden_dim_{i}", 0) for i in range(n_hidden)]
    # Output layer is always 2 units wide.
    dims = ", ".join(str(d) for d in [input_dim, *hidden_dims, 2])

    afunc = lookup(sources, "model_activation_func", "unknown act. fun")
    dtype_name = lookup(sources, "model_dtype", "unknown dtype")

    subtitle = f"{n_hidden + 2} layers à ({dims}) units, {afunc}, {dtype_name}"
    return f"for trial {trial.number}", subtitle
|