move hypertraining class into separate file;

move settings dataclasses into separate file;
add SemiUnitaryLayer;
clean up model response plotting code;
continue hyperparameter search
This commit is contained in:
Joseph Hopfmüller
2024-11-20 22:49:31 +01:00
parent cdca5de473
commit 674033ac2e
11 changed files with 1064 additions and 553 deletions

View File

@@ -0,0 +1,735 @@
import copy
from datetime import datetime
from pathlib import Path
from typing import Literal
import matplotlib.pyplot as plt
import numpy as np
import optuna
import warnings
import torch
import torch.nn as nn
# import torch.nn.functional as F # mse_loss doesn't support complex numbers
import torch.optim as optim
import torch.utils.data
from torch.utils.tensorboard import SummaryWriter
from rich.progress import (
Progress,
TextColumn,
BarColumn,
TaskProgressColumn,
TimeRemainingColumn,
MofNCompleteColumn,
TimeElapsedColumn,
)
from rich.console import Console
# from rich import print as rprint
import multiprocessing
from util.datasets import FiberRegenerationDataset
from util.optuna_helpers import (
force_suggest_categorical,
force_suggest_float,
force_suggest_int,
)
import util
from .settings import (
GlobalSettings,
DataSettings,
ModelSettings,
OptunaSettings,
OptimizerSettings,
PytorchSettings,
)
class HyperTraining:
    """Optuna-driven hyperparameter search for complex-valued regeneration models.

    Bundles all configuration (data, model, optimizer, Optuna, PyTorch runtime)
    into one object, creates/resumes Optuna studies, trains and evaluates trial
    models, and logs losses and response plots to TensorBoard.
    """

    def __init__(
        self,
        *,
        global_settings,
        data_settings,
        pytorch_settings,
        model_settings,
        optimizer_settings,
        optuna_settings,
        console=None,
    ):
        """Store the settings bundles and derive convenience Optuna fields.

        All parameters are keyword-only. `console` defaults to a fresh rich
        Console used for progress bars and rules.
        """
        self.global_settings: GlobalSettings = global_settings
        self.data_settings: DataSettings = data_settings
        self.pytorch_settings: PytorchSettings = pytorch_settings
        self.model_settings: ModelSettings = model_settings
        self.optimizer_settings: OptimizerSettings = optimizer_settings
        self.optuna_settings: OptunaSettings = optuna_settings
        self.console = console or Console()
        # set some extra settings to make the code more readable
        self._extra_optuna_settings()
def setup_tb_writer(self, study_name=None, append=None):
    """Create a TensorBoard SummaryWriter below the configured summary dir.

    The log directory is `<summary_dir>/<study_name>` (falling back to the
    Optuna settings' study name), with `_<append>` attached when given.
    """
    chosen_name = study_name or self.optuna_settings.study_name
    log_dir = f"{self.pytorch_settings.summary_dir}/{chosen_name}"
    if append is not None:
        log_dir = f"{log_dir}_{append}"
    return SummaryWriter(log_dir)
def resume_latest_study(self, verbose=True):
    """Point `optuna_settings.study_name` at the most recent study, if any.

    Args:
        verbose: forwarded to `get_latest_study` and also gates this method's
            own "Resuming study" message. (Previously the parameter was
            accepted but ignored: the print was unconditional and not
            forwarded.)
    """
    study_name = self.get_latest_study(verbose=verbose)
    if study_name:
        if verbose:
            print(f"Resuming study: {study_name}")
        self.optuna_settings.study_name = study_name
def get_latest_study(self, verbose=True):
    """Return the name of the most recently started study, or None.

    Missing start times are normalized to `datetime.min` so the comparison
    never sees None.
    """
    summaries = self.get_studies()
    for summary in summaries:
        summary.datetime_start = summary.datetime_start or datetime.min
    if not summaries:
        if verbose:
            print("No previous studies found")
        return None
    latest = max(summaries, key=lambda s: s.datetime_start)
    if verbose:
        print(f"Last study: {latest.study_name}")
    return latest.study_name
def get_studies(self):
    """Return summaries of every study in the configured Optuna storage."""
    storage = self.optuna_settings.storage
    return optuna.get_all_study_summaries(storage=storage)
