import copy from datetime import datetime from pathlib import Path from typing import Literal import matplotlib.pyplot as plt import numpy as np import optuna # import optunahub import warnings import torch import torch.nn as nn # import torch.nn.functional as F # mse_loss doesn't support complex numbers import torch.optim as optim import torch.utils.data from torch.utils.tensorboard import SummaryWriter # from rich.progress import ( # Progress, # TextColumn, # BarColumn, # TaskProgressColumn, # TimeRemainingColumn, # MofNCompleteColumn, # TimeElapsedColumn, # ) # from rich.console import Console # from rich import print as rprint import multiprocessing from util.datasets import FiberRegenerationDataset # from util.optuna_helpers import ( # suggest_categorical_optional, # noqa: F401 # suggest_float_optional, # noqa: F401 # suggest_int_optional, # noqa: F401 # ) from util.optuna_helpers import install_optional_suggests import util from .settings import ( GlobalSettings, DataSettings, ModelSettings, OptunaSettings, OptimizerSettings, PytorchSettings, ) install_optional_suggests() class HyperTraining: def __init__( self, *, global_settings, data_settings, pytorch_settings, model_settings, optimizer_settings, optuna_settings, # console=None, ): self.global_settings: GlobalSettings = global_settings self.data_settings: DataSettings = data_settings self.pytorch_settings: PytorchSettings = pytorch_settings self.model_settings: ModelSettings = model_settings self.optimizer_settings: OptimizerSettings = optimizer_settings self.optuna_settings: OptunaSettings = optuna_settings self.processes = None # self.console = console or Console() # set some extra settings to make the code more readable self._extra_optuna_settings() self.stop_study = True def setup_tb_writer(self, study_name=None, append=None): log_dir = self.pytorch_settings.summary_dir + "/" + (study_name or self.optuna_settings.study_name) if append is not None: log_dir += "_" + str(append) return SummaryWriter(log_dir) def resume_latest_study(self, verbose=True): study_name = self.get_latest_study() if study_name: if verbose: print(f"Resuming study: {study_name}") self.optuna_settings.study_name = study_name def get_latest_study(self, verbose=False) -> optuna.Study: studies = self.get_studies() study = None for study in studies: study.datetime_start = study.datetime_start or datetime.min if studies: study = sorted(studies, key=lambda x: x.datetime_start, reverse=True)[0] if verbose: print(f"Last study: {study.study_name}") else: if verbose: print("No previous studies found") return optuna.load_study(study_name=study.study_name, storage=self.optuna_settings.storage) # def study(self) -> optuna.Study: # return optuna.load_study(self.optuna_settings.study_name, storage=self.optuna_settings.storage) def get_studies(self): return optuna.get_all_study_summaries(storage=self.optuna_settings.storage) def setup_study(self): # module = optunahub.load_module(package="samplers/auto_sampler") if self.optuna_settings._parallel: self.processes = [] pruner = getattr(optuna.pruners, self.optuna_settings.pruner, None) if pruner and self.optuna_settings.pruner_kwargs is not None: pruner = pruner(**self.optuna_settings.pruner_kwargs) elif pruner: pruner = pruner() self.study = optuna.create_study( study_name=self.optuna_settings.study_name, storage=self.optuna_settings.storage, load_if_exists=True, direction=self.optuna_settings._direction, directions=self.optuna_settings._directions, pruner=pruner, # sampler=module.AutoSampler(), ) # print("using sampler:", self.study.sampler) with warnings.catch_warnings(action="ignore"): self.study.set_metric_names(self.optuna_settings.metrics_names) def run_study(self): try: if self.optuna_settings._parallel: self._run_parallel_study() else: self._run_study() except KeyboardInterrupt: print("Stopping. Please wait for the processes to finish.") self.stop_study = True def trials_left(self): return self.optuna_settings.n_trials - len(self.study.get_trials(states=self.optuna_settings.n_trials_filter)) def remove_completed_processes(self): if self.processes is None: return for p, process in enumerate(self.processes): if not process.is_alive(): process.join() self.processes.pop(p) def remove_outliers(self): if self.optuna_settings.remove_outliers is not None: trials = self.study.get_trials(states=(optuna.trial.TrialState.COMPLETE,)) if len(trials) == 0: return vals = [trial.value for trial in trials] vals = np.log(vals) mean = np.mean(vals) std = np.std(vals) outliers = [ trial for trial in trials if np.log(trial.value) > mean + self.optuna_settings.remove_outliers * std ] for trial in outliers: trial: optuna.trial.Trial = trial trial.state = optuna.trial.TrialState.FAIL trial.set_user_attr("outlier", True) def _run_study(self): while trials_left := self.trials_left(): self.remove_outliers() self._run_optimize(n_trials=trials_left, timeout=self.optuna_settings.timeout) def _run_parallel_study(self): while trials_left := self.trials_left(): self.remove_outliers() self.remove_completed_processes() n_trials = max(trials_left, self.optuna_settings._n_threads) // self.optuna_settings._n_threads def target_fun(): self._run_optimize(n_trials=n_trials, timeout=self.optuna_settings.timeout) for _ in range(self.optuna_settings._n_threads - len(self.processes)): self.processes.append(multiprocessing.Process(target=target_fun)) self.processes[-1].start() def _run_optimize(self, **kwargs): self.study.optimize( self.objective, **kwargs, show_progress_bar=not self.optuna_settings._parallel, ) def _extra_optuna_settings(self): self.optuna_settings._multi_objective = len(self.optuna_settings.directions) > 1 if self.optuna_settings._multi_objective: self.optuna_settings._direction = None self.optuna_settings._directions = self.optuna_settings.directions else: self.optuna_settings._direction = self.optuna_settings.directions[0] self.optuna_settings._directions = None self.optuna_settings._n_train_batches = ( self.optuna_settings.n_train_batches if self.optuna_settings.limit_examples else float("inf") ) self.optuna_settings._n_valid_batches = ( self.optuna_settings.n_valid_batches if self.optuna_settings.limit_examples else float("inf") ) self.optuna_settings._n_threads = self.optuna_settings.n_workers self.optuna_settings._parallel = self.optuna_settings._n_threads > 1 def define_model(self, trial: optuna.Trial, writer=None): n_layers = trial.suggest_int_optional("model_n_hidden_layers", self.model_settings.n_hidden_layers) input_dim = trial.suggest_int_optional( "model_input_dim", self.data_settings.output_size, step=2, multiply=2, set_new=False, ) # trial.set_user_attr("model_input_dim", input_dim) dtype = trial.suggest_categorical_optional("model_dtype", self.data_settings.dtype, set_new=False) dtype = getattr(torch, dtype) afunc = trial.suggest_categorical_optional("model_activation_func", self.model_settings.model_activation_func) # T0 = trial.suggest_float_optional("T0", self.model_settings.satabsT0 , log=True) layers = [] last_dim = input_dim n_nodes = last_dim for i in range(n_layers): if hidden_dim_override := self.model_settings.overrides.get(f"n_hidden_nodes_{i}", False): hidden_dim = trial.suggest_int_optional(f"model_hidden_dim_{i}", hidden_dim_override) else: hidden_dim = trial.suggest_int_optional( f"model_hidden_dim_{i}", self.model_settings.n_hidden_nodes, ) layers.append(util.complexNN.SemiUnitaryLayer(last_dim, hidden_dim, dtype=dtype)) last_dim = hidden_dim layers.append(getattr(util.complexNN, afunc)()) n_nodes += last_dim layers.append(util.complexNN.SemiUnitaryLayer(last_dim, self.model_settings.output_dim, dtype=dtype)) model = nn.Sequential(*layers) if writer is not None: writer.add_graph(model, torch.zeros(1, input_dim, dtype=dtype), use_strict_trace=False) n_params = sum(p.numel() for p in model.parameters()) trial.set_user_attr("model_n_params", n_params) trial.set_user_attr("model_n_nodes", n_nodes) return model.to(self.pytorch_settings.device) def get_sliced_data(self, trial: optuna.Trial, override=None): symbols = trial.suggest_float_optional("dataset_symbols", self.data_settings.symbols, set_new=False) in_out_delay = trial.suggest_float_optional( "dataset_in_out_delay", self.data_settings.in_out_delay, set_new=False ) xy_delay = trial.suggest_float_optional("dataset_xy_delay", self.data_settings.xy_delay, set_new=False) data_size = int( 0.5 * trial.suggest_int_optional( "model_input_dim", self.data_settings.output_size, step=2, multiply=2, set_new=False, ) ) dtype = trial.suggest_categorical_optional("model_dtype", self.data_settings.dtype, set_new=False) dtype = getattr(torch, dtype) num_symbols = None if override is not None: num_symbols = override.get("num_symbols", None) # get dataset dataset = FiberRegenerationDataset( file_path=self.data_settings.config_path, symbols=symbols, output_dim=data_size, target_delay=in_out_delay, xy_delay=xy_delay, drop_first=self.data_settings.drop_first, dtype=dtype, real=not dtype.is_complex, num_symbols=num_symbols, ) dataset_size = len(dataset) indices = list(range(dataset_size)) split = int(np.floor(self.data_settings.train_split * dataset_size)) if self.data_settings.shuffle: np.random.seed(self.global_settings.seed) np.random.shuffle(indices) train_indices, valid_indices = indices[:split], indices[split:] if self.data_settings.shuffle: train_sampler = torch.utils.data.SubsetRandomSampler(train_indices) valid_sampler = torch.utils.data.SubsetRandomSampler(valid_indices) else: train_sampler = train_indices valid_sampler = valid_indices train_loader = torch.utils.data.DataLoader( dataset, batch_size=self.pytorch_settings.batchsize, sampler=train_sampler, drop_last=True, pin_memory=True, num_workers=self.pytorch_settings.dataloader_workers, prefetch_factor=self.pytorch_settings.dataloader_prefetch, ) valid_loader = torch.utils.data.DataLoader( dataset, batch_size=self.pytorch_settings.batchsize, sampler=valid_sampler, drop_last=True, pin_memory=True, num_workers=self.pytorch_settings.dataloader_workers, prefetch_factor=self.pytorch_settings.dataloader_prefetch, ) return train_loader, valid_loader def train_model( self, trial, model, optimizer, train_loader, epoch, writer=None, # enable_progress=False, ): # if enable_progress: # progress = Progress( # TextColumn("[yellow] Training..."), # TextColumn("Error: {task.description}"), # BarColumn(), # TaskProgressColumn(), # TextColumn("[green]Batch"), # MofNCompleteColumn(), # TimeRemainingColumn(), # TimeElapsedColumn(), # # description="Training", # transient=False, # console=self.console, # refresh_per_second=10, # ) # task = progress.add_task("-.---e--", total=len(train_loader)) # progress.start() running_loss2 = 0.0 running_loss = 0.0 model.train() for batch_idx, (x, y) in enumerate(train_loader): if batch_idx >= self.optuna_settings._n_train_batches: break model.zero_grad(set_to_none=True) x, y = ( x.to(self.pytorch_settings.device), y.to(self.pytorch_settings.device), ) y_pred = model(x) loss = util.complexNN.complex_mse_loss(y_pred, y) loss_value = loss.item() loss.backward() optimizer.step() running_loss2 += loss_value running_loss += loss_value # if enable_progress: # progress.update(task, advance=1, description=f"{loss_value:.3e}") if writer is not None: if batch_idx % self.pytorch_settings.write_every == 0: writer.add_scalar( "training loss", running_loss2 / (self.pytorch_settings.write_every if batch_idx > 0 else 1), epoch * min(len(train_loader), self.optuna_settings._n_train_batches) + batch_idx, ) running_loss2 = 0.0 # if enable_progress: # progress.stop() return running_loss / min(len(train_loader), self.optuna_settings._n_train_batches) def eval_model( self, trial, model, valid_loader, epoch, writer=None, # enable_progress=True ): # if enable_progress: # progress = Progress( # TextColumn("[green]Evaluating..."), # TextColumn("Error: {task.description}"), # BarColumn(), # TaskProgressColumn(), # TextColumn("[green]Batch"), # MofNCompleteColumn(), # TimeRemainingColumn(), # TimeElapsedColumn(), # # description="Training", # transient=False, # console=self.console, # refresh_per_second=10, # ) # progress.start() # task = progress.add_task("-.---e--", total=len(valid_loader)) model.eval() running_error = 0 running_error_2 = 0 with torch.no_grad(): for batch_idx, (x, y) in enumerate(valid_loader): if batch_idx >= self.optuna_settings._n_valid_batches: break x, y = ( x.to(self.pytorch_settings.device), y.to(self.pytorch_settings.device), ) y_pred = model(x) error = util.complexNN.complex_mse_loss(y_pred, y) error_value = error.item() running_error += error_value running_error_2 += error_value # if enable_progress: # progress.update(task, advance=1, description=f"{error_value:.3e}") if writer is not None: if batch_idx % self.pytorch_settings.write_every == 0: writer.add_scalar( "eval loss", running_error_2 / (self.pytorch_settings.write_every if batch_idx > 0 else 1), epoch * min(len(valid_loader), self.optuna_settings._n_valid_batches) + batch_idx, ) running_error_2 = 0.0 running_error /= min(len(valid_loader), self.optuna_settings._n_valid_batches) if writer is not None: title_append, subtitle = self.build_title(trial) writer.add_figure( "fiber response", self.plot_model_response( trial, model=model, title_append=title_append, subtitle=subtitle, show=False, ), epoch + 1, ) # if enable_progress: # progress.stop() return running_error def run_model(self, model, loader): model.eval() xs = [] ys = [] y_preds = [] with torch.no_grad(): model = model.to(self.pytorch_settings.device) for x, y in loader: x, y = ( x.to(self.pytorch_settings.device), y.to(self.pytorch_settings.device), ) y_pred = model(x).cpu() # x = x.cpu() # y = y.cpu() y_pred = y_pred.view(y_pred.shape[0], -1, 2) y = y.view(y.shape[0], -1, 2) x = x.view(x.shape[0], -1, 2) xs.append(x[:, 0, :].squeeze()) ys.append(y.squeeze()) y_preds.append(y_pred.squeeze()) xs = torch.vstack(xs).cpu() ys = torch.vstack(ys).cpu() y_preds = torch.vstack(y_preds).cpu() return ys, xs, y_preds def objective(self, trial: optuna.Trial, plot_before=False): if self.stop_study: trial.study.stop() model = None writer = self.setup_tb_writer( self.optuna_settings.study_name, f"{trial.number:0{len(str(self.optuna_settings.n_trials)) + 2}}", ) model = self.define_model(trial, writer) # n_nodes = trial.params.get("model_n_hidden_layers", self.model_settings.model_n_layers) * trial.params.get("model_hidden_dim", self.model_settings.unit_count) title_append, subtitle = self.build_title(trial) writer.add_figure( "fiber response", self.plot_model_response( trial, model=model, title_append=title_append, subtitle=subtitle, show=plot_before, ), 0, ) train_loader, valid_loader = self.get_sliced_data(trial) optimizer_name = trial.suggest_categorical_optional("optimizer", self.optimizer_settings.optimizer) lr = trial.suggest_float_optional("lr", self.optimizer_settings.learning_rate, log=True) optimizer = getattr(optim, optimizer_name)(model.parameters(), lr=lr) if self.optimizer_settings.scheduler is not None: scheduler = getattr(optim.lr_scheduler, self.optimizer_settings.scheduler)( optimizer, **self.optimizer_settings.scheduler_kwargs ) for epoch in range(self.pytorch_settings.epochs): trial.set_user_attr("epoch", epoch) # enable_progress = self.optuna_settings.n_threads == 1 # if enable_progress: # self.console.rule( # f"Epoch {epoch + 1}/{self.pytorch_settings.epochs}" # ) self.train_model( trial, model, optimizer, train_loader, epoch, writer, # enable_progress=enable_progress, ) error = self.eval_model( trial, model, valid_loader, epoch, writer, # enable_progress=enable_progress, ) if self.optimizer_settings.scheduler is not None: scheduler.step(error) trial.set_user_attr("mse", error) trial.set_user_attr("log_mse", np.log10(error + np.finfo(float).eps)) trial.set_user_attr("neg_mse", -error) trial.set_user_attr("neg_log_mse", -np.log10(error + np.finfo(float).eps)) if not self.optuna_settings._multi_objective: trial.report(error, epoch) if trial.should_prune(): raise optuna.exceptions.TrialPruned() writer.close() if self.optuna_settings._multi_objective: return -np.log10(error + np.finfo(float).eps), trial.user_attrs.get("model_n_nodes", -1) if self.pytorch_settings.save_models and model is not None: save_path = Path(self.pytorch_settings.model_dir) / f"{self.optuna_settings.study_name}_{trial.number}.pth" save_path.parent.mkdir(parents=True, exist_ok=True) torch.save(model, save_path) return error def _plot_model_response_eye(self, *signals, labels=None, sps=None, title_append="", subtitle="", show=True): if sps is None: raise ValueError("sps must be provided") if not hasattr(labels, "__iter__") or isinstance(labels, (str, type(None))): labels = [labels] else: labels = list(labels) while len(labels) < len(signals): labels.append(None) # check if there are any labels if not any(labels): labels = [f"signal {i + 1}" for i in range(len(signals))] fig, axs = plt.subplots(2, len(signals), sharex=True, sharey=True) fig.suptitle(f"Eye diagram{f' {title_append}' if title_append else ''}{f'\n{subtitle}' if subtitle else ''}") xaxis = np.linspace(0, 2, 2 * sps, endpoint=False) for j, (label, signal) in enumerate(zip(labels, signals)): # signal = signal.cpu().numpy() for i in range(len(signal) // sps - 1): x, y = signal[i * sps : (i + 2) * sps].T axs[0, j].plot(xaxis, np.abs(x) ** 2, color="C0", alpha=0.02) axs[1, j].plot(xaxis, np.abs(y) ** 2, color="C0", alpha=0.02) axs[0, j].set_title(label + " x") axs[1, j].set_title(label + " y") axs[0, j].set_xlabel("Symbol") axs[1, j].set_xlabel("Symbol") axs[0, j].set_ylabel("normalized power") axs[1, j].set_ylabel("normalized power") if show: plt.show() return fig def _plot_model_response_head(self, *signals, labels=None, sps=None, title_append="", subtitle="", show=True): if not hasattr(labels, "__iter__") or isinstance(labels, (str, type(None))): labels = [labels] else: labels = list(labels) while len(labels) < len(signals): labels.append(None) # check if there are any labels if not any(labels): labels = [f"signal {i + 1}" for i in range(len(signals))] fig, axs = plt.subplots(1, 2, sharex=True, sharey=True) fig.set_size_inches(18, 6) fig.suptitle(f"Fiber response{f' {title_append}' if title_append else ''}{f'\n{subtitle}' if subtitle else ''}") for i, ax in enumerate(axs): for signal, label in zip(signals, labels): if sps is not None: xaxis = np.linspace(0, len(signal) / sps, len(signal), endpoint=False) else: xaxis = np.arange(len(signal)) ax.plot(xaxis, np.abs(signal[:, i]) ** 2, label=label) ax.set_xlabel("Sample" if sps is None else "Symbol") ax.set_ylabel("normalized power") ax.legend(loc="upper right") if show: plt.show() return fig def plot_model_response( self, trial, model=None, title_append="", subtitle="", mode: Literal["eye", "head"] = "head", show=True, ): data_settings_backup = copy.deepcopy(self.data_settings) pytorch_settings_backup = copy.deepcopy(self.pytorch_settings) self.data_settings.drop_first = 100*128 self.data_settings.shuffle = False self.data_settings.train_split = 1.0 self.pytorch_settings.batchsize = ( self.pytorch_settings.eye_symbols if mode == "eye" else self.pytorch_settings.head_symbols ) plot_loader, _ = self.get_sliced_data(trial, override={"num_symbols": self.pytorch_settings.batchsize}) self.data_settings = data_settings_backup self.pytorch_settings = pytorch_settings_backup fiber_in, fiber_out, regen = self.run_model(model, plot_loader) fiber_in = fiber_in.view(-1, 2) fiber_out = fiber_out.view(-1, 2) regen = regen.view(-1, 2) fiber_in = fiber_in.numpy() fiber_out = fiber_out.numpy() regen = regen.numpy() # https://github.com/matplotlib/matplotlib/issues/27713#issue-2104110987 # https://github.com/matplotlib/matplotlib/issues/27713#issuecomment-1915497463 import gc if mode == "head": fig = self._plot_model_response_head( fiber_in, fiber_out, regen, labels=("fiber in", "fiber out", "regen"), sps=plot_loader.dataset.samples_per_symbol, title_append=title_append, subtitle=subtitle, show=show, ) elif mode == "eye": # raise NotImplementedError("Eye diagram not implemented") fig = self._plot_model_response_eye( fiber_in, fiber_out, regen, labels=("fiber in", "fiber out", "regen"), sps=plot_loader.dataset.samples_per_symbol, title_append=title_append, subtitle=subtitle, show=show, ) else: raise ValueError(f"Unknown mode: {mode}") gc.collect() return fig @staticmethod def build_title(trial: optuna.trial.Trial): title_append = f"for trial {trial.number}" model_n_hidden_layers = util.misc.multi_getattr((trial.params, trial.user_attrs), "model_n_hidden_layers", 0) input_dim = util.misc.multi_getattr((trial.params, trial.user_attrs), "model_input_dim", 0) model_dims = [ util.misc.multi_getattr((trial.params, trial.user_attrs), f"model_hidden_dim_{i}", 0) for i in range(model_n_hidden_layers) ] model_dims.insert(0, input_dim) model_dims.append(2) model_dims = [str(dim) for dim in model_dims] model_activation_func = util.misc.multi_getattr( (trial.params, trial.user_attrs), "model_activation_func", "unknown act. fun", ) model_dtype = util.misc.multi_getattr((trial.params, trial.user_attrs), "model_dtype", "unknown dtype") subtitle = ( f"{model_n_hidden_layers+2} layers à ({', '.join(model_dims)}) units, {model_activation_func}, {model_dtype}" ) return title_append, subtitle