refactor complex loss functions for improved readability; update settings and dataset classes for consistency

Joseph Hopfmüller
2024-11-24 01:55:32 +01:00
parent 9a16a5637d
commit 7343ccb3a5
4 changed files with 392 additions and 361 deletions

View File

@@ -6,7 +6,8 @@ import matplotlib.pyplot as plt
import numpy as np
import optuna
import optunahub
# import optunahub
import warnings
import torch
@@ -18,26 +19,28 @@ import torch.utils.data
from torch.utils.tensorboard import SummaryWriter
from rich.progress import (
Progress,
TextColumn,
BarColumn,
TaskProgressColumn,
TimeRemainingColumn,
MofNCompleteColumn,
TimeElapsedColumn,
)
from rich.console import Console
# from rich import print as rprint
# from rich.progress import (
# Progress,
# TextColumn,
# BarColumn,
# TaskProgressColumn,
# TimeRemainingColumn,
# MofNCompleteColumn,
# TimeElapsedColumn,
# )
# from rich.console import Console
# # from rich import print as rprint
import multiprocessing
from util.datasets import FiberRegenerationDataset
from util.optuna_helpers import (
force_suggest_categorical,
force_suggest_float,
force_suggest_int,
)
# from util.optuna_helpers import (
# suggest_categorical_optional, # noqa: F401
# suggest_float_optional, # noqa: F401
# suggest_int_optional, # noqa: F401
# )
from util.optuna_helpers import install_optional_suggests
import util
from .settings import (
@@ -49,6 +52,8 @@ from .settings import (
PytorchSettings,
)
install_optional_suggests()
class HyperTraining:
def __init__(
@@ -60,7 +65,7 @@ class HyperTraining:
model_settings,
optimizer_settings,
optuna_settings,
console=None,
# console=None,
):
self.global_settings: GlobalSettings = global_settings
self.data_settings: DataSettings = data_settings
@@ -68,18 +73,16 @@ class HyperTraining:
self.model_settings: ModelSettings = model_settings
self.optimizer_settings: OptimizerSettings = optimizer_settings
self.optuna_settings: OptunaSettings = optuna_settings
self.processes = None
self.console = console or Console()
# self.console = console or Console()
# set some extra settings to make the code more readable
self._extra_optuna_settings()
self.stop_study = True
def setup_tb_writer(self, study_name=None, append=None):
log_dir = (
self.pytorch_settings.summary_dir
+ "/"
+ (study_name or self.optuna_settings.study_name)
)
log_dir = self.pytorch_settings.summary_dir + "/" + (study_name or self.optuna_settings.study_name)
if append is not None:
log_dir += "_" + str(append)
@@ -89,180 +92,211 @@ class HyperTraining:
study_name = self.get_latest_study()
if study_name:
if verbose:
print(f"Resuming study: {study_name}")
self.optuna_settings.study_name = study_name
def get_latest_study(self, verbose=True):
def get_latest_study(self, verbose=False) -> optuna.Study:
studies = self.get_studies()
study = None
for study in studies:
study.datetime_start = study.datetime_start or datetime.min
if studies:
study = sorted(studies, key=lambda x: x.datetime_start, reverse=True)[0]
if verbose:
print(f"Last study: {study.study_name}")
study_name = study.study_name
else:
if verbose:
print("No previous studies found")
study_name = None
return optuna.load_study(study_name=study.study_name, storage=self.optuna_settings.storage)
return study_name
# def study(self) -> optuna.Study:
# return optuna.load_study(self.optuna_settings.study_name, storage=self.optuna_settings.storage)
def get_studies(self):
return optuna.get_all_study_summaries(storage=self.optuna_settings.storage)
def setup_study(self):
module = optunahub.load_module(package="samplers/auto_sampler")
# module = optunahub.load_module(package="samplers/auto_sampler")
if self.optuna_settings._parallel:
self.processes = []
pruner = getattr(optuna.pruners, self.optuna_settings.pruner, None)
if pruner and self.optuna_settings.pruner_kwargs is not None:
pruner = pruner(**self.optuna_settings.pruner_kwargs)
elif pruner:
pruner = pruner()
self.study = optuna.create_study(
study_name=self.optuna_settings.study_name,
storage=self.optuna_settings.storage,
load_if_exists=True,
direction=self.optuna_settings.direction,
direction=self.optuna_settings._direction,
directions=self.optuna_settings.directions,
directions=self.optuna_settings._directions,
sampler=module.AutoSampler(),
pruner=pruner,
# sampler=module.AutoSampler(),
)
print("using sampler:", self.study.sampler)
# print("using sampler:", self.study.sampler)
with warnings.catch_warnings(action="ignore"):
self.study.set_metric_names(self.optuna_settings.metrics_names)
self.n_threads = min(
self.optuna_settings.n_trials, self.optuna_settings.n_threads
)
self.processes = []
if self.n_threads > 1:
for _ in range(self.n_threads):
p = multiprocessing.Process(
# target=lambda n_trials: self._run_optimize(self, n_trials),
target=self._run_optimize,
args=(self.optuna_settings.n_trials // self.n_threads,),
)
self.processes.append(p)
# def plot_eye(self, width=2, symbols=None, alpha=None, complex=False, show=True):
# data, config = util.datasets.load_data(
# self.data_settings.config_path,
# skipfirst=10,
# symbols=symbols or 1000,
# real=not complex,
# normalize=True,
# )
# eye_data = {"data": data.numpy(), "sps": int(config["glova"]["sps"])}
# return util.plot.eye(
# **eye_data,
# width=width,
# show=show,
# alpha=alpha,
# complex=complex,
# symbols=symbols or 1000,
# skipfirst=0,
# )
def run_study(self):
if self.processes:
for p in self.processes:
p.start()
for p in self.processes:
p.join()
remaining_trials = self.optuna_settings.n_trials % self.n_threads
else:
remaining_trials = self.optuna_settings.n_trials
if remaining_trials:
self._run_optimize(remaining_trials)
def _run_optimize(self, n_trials):
try:
if self.optuna_settings._parallel:
self._run_parallel_study()
else:
self._run_study()
except KeyboardInterrupt:
print("Stopping. Please wait for the processes to finish.")
self.stop_study = True
def trials_left(self):
return self.optuna_settings.n_trials - len(self.study.get_trials(states=self.optuna_settings.n_trials_filter))
def remove_completed_processes(self):
if self.processes is None:
return
for p, process in enumerate(self.processes):
if not process.is_alive():
process.join()
self.processes.pop(p)
def remove_outliers(self):
if self.optuna_settings.remove_outliers is not None:
trials = self.study.get_trials(states=(optuna.trial.TrialState.COMPLETE,))
if len(trials) == 0:
return
vals = [trial.value for trial in trials]
vals = np.log(vals)
mean = np.mean(vals)
std = np.std(vals)
outliers = [
trial for trial in trials if np.log(trial.value) > mean + self.optuna_settings.remove_outliers * std
]
for trial in outliers:
trial: optuna.trial.Trial = trial
trial.state = optuna.trial.TrialState.FAIL
trial.set_user_attr("outlier", True)
def _run_study(self):
while trials_left := self.trials_left():
self.remove_outliers()
self._run_optimize(n_trials=trials_left, timeout=self.optuna_settings.timeout)
def _run_parallel_study(self):
while trials_left := self.trials_left():
self.remove_outliers()
self.remove_completed_processes()
n_trials = max(trials_left, self.optuna_settings._n_threads) // self.optuna_settings._n_threads
def target_fun():
self._run_optimize(n_trials=n_trials, timeout=self.optuna_settings.timeout)
for _ in range(self.optuna_settings._n_threads - len(self.processes)):
self.processes.append(multiprocessing.Process(target=target_fun))
self.processes[-1].start()
def _run_optimize(self, **kwargs):
self.study.optimize(
self.objective, n_trials=n_trials, timeout=self.optuna_settings.timeout
self.objective,
**kwargs,
show_progress_bar=not self.optuna_settings._parallel,
)
def _extra_optuna_settings(self):
self.optuna_settings.multi_objective = len(self.optuna_settings.directions) > 1
if self.optuna_settings.multi_objective:
self.optuna_settings.direction = None
else:
self.optuna_settings.direction = self.optuna_settings.directions[0]
self.optuna_settings.directions = None
self.optuna_settings.n_train_batches = (
self.optuna_settings.n_train_batches
if self.optuna_settings.limit_examples
else float("inf")
)
self.optuna_settings.n_valid_batches = (
self.optuna_settings.n_valid_batches
if self.optuna_settings.limit_examples
else float("inf")
)
self.optuna_settings._multi_objective = len(self.optuna_settings.directions) > 1
if self.optuna_settings._multi_objective:
self.optuna_settings._direction = None
self.optuna_settings._directions = self.optuna_settings.directions
else:
self.optuna_settings._direction = self.optuna_settings.directions[0]
self.optuna_settings._directions = None
self.optuna_settings._n_train_batches = (
self.optuna_settings.n_train_batches if self.optuna_settings.limit_examples else float("inf")
)
self.optuna_settings._n_valid_batches = (
self.optuna_settings.n_valid_batches if self.optuna_settings.limit_examples else float("inf")
)
self.optuna_settings._n_threads = self.optuna_settings.n_workers
self.optuna_settings._parallel = self.optuna_settings._n_threads > 1
def define_model(self, trial: optuna.Trial, writer=None):
n_layers = force_suggest_int(
trial, "model_n_layers", self.model_settings.model_n_layers
)
n_layers = trial.suggest_int_optional("model_n_hidden_layers", self.model_settings.n_hidden_layers)
input_dim = 2 * trial.params.get(
input_dim = trial.suggest_int_optional(
"model_input_dim",
force_suggest_int(trial, "model_input_dim", self.data_settings.model_input_dim),
self.data_settings.output_size,
step=2,
multiply=2,
set_new=False,
)
dtype = trial.params.get(
"model_dtype",
force_suggest_categorical(trial, "model_dtype", self.data_settings.dtype),
)
# trial.set_user_attr("model_input_dim", input_dim)
dtype = trial.suggest_categorical_optional("model_dtype", self.data_settings.dtype, set_new=False)
dtype = getattr(torch, dtype)
afunc = force_suggest_categorical(
trial, "model_activation_func", self.model_settings.model_activation_func
)
afunc = trial.suggest_categorical_optional("model_activation_func", self.model_settings.model_activation_func)
layers = []
last_dim = input_dim
n_nodes = last_dim
for i in range(n_layers):
hidden_dim = force_suggest_int(
trial, f"model_hidden_dim_{i}", self.model_settings.unit_count
)
layers.append(
util.complexNN.SemiUnitaryLayer(last_dim, hidden_dim, dtype=dtype)
if hidden_dim_override := self.model_settings.overrides.get(f"n_hidden_nodes_{i}", False):
hidden_dim = trial.suggest_int_optional(f"model_hidden_dim_{i}", hidden_dim_override, force=True)
else:
hidden_dim = trial.suggest_int_optional(
f"model_hidden_dim_{i}",
self.model_settings.n_hidden_nodes,
# step=2,
)
layers.append(util.complexNN.SemiUnitaryLayer(last_dim, hidden_dim, dtype=dtype))
last_dim = hidden_dim
layers.append(getattr(util.complexNN, afunc)())
n_nodes += last_dim
layers.append(
util.complexNN.UnitaryLayer(
hidden_dim, self.model_settings.output_dim, dtype=dtype
)
)
layers.append(util.complexNN.SemiUnitaryLayer(last_dim, self.model_settings.output_dim, dtype=dtype))
model = nn.Sequential(*layers)
if writer is not None:
writer.add_graph(
model, torch.zeros(1, input_dim, dtype=dtype), use_strict_trace=False
)
writer.add_graph(model, torch.zeros(1, input_dim, dtype=dtype), use_strict_trace=False)
n_params = sum(p.numel() for p in model.parameters())
trial.set_user_attr("model_n_params", n_params)
trial.set_user_attr("model_n_nodes", n_nodes)
return model.to(self.pytorch_settings.device)
def get_sliced_data(self, trial: optuna.Trial, override=None):
symbols = trial.params.get(
"dataset_symbols",
force_suggest_float(trial, "dataset_symbols", self.data_settings.symbols),
symbols = trial.suggest_float_optional("dataset_symbols", self.data_settings.symbols, set_new=False)
in_out_delay = trial.suggest_float_optional(
"dataset_in_out_delay", self.data_settings.in_out_delay, set_new=False
)
xy_delay = trial.params.get(
"dataset_xy_delay",
force_suggest_float(trial, "dataset_xy_delay", self.data_settings.xy_delay),
)
xy_delay = trial.suggest_float_optional("dataset_xy_delay", self.data_settings.xy_delay, set_new=False)
data_size = trial.params.get(
data_size = int(
0.5
* trial.suggest_int_optional(
"model_input_dim",
force_suggest_int(trial, "model_input_dim", self.data_settings.model_input_dim),
self.data_settings.output_size,
step=2,
multiply=2,
set_new=False,
)
)
dtype = trial.params.get(
"model_dtype",
force_suggest_categorical(trial, "model_dtype", self.data_settings.dtype),
)
dtype = trial.suggest_categorical_optional("model_dtype", self.data_settings.dtype, set_new=False)
dtype = getattr(torch, dtype)
num_symbols = None
@@ -273,7 +307,7 @@ class HyperTraining:
file_path=self.data_settings.config_path,
symbols=symbols,
output_dim=data_size,
target_delay=self.data_settings.in_out_delay,
target_delay=in_out_delay,
xy_delay=xy_delay,
drop_first=self.data_settings.drop_first,
dtype=dtype,
@@ -327,31 +361,31 @@ class HyperTraining:
train_loader,
epoch,
writer=None,
enable_progress=False,
# enable_progress=False,
):
if enable_progress:
progress = Progress(
TextColumn("[yellow] Training..."),
TextColumn("Error: {task.description}"),
BarColumn(),
TaskProgressColumn(),
TextColumn("[green]Batch"),
MofNCompleteColumn(),
TimeRemainingColumn(),
TimeElapsedColumn(),
# description="Training",
transient=False,
console=self.console,
refresh_per_second=10,
)
task = progress.add_task("-.---e--", total=len(train_loader))
progress.start()
# if enable_progress:
# progress = Progress(
# TextColumn("[yellow] Training..."),
# TextColumn("Error: {task.description}"),
# BarColumn(),
# TaskProgressColumn(),
# TextColumn("[green]Batch"),
# MofNCompleteColumn(),
# TimeRemainingColumn(),
# TimeElapsedColumn(),
# # description="Training",
# transient=False,
# console=self.console,
# refresh_per_second=10,
# )
# task = progress.add_task("-.---e--", total=len(train_loader))
# progress.start()
running_loss2 = 0.0
running_loss = 0.0
model.train()
for batch_idx, (x, y) in enumerate(train_loader):
if batch_idx >= self.optuna_settings.n_train_batches:
if batch_idx >= self.optuna_settings._n_train_batches:
break
model.zero_grad(set_to_none=True)
x, y = (
@@ -366,55 +400,56 @@ class HyperTraining:
running_loss2 += loss_value
running_loss += loss_value
if enable_progress:
progress.update(task, advance=1, description=f"{loss_value:.3e}")
# if enable_progress:
# progress.update(task, advance=1, description=f"{loss_value:.3e}")
if writer is not None:
if batch_idx % self.pytorch_settings.write_every == 0:
writer.add_scalar(
"training loss",
running_loss2
/ (self.pytorch_settings.write_every if batch_idx > 0 else 1),
epoch
* min(len(train_loader), self.optuna_settings.n_train_batches)
+ batch_idx,
running_loss2 / (self.pytorch_settings.write_every if batch_idx > 0 else 1),
epoch * min(len(train_loader), self.optuna_settings._n_train_batches) + batch_idx,
)
running_loss2 = 0.0
if enable_progress:
progress.stop()
# if enable_progress:
# progress.stop()
return running_loss / min(
len(train_loader), self.optuna_settings.n_train_batches
)
return running_loss / min(len(train_loader), self.optuna_settings._n_train_batches)
def eval_model(
self, trial, model, valid_loader, epoch, writer=None, enable_progress=True
self,
trial,
model,
valid_loader,
epoch,
writer=None,
# enable_progress=True
):
if enable_progress:
progress = Progress(
TextColumn("[green]Evaluating..."),
TextColumn("Error: {task.description}"),
BarColumn(),
TaskProgressColumn(),
TextColumn("[green]Batch"),
MofNCompleteColumn(),
TimeRemainingColumn(),
TimeElapsedColumn(),
# description="Training",
transient=False,
console=self.console,
refresh_per_second=10,
)
progress.start()
task = progress.add_task("-.---e--", total=len(valid_loader))
# if enable_progress:
# progress = Progress(
# TextColumn("[green]Evaluating..."),
# TextColumn("Error: {task.description}"),
# BarColumn(),
# TaskProgressColumn(),
# TextColumn("[green]Batch"),
# MofNCompleteColumn(),
# TimeRemainingColumn(),
# TimeElapsedColumn(),
# # description="Training",
# transient=False,
# console=self.console,
# refresh_per_second=10,
# )
# progress.start()
# task = progress.add_task("-.---e--", total=len(valid_loader))
model.eval()
running_error = 0
running_error_2 = 0
with torch.no_grad():
for batch_idx, (x, y) in enumerate(valid_loader):
if batch_idx >= self.optuna_settings.n_valid_batches:
if batch_idx >= self.optuna_settings._n_valid_batches:
break
x, y = (
x.to(self.pytorch_settings.device),
@@ -426,28 +461,19 @@ class HyperTraining:
running_error += error_value
running_error_2 += error_value
if enable_progress:
progress.update(task, advance=1, description=f"{error_value:.3e}")
# if enable_progress:
# progress.update(task, advance=1, description=f"{error_value:.3e}")
if writer is not None:
if batch_idx % self.pytorch_settings.write_every == 0:
writer.add_scalar(
"eval loss",
running_error_2
/ (
self.pytorch_settings.write_every
if batch_idx > 0
else 1
),
epoch
* min(
len(valid_loader), self.optuna_settings.n_valid_batches
)
+ batch_idx,
running_error_2 / (self.pytorch_settings.write_every if batch_idx > 0 else 1),
epoch * min(len(valid_loader), self.optuna_settings._n_valid_batches) + batch_idx,
)
running_error_2 = 0.0
running_error /= min(len(valid_loader), self.optuna_settings.n_valid_batches)
running_error /= min(len(valid_loader), self.optuna_settings._n_valid_batches)
if writer is not None:
title_append, subtitle = self.build_title(trial)
@@ -463,8 +489,8 @@ class HyperTraining:
epoch + 1,
)
if enable_progress:
progress.stop()
# if enable_progress:
# progress.stop()
return running_error
@@ -496,19 +522,18 @@ class HyperTraining:
return ys, xs, y_preds
def objective(self, trial: optuna.Trial, plot_before=False):
if self.stop_study:
trial.study.stop()
model = None
exc = None
try:
# rprint(*list(self.study_name.split("_")))
writer = self.setup_tb_writer(
self.optuna_settings.study_name,
f"{trial.number:0{len(str(self.optuna_settings.n_trials))}}",
f"{trial.number:0{len(str(self.optuna_settings.n_trials)) + 2}}",
)
model = self.define_model(trial, writer)
n_params = sum(p.numel() for p in model.parameters())
# n_nodes = trial.params.get("model_n_layers", self.model_settings.model_n_layers) * trial.params.get("model_hidden_dim", self.model_settings.unit_count)
# n_nodes = trial.params.get("model_n_hidden_layers", self.model_settings.model_n_layers) * trial.params.get("model_hidden_dim", self.model_settings.unit_count)
title_append, subtitle = self.build_title(trial)
@@ -526,25 +551,23 @@ class HyperTraining:
train_loader, valid_loader = self.get_sliced_data(trial)
optimizer_name = force_suggest_categorical(
trial, "optimizer", self.optimizer_settings.optimizer
)
optimizer_name = trial.suggest_categorical_optional("optimizer", self.optimizer_settings.optimizer)
lr = force_suggest_float(
trial, "lr", self.optimizer_settings.learning_rate, log=True
)
lr = trial.suggest_float_optional("lr", self.optimizer_settings.learning_rate, log=True)
optimizer = getattr(optim, optimizer_name)(model.parameters(), lr=lr)
if self.optimizer_settings.scheduler is not None:
scheduler = getattr(optim.lr_scheduler, self.optimizer_settings.scheduler)(
optimizer, **self.optimizer_settings.scheduler_kwargs)
optimizer, **self.optimizer_settings.scheduler_kwargs
)
for epoch in range(self.pytorch_settings.epochs):
enable_progress = self.optuna_settings.n_threads == 1
if enable_progress:
self.console.rule(
f"Epoch {epoch + 1}/{self.pytorch_settings.epochs}"
)
trial.set_user_attr("epoch", epoch)
# enable_progress = self.optuna_settings.n_threads == 1
# if enable_progress:
# self.console.rule(
# f"Epoch {epoch + 1}/{self.pytorch_settings.epochs}"
# )
self.train_model(
trial,
model,
@@ -552,7 +575,7 @@ class HyperTraining:
train_loader,
epoch,
writer,
enable_progress=enable_progress,
# enable_progress=enable_progress,
)
error = self.eval_model(
trial,
@@ -560,39 +583,33 @@ class HyperTraining:
valid_loader,
epoch,
writer,
enable_progress=enable_progress,
# enable_progress=enable_progress,
)
if self.optimizer_settings.scheduler is not None:
scheduler.step(error)
writer.close()
if self.optuna_settings.multi_objective:
return n_params, error
trial.set_user_attr("mse", error)
trial.set_user_attr("log_mse", np.log10(error + np.finfo(float).eps))
trial.set_user_attr("neg_mse", -error)
trial.set_user_attr("neg_log_mse", -np.log10(error + np.finfo(float).eps))
if not self.optuna_settings._multi_objective:
trial.report(error, epoch)
if trial.should_prune():
raise optuna.exceptions.TrialPruned()
return error
except KeyboardInterrupt:
...
# except Exception as e:
# exc = e
writer.close()
if self.optuna_settings._multi_objective:
return -np.log10(error + np.finfo(float).eps), trial.user_attrs.get("model_n_nodes", -1)
finally:
if model is not None:
save_path = (
Path(self.pytorch_settings.model_dir)
/ f"{self.optuna_settings.study_name}_{trial.number}.pth"
)
if self.pytorch_settings.save_models and model is not None:
save_path = Path(self.pytorch_settings.model_dir) / f"{self.optuna_settings.study_name}_{trial.number}.pth"
save_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(model, save_path)
if exc is not None:
raise exc
return error
def _plot_model_response_eye(
self, *signals, labels=None, sps=None, title_append="", subtitle="", show=True
):
def _plot_model_response_eye(self, *signals, labels=None, sps=None, title_append="", subtitle="", show=True):
if sps is None:
raise ValueError("sps must be provided")
if not hasattr(labels, "__iter__") or isinstance(labels, (str, type(None))):
@@ -608,9 +625,7 @@ class HyperTraining:
labels = [f"signal {i + 1}" for i in range(len(signals))] labels = [f"signal {i + 1}" for i in range(len(signals))]
fig, axs = plt.subplots(2, len(signals), sharex=True, sharey=True) fig, axs = plt.subplots(2, len(signals), sharex=True, sharey=True)
fig.suptitle( fig.suptitle(f"Eye diagram{f' {title_append}' if title_append else ''}{f'\n{subtitle}' if subtitle else ''}")
f"Eye diagram{f' {title_append}' if title_append else ''}{f'\n{subtitle}' if subtitle else ''}"
)
xaxis = np.linspace(0, 2, 2 * sps, endpoint=False) xaxis = np.linspace(0, 2, 2 * sps, endpoint=False)
for j, (label, signal) in enumerate(zip(labels, signals)): for j, (label, signal) in enumerate(zip(labels, signals)):
# signal = signal.cpu().numpy() # signal = signal.cpu().numpy()
@@ -628,9 +643,7 @@ class HyperTraining:
if show:
plt.show()
def _plot_model_response_head(
self, *signals, labels=None, sps=None, title_append="", subtitle="", show=True
):
def _plot_model_response_head(self, *signals, labels=None, sps=None, title_append="", subtitle="", show=True):
if not hasattr(labels, "__iter__") or isinstance(labels, (str, type(None))):
labels = [labels]
else:
@@ -644,16 +657,12 @@ class HyperTraining:
labels = [f"signal {i + 1}" for i in range(len(signals))] labels = [f"signal {i + 1}" for i in range(len(signals))]
fig, axs = plt.subplots(1, 2, sharex=True, sharey=True) fig, axs = plt.subplots(1, 2, sharex=True, sharey=True)
fig.set_size_inches(18,6) fig.set_size_inches(18, 6)
fig.suptitle( fig.suptitle(f"Fiber response{f' {title_append}' if title_append else ''}{f'\n{subtitle}' if subtitle else ''}")
f"Fiber response{f' {title_append}' if title_append else ''}{f'\n{subtitle}' if subtitle else ''}"
)
for i, ax in enumerate(axs): for i, ax in enumerate(axs):
for signal, label in zip(signals, labels): for signal, label in zip(signals, labels):
if sps is not None: if sps is not None:
xaxis = np.linspace( xaxis = np.linspace(0, len(signal) / sps, len(signal), endpoint=False)
0, len(signal) / sps, len(signal), endpoint=False
)
else: else:
xaxis = np.arange(len(signal)) xaxis = np.arange(len(signal))
ax.plot(xaxis, np.abs(signal[:, i]) ** 2, label=label) ax.plot(xaxis, np.abs(signal[:, i]) ** 2, label=label)
@@ -678,10 +687,10 @@ class HyperTraining:
self.data_settings.drop_first = 100
self.data_settings.shuffle = False
self.data_settings.train_split = 1.0
self.pytorch_settings.batchsize = self.pytorch_settings.eye_symbols if mode == "eye" else self.pytorch_settings.head_symbols
plot_loader, _ = self.get_sliced_data(
trial, override={"num_symbols": self.pytorch_settings.batchsize}
)
self.pytorch_settings.batchsize = (
self.pytorch_settings.eye_symbols if mode == "eye" else self.pytorch_settings.head_symbols
)
plot_loader, _ = self.get_sliced_data(trial, override={"num_symbols": self.pytorch_settings.batchsize})
self.data_settings = data_settings_backup
self.pytorch_settings = pytorch_settings_backup
@@ -728,13 +737,22 @@ class HyperTraining:
return fig
@staticmethod
def build_title(trial):
def build_title(trial: optuna.trial.Trial):
title_append = f"for trial {trial.number}"
model_n_layers = util.misc.multi_getattr((trial.params, trial.user_attrs), "model_n_layers", 0)
model_hidden_dims = [
util.misc.multi_getattr((trial.params, trial.user_attrs), f"model_hidden_dim_{i}", 0)
for i in range(model_n_layers)
]
model_activation_func = util.misc.multi_getattr(
(trial.params, trial.user_attrs),
"model_activation_func",
"unknown act. fun",
)
model_dtype = util.misc.multi_getattr((trial.params, trial.user_attrs), "model_dtype", "unknown dtype")
subtitle = (
f"{trial.params['model_n_layers']} layers, "
f"{', '.join([str(trial.params[f'model_hidden_dim_{i}']) for i in range(trial.params['model_n_layers'])])} units, "
f"{trial.params['model_activation_func']}, "
f"{trial.params['model_dtype']}"
f"{model_n_layers} layers à ({', '.join(model_hidden_dims)}) units, {model_activation_func}, {model_dtype}"
)
return title_append, subtitle
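For reference, the new call sites above go through install_optional_suggests() from util.optuna_helpers. A minimal sketch of how such a helper could be wired up is shown below; the names, the monkey-patching approach, and the set_new/force/multiply semantics are assumptions inferred from the call sites in this diff, not the actual implementation in the repository:

# Hypothetical sketch only: patch an "optional" suggest method onto optuna.Trial.
# Assumption: a tuple/list value means "search this range", a scalar means "use as-is".
import optuna

def _suggest_int_optional(trial, name, value, *, multiply=1, set_new=True, force=False, **kwargs):
    if isinstance(value, (tuple, list)) or force:
        low, high = value if isinstance(value, (tuple, list)) else (value, value)
        result = trial.suggest_int(name, low, high, **kwargs)
    else:
        # fixed value: keep it out of the search space, optionally record it for later lookup
        result = value
        if set_new:
            trial.set_user_attr(name, value)
    return multiply * result

def install_optional_suggests():
    # float/categorical variants would be attached the same way
    optuna.Trial.suggest_int_optional = _suggest_int_optional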

View File

@@ -1,4 +1,4 @@
from dataclasses import dataclass
from dataclasses import dataclass, field
from datetime import datetime
@@ -14,7 +14,7 @@ class DataSettings:
config_path: str # = "data/*-128-16384-100000-0-0-17-0-PAM4-0.ini"
dtype: tuple = ("complex64", "float64")
symbols: tuple | float | int = 8
model_input_dim: tuple | float | int = 64
output_size: tuple | float | int = 64
shuffle: bool = True
in_out_delay: float = 0
xy_delay: tuple | float | int = 0
@@ -33,6 +33,7 @@ class PytorchSettings:
dataloader_workers: int = 2
dataloader_prefetch: int = 2
save_models: bool = True
model_dir: str = ".models"
summary_dir: str = ".runs"
@@ -45,11 +46,10 @@ class PytorchSettings:
@dataclass
class ModelSettings:
output_dim: int = 2
model_n_layers: tuple | int = 3
unit_count: tuple | int = 8
# n_units_range: tuple | int = (2, 32)
# activation_func_range: tuple = ("ModReLU", "ZReLU", "CReLU", "Mag", "Identity")
model_activation_func: tuple = ("ModReLU",)
n_hidden_layers: tuple | int = 3
n_hidden_nodes: tuple | int = 8
model_activation_func: tuple = "ModReLU"
overrides: dict = field(default_factory=dict)
@dataclass
@@ -60,18 +60,36 @@ class OptimizerSettings:
scheduler_kwargs: dict | None = None
def _pruner_default_kwargs():
# MedianPruner
return {
"n_startup_trials": 0,
"n_warmup_steps": 5,
}
# optuna settings
@dataclass
class OptunaSettings:
n_trials: int = 128
n_threads: int = 4
n_workers: int = 1
timeout: int = 600
timeout: int = None
pruner: str = "MedianPruner"
pruner_kwargs: dict = field(default_factory=_pruner_default_kwargs)
directions: tuple = ("minimize",) directions: tuple = ("minimize",)
metrics_names: tuple = ("mse",) metrics_names: tuple = ("mse",)
limit_examples: bool = True limit_examples: bool = True
n_train_batches: int = 100 n_train_batches: int = float("inf")
n_valid_batches: int = 100 n_valid_batches: int = float("inf")
storage: str = "sqlite:///example.db" storage: str = "sqlite:///example.db"
study_name: str = ( study_name: str = (
f"optuna_study_{datetime.now().strftime('%Y-%m-%d_%H:%M:%S')}" f"optuna_study_{datetime.now().strftime('%Y-%m-%d_%H:%M:%S')}"
) )
n_trials_filter: tuple | list = None  # (optuna.trial.TrialState.COMPLETE,)
remove_outliers: float | int = None
## reserved, set by HyperTraining
_multi_objective = None
_parallel = None
_n_threads = None
_directions = None
_direction = None
_n_train_batches = None
_n_valid_batches = None
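Side note on the field(default_factory=...) pattern introduced above for pruner_kwargs and overrides: a mutable default such as a dict has to go through a factory so each settings instance gets its own copy. A small standalone illustration (generic names, not code from this repository):

from dataclasses import dataclass, field

def _pruner_default_kwargs():
    return {"n_startup_trials": 0, "n_warmup_steps": 5}

@dataclass
class Example:
    pruner_kwargs: dict = field(default_factory=_pruner_default_kwargs)

a, b = Example(), Example()
a.pruner_kwargs["n_warmup_steps"] = 10
print(b.pruner_kwargs["n_warmup_steps"])  # still 5: each instance owns an independent dict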

View File

@@ -8,10 +8,7 @@ def complex_mse_loss(input, target):
Compute the mean squared error between two complex tensors.
"""
if input.is_complex():
return torch.mean(
torch.square(input.real - target.real)
+ torch.square(input.imag - target.imag)
)
return torch.mean(torch.square(input.real - target.real) + torch.square(input.imag - target.imag))
else:
return F.mse_loss(input, target)
@@ -21,10 +18,7 @@ def complex_sse_loss(input, target):
Compute the sum squared error between two complex tensors.
"""
if input.is_complex():
return torch.sum(
torch.square(input.real - target.real)
+ torch.square(input.imag - target.imag)
)
return torch.sum(torch.square(input.real - target.real) + torch.square(input.imag - target.imag))
else:
return torch.sum(torch.square(input - target))
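For context, the real/imaginary decomposition used in both losses is simply the squared magnitude of the complex difference; a quick standalone equivalence check (illustrative, not part of this commit):

import torch

x = torch.randn(8, dtype=torch.complex64)
y = torch.randn(8, dtype=torch.complex64)
# mean(|Re dx|^2 + |Im dx|^2) == mean(|dx|^2) for dx = x - y
a = torch.mean(torch.square(x.real - y.real) + torch.square(x.imag - y.imag))
b = torch.mean(torch.abs(x - y) ** 2)
assert torch.allclose(a, b)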
@@ -48,6 +42,7 @@ class UnitaryLayer(nn.Module):
def __repr__(self):
return f"UnitaryLayer({self.in_features}, {self.out_features})"
class SemiUnitaryLayer(nn.Module):
def __init__(self, input_dim, output_dim, dtype=None):
super(SemiUnitaryLayer, self).__init__()
@@ -62,9 +57,9 @@ class SemiUnitaryLayer(nn.Module):
# Ensure the weights are semi-unitary by QR decomposition
q, _ = torch.linalg.qr(self.weight)
if self.input_dim > self.output_dim:
self.weight.data = q[:self.input_dim, :self.output_dim]
self.weight.data = q[: self.input_dim, : self.output_dim]
else:
self.weight.data = q[:self.output_dim, :self.input_dim].t()
self.weight.data = q[: self.output_dim, : self.input_dim].t()
def forward(self, x):
out = torch.matmul(x, self.weight)
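The QR-based reset above relies on Q having orthonormal columns, which is what makes the sliced weight semi-unitary. A quick standalone check of that property (illustrative, independent of the layer class):

import torch

w = torch.randn(8, 4)
q, _ = torch.linalg.qr(w)  # q: (8, 4) with orthonormal columns
print(torch.allclose(q.T @ q, torch.eye(4), atol=1e-6))  # True: q.T @ q == I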

View File

@@ -1,6 +1,7 @@
from pathlib import Path
import torch
from torch.utils.data import Dataset
# from torch.utils.data import Sampler
import numpy as np
import configparser
@@ -22,6 +23,7 @@ import configparser
# def __len__(self):
# return len(self.indices)
def load_data(config_path, skipfirst=0, symbols=None, real=False, normalize=False, device=None, dtype=None):
filepath = Path(config_path)
filepath = filepath.parent.glob(filepath.name)
@@ -43,7 +45,7 @@ def load_data(config_path, skipfirst=0, symbols=None, real=False, normalize=Fals
if normalize:
# square gets normalized to 1, as the power is (proportional to) the square of the amplitude
a, b, c, d = np.square(data.T)
a, b, c, d = a/np.max(np.abs(a)), b/np.max(np.abs(b)), c/np.max(np.abs(c)), d/np.max(np.abs(d))
a, b, c, d = a / np.max(np.abs(a)), b / np.max(np.abs(b)), c / np.max(np.abs(c)), d / np.max(np.abs(d))
data = np.sqrt(np.array([a, b, c, d]).T)
if real:
@@ -55,6 +57,7 @@ def load_data(config_path, skipfirst=0, symbols=None, real=False, normalize=Fals
return data, config
def roll_along(arr, shifts, dim):
# https://stackoverflow.com/a/76920720
# (c) Mateen Ulhaq, 2023
@@ -67,6 +70,7 @@ def roll_along(arr, shifts, dim):
indices = (dim_indices - shifts.unsqueeze(dim)) % arr.shape[dim]
return torch.gather(arr, dim, indices)
class FiberRegenerationDataset(Dataset):
"""
Dataset for fiber regeneration training.
@@ -105,7 +109,7 @@ class FiberRegenerationDataset(Dataset):
drop_first: float | int = 0,
dtype: torch.dtype = None,
real: bool = False,
device = None,
device=None,
**kwargs,
):
"""
@@ -127,18 +131,10 @@ class FiberRegenerationDataset(Dataset):
# check types
assert isinstance(file_path, str), "file_path must be a string"
assert isinstance(symbols, (float, int)), (
"symbols must be a float or an integer"
)
assert output_dim is None or isinstance(output_dim, int), (
"output_len must be an integer"
)
assert isinstance(target_delay, (float, int)), (
"target_delay must be a float or an integer"
)
assert isinstance(xy_delay, (float, int)), (
"xy_delay must be a float or an integer"
)
assert isinstance(symbols, (float, int)), "symbols must be a float or an integer"
assert output_dim is None or isinstance(output_dim, int), "output_len must be an integer"
assert isinstance(target_delay, (float, int)), "target_delay must be a float or an integer"
assert isinstance(xy_delay, (float, int)), "xy_delay must be a float or an integer"
assert isinstance(drop_first, int), "drop_first must be an integer"
# check values
@@ -159,7 +155,15 @@ class FiberRegenerationDataset(Dataset):
"glova": {"sps": 128}, "glova": {"sps": 128},
} }
else: else:
data_raw, self.config = load_data(file_path, skipfirst=drop_first, symbols=kwargs.pop("num_symbols", None), real=real, normalize=True, device=device, dtype=dtype) data_raw, self.config = load_data(
file_path,
skipfirst=drop_first,
symbols=kwargs.pop("num_symbols", None),
real=real,
normalize=True,
device=device,
dtype=dtype,
)
self.device = data_raw.device self.device = data_raw.device
@@ -180,9 +184,7 @@ class FiberRegenerationDataset(Dataset):
else int(self.target_delay * self.samples_per_symbol)
)
self.xy_delay_samples = (
ovrd_xy_delay_samples
if ovrd_xy_delay_samples is not None
else int(self.xy_delay * self.samples_per_symbol)
ovrd_xy_delay_samples if ovrd_xy_delay_samples is not None else int(self.xy_delay * self.samples_per_symbol)
)
# data_raw = torch.tensor(data_raw, dtype=dtype)
@@ -208,9 +210,7 @@ class FiberRegenerationDataset(Dataset):
# E_out_y[0:N] ] E_out_y[-3:N-3] ] E_out_y[0:N-3] ]
if self.xy_delay_samples != 0:
data_raw = roll_along(
data_raw, [0, self.xy_delay_samples, 0, self.xy_delay_samples], dim=1
)
data_raw = roll_along(data_raw, [0, self.xy_delay_samples, 0, self.xy_delay_samples], dim=1)
if self.xy_delay_samples > 0:
data_raw = data_raw[:, self.xy_delay_samples :]
elif self.xy_delay_samples < 0:
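For reference, the gather-based roll_along helper (credited above to the Stack Overflow answer) shifts each row by its own amount, which is how the dataset applies the x/y delay to two of the four data channels. A minimal 2D sketch of the same idea (simplified, not the exact library code):

import torch

def roll_along_2d(arr, shifts, dim=1):
    # out[i, j] = arr[i, (j - shifts[i]) % N]: each row is circularly shifted by its own amount
    shifts = torch.as_tensor(shifts, device=arr.device)
    n = arr.shape[dim]
    dim_indices = torch.arange(n, device=arr.device).expand(arr.shape)
    indices = (dim_indices - shifts.unsqueeze(dim)) % n
    return torch.gather(arr, dim, indices)

x = torch.arange(12).reshape(4, 3)
# rows 1 and 3 (e.g. the y-polarisation channels) rolled by one sample, rows 0 and 2 untouched
print(roll_along_2d(x, [0, 1, 0, 1]))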