def setup_study(self):
    """Create (or load) the Optuna study and prepare optional worker processes.

    Sets `self.study`, `self.n_threads` and `self.processes`. Exactly one of
    `direction`/`directions` is non-None (see `_extra_optuna_settings`).
    The worker processes are only created here; `run_study` starts them.
    """
    self.study = optuna.create_study(
        study_name=self.optuna_settings.study_name,
        storage=self.optuna_settings.storage,
        load_if_exists=True,
        direction=self.optuna_settings.direction,
        directions=self.optuna_settings.directions,
    )
    # set_metric_names is experimental in Optuna and warns on use; silence it.
    with warnings.catch_warnings(action="ignore"):
        self.study.set_metric_names(self.optuna_settings.metrics_names)
    # Never spawn more workers than there are trials to run.
    self.n_threads = min(
        self.optuna_settings.n_trials, self.optuna_settings.n_threads
    )
    self.processes = []
    if self.n_threads > 1:
        # NOTE(review): multiprocessing pickles the bound method target, i.e.
        # the whole HyperTraining instance — presumably all settings objects
        # are picklable; confirm before enabling n_threads > 1.
        for _ in range(self.n_threads):
            p = multiprocessing.Process(
                # target=lambda n_trials: self._run_optimize(self, n_trials),
                target=self._run_optimize,
                args=(self.optuna_settings.n_trials // self.n_threads,),
            )
            self.processes.append(p)
# def plot_eye(self, width=2, symbols=None, alpha=None, complex=False, show=True):
# data, config = util.datasets.load_data(
# self.data_settings.config_path,
# skipfirst=10,
# symbols=symbols or 1000,
# real=not complex,
# normalize=True,
# )
# eye_data = {"data": data.numpy(), "sps": int(config["glova"]["sps"])}
# return util.plot.eye(
# **eye_data,
# width=width,
# show=show,
# alpha=alpha,
# complex=complex,
# symbols=symbols or 1000,
# skipfirst=0,
# )
def run_study(self):
    """Run the configured number of trials.

    When worker processes exist they are started and joined first; any trials
    that do not divide evenly among the workers (or all of them, when running
    single-process) are executed in this process.
    """
    if not self.processes:
        leftover = self.optuna_settings.n_trials
    else:
        for proc in self.processes:
            proc.start()
        for proc in self.processes:
            proc.join()
        leftover = self.optuna_settings.n_trials % self.n_threads
    if leftover:
        self._run_optimize(leftover)
def _run_optimize(self, n_trials):
self.study.optimize(
self.objective, n_trials=n_trials, timeout=self.optuna_settings.timeout
)
def _extra_optuna_settings(self):
self.optuna_settings.multi_objective = len(self.optuna_settings.directions) > 1
if self.optuna_settings.multi_objective:
self.optuna_settings.direction = None
else:
self.optuna_settings.direction = self.optuna_settings.directions[0]
self.optuna_settings.directions = None
self.optuna_settings.n_train_batches = (
self.optuna_settings.n_train_batches
if self.optuna_settings.limit_examples
else float("inf")
)
self.optuna_settings.n_valid_batches = (
self.optuna_settings.n_valid_batches
if self.optuna_settings.limit_examples
else float("inf")
)
def define_model(self, trial: optuna.Trial, writer=None):
    """Build the layered complex-valued model from trial hyperparameters.

    Architecture: n_layers x (SemiUnitaryLayer -> activation), followed by a
    final UnitaryLayer projecting to `model_settings.output_dim`. The input
    width is 2 * model_input_dim (both polarizations interleaved).

    Args:
        trial: the Optuna trial supplying/recording hyperparameters.
        writer: optional SummaryWriter; when given, the traced model graph is
            logged once.

    Returns:
        The model moved to the configured device.
    """
    n_layers = force_suggest_int(
        trial, "model_n_layers", self.model_settings.model_n_layers
    )
    # NOTE(review): dict.get evaluates its default eagerly, so force_suggest_*
    # runs even when the key already exists in trial.params — presumably these
    # helpers are idempotent per trial; confirm in util.optuna_helpers.
    input_dim = 2 * trial.params.get(
        "model_input_dim",
        force_suggest_int(trial, "model_input_dim", self.data_settings.model_input_dim),
    )
    dtype = trial.params.get(
        "model_dtype",
        force_suggest_categorical(trial, "model_dtype", self.data_settings.dtype),
    )
    dtype = getattr(torch, dtype)
    afunc = force_suggest_categorical(
        trial, "model_activation_func", self.model_settings.model_activation_func
    )
    layers = []
    last_dim = input_dim
    for i in range(n_layers):
        hidden_dim = force_suggest_int(
            trial, f"model_hidden_dim_{i}", self.model_settings.unit_count
        )
        layers.append(
            util.complexNN.SemiUnitaryLayer(last_dim, hidden_dim, dtype=dtype)
        )
        last_dim = hidden_dim
        layers.append(getattr(util.complexNN, afunc)())
    # Use last_dim here (not the loop-local hidden_dim): identical for
    # n_layers >= 1, and avoids a NameError when the search space allows
    # n_layers == 0.
    layers.append(
        util.complexNN.UnitaryLayer(
            last_dim, self.model_settings.output_dim, dtype=dtype
        )
    )
    model = nn.Sequential(*layers)
    if writer is not None:
        writer.add_graph(
            model, torch.zeros(1, input_dim, dtype=dtype), use_strict_trace=False
        )
    return model.to(self.pytorch_settings.device)
def get_sliced_data(self, trial: optuna.Trial, override=None):
    """Build train/validation DataLoaders for the trial's data hyperparameters.

    Args:
        trial: Optuna trial supplying dataset_symbols, dataset_xy_delay,
            model_input_dim and model_dtype (suggesting them if absent).
        override: optional dict; currently only "num_symbols" is honored and
            is forwarded to the dataset to cap its length.

    Returns:
        (train_loader, valid_loader) over the same dataset, split by
        `data_settings.train_split` and optionally shuffled.
    """
    symbols = trial.params.get(
        "dataset_symbols",
        force_suggest_float(trial, "dataset_symbols", self.data_settings.symbols),
    )
    xy_delay = trial.params.get(
        "dataset_xy_delay",
        force_suggest_float(trial, "dataset_xy_delay", self.data_settings.xy_delay),
    )
    data_size = trial.params.get(
        "model_input_dim",
        force_suggest_int(trial, "model_input_dim", self.data_settings.model_input_dim),
    )
    dtype = trial.params.get(
        "model_dtype",
        force_suggest_categorical(trial, "model_dtype", self.data_settings.dtype),
    )
    dtype = getattr(torch, dtype)
    num_symbols = None
    if override is not None:
        num_symbols = override.get("num_symbols", None)
    # get dataset
    dataset = FiberRegenerationDataset(
        file_path=self.data_settings.config_path,
        symbols=symbols,
        output_dim=data_size,
        target_delay=self.data_settings.in_out_delay,
        xy_delay=xy_delay,
        drop_first=self.data_settings.drop_first,
        dtype=dtype,
        real=not dtype.is_complex,
        num_symbols=num_symbols,
    )
    # Contiguous index split, optionally shuffled for random train/valid
    # membership.
    dataset_size = len(dataset)
    indices = list(range(dataset_size))
    split = int(np.floor(self.data_settings.train_split * dataset_size))
    if self.data_settings.shuffle:
        # NOTE(review): seeds NumPy's *global* RNG — affects any other code
        # using np.random in this process.
        np.random.seed(self.global_settings.seed)
        np.random.shuffle(indices)
    train_indices, valid_indices = indices[:split], indices[split:]
    if self.data_settings.shuffle:
        train_sampler = torch.utils.data.SubsetRandomSampler(train_indices)
        valid_sampler = torch.utils.data.SubsetRandomSampler(valid_indices)
    else:
        # Plain index lists keep the original (temporal) order.
        train_sampler = train_indices
        valid_sampler = valid_indices
    train_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=self.pytorch_settings.batchsize,
        sampler=train_sampler,
        drop_last=True,
        pin_memory=True,
        num_workers=self.pytorch_settings.dataloader_workers,
        prefetch_factor=self.pytorch_settings.dataloader_prefetch,
    )
    valid_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=self.pytorch_settings.batchsize,
        sampler=valid_sampler,
        drop_last=True,
        pin_memory=True,
        num_workers=self.pytorch_settings.dataloader_workers,
        prefetch_factor=self.pytorch_settings.dataloader_prefetch,
    )
    return train_loader, valid_loader
