import copy
from datetime import datetime
from pathlib import Path
from typing import Literal

import matplotlib.pyplot as plt
import numpy as np
import optuna
import warnings
import torch
import torch.nn as nn

# import torch.nn.functional as F  # mse_loss doesn't support complex numbers
import torch.optim as optim
import torch.utils.data
from torch.utils.tensorboard import SummaryWriter
from rich.progress import (
    Progress,
    TextColumn,
    BarColumn,
    TaskProgressColumn,
    TimeRemainingColumn,
    MofNCompleteColumn,
    TimeElapsedColumn,
)
from rich.console import Console

# from rich import print as rprint
import multiprocessing

from util.datasets import FiberRegenerationDataset
from util.optuna_helpers import (
    force_suggest_categorical,
    force_suggest_float,
    force_suggest_int,
)
import util

from .settings import (
    GlobalSettings,
    DataSettings,
    ModelSettings,
    OptunaSettings,
    OptimizerSettings,
    PytorchSettings,
)


class HyperTraining:
    def __init__(
        self,
        *,
        global_settings,
        data_settings,
        pytorch_settings,
        model_settings,
        optimizer_settings,
        optuna_settings,
        console=None,
    ):
        self.global_settings: GlobalSettings = global_settings
        self.data_settings: DataSettings = data_settings
        self.pytorch_settings: PytorchSettings = pytorch_settings
        self.model_settings: ModelSettings = model_settings
        self.optimizer_settings: OptimizerSettings = optimizer_settings
        self.optuna_settings: OptunaSettings = optuna_settings
        self.console = console or Console()
        # derive convenience flags from the raw settings to keep the rest of
        # the code readable
        self._extra_optuna_settings()

    def setup_tb_writer(self, study_name=None, append=None):
        log_dir = (
            self.pytorch_settings.summary_dir
            + "/"
            + (study_name or self.optuna_settings.study_name)
        )
        if append is not None:
            log_dir += "_" + str(append)
        return SummaryWriter(log_dir)

    def resume_latest_study(self, verbose=True):
        study_name = self.get_latest_study(verbose=verbose)
        if study_name:
            print(f"Resuming study: {study_name}")
            self.optuna_settings.study_name = study_name

    def get_latest_study(self, verbose=True):
        studies = self.get_studies()
        for study in studies:
            # studies without a recorded start time sort last
            study.datetime_start = study.datetime_start or datetime.min
        if studies:
            study = sorted(studies, key=lambda x: x.datetime_start, reverse=True)[0]
            if verbose:
                print(f"Last study: {study.study_name}")
            study_name = study.study_name
        else:
            if verbose:
                print("No previous studies found")
            study_name = None
        return study_name

    def get_studies(self):
        return optuna.get_all_study_summaries(storage=self.optuna_settings.storage)

    def setup_study(self):
        self.study = optuna.create_study(
            study_name=self.optuna_settings.study_name,
            storage=self.optuna_settings.storage,
            load_if_exists=True,
            direction=self.optuna_settings.direction,
            directions=self.optuna_settings.directions,
        )
        with warnings.catch_warnings(action="ignore"):
            self.study.set_metric_names(self.optuna_settings.metrics_names)
        self.n_threads = min(
            self.optuna_settings.n_trials, self.optuna_settings.n_threads
        )
        self.processes = []
        if self.n_threads > 1:
            for _ in range(self.n_threads):
                p = multiprocessing.Process(
                    # target=lambda n_trials: self._run_optimize(self, n_trials),
                    target=self._run_optimize,
                    args=(self.optuna_settings.n_trials // self.n_threads,),
                )
                self.processes.append(p)

    # def plot_eye(self, width=2, symbols=None, alpha=None, complex=False, show=True):
    #     data, config = util.datasets.load_data(
    #         self.data_settings.config_path,
    #         skipfirst=10,
    #         symbols=symbols or 1000,
    #         real=not complex,
    #         normalize=True,
    #     )
    #     eye_data = {"data": data.numpy(), "sps": int(config["glova"]["sps"])}
    #     return util.plot.eye(
    #         **eye_data,
    #         width=width,
    #         show=show,
    #         alpha=alpha,
    #         complex=complex,
    #         symbols=symbols or 1000,
    #         skipfirst=0,
    #     )
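    # ------------------------------------------------------------------
    # Hedged sketch of the parallelisation pattern that setup_study() and
    # run_study() rely on: each worker process attaches to the same named
    # study in shared storage and calls study.optimize(), so Optuna
    # serialises all trial bookkeeping through that storage.  The study
    # name, SQLite URL and toy objective below are illustrative
    # assumptions, not part of this project.
    # ------------------------------------------------------------------
    @staticmethod
    def _demo_parallel_optimize(n_workers=2, n_trials_per_worker=5):
        def worker(n_trials):
            # every process re-attaches to the shared study by name
            study = optuna.create_study(
                study_name="demo_parallel",  # hypothetical name
                storage="sqlite:///demo_parallel.db",  # hypothetical path
                load_if_exists=True,
            )
            study.optimize(
                lambda trial: (trial.suggest_float("x", -10, 10) - 2) ** 2,
                n_trials=n_trials,
            )

        # note: passing a local function as target assumes the "fork"
        # start method (the default on Linux)
        processes = [
            multiprocessing.Process(target=worker, args=(n_trials_per_worker,))
            for _ in range(n_workers)
        ]
        for p in processes:
            p.start()
        for p in processes:
            p.join()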
    def run_study(self):
        if self.processes:
            for p in self.processes:
                p.start()
            for p in self.processes:
                p.join()
            remaining_trials = self.optuna_settings.n_trials % self.n_threads
        else:
            remaining_trials = self.optuna_settings.n_trials
        if remaining_trials:
            # run the division remainder (or everything, if single-threaded)
            # in the current process
            self._run_optimize(remaining_trials)

    def _run_optimize(self, n_trials):
        self.study.optimize(
            self.objective, n_trials=n_trials, timeout=self.optuna_settings.timeout
        )

    def _extra_optuna_settings(self):
        self.optuna_settings.multi_objective = len(self.optuna_settings.directions) > 1
        if self.optuna_settings.multi_objective:
            self.optuna_settings.direction = None
        else:
            self.optuna_settings.direction = self.optuna_settings.directions[0]
            self.optuna_settings.directions = None
        self.optuna_settings.n_train_batches = (
            self.optuna_settings.n_train_batches
            if self.optuna_settings.limit_examples
            else float("inf")
        )
        self.optuna_settings.n_valid_batches = (
            self.optuna_settings.n_valid_batches
            if self.optuna_settings.limit_examples
            else float("inf")
        )

    def define_model(self, trial: optuna.Trial, writer=None):
        n_layers = force_suggest_int(
            trial, "model_n_layers", self.model_settings.model_n_layers
        )
        # factor 2: x and y polarisation samples are stacked in each input
        # vector (cf. the view(..., -1, 2) reshapes in run_model)
        input_dim = 2 * trial.params.get(
            "model_input_dim",
            force_suggest_int(
                trial, "model_input_dim", self.data_settings.model_input_dim
            ),
        )
        dtype = trial.params.get(
            "model_dtype",
            force_suggest_categorical(trial, "model_dtype", self.data_settings.dtype),
        )
        dtype = getattr(torch, dtype)
        afunc = force_suggest_categorical(
            trial, "model_activation_func", self.model_settings.model_activation_func
        )

        layers = []
        last_dim = input_dim
        for i in range(n_layers):
            hidden_dim = force_suggest_int(
                trial, f"model_hidden_dim_{i}", self.model_settings.unit_count
            )
            layers.append(
                util.complexNN.SemiUnitaryLayer(last_dim, hidden_dim, dtype=dtype)
            )
            last_dim = hidden_dim
            layers.append(getattr(util.complexNN, afunc)())
        layers.append(
            util.complexNN.UnitaryLayer(
                last_dim, self.model_settings.output_dim, dtype=dtype
            )
        )
        model = nn.Sequential(*layers)
        if writer is not None:
            writer.add_graph(
                model, torch.zeros(1, input_dim, dtype=dtype), use_strict_trace=False
            )
        return model.to(self.pytorch_settings.device)

    def get_sliced_data(self, trial: optuna.Trial, override=None):
        symbols = trial.params.get(
            "dataset_symbols",
            force_suggest_float(trial, "dataset_symbols", self.data_settings.symbols),
        )
        xy_delay = trial.params.get(
            "dataset_xy_delay",
            force_suggest_float(trial, "dataset_xy_delay", self.data_settings.xy_delay),
        )
        data_size = trial.params.get(
            "model_input_dim",
            force_suggest_int(
                trial, "model_input_dim", self.data_settings.model_input_dim
            ),
        )
        dtype = trial.params.get(
            "model_dtype",
            force_suggest_categorical(trial, "model_dtype", self.data_settings.dtype),
        )
        dtype = getattr(torch, dtype)

        num_symbols = None
        if override is not None:
            num_symbols = override.get("num_symbols", None)

        dataset = FiberRegenerationDataset(
            file_path=self.data_settings.config_path,
            symbols=symbols,
            output_dim=data_size,
            target_delay=self.data_settings.in_out_delay,
            xy_delay=xy_delay,
            drop_first=self.data_settings.drop_first,
            dtype=dtype,
            real=not dtype.is_complex,
            num_symbols=num_symbols,
        )
        dataset_size = len(dataset)
        indices = list(range(dataset_size))
        split = int(np.floor(self.data_settings.train_split * dataset_size))
        if self.data_settings.shuffle:
            np.random.seed(self.global_settings.seed)
            np.random.shuffle(indices)
        train_indices, valid_indices = indices[:split], indices[split:]

        if self.data_settings.shuffle:
            train_sampler = torch.utils.data.SubsetRandomSampler(train_indices)
            valid_sampler = torch.utils.data.SubsetRandomSampler(valid_indices)
        else:
            train_sampler = train_indices
            valid_sampler = valid_indices

        train_loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=self.pytorch_settings.batchsize,
            sampler=train_sampler,
            drop_last=True,
            pin_memory=True,
            num_workers=self.pytorch_settings.dataloader_workers,
            prefetch_factor=self.pytorch_settings.dataloader_prefetch,
        )
        valid_loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=self.pytorch_settings.batchsize,
            sampler=valid_sampler,
            drop_last=True,
            pin_memory=True,
            num_workers=self.pytorch_settings.dataloader_workers,
            prefetch_factor=self.pytorch_settings.dataloader_prefetch,
        )
        return train_loader, valid_loader
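    # ------------------------------------------------------------------
    # Hedged sketch of the split logic in get_sliced_data() above: one
    # dataset, a seeded shuffle of its indices, and two DataLoaders fed by
    # SubsetRandomSampler.  The tensor shapes, batch size and 80/20 split
    # are illustrative assumptions.
    # ------------------------------------------------------------------
    @staticmethod
    def _demo_subset_split(train_split=0.8, seed=0):
        dataset = torch.utils.data.TensorDataset(
            torch.randn(100, 4), torch.randn(100, 2)
        )
        indices = list(range(len(dataset)))
        split = int(np.floor(train_split * len(dataset)))
        np.random.default_rng(seed).shuffle(indices)
        train_loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=10,
            sampler=torch.utils.data.SubsetRandomSampler(indices[:split]),
        )
        valid_loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=10,
            sampler=torch.utils.data.SubsetRandomSampler(indices[split:]),
        )
        return train_loader, valid_loader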
    def train_model(
        self,
        trial,
        model,
        optimizer,
        train_loader,
        epoch,
        writer=None,
        enable_progress=False,
    ):
        if enable_progress:
            progress = Progress(
                TextColumn("[yellow] Training..."),
                TextColumn("Error: {task.description}"),
                BarColumn(),
                TaskProgressColumn(),
                TextColumn("[green]Batch"),
                MofNCompleteColumn(),
                TimeRemainingColumn(),
                TimeElapsedColumn(),
                transient=False,
                console=self.console,
                refresh_per_second=10,
            )
            task = progress.add_task("-.---e--", total=len(train_loader))
            progress.start()

        running_loss2 = 0.0
        running_loss = 0.0
        model.train()
        for batch_idx, (x, y) in enumerate(train_loader):
            if batch_idx >= self.optuna_settings.n_train_batches:
                break
            model.zero_grad(set_to_none=True)
            x, y = (
                x.to(self.pytorch_settings.device),
                y.to(self.pytorch_settings.device),
            )
            y_pred = model(x)
            loss = util.complexNN.complex_mse_loss(y_pred, y)
            loss_value = loss.item()
            loss.backward()
            optimizer.step()
            running_loss2 += loss_value
            running_loss += loss_value
            if enable_progress:
                progress.update(task, advance=1, description=f"{loss_value:.3e}")
            if writer is not None:
                if batch_idx % self.pytorch_settings.write_every == 0:
                    writer.add_scalar(
                        "training loss",
                        running_loss2
                        / (self.pytorch_settings.write_every if batch_idx > 0 else 1),
                        epoch
                        * min(len(train_loader), self.optuna_settings.n_train_batches)
                        + batch_idx,
                    )
                    running_loss2 = 0.0
        if enable_progress:
            progress.stop()
        return running_loss / min(
            len(train_loader), self.optuna_settings.n_train_batches
        )

    def eval_model(
        self, trial, model, valid_loader, epoch, writer=None, enable_progress=True
    ):
        if enable_progress:
            progress = Progress(
                TextColumn("[green]Evaluating..."),
                TextColumn("Error: {task.description}"),
                BarColumn(),
                TaskProgressColumn(),
                TextColumn("[green]Batch"),
                MofNCompleteColumn(),
                TimeRemainingColumn(),
                TimeElapsedColumn(),
                transient=False,
                console=self.console,
                refresh_per_second=10,
            )
            progress.start()
            task = progress.add_task("-.---e--", total=len(valid_loader))

        model.eval()
        running_error = 0
        running_error_2 = 0
        with torch.no_grad():
            for batch_idx, (x, y) in enumerate(valid_loader):
                if batch_idx >= self.optuna_settings.n_valid_batches:
                    break
                x, y = (
                    x.to(self.pytorch_settings.device),
                    y.to(self.pytorch_settings.device),
                )
                y_pred = model(x)
                error = util.complexNN.complex_mse_loss(y_pred, y)
                error_value = error.item()
                running_error += error_value
                running_error_2 += error_value
                if enable_progress:
                    progress.update(task, advance=1, description=f"{error_value:.3e}")
                if writer is not None:
                    if batch_idx % self.pytorch_settings.write_every == 0:
                        writer.add_scalar(
                            "eval loss",
                            running_error_2
                            / (
                                self.pytorch_settings.write_every
                                if batch_idx > 0
                                else 1
                            ),
                            epoch
                            * min(
                                len(valid_loader),
                                self.optuna_settings.n_valid_batches,
                            )
                            + batch_idx,
                        )
                        running_error_2 = 0.0
        running_error /= min(len(valid_loader), self.optuna_settings.n_valid_batches)
        if writer is not None:
            title_append, subtitle = self.build_title(trial)
            writer.add_figure(
                "fiber response",
                self.plot_model_response(
                    trial,
                    model=model,
                    title_append=title_append,
                    subtitle=subtitle,
                    show=False,
                ),
                epoch + 1,
            )
        if enable_progress:
            progress.stop()
        return running_error
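    # ------------------------------------------------------------------
    # train_model() and eval_model() use util.complexNN.complex_mse_loss
    # because torch.nn.functional.mse_loss rejects complex inputs (see the
    # commented import at the top of this file).  A minimal stand-in with
    # the usual definition, the mean of |y_pred - y|^2, assuming that is
    # what the util implementation computes:
    # ------------------------------------------------------------------
    @staticmethod
    def _demo_complex_mse_loss(y_pred, y):
        # torch.abs returns the magnitude for both real and complex dtypes,
        # so the result is always a real scalar with a valid backward pass
        return torch.mean(torch.abs(y_pred - y) ** 2)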
    def run_model(self, model, loader):
        model.eval()
        xs = []
        ys = []
        y_preds = []
        with torch.no_grad():
            model = model.to(self.pytorch_settings.device)
            for x, y in loader:
                x, y = (
                    x.to(self.pytorch_settings.device),
                    y.to(self.pytorch_settings.device),
                )
                y_pred = model(x).cpu()
                # x = x.cpu()
                # y = y.cpu()
                # reshape flat feature vectors into (batch, samples, polarisation)
                y_pred = y_pred.view(y_pred.shape[0], -1, 2)
                y = y.view(y.shape[0], -1, 2)
                x = x.view(x.shape[0], -1, 2)
                xs.append(x[:, 0, :].squeeze())
                ys.append(y.squeeze())
                y_preds.append(y_pred.squeeze())
        xs = torch.vstack(xs).cpu()
        ys = torch.vstack(ys).cpu()
        y_preds = torch.vstack(y_preds).cpu()
        return ys, xs, y_preds

    def objective(self, trial: optuna.Trial, plot_before=False):
        model = None
        exc = None
        try:
            # rprint(*list(self.study_name.split("_")))
            writer = self.setup_tb_writer(
                self.optuna_settings.study_name,
                # zero-pad the trial number so TensorBoard sorts runs correctly
                f"{trial.number:0{len(str(self.optuna_settings.n_trials))}}",
            )
            model = self.define_model(trial, writer)
            n_params = sum(p.numel() for p in model.parameters())
            title_append, subtitle = self.build_title(trial)
            writer.add_figure(
                "fiber response",
                self.plot_model_response(
                    trial,
                    model=model,
                    title_append=title_append,
                    subtitle=subtitle,
                    show=plot_before,
                ),
                0,
            )

            train_loader, valid_loader = self.get_sliced_data(trial)
            optimizer_name = force_suggest_categorical(
                trial, "optimizer", self.optimizer_settings.optimizer
            )
            lr = force_suggest_float(
                trial, "lr", self.optimizer_settings.learning_rate, log=True
            )
            optimizer = getattr(optim, optimizer_name)(model.parameters(), lr=lr)
            if self.optimizer_settings.scheduler is not None:
                scheduler = getattr(
                    optim.lr_scheduler, self.optimizer_settings.scheduler
                )(optimizer, **self.optimizer_settings.scheduler_kwargs)

            for epoch in range(self.pytorch_settings.epochs):
                # progress bars and parallel workers don't mix
                enable_progress = self.optuna_settings.n_threads == 1
                if enable_progress:
                    self.console.rule(
                        f"Epoch {epoch + 1}/{self.pytorch_settings.epochs}"
                    )
                self.train_model(
                    trial,
                    model,
                    optimizer,
                    train_loader,
                    epoch,
                    writer,
                    enable_progress=enable_progress,
                )
                error = self.eval_model(
                    trial,
                    model,
                    valid_loader,
                    epoch,
                    writer,
                    enable_progress=enable_progress,
                )
                if self.optimizer_settings.scheduler is not None:
                    scheduler.step(error)
                if not self.optuna_settings.multi_objective:
                    # pruning on intermediate values is only defined for
                    # single-objective studies
                    trial.report(error, epoch)
                    if trial.should_prune():
                        raise optuna.exceptions.TrialPruned()

            writer.close()
            if self.optuna_settings.multi_objective:
                return n_params, error
            return error
        except KeyboardInterrupt:
            ...
        # except Exception as e:
        #     exc = e
        finally:
            if model is not None:
                save_path = (
                    Path(self.pytorch_settings.model_dir)
                    / f"{self.optuna_settings.study_name}_{trial.number}.pth"
                )
                save_path.parent.mkdir(parents=True, exist_ok=True)
                torch.save(model, save_path)
            if exc is not None:
                raise exc
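    # ------------------------------------------------------------------
    # Hedged sketch of the report/prune protocol that objective() follows
    # for single-objective studies: report an intermediate value once per
    # epoch and let the study's pruner abort unpromising trials early.  The
    # toy "training curve" is illustrative only.
    # ------------------------------------------------------------------
    @staticmethod
    def _demo_pruning_objective(trial):
        error = trial.suggest_float("x", 0.1, 1.0)
        for epoch in range(10):
            error *= 0.9  # pretend each epoch improves the error a little
            trial.report(error, epoch)
            if trial.should_prune():
                raise optuna.exceptions.TrialPruned()
        return error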
    def _plot_model_response_eye(
        self, *signals, labels=None, sps=None, title_append="", subtitle="", show=True
    ):
        if sps is None:
            raise ValueError("sps must be provided")
        if not hasattr(labels, "__iter__") or isinstance(labels, (str, type(None))):
            labels = [labels]
        else:
            labels = list(labels)
        while len(labels) < len(signals):
            labels.append(None)
        # fall back to generic labels if none were given
        if not any(labels):
            labels = [f"signal {i + 1}" for i in range(len(signals))]

        fig, axs = plt.subplots(2, len(signals), sharex=True, sharey=True)
        title = f"Eye diagram {title_append}".rstrip()
        if subtitle:
            title += "\n" + subtitle
        fig.suptitle(title)

        # overlay sliding two-symbol windows to build up the eye diagram
        xaxis = np.linspace(0, 2, 2 * sps, endpoint=False)
        for j, (label, signal) in enumerate(zip(labels, signals)):
            # signal = signal.cpu().numpy()
            for i in range(len(signal) // sps - 1):
                x, y = signal[i * sps : (i + 2) * sps].T
                axs[0, j].plot(xaxis, np.abs(x) ** 2, color="C0", alpha=0.02)
                axs[1, j].plot(xaxis, np.abs(y) ** 2, color="C0", alpha=0.02)
            axs[0, j].set_title(label + " x")
            axs[1, j].set_title(label + " y")
            axs[0, j].set_xlabel("Symbol")
            axs[1, j].set_xlabel("Symbol")
            axs[0, j].set_ylabel("normalized power")
            axs[1, j].set_ylabel("normalized power")
        if show:
            plt.show()
        return fig

    def _plot_model_response_head(
        self, *signals, labels=None, sps=None, title_append="", subtitle="", show=True
    ):
        if not hasattr(labels, "__iter__") or isinstance(labels, (str, type(None))):
            labels = [labels]
        else:
            labels = list(labels)
        while len(labels) < len(signals):
            labels.append(None)
        # fall back to generic labels if none were given
        if not any(labels):
            labels = [f"signal {i + 1}" for i in range(len(signals))]

        fig, axs = plt.subplots(1, 2, sharex=True, sharey=True)
        fig.set_size_inches(18, 6)
        title = f"Fiber response {title_append}".rstrip()
        if subtitle:
            title += "\n" + subtitle
        fig.suptitle(title)

        for i, ax in enumerate(axs):
            for signal, label in zip(signals, labels):
                if sps is not None:
                    xaxis = np.linspace(
                        0, len(signal) / sps, len(signal), endpoint=False
                    )
                else:
                    xaxis = np.arange(len(signal))
                ax.plot(xaxis, np.abs(signal[:, i]) ** 2, label=label)
            ax.set_xlabel("Sample" if sps is None else "Symbol")
            ax.set_ylabel("normalized power")
            ax.legend(loc="upper right")
        if show:
            plt.show()
        return fig

    def plot_model_response(
        self,
        trial,
        model=None,
        title_append="",
        subtitle="",
        mode: Literal["eye", "head"] = "head",
        show=True,
    ):
        # temporarily rewire the data settings so the loader yields one
        # contiguous, unshuffled slice for plotting
        data_settings_backup = copy.deepcopy(self.data_settings)
        pytorch_settings_backup = copy.deepcopy(self.pytorch_settings)
        self.data_settings.drop_first = 100
        self.data_settings.shuffle = False
        self.data_settings.train_split = 1.0
        self.pytorch_settings.batchsize = (
            self.pytorch_settings.eye_symbols
            if mode == "eye"
            else self.pytorch_settings.head_symbols
        )
        plot_loader, _ = self.get_sliced_data(
            trial, override={"num_symbols": self.pytorch_settings.batchsize}
        )
        self.data_settings = data_settings_backup
        self.pytorch_settings = pytorch_settings_backup

        fiber_in, fiber_out, regen = self.run_model(model, plot_loader)
        fiber_in = fiber_in.view(-1, 2).numpy()
        fiber_out = fiber_out.view(-1, 2).numpy()
        regen = regen.view(-1, 2).numpy()

        # figures accumulate without an explicit collection pass, see
        # https://github.com/matplotlib/matplotlib/issues/27713#issue-2104110987
        # https://github.com/matplotlib/matplotlib/issues/27713#issuecomment-1915497463
        import gc

        if mode == "head":
            fig = self._plot_model_response_head(
                fiber_in,
                fiber_out,
                regen,
                labels=("fiber in", "fiber out", "regen"),
                sps=plot_loader.dataset.samples_per_symbol,
                title_append=title_append,
                subtitle=subtitle,
                show=show,
            )
        elif mode == "eye":
            fig = self._plot_model_response_eye(
                fiber_in,
                fiber_out,
                regen,
                labels=("fiber in", "fiber out", "regen"),
                sps=plot_loader.dataset.samples_per_symbol,
                title_append=title_append,
                subtitle=subtitle,
                show=show,
            )
        else:
            raise ValueError(f"Unknown mode: {mode}")
        gc.collect()
        return fig
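    # ------------------------------------------------------------------
    # The local gc import and gc.collect() in plot_model_response() work
    # around the matplotlib figure leak tracked in the issue linked above:
    # figures created during a long study can otherwise accumulate in
    # memory.  A minimal illustration of the same hygiene, with the loop
    # count chosen arbitrarily:
    # ------------------------------------------------------------------
    @staticmethod
    def _demo_figure_cleanup(n_figures=3):
        import gc

        for _ in range(n_figures):
            fig, ax = plt.subplots()
            ax.plot([0, 1], [0, 1])
            plt.close(fig)  # drop pyplot's reference to the figure
        gc.collect()  # reclaim reference cycles the figures participate in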
    @staticmethod
    def build_title(trial):
        title_append = f"for trial {trial.number}"
        hidden_dims = ", ".join(
            str(trial.params[f"model_hidden_dim_{i}"])
            for i in range(trial.params["model_n_layers"])
        )
        subtitle = (
            f"{trial.params['model_n_layers']} layers, "
            f"{hidden_dims} units, "
            f"{trial.params['model_activation_func']}, "
            f"{trial.params['model_dtype']}"
        )
        return title_append, subtitle
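# ----------------------------------------------------------------------
# Hedged usage sketch (not executed on import): the expected call order is
# construct -> optionally resume_latest_study() -> setup_study() ->
# run_study().  Construction of the settings objects is elided because
# their signatures live in .settings; the ellipses are intentional.
#
# trainer = HyperTraining(
#     global_settings=...,       # GlobalSettings instance
#     data_settings=...,         # DataSettings instance
#     pytorch_settings=...,      # PytorchSettings instance
#     model_settings=...,        # ModelSettings instance
#     optimizer_settings=...,    # OptimizerSettings instance
#     optuna_settings=...,       # OptunaSettings instance
# )
# trainer.resume_latest_study()
# trainer.setup_study()
# trainer.run_study()
# ----------------------------------------------------------------------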