From cfa08aae4e1e026175fbec826118bffdfef8a750 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Joseph=20Hopfm=C3=BCller?= Date: Fri, 29 Nov 2024 15:48:18 +0100 Subject: [PATCH] add training.py for defining and running models without hyperparametertuning --- .../hypertraining/training.py | 739 ++++++++++++++++++ 1 file changed, 739 insertions(+) create mode 100644 src/single-core-regen/hypertraining/training.py diff --git a/src/single-core-regen/hypertraining/training.py b/src/single-core-regen/hypertraining/training.py new file mode 100644 index 0000000..0a7f668 --- /dev/null +++ b/src/single-core-regen/hypertraining/training.py @@ -0,0 +1,739 @@ +import copy +from datetime import datetime +from pathlib import Path +from typing import Literal +import matplotlib +import torch.nn.utils.parametrize + +try: + matplotlib.use("cairo") +except ImportError: + matplotlib.use("Agg") +import matplotlib.pyplot as plt + +import numpy as np + +import torch +import torch.nn as nn + +# import torch.nn.functional as F # mse_loss doesn't support complex numbers +import torch.optim as optim +import torch.utils.data + +from torch.utils.tensorboard import SummaryWriter + +from rich.progress import ( + Progress, + TextColumn, + BarColumn, + TaskProgressColumn, + TimeRemainingColumn, + MofNCompleteColumn, + TimeElapsedColumn, +) +from rich.console import Console + +from util.datasets import FiberRegenerationDataset +import util + +from .settings import ( + GlobalSettings, + DataSettings, + ModelSettings, + OptimizerSettings, + PytorchSettings, +) + + +class regenerator(nn.Module): + def __init__( + self, + *dims, + layer_function=util.complexNN.ONN, + layer_parametrizations: list[dict] = None, + # [ + # { + # "tensor_name": "weight", + # "parametrization": util.complexNN.Unitary, + # }, + # { + # "tensor_name": "scale", + # "parametrization": util.complexNN.Clamp, + # }, + # ], + activation_function=util.complexNN.Pow, + dtype=torch.float64, + dropout_prob=0.01, + **kwargs, + ): + super(regenerator, self).__init__() + if len(dims) == 0: + try: + dims = kwargs["dims"] + except KeyError: + raise ValueError("dims must be provided") + self._n_hidden_layers = len(dims) - 2 + self._layers = nn.Sequential() + + for i in range(self._n_hidden_layers + 1): + self._layers.append(layer_function(dims[i], dims[i + 1], dtype=dtype)) + if i < self._n_hidden_layers: + if dropout_prob is not None: + self._layers.append(util.complexNN.DropoutComplex(p=dropout_prob)) + self._layers.append(activation_function()) + + # add parametrizations + if layer_parametrizations is not None: + for layer_parametrization in layer_parametrizations: + tensor_name = layer_parametrization.get("tensor_name", None) + parametrization = layer_parametrization.get("parametrization", None) + param_kwargs = layer_parametrization.get("kwargs", {}) + if ( + tensor_name is not None + and tensor_name in self._layers[-1]._parameters + and parametrization is not None + ): + parametrization(self._layers[-1], tensor_name, **param_kwargs) + + def forward(self, input_x): + x = input_x + # check if tracing + if torch.jit.is_tracing(): + for layer in self._layers: + x = layer(x) + else: + # with torch.nn.utils.parametrize.cached(): + for layer in self._layers: + x = layer(x) + return x + +def traverse_dict_update(target, source): + for k, v in source.items(): + if isinstance(v, dict): + if k not in target: + target[k] = {} + traverse_dict_update(target[k], v) + else: + try: + target[k] = v + except TypeError: + target.__dict__[k] = v + +class Trainer: + def __init__( + self, + 
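+        # Keyword-only arguments: either supply all five settings objects for a fresh
+        # run, or a checkpoint_path (optionally with settings_override / reset_epoch)
+        # to resume training from a saved checkpoint.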
*, + global_settings=None, + data_settings=None, + pytorch_settings=None, + model_settings=None, + optimizer_settings=None, + console=None, + checkpoint_path=None, + settings_override=None, + reset_epoch=False, + ): + self.resume = checkpoint_path is not None + torch.serialization.add_safe_globals([ + *util.complexNN.__all__, + GlobalSettings, + DataSettings, + ModelSettings, + OptimizerSettings, + PytorchSettings, + regenerator, + torch.nn.utils.parametrizations.orthogonal + ]) + if self.resume: + self.checkpoint_dict = torch.load(checkpoint_path, weights_only=True) + if settings_override is not None: + traverse_dict_update(self.checkpoint_dict["settings"], settings_override) + if reset_epoch: + self.checkpoint_dict["epoch"] = -1 + + self.global_settings: GlobalSettings = self.checkpoint_dict["settings"]["global_settings"] + self.data_settings: DataSettings = self.checkpoint_dict["settings"]["data_settings"] + self.pytorch_settings: PytorchSettings = self.checkpoint_dict["settings"]["pytorch_settings"] + self.model_settings: ModelSettings = self.checkpoint_dict["settings"]["model_settings"] + self.optimizer_settings: OptimizerSettings = self.checkpoint_dict["settings"]["optimizer_settings"] + else: + if global_settings is None: + raise ValueError("global_settings must be provided") + if data_settings is None: + raise ValueError("data_settings must be provided") + if pytorch_settings is None: + raise ValueError("pytorch_settings must be provided") + if model_settings is None: + raise ValueError("model_settings must be provided") + if optimizer_settings is None: + raise ValueError("optimizer_settings must be provided") + + self.global_settings: GlobalSettings = global_settings + self.data_settings: DataSettings = data_settings + self.pytorch_settings: PytorchSettings = pytorch_settings + self.model_settings: ModelSettings = model_settings + self.optimizer_settings: OptimizerSettings = optimizer_settings + + self.console = console or Console() + self.writer = None + + def setup_tb_writer(self, append=None): + log_dir = self.pytorch_settings.summary_dir + "/" + (datetime.now().strftime("%Y%m%d_%H%M%S")) + if append is not None: + log_dir += "_" + str(append) + + print(f"Logging to {log_dir}") + self.writer = SummaryWriter(log_dir=log_dir) + + def save_checkpoint(self, save_dict, filename): + torch.save(save_dict, filename) + + def build_checkpoint_dict(self, loss=None, epoch=None): + return { + "epoch": -1 if epoch is None else epoch, + "loss": float("inf") if loss is None else loss, + "model_state_dict": copy.deepcopy(self.model.state_dict()), + "optimizer_state_dict": copy.deepcopy(self.optimizer.state_dict()), + "scheduler_state_dict": copy.deepcopy(self.scheduler.state_dict()) if hasattr(self, "scheduler") else None, + "model_kwargs": copy.deepcopy(self.model_kwargs), + "settings": { + "global_settings": copy.deepcopy(self.global_settings), + "data_settings": copy.deepcopy(self.data_settings), + "pytorch_settings": copy.deepcopy(self.pytorch_settings), + "model_settings": copy.deepcopy(self.model_settings), + "optimizer_settings": copy.deepcopy(self.optimizer_settings), + }, + } + + def define_model(self, model_kwargs=None): + if model_kwargs is None: + n_hidden_layers = self.model_settings.n_hidden_layers + + input_dim = 2 * self.data_settings.output_size + + dtype = getattr(torch, self.data_settings.dtype) + + afunc = getattr(util.complexNN, self.model_settings.model_activation_func) + + layer_func = getattr(util.complexNN, self.model_settings.model_layer_function) + + 
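+            # Activation and layer classes are resolved by name from util.complexNN; the
+            # assembled model_kwargs are stored in every checkpoint via build_checkpoint_dict
+            # and reused to rebuild the model when resuming.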
layer_parametrizations = self.model_settings.model_layer_parametrizations + + hidden_dims = [self.model_settings.overrides.get(f"n_hidden_nodes_{i}") for i in range(n_hidden_layers)] + + self.model_kwargs = { + "dims": (input_dim, *hidden_dims, self.model_settings.output_dim), + "layer_function": layer_func, + "layer_parametrizations": layer_parametrizations, + "activation_function": afunc, + "dtype": dtype, + "dropout_prob": self.model_settings.dropout_prob, + } + else: + self.model_kwargs = model_kwargs + input_dim = self.model_kwargs["dims"][0] + dtype = self.model_kwargs["dtype"] + + # dims = self.model_kwargs.pop("dims") + self.model = regenerator(**self.model_kwargs) + + self.writer.add_graph(self.model, torch.zeros(1, input_dim, dtype=dtype)) + + self.model = self.model.to(self.pytorch_settings.device) + + def get_sliced_data(self, override=None): + symbols = self.data_settings.symbols + + in_out_delay = self.data_settings.in_out_delay + + xy_delay = self.data_settings.xy_delay + + data_size = self.data_settings.output_size + + dtype = getattr(torch, self.data_settings.dtype) + + num_symbols = None + if override is not None: + num_symbols = override.get("num_symbols", None) + # get dataset + dataset = FiberRegenerationDataset( + file_path=self.data_settings.config_path, + symbols=symbols, + output_dim=data_size, + target_delay=in_out_delay, + xy_delay=xy_delay, + drop_first=self.data_settings.drop_first, + dtype=dtype, + real=not dtype.is_complex, + num_symbols=num_symbols, + ) + + dataset_size = len(dataset) + indices = list(range(dataset_size)) + split = int(np.floor(self.data_settings.train_split * dataset_size)) + if self.data_settings.shuffle: + np.random.seed(self.global_settings.seed) + np.random.shuffle(indices) + + train_indices, valid_indices = indices[:split], indices[split:] + + if self.data_settings.shuffle: + train_sampler = torch.utils.data.SubsetRandomSampler(train_indices) + valid_sampler = torch.utils.data.SubsetRandomSampler(valid_indices) + else: + train_sampler = train_indices + valid_sampler = valid_indices + + train_loader = torch.utils.data.DataLoader( + dataset, + batch_size=self.pytorch_settings.batchsize, + sampler=train_sampler, + drop_last=True, + pin_memory=True, + num_workers=self.pytorch_settings.dataloader_workers, + prefetch_factor=self.pytorch_settings.dataloader_prefetch, + ) + + valid_loader = torch.utils.data.DataLoader( + dataset, + batch_size=self.pytorch_settings.batchsize, + sampler=valid_sampler, + drop_last=True, + pin_memory=True, + num_workers=self.pytorch_settings.dataloader_workers, + prefetch_factor=self.pytorch_settings.dataloader_prefetch, + ) + + return train_loader, valid_loader + + def train_model( + self, + optimizer, + train_loader, + epoch, + enable_progress=False, + ): + if enable_progress: + progress = Progress( + TextColumn("[yellow] Training..."), + TextColumn("Error: {task.description}"), + BarColumn(), + TaskProgressColumn(), + TextColumn("[green]Batch"), + MofNCompleteColumn(), + TimeRemainingColumn(), + TimeElapsedColumn(), + transient=False, + console=self.console, + refresh_per_second=10, + ) + task = progress.add_task("-.---e--", total=len(train_loader)) + progress.start() + + running_loss2 = 0.0 + running_loss = 0.0 + self.model.train() + for batch_idx, (x, y) in enumerate(train_loader): + self.model.zero_grad(set_to_none=True) + x, y = ( + x.to(self.pytorch_settings.device), + y.to(self.pytorch_settings.device), + ) + y_pred = self.model(x) + loss = util.complexNN.complex_mse_loss(y_pred, y, power=True) + 
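+            # complex_mse_loss with power=True presumably compares signal powers (|.|^2)
+            # rather than raw complex fields; the resulting scalar feeds both running_loss
+            # (epoch average) and running_loss2 (periodic TensorBoard logging below).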
loss_value = loss.item() + loss.backward() + optimizer.step() + running_loss2 += loss_value + running_loss += loss_value + + if enable_progress: + progress.update(task, advance=1, description=f"{loss_value:.3e}") + + if batch_idx % self.pytorch_settings.write_every == 0: + self.writer.add_scalar( + "training loss", + running_loss2 / (self.pytorch_settings.write_every if batch_idx > 0 else 1), + epoch * len(train_loader) + batch_idx, + ) + running_loss2 = 0.0 + + if enable_progress: + progress.stop() + + return running_loss / len(train_loader) + + def eval_model(self, valid_loader, epoch, enable_progress=True): + if enable_progress: + progress = Progress( + TextColumn("[green]Evaluating..."), + TextColumn("Error: {task.description}"), + BarColumn(), + TaskProgressColumn(), + TextColumn("[green]Batch"), + MofNCompleteColumn(), + TimeRemainingColumn(), + TimeElapsedColumn(), + transient=False, + console=self.console, + refresh_per_second=10, + ) + progress.start() + task = progress.add_task("-.---e--", total=len(valid_loader)) + + self.model.eval() + running_error = 0 + with torch.no_grad(): + for batch_idx, (x, y) in enumerate(valid_loader): + x, y = ( + x.to(self.pytorch_settings.device), + y.to(self.pytorch_settings.device), + ) + y_pred = self.model(x) + error = util.complexNN.complex_mse_loss(y_pred, y, power=True) + error_value = error.item() + running_error += error_value + + if enable_progress: + progress.update(task, advance=1, description=f"{error_value:.3e}") + + + running_error /= len(valid_loader) + self.writer.add_scalar( + "eval loss", + running_error, + epoch, + ) + title_append, subtitle = self.build_title(epoch + 1) + self.writer.add_figure( + "fiber response", + self.plot_model_response( + model=self.model, + title_append=title_append, + subtitle=subtitle, + show=False, + ), + epoch + 1, + ) + self.writer.add_figure( + "eye diagram", + self.plot_model_response( + model=self.model, + title_append=title_append, + subtitle=subtitle, + show=False, + mode="eye", + ), + epoch + 1, + ) + self.writer_histograms(epoch + 1) + + if enable_progress: + progress.stop() + + return running_error + + def run_model(self, model, loader): + model.eval() + xs = [] + ys = [] + y_preds = [] + with torch.no_grad(): + model = model.to(self.pytorch_settings.device) + for x, y in loader: + x, y = ( + x.to(self.pytorch_settings.device), + y.to(self.pytorch_settings.device), + ) + y_pred = model(x).cpu() + # x = x.cpu() + # y = y.cpu() + y_pred = y_pred.view(y_pred.shape[0], -1, 2) + y = y.view(y.shape[0], -1, 2) + x = x.view(x.shape[0], -1, 2) + xs.append(x[:, 0, :].squeeze()) + ys.append(y.squeeze()) + y_preds.append(y_pred.squeeze()) + + xs = torch.vstack(xs).cpu() + ys = torch.vstack(ys).cpu() + y_preds = torch.vstack(y_preds).cpu() + return ys, xs, y_preds + + def writer_histograms(self, epoch, attributes=["weight", "weight_U", "weight_V", "bias", "sigma", "scale"]): + for i, layer in enumerate(self.model._layers): + tag = f"layer {i}" + for attribute in attributes: + if hasattr(layer, attribute): + vals: np.ndarray = getattr(layer, attribute).detach().cpu().numpy().flatten() + if vals.ndim <= 1 and len(vals) == 1: + if np.iscomplexobj(vals): + self.writer.add_scalar(f"{tag} {attribute} (Mag)", np.abs(vals), epoch) + self.writer.add_scalar(f"{tag} {attribute} (Phase)", np.angle(vals), epoch) + else: + self.writer.add_scalar(f"{tag} {attribute}", vals, epoch) + else: + if np.iscomplexobj(vals): + self.writer.add_histogram(f"{tag} {attribute} (Mag)", np.abs(vals), epoch, bins="fd") + 
self.writer.add_histogram(f"{tag} {attribute} (Phase)", np.angle(vals), epoch, bins="fd") + else: + self.writer.add_histogram(f"{tag} {attribute}", vals, epoch, bins="fd") + + def train(self): + if self.writer is None: + self.setup_tb_writer() + + if self.resume: + model_kwargs = self.checkpoint_dict["model_kwargs"] + else: + model_kwargs = None + + self.define_model(model_kwargs=model_kwargs) + + print(f"number of parameters (trainable): {sum(p.numel() for p in self.model.parameters())} ({sum(p.numel() for p in self.model.parameters() if p.requires_grad)})") + + title_append, subtitle = self.build_title(0) + + self.writer.add_figure( + "fiber response", + self.plot_model_response( + model=self.model, + title_append=title_append, + subtitle=subtitle, + show=False, + ), + 0, + ) + self.writer.add_figure( + "eye diagram", + self.plot_model_response( + model=self.model, + title_append=title_append, + subtitle=subtitle, + mode="eye", + show=False, + ), + 0, + ) + self.writer_histograms(0) + + train_loader, valid_loader = self.get_sliced_data() + + optimizer_name = self.optimizer_settings.optimizer + + lr = self.optimizer_settings.learning_rate + + self.optimizer: optim.Optimizer = getattr(optim, optimizer_name)(self.model.parameters(), lr=lr) + if self.optimizer_settings.scheduler is not None: + self.scheduler = getattr(optim.lr_scheduler, self.optimizer_settings.scheduler)( + self.optimizer, **self.optimizer_settings.scheduler_kwargs + ) + if self.resume: + try: + self.scheduler.load_state_dict(self.checkpoint_dict["scheduler_state_dict"]) + except ValueError: + pass + self.writer.add_scalar("learning rate", self.scheduler.get_last_lr()[0], -1) + + + if not self.resume: + self.best = self.build_checkpoint_dict() + else: + self.best = self.checkpoint_dict + self.model.load_state_dict(self.best["model_state_dict"], strict=False) + try: + self.optimizer.load_state_dict(self.best["optimizer_state_dict"]) + except ValueError: + pass + + for epoch in range(self.best["epoch"] + 1, self.pytorch_settings.epochs): + enable_progress = True + if enable_progress: + self.console.rule(f"Epoch {epoch + 1}/{self.pytorch_settings.epochs}") + self.train_model( + self.optimizer, + train_loader, + epoch, + enable_progress=enable_progress, + ) + loss = self.eval_model( + valid_loader, + epoch, + enable_progress=enable_progress, + ) + if self.optimizer_settings.scheduler is not None: + lr_old = self.scheduler.get_last_lr() + self.scheduler.step(loss) + lr_new = self.scheduler.get_last_lr() + if lr_old[0] != lr_new[0]: + self.writer.add_scalar("learning rate", lr_new[0], epoch) + + if self.pytorch_settings.save_models and self.model is not None: + save_path = ( + Path(self.pytorch_settings.model_dir) / f"{self.writer.get_logdir().split('/')[-1]}_{epoch}.tar" + ) + save_path.parent.mkdir(parents=True, exist_ok=True) + checkpoint = self.build_checkpoint_dict(loss, epoch) + self.save_checkpoint(checkpoint, save_path) + + if loss < self.best["loss"]: + self.best = checkpoint + save_path = ( + Path(self.pytorch_settings.model_dir) / f"best_{self.writer.get_logdir().split('/')[-1]}.tar" + ) + save_path.parent.mkdir(parents=True, exist_ok=True) + self.save_checkpoint(self.best, save_path) + self.writer.flush() + + self.writer.close() + return self.best + + def _plot_model_response_eye(self, *signals, labels=None, sps=None, title_append="", subtitle="", show=True): + if sps is None: + raise ValueError("sps must be provided") + if not hasattr(labels, "__iter__") or isinstance(labels, (str, type(None))): + labels = [labels] + 
else: + labels = list(labels) + + while len(labels) < len(signals): + labels.append(None) + + # check if there are any labels + if not any(labels): + labels = [f"signal {i + 1}" for i in range(len(signals))] + + fig, axs = plt.subplots(1, 2 * len(signals), sharex=True, sharey=True) + fig.set_figwidth(18) + fig.suptitle(f"Eye diagram{f' {title_append}' if title_append else ''}{f'\n{subtitle}' if subtitle else ''}") + xaxis = np.linspace(0, 2, 2 * sps, endpoint=False) + for j, (label, signal) in enumerate(zip(labels, signals)): + # signal = signal.cpu().numpy() + for i in range(len(signal) // sps - 1): + x, y = signal[i * sps : (i + 2) * sps].T + axs[0 + 2 * j].plot(xaxis, np.abs(x) ** 2, color=f"C{j}", alpha=1 / (len(signal) // sps) * 10) + axs[1 + 2 * j].plot(xaxis, np.abs(y) ** 2, color=f"C{j}", alpha=1 / (len(signal) // sps) * 10) + axs[0 + 2 * j].set_title(label + " x") + axs[1 + 2 * j].set_title(label + " y") + axs[0 + 2 * j].set_xlabel("Symbol") + axs[1 + 2 * j].set_xlabel("Symbol") + axs[0 + 2 * j].set_box_aspect(1) + axs[1 + 2 * j].set_box_aspect(1) + axs[0].set_ylabel("normalized power") + fig.tight_layout() + # axs[1+2*len(labels)-1].set_ylabel("normalized power") + + if show: + plt.show() + return fig + + def _plot_model_response_head(self, *signals, labels=None, sps=None, title_append="", subtitle="", show=True): + if not hasattr(labels, "__iter__") or isinstance(labels, (str, type(None))): + labels = [labels] + else: + labels = list(labels) + + while len(labels) < len(signals): + labels.append(None) + + # check if there are any labels + if not any(labels): + labels = [f"signal {i + 1}" for i in range(len(signals))] + + fig, axs = plt.subplots(1, 2, sharex=True, sharey=True) + fig.set_figwidth(18) + fig.set_figheight(4) + fig.suptitle(f"Fiber response{f' {title_append}' if title_append else ''}{f'\n{subtitle}' if subtitle else ''}") + for i, ax in enumerate(axs): + for signal, label in zip(signals, labels): + if sps is not None: + xaxis = np.linspace(0, len(signal) / sps, len(signal), endpoint=False) + else: + xaxis = np.arange(len(signal)) + ax.plot(xaxis, np.abs(signal[:, i]) ** 2, label=label) + ax.set_xlabel("Sample" if sps is None else "Symbol") + ax.set_ylabel("normalized power") + ax.legend(loc="upper right") + fig.tight_layout() + if show: + plt.show() + return fig + + def plot_model_response( + self, + model=None, + title_append="", + subtitle="", + mode: Literal["eye", "head"] = "head", + show=False, + ): + data_settings_backup = copy.deepcopy(self.data_settings) + pytorch_settings_backup = copy.deepcopy(self.pytorch_settings) + self.data_settings.drop_first = 100 * 128 + self.data_settings.shuffle = False + self.data_settings.train_split = 1.0 + self.pytorch_settings.batchsize = ( + self.pytorch_settings.eye_symbols if mode == "eye" else self.pytorch_settings.head_symbols + ) + plot_loader, _ = self.get_sliced_data(override={"num_symbols": self.pytorch_settings.batchsize}) + self.data_settings = data_settings_backup + self.pytorch_settings = pytorch_settings_backup + + fiber_in, fiber_out, regen = self.run_model(model, plot_loader) + fiber_in = fiber_in.view(-1, 2) + fiber_out = fiber_out.view(-1, 2) + regen = regen.view(-1, 2) + + fiber_in = fiber_in.numpy() + fiber_out = fiber_out.numpy() + regen = regen.numpy() + + # https://github.com/matplotlib/matplotlib/issues/27713#issue-2104110987 + # https://github.com/matplotlib/matplotlib/issues/27713#issuecomment-1915497463 + import gc + + if mode == "head": + fig = self._plot_model_response_head( + fiber_in, + fiber_out, 
+ regen, + labels=("fiber in", "fiber out", "regen"), + sps=plot_loader.dataset.samples_per_symbol, + title_append=title_append, + subtitle=subtitle, + show=show, + ) + elif mode == "eye": + # raise NotImplementedError("Eye diagram not implemented") + fig = self._plot_model_response_eye( + fiber_in, + fiber_out, + regen, + labels=("fiber in", "fiber out", "regen"), + sps=plot_loader.dataset.samples_per_symbol, + title_append=title_append, + subtitle=subtitle, + show=show, + ) + else: + raise ValueError(f"Unknown mode: {mode}") + gc.collect() + + return fig + + def build_title(self, number: int): + title_append = f"epoch {number}" + model_n_hidden_layers = self.model_settings.n_hidden_layers + input_dim = 2 * self.data_settings.output_size + model_dims = [ + self.model_settings.overrides.get(f"n_hidden_nodes_{i}", -1) for i in range(model_n_hidden_layers) + ] + model_dims.insert(0, input_dim) + model_dims.append(2) + model_dims = [str(dim) for dim in model_dims] + model_activation_func = self.model_settings.model_activation_func + model_dtype = self.data_settings.dtype + + subtitle = f"{model_n_hidden_layers + 2} layers à ({', '.join(model_dims)}) units, {model_activation_func}, {model_dtype}" + + return title_append, subtitle