def train_model(
    self,
    trial,
    model,
    optimizer,
    train_loader,
    epoch,
    writer=None,
    enable_progress=False,
):
    """Train `model` for one epoch and return the mean training loss.

    Args:
        trial: Optuna trial (unused here; kept for signature symmetry with
            eval_model).
        model: network to train; batches are moved to the configured device.
        optimizer: torch optimizer stepping the model parameters.
        train_loader: DataLoader yielding (x, y) batches.
        epoch: current epoch index, used to offset TensorBoard global steps.
        writer: optional SummaryWriter; batch-averaged loss is written every
            `pytorch_settings.write_every` batches.
        enable_progress: show a rich progress bar (caller enables this only
            when running single-threaded).

    Returns:
        Mean loss over the batches actually processed (capped by
        `optuna_settings.n_train_batches`).
    """
    if enable_progress:
        progress = Progress(
            TextColumn("[yellow] Training..."),
            TextColumn("Error: {task.description}"),
            BarColumn(),
            TaskProgressColumn(),
            TextColumn("[green]Batch"),
            MofNCompleteColumn(),
            TimeRemainingColumn(),
            TimeElapsedColumn(),
            # description="Training",
            transient=False,
            console=self.console,
            refresh_per_second=10,
        )
        task = progress.add_task("-.---e--", total=len(train_loader))
        progress.start()
    # running_loss2 accumulates between TensorBoard writes; running_loss over
    # the whole epoch.
    running_loss2 = 0.0
    running_loss = 0.0
    model.train()
    for batch_idx, (x, y) in enumerate(train_loader):
        # n_train_batches may be +inf (see _extra_optuna_settings), in which
        # case the full loader is consumed.
        if batch_idx >= self.optuna_settings.n_train_batches:
            break
        model.zero_grad(set_to_none=True)
        x, y = (
            x.to(self.pytorch_settings.device),
            y.to(self.pytorch_settings.device),
        )
        y_pred = model(x)
        # F.mse_loss does not support complex tensors (see module imports),
        # hence the project-local complex MSE.
        loss = util.complexNN.complex_mse_loss(y_pred, y)
        loss_value = loss.item()
        loss.backward()
        optimizer.step()
        running_loss2 += loss_value
        running_loss += loss_value
        if enable_progress:
            progress.update(task, advance=1, description=f"{loss_value:.3e}")
        if writer is not None:
            if batch_idx % self.pytorch_settings.write_every == 0:
                writer.add_scalar(
                    "training loss",
                    running_loss2
                    / (self.pytorch_settings.write_every if batch_idx > 0 else 1),
                    epoch
                    * min(len(train_loader), self.optuna_settings.n_train_batches)
                    + batch_idx,
                )
                running_loss2 = 0.0
    if enable_progress:
        progress.stop()
    return running_loss / min(
        len(train_loader), self.optuna_settings.n_train_batches
    )
