diff --git a/src/single-core-regen/regen_no_hyper.py b/src/single-core-regen/regen_no_hyper.py
index 45e18c0..6d84c3e 100644
--- a/src/single-core-regen/regen_no_hyper.py
+++ b/src/single-core-regen/regen_no_hyper.py
@@ -1,414 +1,130 @@
-import copy
-from dataclasses import dataclass
-from datetime import datetime
-from pathlib import Path
-import matplotlib.pyplot as plt
-
-import numpy as np
-
-import torch
-import torch.nn as nn
-
-# import torch.nn.functional as F # mse_loss doesn't support complex numbers
-import torch.optim as optim
-import torch.utils.data
-
-from torch.utils.tensorboard import SummaryWriter
-
-from rich.progress import (
-    Progress,
-    TextColumn,
-    BarColumn,
-    TaskProgressColumn,
-    TimeRemainingColumn,
-    MofNCompleteColumn,
-    TimeElapsedColumn,
+from hypertraining.settings import (
+    GlobalSettings,
+    DataSettings,
+    PytorchSettings,
+    ModelSettings,
+    OptimizerSettings,
 )
-from rich.console import Console
-from rich import print as rprint
-# from util.optuna_helpers import optional_suggest_categorical, optional_suggest_float, optional_suggest_int
+from hypertraining.training import Trainer
+import torch
+import json
 
 import util
 
+global_settings = GlobalSettings(
+    seed=42,
+)
 
-# global settings
-@dataclass
-class GlobalSettings:
-    seed: int = 42
+data_settings = DataSettings(
+    config_path="data/*-128-16384-100000-0-0-17-0-PAM4-0.ini",
+    dtype="complex64",
+    # symbols = (9, 20), # 13 symbols @ 10GBd <-> 1.3ns <-> 0.26m of fiber
+    symbols=13, # study: single_core_regen_20241123_011232
+    # output_size = (11, 32), # 26 taps -> 2 taps per input symbol -> 1 tap every 1cm (model has 52 inputs (x/y))
+    output_size=26, # study: single_core_regen_20241123_011232 (model_input_dim/2)
+    shuffle=True,
+    in_out_delay=0,
+    xy_delay=0,
+    drop_first=128*64,
+    train_split=0.8,
+)
 
+pytorch_settings = PytorchSettings(
+    epochs=10000,
+    batchsize=2**12,
+    device="cuda",
+    dataloader_workers=12,
+    dataloader_prefetch=8,
+    summary_dir=".runs",
+    write_every=2**5,
+    save_models=True,
+    model_dir=".models",
+)
 
-# data settings
-@dataclass
-class DataSettings:
-    config_path: str = "data/*-128-16384-100000-0-0-17-0-PAM4-0.ini"
-    dtype: torch.dtype = torch.complex64
-    symbols_range: float | int = 8
-    data_size_range: float | int = 64
-    shuffle: bool = True
-    target_delay: float = 0
-    xy_delay_range: float | int = 0
-    drop_first: int = 10
-    train_split: float = 0.8
+model_settings = ModelSettings(
+    output_dim=2,
+    n_hidden_layers=4,
+    overrides={
+        "n_hidden_nodes_0": 8,
+        "n_hidden_nodes_1": 8,
+        "n_hidden_nodes_2": 4,
+        "n_hidden_nodes_3": 6,
+    },
+    model_activation_func="PowScale",
+    # dropout_prob=0.01,
+    model_layer_function="ONN",
+    model_layer_parametrizations=[
+        {
+            "tensor_name": "weight",
+            "parametrization": torch.nn.utils.parametrizations.orthogonal,
+        },
+        {
+            "tensor_name": "scales",
+            "parametrization": util.complexNN.clamp,
+        },
+        {
+            "tensor_name": "scale",
+            "parametrization": util.complexNN.clamp,
+        },
+        {
+            "tensor_name": "bias",
+            "parametrization": util.complexNN.clamp,
+        },
+        # {
+        #     "tensor_name": "V",
+        #     "parametrization": torch.nn.utils.parametrizations.orthogonal,
+        # },
+        # {
+        #     "tensor_name": "S",
+        #     "parametrization": util.complexNN.clamp,
+        # },
+    ],
+)
 
+optimizer_settings = OptimizerSettings(
+    optimizer="Adam",
+    learning_rate=0.05,
+    scheduler="ReduceLROnPlateau",
+    scheduler_kwargs={
+        "patience": 2**6,
+        "factor": 0.9,
+        # "threshold": 1e-3,
+        "min_lr": 1e-6,
+        "cooldown": 10,
+    },
+)
 
-# pytorch settings
-@dataclass
-class PytorchSettings:
-    epochs: int = 1000
-    batchsize: int = 2**12
-    device: str = "cuda"
-    summary_dir: str = ".runs"
-    model_dir: str = ".models"
+def save_dict_to_file(dictionary, filename):
+    """
+    Save a dictionary to a JSON file.
+    :param dictionary: Dictionary containing the best training results.
+    :type dictionary: dict
+    :param filename: Path to the JSON file where the dictionary will be saved.
+    :type filename: str
+    """
+    with open(filename, 'w') as f:
+        json.dump(dictionary, f, indent=4)
 
-# model settings
-@dataclass
-class ModelSettings:
-    output_size: int = 2
-    # n_layer_range: float|int = 2
-    # n_units_range: float|int = 32
-    n_layers: int = 3
-    n_units: int = 32
-    activation_func: tuple | str = "ModReLU"
-
-
-@dataclass
-class OptimizerSettings:
-    optimizer_range: str = "Adam"
-    lr_range: float = 2e-3
-
-
-class Training:
-    def __init__(self):
-        self.global_settings = GlobalSettings()
-        self.data_settings = DataSettings()
-        self.pytorch_settings = PytorchSettings()
-        self.model_settings = ModelSettings()
-        self.optimizer_settings = OptimizerSettings()
-        self.study_name = (
-            f"single_core_regen_{datetime.now().strftime('%Y-%m-%d_%H:%M:%S')}"
-        )
-
-        if not hasattr(self.pytorch_settings, "model_dir"):
-            self.pytorch_settings.model_dir = ".models"
-
-        self.writer = None
-        self.console = Console()
-
-    def setup_tb_writer(self, study_name=None, append=None):
-        log_dir = (
-            self.pytorch_settings.summary_dir + "/" + (study_name or self.study_name) + ("_" + str(append)) if append else ""
-        )
-        self.writer = SummaryWriter(log_dir)
-        return self.writer
-
-    def plot_eye(self, width=2, symbols=None, alpha=None, complex=False, show=True):
-        if not hasattr(self, "eye_data"):
-            data, config = util.datasets.load_data(
-                self.data_settings.config_path,
-                skipfirst=10,
-                symbols=symbols or 1000,
-                real=not self.data_settings.dtype.is_complex,
-                normalize=True,
-            )
-            self.eye_data = {"data": data, "sps": int(config["glova"]["sps"])}
-        return util.plot.eye(
-            **self.eye_data,
-            width=width,
-            show=show,
-            alpha=alpha,
-            complex=complex,
-            symbols=symbols or 1000,
-            skipfirst=0,
-        )
-
-    def define_model(self):
-        n_layers = self.model_settings.n_layers
-
-        in_features = 2 * self.data_settings.data_size_range
-
-        layers = []
-        for i in range(n_layers):
-            out_features = self.model_settings.n_units
-
-            layers.append(util.complexNN.UnitaryLayer(in_features, out_features))
-            # layers.append(getattr(nn, self.model_settings.activation_func)())
-            layers.append(
-                getattr(util.complexNN, self.model_settings.activation_func)()
-            )
-            in_features = out_features
-
-        layers.append(
-            util.complexNN.UnitaryLayer(in_features, self.model_settings.output_size)
-        )
-
-        if self.writer is not None:
-            self.writer.add_graph(
-                nn.Sequential(*layers),
-                torch.zeros(1, layers[0].in_features, dtype=self.data_settings.dtype),
-            )
-
-        return nn.Sequential(*layers)
-
-    def get_sliced_data(self):
-        symbols = self.data_settings.symbols_range
-
-        xy_delay = self.data_settings.xy_delay_range
-
-        data_size = self.data_settings.data_size_range
-
-        # get dataset
-        dataset = util.datasets.FiberRegenerationDataset(
-            file_path=self.data_settings.config_path,
-            symbols=symbols,
-            output_dim=data_size,
-            target_delay=self.data_settings.target_delay,
-            xy_delay=xy_delay,
-            drop_first=self.data_settings.drop_first,
-            dtype=self.data_settings.dtype,
-            real=not self.data_settings.dtype.is_complex,
-            # device=self.pytorch_settings.device,
-        )
-
-        dataset_size = len(dataset)
-        indices = list(range(dataset_size))
-        split = int(np.floor(self.data_settings.train_split * dataset_size))
-        if self.data_settings.shuffle:
-            np.random.seed(self.global_settings.seed)
-            np.random.shuffle(indices)
-
-        train_indices, valid_indices = indices[:split], indices[split:]
-
-        if self.data_settings.shuffle:
-            train_sampler = torch.utils.data.SubsetRandomSampler(train_indices)
-            valid_sampler = torch.utils.data.SubsetRandomSampler(valid_indices)
-        else:
-            train_sampler = train_indices
-            valid_sampler = valid_indices
-
-
-        train_loader = torch.utils.data.DataLoader(
-            dataset,
-            batch_size=self.pytorch_settings.batchsize,
-            sampler=train_sampler,
-            drop_last=True,
-            pin_memory=True,
-            num_workers=24,
-            prefetch_factor=4,
-            # persistent_workers=True
-        )
-        valid_loader = torch.utils.data.DataLoader(
-            dataset,
-            batch_size=self.pytorch_settings.batchsize,
-            sampler=valid_sampler,
-            drop_last=True,
-            pin_memory=True,
-            num_workers=24,
-            prefetch_factor=4,
-            # persistent_workers=True
-        )
-
-        return train_loader, valid_loader
-
-    def train_model(self, model, optimizer, train_loader, epoch):
-        with Progress(
-            TextColumn("[yellow] Training..."),
-            TextColumn("Error: {task.description}"),
-            BarColumn(),
-            TaskProgressColumn(),
-            TextColumn("[green]Batch"),
-            MofNCompleteColumn(),
-            TimeRemainingColumn(),
-            TimeElapsedColumn(),
-            # description="Training",
-            transient=False,
-            console=self.console,
-            refresh_per_second=10,
-        ) as progress:
-            task = progress.add_task("-.---e--", total=len(train_loader))
-
-            running_loss = 0.0
-            model.train()
-            for batch_idx, (x, y) in enumerate(train_loader):
-                model.zero_grad(set_to_none=True)
-                x, y = (
-                    x.to(self.pytorch_settings.device),
-                    y.to(self.pytorch_settings.device),
-                )
-                y_pred = model(x)
-                loss = util.complexNN.complex_mse_loss(y_pred, y)
-                loss.backward()
-                optimizer.step()
-
-                progress.update(task, advance=1, description=f"{loss.item():.3e}")
-
-                running_loss += loss.item()
-                if self.writer is not None:
-                    if (batch_idx + 1) % 10 == 0:
-                        self.writer.add_scalar(
-                            "training loss",
-                            running_loss / 10,
-                            epoch * len(train_loader) + batch_idx,
-                        )
-                        running_loss = 0.0
-
-            return running_loss
-
-    def eval_model(self, model, valid_loader, epoch):
-        with Progress(
-            TextColumn("[green]Evaluating..."),
-            TextColumn("Error: {task.description}"),
-            BarColumn(),
-            TaskProgressColumn(),
-            TextColumn("[green]Batch"),
-            MofNCompleteColumn(),
-            TimeRemainingColumn(),
-            TimeElapsedColumn(),
-            # description="Training",
-            transient=False,
-            console=self.console,
-            refresh_per_second=10,
-        ) as progress:
-            task = progress.add_task("-.---e--", total=len(valid_loader))
-
-            model.eval()
-            running_loss = 0
-            running_loss2 = 0
-            with torch.no_grad():
-                for batch_idx, (x, y) in enumerate(valid_loader):
-                    x, y = (
-                        x.to(self.pytorch_settings.device),
-                        y.to(self.pytorch_settings.device),
-                    )
-                    y_pred = model(x)
-                    loss = util.complexNN.complex_mse_loss(y_pred, y)
-                    running_loss += loss.item()
-                    running_loss2 += loss.item()
-
-                    progress.update(task, advance=1, description=f"{loss.item():.3e}")
-                    if self.writer is not None:
-                        if (batch_idx + 1) % 10 == 0:
-                            self.writer.add_scalar(
-                                "loss",
-                                running_loss / 10,
-                                epoch * len(valid_loader) + batch_idx,
-                            )
-                            running_loss = 0.0
-
-            if self.writer is not None:
-                self.writer.add_figure("fiber response", self.plot_model_response(model, plot=False), epoch+1)
-
-            return running_loss2 / len(valid_loader)
-
-    def run_model(self, model, loader):
-        model.eval()
-        xs = []
-        ys = []
-        y_preds = []
-        with torch.no_grad():
-            model = model.to(self.pytorch_settings.device)
-            for x, y in loader:
-                x, y = (
-                    x.to(self.pytorch_settings.device),
-                    y.to(self.pytorch_settings.device),
-                )
-                y_pred = model(x).cpu()
-                # x = x.cpu()
-                # y = y.cpu()
-                y_pred = y_pred.view(y_pred.shape[0], -1, 2)
-                y = y.view(y.shape[0], -1, 2)
-                x = x.view(x.shape[0], -1, 2)
-                xs.append(x[:, 0, :].squeeze())
-                ys.append(y.squeeze())
-                y_preds.append(y_pred.squeeze())
-
-        xs = torch.vstack(xs).cpu()
-        ys = torch.vstack(ys).cpu()
-        y_preds = torch.vstack(y_preds).cpu()
-        return ys, xs, y_preds
-
-    def objective(self, save=False, plot_before=False):
-        try:
-            rprint(*list(self.study_name.split("_")))
-
-            self.model = self.define_model().to(self.pytorch_settings.device)
-
-            if self.writer is not None:
-                self.writer.add_figure("fiber response", self.plot_model_response(plot=plot_before), 0)
-
-            train_loader, valid_loader = self.get_sliced_data()
-
-            optimizer_name = self.optimizer_settings.optimizer_range
-
-            lr = self.optimizer_settings.lr_range
-
-            optimizer = getattr(optim, optimizer_name)(self.model.parameters(), lr=lr)
-
-            for epoch in range(self.pytorch_settings.epochs):
-                self.console.rule(f"Epoch {epoch + 1}/{self.pytorch_settings.epochs}")
-                self.train_model(self.model, optimizer, train_loader, epoch)
-                eval_loss = self.eval_model(self.model, valid_loader, epoch)
-
-            return eval_loss
-
-        except KeyboardInterrupt:
-            ...
-        finally:
-            if hasattr(self, "model"):
-                save_path = (
-                    Path(self.pytorch_settings.model_dir) / f"{self.study_name}.pth"
-                )
-                save_path.parent.mkdir(parents=True, exist_ok=True)
-                torch.save(self.model, save_path)
-
-    def _plot_model_response_plotter(self, fiber_in, fiber_out, regen, plot=True):
-        fig, axs = plt.subplots(2)
-        for i, ax in enumerate(axs):
-            ax.plot(np.abs(fiber_in[:, i]) ** 2, label="fiber in")
-            ax.plot(np.abs(fiber_out[:, i]) ** 2, label="fiber out")
-            ax.plot(np.abs(regen[:, i]) ** 2, label="regenerated")
-            ax.legend()
-        if plot:
-            plt.show()
-        return fig
-
-    def plot_model_response(self, model=None, plot=True):
-        data_settings_backup = copy.copy(self.data_settings)
-        self.data_settings.shuffle = False
-        self.data_settings.train_split = 0.01
-        self.data_settings.drop_first = 100
-        plot_loader, _ = self.get_sliced_data()
-        self.data_settings = data_settings_backup
-
-        fiber_in, fiber_out, regen = self.run_model(model or self.model, plot_loader)
-        fiber_in = fiber_in.view(-1, 2)
-        fiber_out = fiber_out.view(-1, 2)
-        regen = regen.view(-1, 2)
-
-        fiber_in = fiber_in.numpy()
-        fiber_out = fiber_out.numpy()
-        regen = regen.numpy()
-
-        # https://github.com/matplotlib/matplotlib/issues/27713#issue-2104110987
-        # https://github.com/matplotlib/matplotlib/issues/27713#issuecomment-1915497463
-        import gc
-        fig = self._plot_model_response_plotter(fiber_in, fiber_out, regen, plot=plot)
-        gc.collect()
-
-        return fig
 
 
 if __name__ == "__main__":
-    trainer = Training()
+    trainer = Trainer(
+        global_settings=global_settings,
+        data_settings=data_settings,
+        pytorch_settings=pytorch_settings,
+        model_settings=model_settings,
+        optimizer_settings=optimizer_settings,
+        checkpoint_path='.models/20241128_084935_8885.tar',
+        settings_override={
+            "model_settings": {
+                # "model_activation_func": "PowScale",
+                "dropout_prob": 0,
+            }
+        },
+        reset_epoch=True,
+    )
 
-    # trainer.plot_eye()
-    trainer.setup_tb_writer()
-    trainer.objective(save=True)
-
-    best_model = trainer.model
-
-    # best_model = trainer.define_model(trainer.study.best_trial).to(trainer.pytorch_settings.device)
-    trainer.plot_model_response(best_model)
-
-    # print(f"Best model: {best_model}")
+    best = trainer.train()
".models/best_results.json") ...