def eval_model(
    self, trial, model, valid_loader, epoch, writer=None, enable_progress=True
):
    """Evaluate `model` on the validation loader and return the mean error.

    Mirrors `train_model` (progress bar, periodic TensorBoard scalars) but
    runs under `torch.no_grad()` and additionally logs a "fiber response"
    figure at the end of the epoch when a writer is given.

    Returns:
        Mean complex-MSE over the batches actually processed (capped by
        `optuna_settings.n_valid_batches`).
    """
    if enable_progress:
        progress = Progress(
            TextColumn("[green]Evaluating..."),
            TextColumn("Error: {task.description}"),
            BarColumn(),
            TaskProgressColumn(),
            TextColumn("[green]Batch"),
            MofNCompleteColumn(),
            TimeRemainingColumn(),
            TimeElapsedColumn(),
            # description="Training",
            transient=False,
            console=self.console,
            refresh_per_second=10,
        )
        progress.start()
        task = progress.add_task("-.---e--", total=len(valid_loader))
    model.eval()
    # running_error: whole-epoch sum; running_error_2: sum since last write.
    running_error = 0
    running_error_2 = 0
    with torch.no_grad():
        for batch_idx, (x, y) in enumerate(valid_loader):
            if batch_idx >= self.optuna_settings.n_valid_batches:
                break
            x, y = (
                x.to(self.pytorch_settings.device),
                y.to(self.pytorch_settings.device),
            )
            y_pred = model(x)
            error = util.complexNN.complex_mse_loss(y_pred, y)
            error_value = error.item()
            running_error += error_value
            running_error_2 += error_value
            if enable_progress:
                progress.update(task, advance=1, description=f"{error_value:.3e}")
            if writer is not None:
                if batch_idx % self.pytorch_settings.write_every == 0:
                    writer.add_scalar(
                        "eval loss",
                        running_error_2
                        / (
                            self.pytorch_settings.write_every
                            if batch_idx > 0
                            else 1
                        ),
                        epoch
                        * min(
                            len(valid_loader), self.optuna_settings.n_valid_batches
                        )
                        + batch_idx,
                    )
                    running_error_2 = 0.0
    running_error /= min(len(valid_loader), self.optuna_settings.n_valid_batches)
    if writer is not None:
        # Log a qualitative response plot once per epoch (step = epoch + 1;
        # step 0 is written by `objective` before training starts).
        title_append, subtitle = self.build_title(trial)
        writer.add_figure(
            "fiber response",
            self.plot_model_response(
                trial,
                model=model,
                title_append=title_append,
                subtitle=subtitle,
                show=False,
            ),
            epoch + 1,
        )
    if enable_progress:
        progress.stop()
    return running_error
def run_model(self, model, loader):
    """Run `model` over every batch of `loader` without gradients.

    Each batch is reshaped to (batch, samples, 2) — the trailing axis holds
    the two polarizations — and only the first sample of each input window is
    kept.

    Returns:
        (targets, inputs, predictions): stacked CPU tensors, in that order.
    """
    model.eval()
    inputs, targets, predictions = [], [], []
    with torch.no_grad():
        model = model.to(self.pytorch_settings.device)
        for x, y in loader:
            x = x.to(self.pytorch_settings.device)
            y = y.to(self.pytorch_settings.device)
            out = model(x).cpu()
            out = out.view(out.shape[0], -1, 2)
            y = y.view(y.shape[0], -1, 2)
            x = x.view(x.shape[0], -1, 2)
            inputs.append(x[:, 0, :].squeeze())
            targets.append(y.squeeze())
            predictions.append(out.squeeze())
    stacked_targets = torch.vstack(targets).cpu()
    stacked_inputs = torch.vstack(inputs).cpu()
    stacked_predictions = torch.vstack(predictions).cpu()
    return stacked_targets, stacked_inputs, stacked_predictions
def objective(self, trial: optuna.Trial, plot_before=False):
    """Optuna objective: build, train and evaluate one trial's model.

    Logs the initial model response (step 0), trains for
    `pytorch_settings.epochs` epochs with per-epoch evaluation, supports
    pruning in the single-objective case, and always saves the model (if one
    was built) in the `finally` block.

    Args:
        trial: the Optuna trial being evaluated.
        plot_before: show the pre-training response plot interactively.

    Returns:
        `(n_params, error)` for multi-objective studies, else `error`.
        NOTE(review): on KeyboardInterrupt this returns None — Optuna will
        record the trial as failed; confirm that is the intended behavior.
    """
    model = None
    # exc is only set by the commented-out broad except below; with that
    # disabled it stays None and the re-raise in `finally` is dead code.
    exc = None
    try:
        # rprint(*list(self.study_name.split("_")))
        # Zero-pad the trial number so TensorBoard run names sort naturally.
        writer = self.setup_tb_writer(
            self.optuna_settings.study_name,
            f"{trial.number:0{len(str(self.optuna_settings.n_trials))}}",
        )
        model = self.define_model(trial, writer)
        n_params = sum(p.numel() for p in model.parameters())
        # n_nodes = trial.params.get("model_n_layers", self.model_settings.model_n_layers) * trial.params.get("model_hidden_dim", self.model_settings.unit_count)
        # Log the untrained model's response as the step-0 figure.
        title_append, subtitle = self.build_title(trial)
        writer.add_figure(
            "fiber response",
            self.plot_model_response(
                trial,
                model=model,
                title_append=title_append,
                subtitle=subtitle,
                show=plot_before,
            ),
            0,
        )
        train_loader, valid_loader = self.get_sliced_data(trial)
        optimizer_name = force_suggest_categorical(
            trial, "optimizer", self.optimizer_settings.optimizer
        )
        lr = force_suggest_float(
            trial, "lr", self.optimizer_settings.learning_rate, log=True
        )
        optimizer = getattr(optim, optimizer_name)(model.parameters(), lr=lr)
        if self.optimizer_settings.scheduler is not None:
            scheduler = getattr(optim.lr_scheduler, self.optimizer_settings.scheduler)(
                optimizer, **self.optimizer_settings.scheduler_kwargs)
        for epoch in range(self.pytorch_settings.epochs):
            # Progress bars would interleave across processes; only show them
            # when running single-threaded.
            enable_progress = self.optuna_settings.n_threads == 1
            if enable_progress:
                self.console.rule(
                    f"Epoch {epoch + 1}/{self.pytorch_settings.epochs}"
                )
            self.train_model(
                trial,
                model,
                optimizer,
                train_loader,
                epoch,
                writer,
                enable_progress=enable_progress,
            )
            error = self.eval_model(
                trial,
                model,
                valid_loader,
                epoch,
                writer,
                enable_progress=enable_progress,
            )
            if self.optimizer_settings.scheduler is not None:
                # assumes a metric-driven scheduler (e.g. ReduceLROnPlateau);
                # other schedulers ignore or reject the metric argument.
                scheduler.step(error)
        writer.close()
        if self.optuna_settings.multi_objective:
            # Multi-objective studies do not support pruning; report both
            # objectives directly.
            return n_params, error
        trial.report(error, epoch)
        if trial.should_prune():
            raise optuna.exceptions.TrialPruned()
        return error
    except KeyboardInterrupt:
        ...
    # except Exception as e:
    #     exc = e
    finally:
        # Persist whatever model exists, even on interrupt/prune.
        if model is not None:
            save_path = (
                Path(self.pytorch_settings.model_dir)
                / f"{self.optuna_settings.study_name}_{trial.number}.pth"
            )
            save_path.parent.mkdir(parents=True, exist_ok=True)
            torch.save(model, save_path)
        if exc is not None:
            raise exc
def _plot_model_response_eye(
self, *signals, labels=None, sps=None, title_append="", subtitle="", show=True
):
if sps is None:
raise ValueError("sps must be provided")
if not hasattr(labels, "__iter__") or isinstance(labels, (str, type(None))):
labels = [labels]
else:
labels = list(labels)
while len(labels) < len(signals):
labels.append(None)
# check if there are any labels
if not any(labels):
labels = [f"signal {i + 1}" for i in range(len(signals))]
fig, axs = plt.subplots(2, len(signals), sharex=True, sharey=True)
fig.suptitle(
f"Eye diagram{f' {title_append}' if title_append else ''}{f'\n{subtitle}' if subtitle else ''}"
)
xaxis = np.linspace(0, 2, 2 * sps, endpoint=False)
for j, (label, signal) in enumerate(zip(labels, signals)):
# signal = signal.cpu().numpy()
for i in range(len(signal) // sps - 1):
x, y = signal[i * sps : (i + 2) * sps].T
axs[0, j].plot(xaxis, np.abs(x) ** 2, color="C0", alpha=0.02)
axs[1, j].plot(xaxis, np.abs(y) ** 2, color="C0", alpha=0.02)
axs[0, j].set_title(label + " x")
axs[1, j].set_title(label + " y")
axs[0, j].set_xlabel("Symbol")
axs[1, j].set_xlabel("Symbol")
axs[0, j].set_ylabel("normalized power")
axs[1, j].set_ylabel("normalized power")
if show:
plt.show()
def _plot_model_response_head(
    self, *signals, labels=None, sps=None, title_append="", subtitle="", show=True
):
    """Plot the power of each signal's two polarizations (one axis each).

    Returns the matplotlib figure. The x axis is in symbols when `sps` is
    given, otherwise in raw samples.
    """
    # Normalize `labels` to a list padded with None up to len(signals).
    if isinstance(labels, (str, type(None))) or not hasattr(labels, "__iter__"):
        labels = [labels]
    else:
        labels = list(labels)
    labels += [None] * (len(signals) - len(labels))
    if not any(labels):
        labels = [f"signal {i + 1}" for i in range(len(signals))]
    fig, axs = plt.subplots(1, 2, sharex=True, sharey=True)
    fig.set_size_inches(18, 6)
    title = "Fiber response"
    if title_append:
        title += f" {title_append}"
    if subtitle:
        title += f"\n{subtitle}"
    fig.suptitle(title)
    for pol, ax in enumerate(axs):
        for signal, label in zip(signals, labels):
            if sps is None:
                timebase = np.arange(len(signal))
            else:
                timebase = np.linspace(
                    0, len(signal) / sps, len(signal), endpoint=False
                )
            ax.plot(timebase, np.abs(signal[:, pol]) ** 2, label=label)
        ax.set_xlabel("Sample" if sps is None else "Symbol")
        ax.set_ylabel("normalized power")
        ax.legend(loc="upper right")
    if show:
        plt.show()
    return fig
def plot_model_response(
    self,
    trial,
    model=None,
    title_append="",
    subtitle="",
    mode: Literal["eye", "head"] = "head",
    show=True,
):
    """Plot fiber input, fiber output and regenerated signal for `model`.

    Temporarily overrides the data/pytorch settings (no shuffling, no
    train/valid split, fixed batch size) to obtain a contiguous plotting
    loader, then delegates to the "head" or "eye" plot helper.

    Args:
        trial: Optuna trial whose data hyperparameters define the loader.
        model: the model to run; forwarded to `run_model`.
        title_append/subtitle: title decoration for the figure.
        mode: "head" (time trace) or "eye" (eye diagram).
        show: display the figure interactively.

    Returns:
        The matplotlib figure.

    Raises:
        ValueError: for an unknown `mode`.
    """
    # See matplotlib issues #27713 (linked comments below): figures logged to
    # TensorBoard can linger; collect explicitly after plotting.
    # https://github.com/matplotlib/matplotlib/issues/27713#issue-2104110987
    # https://github.com/matplotlib/matplotlib/issues/27713#issuecomment-1915497463
    import gc

    data_settings_backup = copy.deepcopy(self.data_settings)
    pytorch_settings_backup = copy.deepcopy(self.pytorch_settings)
    try:
        # Fix: restore the settings even if loader construction raises;
        # previously an exception here left the overrides in place for the
        # rest of the run.
        self.data_settings.drop_first = 100
        self.data_settings.shuffle = False
        self.data_settings.train_split = 1.0
        self.pytorch_settings.batchsize = (
            self.pytorch_settings.eye_symbols
            if mode == "eye"
            else self.pytorch_settings.head_symbols
        )
        plot_loader, _ = self.get_sliced_data(
            trial, override={"num_symbols": self.pytorch_settings.batchsize}
        )
    finally:
        self.data_settings = data_settings_backup
        self.pytorch_settings = pytorch_settings_backup
    fiber_in, fiber_out, regen = self.run_model(model, plot_loader)
    fiber_in = fiber_in.view(-1, 2).numpy()
    fiber_out = fiber_out.view(-1, 2).numpy()
    regen = regen.view(-1, 2).numpy()
    if mode == "head":
        fig = self._plot_model_response_head(
            fiber_in,
            fiber_out,
            regen,
            labels=("fiber in", "fiber out", "regen"),
            sps=plot_loader.dataset.samples_per_symbol,
            title_append=title_append,
            subtitle=subtitle,
            show=show,
        )
    elif mode == "eye":
        fig = self._plot_model_response_eye(
            fiber_in,
            fiber_out,
            regen,
            labels=("fiber in", "fiber out", "regen"),
            sps=plot_loader.dataset.samples_per_symbol,
            title_append=title_append,
            subtitle=subtitle,
            show=show,
        )
    else:
        raise ValueError(f"Unknown mode: {mode}")
    gc.collect()
    return fig
@staticmethod
def build_title(trial):
title_append = f"for trial {trial.number}"
subtitle = (
f"{trial.params['model_n_layers']} layers, "
f"{', '.join([str(trial.params[f'model_hidden_dim_{i}']) for i in range(trial.params['model_n_layers'])])} units, "
f"{trial.params['model_activation_func']}, "
f"{trial.params['model_dtype']}"
)
return title_append, subtitle