diff --git a/data/optuna_single_core_regen.db b/data/optuna_single_core_regen.db index 5d9549f..6236cfe 100644 --- a/data/optuna_single_core_regen.db +++ b/data/optuna_single_core_regen.db @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:72460af57347d35df91cd76982231bcf538a82fd7f1b8522795202fa298a2dcb -size 696320 +oid sha256:e12f0c21fca93620a165fbb6ed58d0b313093e972ef4416694c29c9cea6dc867 +size 831488 diff --git a/src/single-core-regen/hypertraining/hypertraining.py b/src/single-core-regen/hypertraining/hypertraining.py new file mode 100644 index 0000000..3d5cca7 --- /dev/null +++ b/src/single-core-regen/hypertraining/hypertraining.py @@ -0,0 +1,735 @@ +import copy +from datetime import datetime +from pathlib import Path +from typing import Literal +import matplotlib.pyplot as plt + +import numpy as np +import optuna +import warnings + +import torch +import torch.nn as nn + +# import torch.nn.functional as F # mse_loss doesn't support complex numbers +import torch.optim as optim +import torch.utils.data + +from torch.utils.tensorboard import SummaryWriter + +from rich.progress import ( + Progress, + TextColumn, + BarColumn, + TaskProgressColumn, + TimeRemainingColumn, + MofNCompleteColumn, + TimeElapsedColumn, +) +from rich.console import Console +# from rich import print as rprint + +import multiprocessing + +from util.datasets import FiberRegenerationDataset +from util.optuna_helpers import ( + force_suggest_categorical, + force_suggest_float, + force_suggest_int, +) +import util + +from .settings import ( + GlobalSettings, + DataSettings, + ModelSettings, + OptunaSettings, + OptimizerSettings, + PytorchSettings, +) + + +class HyperTraining: + def __init__( + self, + *, + global_settings, + data_settings, + pytorch_settings, + model_settings, + optimizer_settings, + optuna_settings, + console=None, + ): + self.global_settings: GlobalSettings = global_settings + self.data_settings: DataSettings = data_settings + self.pytorch_settings: PytorchSettings = pytorch_settings + self.model_settings: ModelSettings = model_settings + self.optimizer_settings: OptimizerSettings = optimizer_settings + self.optuna_settings: OptunaSettings = optuna_settings + + self.console = console or Console() + + # set some extra settings to make the code more readable + self._extra_optuna_settings() + + def setup_tb_writer(self, study_name=None, append=None): + log_dir = ( + self.pytorch_settings.summary_dir + + "/" + + (study_name or self.optuna_settings.study_name) + ) + if append is not None: + log_dir += "_" + str(append) + + return SummaryWriter(log_dir) + + def resume_latest_study(self, verbose=True): + study_name = self.get_latest_study() + + if study_name: + print(f"Resuming study: {study_name}") + self.optuna_settings.study_name = study_name + + def get_latest_study(self, verbose=True): + studies = self.get_studies() + for study in studies: + study.datetime_start = study.datetime_start or datetime.min + if studies: + study = sorted(studies, key=lambda x: x.datetime_start, reverse=True)[0] + if verbose: + print(f"Last study: {study.study_name}") + study_name = study.study_name + else: + if verbose: + print("No previous studies found") + study_name = None + return study_name + + def get_studies(self): + return optuna.get_all_study_summaries(storage=self.optuna_settings.storage) + + def setup_study(self): + self.study = optuna.create_study( + study_name=self.optuna_settings.study_name, + storage=self.optuna_settings.storage, + load_if_exists=True, + 
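+            # load_if_exists resumes an existing study of this name instead of raising optuna.exceptions.DuplicatedStudyError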
direction=self.optuna_settings.direction, + directions=self.optuna_settings.directions, + ) + + with warnings.catch_warnings(action="ignore"): + self.study.set_metric_names(self.optuna_settings.metrics_names) + + self.n_threads = min( + self.optuna_settings.n_trials, self.optuna_settings.n_threads + ) + self.processes = [] + if self.n_threads > 1: + for _ in range(self.n_threads): + p = multiprocessing.Process( + # target=lambda n_trials: self._run_optimize(self, n_trials), + target=self._run_optimize, + args=(self.optuna_settings.n_trials // self.n_threads,), + ) + self.processes.append(p) + + # def plot_eye(self, width=2, symbols=None, alpha=None, complex=False, show=True): + # data, config = util.datasets.load_data( + # self.data_settings.config_path, + # skipfirst=10, + # symbols=symbols or 1000, + # real=not complex, + # normalize=True, + # ) + # eye_data = {"data": data.numpy(), "sps": int(config["glova"]["sps"])} + # return util.plot.eye( + # **eye_data, + # width=width, + # show=show, + # alpha=alpha, + # complex=complex, + # symbols=symbols or 1000, + # skipfirst=0, + # ) + + def run_study(self): + if self.processes: + for p in self.processes: + p.start() + for p in self.processes: + p.join() + + remaining_trials = self.optuna_settings.n_trials % self.n_threads + else: + remaining_trials = self.optuna_settings.n_trials + + if remaining_trials: + self._run_optimize(remaining_trials) + + def _run_optimize(self, n_trials): + self.study.optimize( + self.objective, n_trials=n_trials, timeout=self.optuna_settings.timeout + ) + + def _extra_optuna_settings(self): + self.optuna_settings.multi_objective = len(self.optuna_settings.directions) > 1 + if self.optuna_settings.multi_objective: + self.optuna_settings.direction = None + else: + self.optuna_settings.direction = self.optuna_settings.directions[0] + self.optuna_settings.directions = None + + self.optuna_settings.n_train_batches = ( + self.optuna_settings.n_train_batches + if self.optuna_settings.limit_examples + else float("inf") + ) + self.optuna_settings.n_valid_batches = ( + self.optuna_settings.n_valid_batches + if self.optuna_settings.limit_examples + else float("inf") + ) + + def define_model(self, trial: optuna.Trial, writer=None): + n_layers = force_suggest_int( + trial, "model_n_layers", self.model_settings.model_n_layers + ) + + input_dim = 2 * trial.params.get( + "model_input_dim", + force_suggest_int(trial, "model_input_dim", self.data_settings.model_input_dim), + ) + + dtype = trial.params.get( + "model_dtype", + force_suggest_categorical(trial, "model_dtype", self.data_settings.dtype), + ) + dtype = getattr(torch, dtype) + + afunc = force_suggest_categorical( + trial, "model_activation_func", self.model_settings.model_activation_func + ) + + layers = [] + last_dim = input_dim + for i in range(n_layers): + hidden_dim = force_suggest_int( + trial, f"model_hidden_dim_{i}", self.model_settings.unit_count + ) + layers.append( + util.complexNN.SemiUnitaryLayer(last_dim, hidden_dim, dtype=dtype) + ) + last_dim = hidden_dim + layers.append(getattr(util.complexNN, afunc)()) + + layers.append( + util.complexNN.UnitaryLayer( + hidden_dim, self.model_settings.output_dim, dtype=dtype + ) + ) + + model = nn.Sequential(*layers) + + if writer is not None: + writer.add_graph( + model, torch.zeros(1, input_dim, dtype=dtype), use_strict_trace=False + ) + + return model.to(self.pytorch_settings.device) + + def get_sliced_data(self, trial: optuna.Trial, override=None): + symbols = trial.params.get( + "dataset_symbols", + 
force_suggest_float(trial, "dataset_symbols", self.data_settings.symbols), + ) + + xy_delay = trial.params.get( + "dataset_xy_delay", + force_suggest_float(trial, "dataset_xy_delay", self.data_settings.xy_delay), + ) + + data_size = trial.params.get( + "model_input_dim", + force_suggest_int(trial, "model_input_dim", self.data_settings.model_input_dim), + ) + + dtype = trial.params.get( + "model_dtype", + force_suggest_categorical(trial, "model_dtype", self.data_settings.dtype), + ) + dtype = getattr(torch, dtype) + + num_symbols = None + if override is not None: + num_symbols = override.get("num_symbols", None) + # get dataset + dataset = FiberRegenerationDataset( + file_path=self.data_settings.config_path, + symbols=symbols, + output_dim=data_size, + target_delay=self.data_settings.in_out_delay, + xy_delay=xy_delay, + drop_first=self.data_settings.drop_first, + dtype=dtype, + real=not dtype.is_complex, + num_symbols=num_symbols, + ) + + dataset_size = len(dataset) + indices = list(range(dataset_size)) + split = int(np.floor(self.data_settings.train_split * dataset_size)) + if self.data_settings.shuffle: + np.random.seed(self.global_settings.seed) + np.random.shuffle(indices) + + train_indices, valid_indices = indices[:split], indices[split:] + + if self.data_settings.shuffle: + train_sampler = torch.utils.data.SubsetRandomSampler(train_indices) + valid_sampler = torch.utils.data.SubsetRandomSampler(valid_indices) + else: + train_sampler = train_indices + valid_sampler = valid_indices + + train_loader = torch.utils.data.DataLoader( + dataset, + batch_size=self.pytorch_settings.batchsize, + sampler=train_sampler, + drop_last=True, + pin_memory=True, + num_workers=self.pytorch_settings.dataloader_workers, + prefetch_factor=self.pytorch_settings.dataloader_prefetch, + ) + + valid_loader = torch.utils.data.DataLoader( + dataset, + batch_size=self.pytorch_settings.batchsize, + sampler=valid_sampler, + drop_last=True, + pin_memory=True, + num_workers=self.pytorch_settings.dataloader_workers, + prefetch_factor=self.pytorch_settings.dataloader_prefetch, + ) + + return train_loader, valid_loader + + def train_model( + self, + trial, + model, + optimizer, + train_loader, + epoch, + writer=None, + enable_progress=False, + ): + if enable_progress: + progress = Progress( + TextColumn("[yellow] Training..."), + TextColumn("Error: {task.description}"), + BarColumn(), + TaskProgressColumn(), + TextColumn("[green]Batch"), + MofNCompleteColumn(), + TimeRemainingColumn(), + TimeElapsedColumn(), + # description="Training", + transient=False, + console=self.console, + refresh_per_second=10, + ) + task = progress.add_task("-.---e--", total=len(train_loader)) + progress.start() + + running_loss2 = 0.0 + running_loss = 0.0 + model.train() + for batch_idx, (x, y) in enumerate(train_loader): + if batch_idx >= self.optuna_settings.n_train_batches: + break + model.zero_grad(set_to_none=True) + x, y = ( + x.to(self.pytorch_settings.device), + y.to(self.pytorch_settings.device), + ) + y_pred = model(x) + loss = util.complexNN.complex_mse_loss(y_pred, y) + loss_value = loss.item() + loss.backward() + optimizer.step() + running_loss2 += loss_value + running_loss += loss_value + + if enable_progress: + progress.update(task, advance=1, description=f"{loss_value:.3e}") + + if writer is not None: + if batch_idx % self.pytorch_settings.write_every == 0: + writer.add_scalar( + "training loss", + running_loss2 + / (self.pytorch_settings.write_every if batch_idx > 0 else 1), + epoch + * min(len(train_loader), 
self.optuna_settings.n_train_batches) + + batch_idx, + ) + running_loss2 = 0.0 + + if enable_progress: + progress.stop() + + return running_loss / min( + len(train_loader), self.optuna_settings.n_train_batches + ) + + def eval_model( + self, trial, model, valid_loader, epoch, writer=None, enable_progress=True + ): + if enable_progress: + progress = Progress( + TextColumn("[green]Evaluating..."), + TextColumn("Error: {task.description}"), + BarColumn(), + TaskProgressColumn(), + TextColumn("[green]Batch"), + MofNCompleteColumn(), + TimeRemainingColumn(), + TimeElapsedColumn(), + # description="Training", + transient=False, + console=self.console, + refresh_per_second=10, + ) + progress.start() + task = progress.add_task("-.---e--", total=len(valid_loader)) + + model.eval() + running_error = 0 + running_error_2 = 0 + with torch.no_grad(): + for batch_idx, (x, y) in enumerate(valid_loader): + if batch_idx >= self.optuna_settings.n_valid_batches: + break + x, y = ( + x.to(self.pytorch_settings.device), + y.to(self.pytorch_settings.device), + ) + y_pred = model(x) + error = util.complexNN.complex_mse_loss(y_pred, y) + error_value = error.item() + running_error += error_value + running_error_2 += error_value + + if enable_progress: + progress.update(task, advance=1, description=f"{error_value:.3e}") + + if writer is not None: + if batch_idx % self.pytorch_settings.write_every == 0: + writer.add_scalar( + "eval loss", + running_error_2 + / ( + self.pytorch_settings.write_every + if batch_idx > 0 + else 1 + ), + epoch + * min( + len(valid_loader), self.optuna_settings.n_valid_batches + ) + + batch_idx, + ) + running_error_2 = 0.0 + + running_error /= min(len(valid_loader), self.optuna_settings.n_valid_batches) + + if writer is not None: + title_append, subtitle = self.build_title(trial) + writer.add_figure( + "fiber response", + self.plot_model_response( + trial, + model=model, + title_append=title_append, + subtitle=subtitle, + show=False, + ), + epoch + 1, + ) + + if enable_progress: + progress.stop() + + return running_error + + def run_model(self, model, loader): + model.eval() + xs = [] + ys = [] + y_preds = [] + with torch.no_grad(): + model = model.to(self.pytorch_settings.device) + for x, y in loader: + x, y = ( + x.to(self.pytorch_settings.device), + y.to(self.pytorch_settings.device), + ) + y_pred = model(x).cpu() + # x = x.cpu() + # y = y.cpu() + y_pred = y_pred.view(y_pred.shape[0], -1, 2) + y = y.view(y.shape[0], -1, 2) + x = x.view(x.shape[0], -1, 2) + xs.append(x[:, 0, :].squeeze()) + ys.append(y.squeeze()) + y_preds.append(y_pred.squeeze()) + + xs = torch.vstack(xs).cpu() + ys = torch.vstack(ys).cpu() + y_preds = torch.vstack(y_preds).cpu() + return ys, xs, y_preds + + def objective(self, trial: optuna.Trial, plot_before=False): + model = None + exc = None + try: + # rprint(*list(self.study_name.split("_"))) + + writer = self.setup_tb_writer( + self.optuna_settings.study_name, + f"{trial.number:0{len(str(self.optuna_settings.n_trials))}}", + ) + + model = self.define_model(trial, writer) + n_params = sum(p.numel() for p in model.parameters()) + # n_nodes = trial.params.get("model_n_layers", self.model_settings.model_n_layers) * trial.params.get("model_hidden_dim", self.model_settings.unit_count) + + title_append, subtitle = self.build_title(trial) + + writer.add_figure( + "fiber response", + self.plot_model_response( + trial, + model=model, + title_append=title_append, + subtitle=subtitle, + show=plot_before, + ), + 0, + ) + + train_loader, valid_loader = 
self.get_sliced_data(trial) + + optimizer_name = force_suggest_categorical( + trial, "optimizer", self.optimizer_settings.optimizer + ) + + lr = force_suggest_float( + trial, "lr", self.optimizer_settings.learning_rate, log=True + ) + + optimizer = getattr(optim, optimizer_name)(model.parameters(), lr=lr) + if self.optimizer_settings.scheduler is not None: + scheduler = getattr(optim.lr_scheduler, self.optimizer_settings.scheduler)( + optimizer, **self.optimizer_settings.scheduler_kwargs) + + for epoch in range(self.pytorch_settings.epochs): + enable_progress = self.optuna_settings.n_threads == 1 + if enable_progress: + self.console.rule( + f"Epoch {epoch + 1}/{self.pytorch_settings.epochs}" + ) + self.train_model( + trial, + model, + optimizer, + train_loader, + epoch, + writer, + enable_progress=enable_progress, + ) + error = self.eval_model( + trial, + model, + valid_loader, + epoch, + writer, + enable_progress=enable_progress, + ) + if self.optimizer_settings.scheduler is not None: + scheduler.step(error) + + writer.close() + + if self.optuna_settings.multi_objective: + return n_params, error + trial.report(error, epoch) + if trial.should_prune(): + raise optuna.exceptions.TrialPruned() + return error + + except KeyboardInterrupt: + ... + # except Exception as e: + # exc = e + finally: + if model is not None: + save_path = ( + Path(self.pytorch_settings.model_dir) + / f"{self.optuna_settings.study_name}_{trial.number}.pth" + ) + save_path.parent.mkdir(parents=True, exist_ok=True) + torch.save(model, save_path) + if exc is not None: + raise exc + + + def _plot_model_response_eye( + self, *signals, labels=None, sps=None, title_append="", subtitle="", show=True + ): + if sps is None: + raise ValueError("sps must be provided") + if not hasattr(labels, "__iter__") or isinstance(labels, (str, type(None))): + labels = [labels] + else: + labels = list(labels) + + while len(labels) < len(signals): + labels.append(None) + + # check if there are any labels + if not any(labels): + labels = [f"signal {i + 1}" for i in range(len(signals))] + + fig, axs = plt.subplots(2, len(signals), sharex=True, sharey=True) + fig.suptitle( + f"Eye diagram{f' {title_append}' if title_append else ''}{f'\n{subtitle}' if subtitle else ''}" + ) + xaxis = np.linspace(0, 2, 2 * sps, endpoint=False) + for j, (label, signal) in enumerate(zip(labels, signals)): + # signal = signal.cpu().numpy() + for i in range(len(signal) // sps - 1): + x, y = signal[i * sps : (i + 2) * sps].T + axs[0, j].plot(xaxis, np.abs(x) ** 2, color="C0", alpha=0.02) + axs[1, j].plot(xaxis, np.abs(y) ** 2, color="C0", alpha=0.02) + axs[0, j].set_title(label + " x") + axs[1, j].set_title(label + " y") + axs[0, j].set_xlabel("Symbol") + axs[1, j].set_xlabel("Symbol") + axs[0, j].set_ylabel("normalized power") + axs[1, j].set_ylabel("normalized power") + + if show: + plt.show() + + def _plot_model_response_head( + self, *signals, labels=None, sps=None, title_append="", subtitle="", show=True + ): + if not hasattr(labels, "__iter__") or isinstance(labels, (str, type(None))): + labels = [labels] + else: + labels = list(labels) + + while len(labels) < len(signals): + labels.append(None) + + # check if there are any labels + if not any(labels): + labels = [f"signal {i + 1}" for i in range(len(signals))] + + fig, axs = plt.subplots(1, 2, sharex=True, sharey=True) + fig.set_size_inches(18,6) + fig.suptitle( + f"Fiber response{f' {title_append}' if title_append else ''}{f'\n{subtitle}' if subtitle else ''}" + ) + for i, ax in enumerate(axs): + for signal, 
label in zip(signals, labels): + if sps is not None: + xaxis = np.linspace( + 0, len(signal) / sps, len(signal), endpoint=False + ) + else: + xaxis = np.arange(len(signal)) + ax.plot(xaxis, np.abs(signal[:, i]) ** 2, label=label) + ax.set_xlabel("Sample" if sps is None else "Symbol") + ax.set_ylabel("normalized power") + ax.legend(loc="upper right") + if show: + plt.show() + return fig + + def plot_model_response( + self, + trial, + model=None, + title_append="", + subtitle="", + mode: Literal["eye", "head"] = "head", + show=True, + ): + data_settings_backup = copy.deepcopy(self.data_settings) + pytorch_settings_backup = copy.deepcopy(self.pytorch_settings) + self.data_settings.drop_first = 100 + self.data_settings.shuffle = False + self.data_settings.train_split = 1.0 + self.pytorch_settings.batchsize = self.pytorch_settings.eye_symbols if mode == "eye" else self.pytorch_settings.head_symbols + plot_loader, _ = self.get_sliced_data( + trial, override={"num_symbols": self.pytorch_settings.batchsize} + ) + self.data_settings = data_settings_backup + self.pytorch_settings = pytorch_settings_backup + + fiber_in, fiber_out, regen = self.run_model(model, plot_loader) + fiber_in = fiber_in.view(-1, 2) + fiber_out = fiber_out.view(-1, 2) + regen = regen.view(-1, 2) + + fiber_in = fiber_in.numpy() + fiber_out = fiber_out.numpy() + regen = regen.numpy() + + # https://github.com/matplotlib/matplotlib/issues/27713#issue-2104110987 + # https://github.com/matplotlib/matplotlib/issues/27713#issuecomment-1915497463 + import gc + + if mode == "head": + fig = self._plot_model_response_head( + fiber_in, + fiber_out, + regen, + labels=("fiber in", "fiber out", "regen"), + sps=plot_loader.dataset.samples_per_symbol, + title_append=title_append, + subtitle=subtitle, + show=show, + ) + elif mode == "eye": + # raise NotImplementedError("Eye diagram not implemented") + fig = self._plot_model_response_eye( + fiber_in, + fiber_out, + regen, + labels=("fiber in", "fiber out", "regen"), + sps=plot_loader.dataset.samples_per_symbol, + title_append=title_append, + subtitle=subtitle, + show=show, + ) + else: + raise ValueError(f"Unknown mode: {mode}") + gc.collect() + + return fig + + @staticmethod + def build_title(trial): + title_append = f"for trial {trial.number}" + subtitle = ( + f"{trial.params['model_n_layers']} layers, " + f"{', '.join([str(trial.params[f'model_hidden_dim_{i}']) for i in range(trial.params['model_n_layers'])])} units, " + f"{trial.params['model_activation_func']}, " + f"{trial.params['model_dtype']}" + ) + + return title_append, subtitle diff --git a/src/single-core-regen/hypertraining/settings.py b/src/single-core-regen/hypertraining/settings.py new file mode 100644 index 0000000..1fca6f2 --- /dev/null +++ b/src/single-core-regen/hypertraining/settings.py @@ -0,0 +1,77 @@ +from dataclasses import dataclass +from datetime import datetime + + +# global settings +@dataclass(frozen=True) +class GlobalSettings: + seed: int = 42 + + +# data settings +@dataclass +class DataSettings: + config_path: str # = "data/*-128-16384-100000-0-0-17-0-PAM4-0.ini" + dtype: tuple = ("complex64", "float64") + symbols: tuple | float | int = 8 + model_input_dim: tuple | float | int = 64 + shuffle: bool = True + in_out_delay: float = 0 + xy_delay: tuple | float | int = 0 + drop_first: int = 1000 + train_split: float = 0.8 + + +# pytorch settings +@dataclass +class PytorchSettings: + epochs: int = 1 + batchsize: int = 2**10 + + device: str = "cuda" + + dataloader_workers: int = 2 + dataloader_prefetch: int = 2 + + 
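+    # output locations: model_dir collects the torch.save() checkpoints, summary_dir the TensorBoard runs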
model_dir: str = ".models" + + summary_dir: str = ".runs" + write_every: int = 10 + head_symbols: int = 40 + eye_symbols: int = 1000 + + +# model settings +@dataclass +class ModelSettings: + output_dim: int = 2 + model_n_layers: tuple | int = 3 + unit_count: tuple | int = 8 + # n_units_range: tuple | int = (2, 32) + # activation_func_range: tuple = ("ModReLU", "ZReLU", "CReLU", "Mag", "Identity") + model_activation_func: tuple = ("ModReLU",) + + +@dataclass +class OptimizerSettings: + optimizer: tuple | str = ("Adam", "RMSprop", "SGD") + learning_rate: tuple | float = (1e-5, 1e-1) + scheduler: str | None = None + scheduler_kwargs: dict | None = None + + +# optuna settings +@dataclass +class OptunaSettings: + n_trials: int = 128 + n_threads: int = 4 + timeout: int = 600 + directions: tuple = ("minimize",) + metrics_names: tuple = ("mse",) + limit_examples: bool = True + n_train_batches: int = 100 + n_valid_batches: int = 100 + storage: str = "sqlite:///example.db" + study_name: str = ( + f"optuna_study_{datetime.now().strftime('%Y-%m-%d_%H:%M:%S')}" + ) \ No newline at end of file diff --git a/src/single-core-regen/regen.py b/src/single-core-regen/regen.py index c3791db..7dbdd3e 100644 --- a/src/single-core-regen/regen.py +++ b/src/single-core-regen/regen.py @@ -1,464 +1,107 @@ -import copy -from dataclasses import dataclass from datetime import datetime -import matplotlib.pyplot as plt - -import numpy as np -import optuna -import warnings - -import torch -import torch.nn as nn -# import torch.nn.functional as F # mse_loss doesn't support complex numbers -import torch.optim as optim -import torch.utils.data - -from torch.utils.tensorboard import SummaryWriter - -from rich.progress import Progress, TextColumn, BarColumn, TaskProgressColumn, TimeRemainingColumn, MofNCompleteColumn -from rich.console import Console - -import multiprocessing - -from util.datasets import FiberRegenerationDataset -from util.complexNN import complex_sse_loss -from util.optuna_helpers import optional_suggest_categorical, optional_suggest_float, optional_suggest_int -import util -# global settings -@dataclass -class GlobalSettings: - seed: int = 42 - - -# data settings -@dataclass -class DataSettings: - config_path: str = "data/*-128-16384-1000-0-0-17-0-PAM4-0.ini" - dtype: torch.dtype = torch.complex64 - symbols_range: tuple|float|int = 16 - data_size_range: tuple|float|int = 32 - shuffle: bool = True - target_delay: float = 0 - xy_delay_range: tuple|float|int = 0 - drop_first: int = 10 - train_split: float = 0.8 - - -# pytorch settings -@dataclass -class PytorchSettings: - device: str = "cuda" - batchsize: int = 1024 - epochs: int = 10 - summary_dir: str = ".runs" - - -# model settings -@dataclass -class ModelSettings: - output_size: int = 2 - n_layer_range: tuple|float|int = (2,8) - n_units_range: tuple|float|int = (2,32) - # activation_func_range: tuple = ("ReLU",) - - -@dataclass -class OptimizerSettings: - # optimizer_range: tuple|str = ("Adam", "RMSprop", "SGD") - optimizer_range: tuple|str = "RMSprop" - # lr_range: tuple|float = (1e-5, 1e-1) - lr_range: tuple|float = 2e-5 - - -# optuna settings -@dataclass -class OptunaSettings: - n_trials: int = 128 - n_threads: int = 8 - timeout: int = 600 - directions: tuple = ("minimize",) - metrics_names: tuple = ("sse",) - - limit_examples: bool = True - n_train_examples: int = PytorchSettings.batchsize * 50 - # n_valid_examples: int = PytorchSettings.batchsize * 100 - n_valid_examples: int = float("inf") - storage: str = "sqlite:///optuna_single_core_regen.db" - 
study_name: str = ( - f"single_core_regen_{datetime.now().strftime('%Y-%m-%d_%H:%M:%S')}" - ) - - -class HyperTraining: - def __init__(self): - self.global_settings = GlobalSettings() - self.data_settings = DataSettings() - self.pytorch_settings = PytorchSettings() - self.model_settings = ModelSettings() - self.optimizer_settings = OptimizerSettings() - self.optuna_settings = OptunaSettings() - - self.console = Console() - - # set some extra settings to make the code more readable - self._extra_optuna_settings() - - def setup_tb_writer(self, study_name=None, append=None): - log_dir = self.pytorch_settings.summary_dir + "/" + (study_name or self.optuna_settings.study_name) - if append is not None: - log_dir += "_" + str(append) - - return SummaryWriter(log_dir) - - def resume_latest_study(self, verbose=True): - study_name = hyper_training.get_latest_study() - - if study_name: - print(f"Resuming study: {study_name}") - self.optuna_settings.study_name = study_name - - def get_latest_study(self, verbose=True): - studies = self.get_studies() - for study in studies: - study.datetime_start = study.datetime_start or datetime.min - if studies: - study = sorted(studies, key = lambda x: x.datetime_start, reverse=True)[0] - if verbose: - print(f"Last study: {study.study_name}") - study_name = study.study_name - else: - if verbose: - print("No previous studies found") - study_name = None - return study_name - - def get_studies(self): - return optuna.get_all_study_summaries(storage=self.optuna_settings.storage) - - def setup_study(self): - self.study = optuna.create_study( - study_name=self.optuna_settings.study_name, - storage=self.optuna_settings.storage, - load_if_exists=True, - direction=self.optuna_settings.direction, - directions=self.optuna_settings.directions, - ) - - with warnings.catch_warnings(action="ignore"): - self.study.set_metric_names(self.optuna_settings.metrics_names) - - self.n_threads = min( - self.optuna_settings.n_trials, self.optuna_settings.n_threads - ) - self.processes = [] - if self.n_threads > 1: - for _ in range(self.n_threads): - p = multiprocessing.Process( - # target=lambda n_trials: self._run_optimize(self, n_trials), - target = self._run_optimize, - args = (self.optuna_settings.n_trials // self.n_threads,), - ) - self.processes.append(p) - - def run_study(self): - if self.processes: - for p in self.processes: - p.start() - for p in self.processes: - p.join() - - remaining_trials = ( - self.optuna_settings.n_trials - - self.optuna_settings.n_trials % self.optuna_settings.n_threads - ) - else: - remaining_trials = self.optuna_settings.n_trials - - if remaining_trials: - self._run_optimize(remaining_trials) - - def _run_optimize(self, n_trials): - self.study.optimize( - self.objective, n_trials=n_trials, timeout=self.optuna_settings.timeout - ) - - def plot_eye(self, show=True): - if not hasattr(self, "eye_data"): - data, config = util.datasets.load_data( - self.data_settings.config_path, skipfirst=10, symbols=1000 - ) - self.eye_data = {"data": data, "sps": int(config["glova"]["sps"])} - return util.plot.eye(**self.eye_data, show=show) - - def _extra_optuna_settings(self): - self.optuna_settings.multi_objective = len(self.optuna_settings.directions) > 1 - if self.optuna_settings.multi_objective: - self.optuna_settings.direction = None - else: - self.optuna_settings.direction = self.optuna_settings.directions[0] - self.optuna_settings.directions = None - - self.optuna_settings.n_train_examples = ( - self.optuna_settings.n_train_examples - if 
self.optuna_settings.limit_examples - else float("inf") - ) - self.optuna_settings.n_valid_examples = ( - self.optuna_settings.n_valid_examples - if self.optuna_settings.limit_examples - else float("inf") - ) - - def define_model(self, trial: optuna.Trial, writer=None): - n_layers = optional_suggest_int(trial, "model_n_layers", self.model_settings.n_layer_range) - - in_features = 2 * trial.params.get( - "dataset_data_size", - optional_suggest_int(trial, "dataset_data_size", self.data_settings.data_size_range), - ) - trial.set_user_attr("input_dim", in_features) - - layers = [] - for i in range(n_layers): - out_features = optional_suggest_int(trial, f"model_n_units_l{i}", self.model_settings.n_units_range, log=True) - - layers.append(nn.Linear(in_features, out_features, dtype=self.data_settings.dtype)) - # layers.append(getattr(nn, activation_func)()) - in_features = out_features - - layers.append(nn.Linear(in_features, self.model_settings.output_size, dtype=self.data_settings.dtype)) - - if writer is not None: - writer.add_graph(nn.Sequential(*layers), torch.zeros(1, trial.user_attrs["input_dim"], dtype=self.data_settings.dtype)) - - return nn.Sequential(*layers) - - def get_sliced_data(self, trial: optuna.Trial): - symbols = optional_suggest_float(trial, "dataset_symbols", self.data_settings.symbols_range) - - xy_delay = optional_suggest_float(trial, "dataset_xy_delay", self.data_settings.xy_delay_range) - - data_size = trial.params.get( - "dataset_data_size", - optional_suggest_int(trial, "dataset_data_size", self.data_settings.data_size_range) - ) - - # get dataset - dataset = FiberRegenerationDataset( - file_path=self.data_settings.config_path, - symbols=symbols, - data_size=data_size, - target_delay=self.data_settings.target_delay, - xy_delay=xy_delay, - drop_first=self.data_settings.drop_first, - dtype=self.data_settings.dtype, - ) - - dataset_size = len(dataset) - indices = list(range(dataset_size)) - split = int(np.floor(self.data_settings.train_split * dataset_size)) - if self.data_settings.shuffle: - np.random.seed(self.global_settings.seed) - np.random.shuffle(indices) - - train_indices, valid_indices = indices[:split], indices[split:] - - train_sampler = torch.utils.data.SubsetRandomSampler(train_indices) - valid_sampler = torch.utils.data.SubsetRandomSampler(valid_indices) - - train_loader = torch.utils.data.DataLoader( - dataset, batch_size=self.pytorch_settings.batchsize, sampler=train_sampler, drop_last=True - ) - valid_loader = torch.utils.data.DataLoader( - dataset, batch_size=self.pytorch_settings.batchsize, sampler=valid_sampler, drop_last=True - ) - - return train_loader, valid_loader - - def train_model(self, model, optimizer, train_loader, epoch, writer=None, enable_progress=True): - if enable_progress: - progress = Progress( - TextColumn("[yellow] Training..."), - TextColumn(" Loss: {task.description}"), - BarColumn(), - TaskProgressColumn(), - TextColumn("[green]Batch"), - MofNCompleteColumn(), - TimeRemainingColumn(), - # description="Training", - transient=False, - console=self.console, - refresh_per_second=10, - ) - task = progress.add_task("-.---e--", total=len(train_loader)) - - running_loss = 0.0 - last_loss = 0.0 - model.train() - for batch_idx, (x, y) in enumerate(train_loader): - if ( - batch_idx * train_loader.batch_size - >= self.optuna_settings.n_train_examples - ): - break - optimizer.zero_grad() - x, y = ( - x.to(self.pytorch_settings.device), - y.to(self.pytorch_settings.device), - ) - y_pred = model(x) - loss = complex_sse_loss(y_pred, y) - 
loss.backward() - optimizer.step() - - # clamp weights to keep energy bounded - for p in model.parameters(): - p.data.clamp_(-1.0, 1.0) - - last_loss = loss.item() - - if enable_progress: - progress.update(task, advance=1, description=f"{last_loss:.3e}") - - running_loss += loss.item() - if writer is not None: - if batch_idx % 10 == 0: - writer.add_scalar("training loss", running_loss/10, epoch*min(len(train_loader), self.optuna_settings.n_train_examples/train_loader.batch_size) + batch_idx) - running_loss = 0.0 - - if enable_progress: - progress.update(task, description=f"{last_loss:.3e}") - progress.stop() - - - def eval_model(self, model, valid_loader, epoch, writer=None, enable_progress=True): - if enable_progress: - progress = Progress( - TextColumn("[green]Evaluating..."), - TextColumn("Error: {task.description}"), - BarColumn(), - TaskProgressColumn(), - TextColumn("[green]Batch"), - MofNCompleteColumn(), - TimeRemainingColumn(), - # description="Training", - transient=False, - console=self.console, - refresh_per_second=10, - ) - task = progress.add_task("-.---e--", total=len(valid_loader)) - - model.eval() - running_error = 0 - running_error_2 = 0 - with torch.no_grad(): - for batch_idx, (x, y) in enumerate(valid_loader): - if ( - batch_idx * valid_loader.batch_size - >= self.optuna_settings.n_valid_examples - ): - break - x, y = ( - x.to(self.pytorch_settings.device), - y.to(self.pytorch_settings.device), - ) - y_pred = model(x) - error = complex_sse_loss(y_pred, y) - running_error += error.item() - running_error_2 += error.item() - - if enable_progress: - progress.update(task, advance=1, description=f"{error.item():.3e}") - - if writer is not None: - if batch_idx % 10 == 0: - writer.add_scalar("sse", running_error_2/10, epoch*min(len(valid_loader), self.optuna_settings.n_valid_examples/valid_loader.batch_size) + batch_idx) - running_error_2 = 0.0 - - running_error /= batch_idx + 1 - - if enable_progress: - progress.update(task, description=f"{running_error:.3e}") - progress.stop() - - return running_error - - def run_model(self, model, loader): - model.eval() - y_preds = [] - with torch.no_grad(): - for x, y in loader: - x, y = ( - x.to(self.pytorch_settings.device), - y.to(self.pytorch_settings.device), - ) - y_preds.append(model(x)) - return torch.stack(y_preds) - - - def objective(self, trial: optuna.Trial): - writer = self.setup_tb_writer(self.optuna_settings.study_name, f"{trial.number:0>len(str(self.optuna_settings.n_trials))}") - train_loader, valid_loader = self.get_sliced_data(trial) - - model = self.define_model(trial, writer).to(self.pytorch_settings.device) - - optimizer_name = optional_suggest_categorical(trial, "optimizer", self.optimizer_settings.optimizer_range) - - lr = optional_suggest_float(trial, "lr", self.optimizer_settings.lr_range, log=True) - - optimizer = getattr(optim, optimizer_name)(model.parameters(), lr=lr) - - for epoch in range(self.pytorch_settings.epochs): - enable_progress = self.optuna_settings.n_threads == 1 - if enable_progress: - print(f"Epoch {epoch+1}/{self.pytorch_settings.epochs}") - self.train_model(model, optimizer, train_loader, epoch, writer, enable_progress=enable_progress) - sse = self.eval_model(model, valid_loader, epoch, writer, enable_progress=enable_progress) - - if not self.optuna_settings.multi_objective: - trial.report(sse, epoch) - if trial.should_prune(): - raise optuna.exceptions.TrialPruned() - - writer.close() - - return sse - - +from hypertraining.hypertraining import HyperTraining +from hypertraining.settings import 
( + GlobalSettings, + DataSettings, + PytorchSettings, + ModelSettings, + OptimizerSettings, + OptunaSettings, +) + +global_settings = GlobalSettings( + seed = 42, +) + +data_settings = DataSettings( + config_path = "data/*-128-16384-100000-0-0-17-0-PAM4-0.ini", + dtype = ("complex64", "float64", "complex32", "float32"), + symbols = (1, 16), + model_input_dim = (1, 32), + shuffle = True, + in_out_delay = 0, + xy_delay = 0, + drop_first = 1000, + train_split = 0.8, +) + +pytorch_settings = PytorchSettings( + epochs = 25, + batchsize = 2**10, + device = "cuda", + dataloader_workers = 2, + dataloader_prefetch = 2, + summary_dir = ".runs", + write_every = 2**5, + model_dir = ".models", +) + +model_settings = ModelSettings( + output_dim = 2, + model_n_layers = (2, 8), + unit_count = (2, 16), + model_activation_func = ("ModReLU")#, "ZReLU", "Mag")#, "CReLU", "Identity"), +) + +optimizer_settings = OptimizerSettings( + optimizer = ("Adam", "RMSprop"),#, "SGD"), + # learning_rate = (1e-5, 1e-1), + learning_rate=1e-3, + # scheduler = "ReduceLROnPlateau", + # scheduler_kwargs = {"mode": "min", "factor": 0.5, "patience": 10} +) + +optuna_settings = OptunaSettings( + n_trials = 4096, + n_threads = 16, + timeout = 600, + directions = ("minimize","minimize"), + metrics_names = ("n_params","mse"), + + limit_examples = True, + n_train_batches = 100, + n_valid_batches = 100, + storage = "sqlite:///data/single_core_regen.db", + study_name = f"single_core_regen_{datetime.now().strftime('%Y%m%d_%H%M%S')}", +) if __name__ == "__main__": - hyper_training = HyperTraining() + hyper_training = HyperTraining( + global_settings=global_settings, + data_settings=data_settings, + pytorch_settings=pytorch_settings, + model_settings=model_settings, + optimizer_settings=optimizer_settings, + optuna_settings=optuna_settings, + ) + + hyper_training.setup_study() # hyper_training.resume_latest_study() - - hyper_training.setup_study() + hyper_training.run_study() + # best_trial = hyper_training.study.best_trial - best_model = hyper_training.define_model(hyper_training.study.best_trial).to(hyper_training.pytorch_settings.device) - data_settings_backup = copy.copy(hyper_training.data_settings) - hyper_training.data_settings.shuffle = False - hyper_training.data_settings.train_split = 0.01 - plot_loader, _ = hyper_training.get_sliced_data(hyper_training.study.best_trial) - - regen = hyper_training.run_model(best_model, plot_loader) - regen = regen.view(-1, 2) - # [batch_no, batch_size, 2] -> [no, 2] + # best_model = hyper_training.define_model(best_trial).to( + # hyper_training.pytorch_settings.device + # ) - original, _ = util.datasets.load_data(hyper_training.data_settings.config_path, skipfirst=hyper_training.data_settings.drop_first) - original = original[:len(regen)] - - regen = regen.cpu().numpy() - _, axs = plt.subplots(2) - for i, ax in enumerate(axs): - ax.plot(np.abs(original[:, i])**2, label="original") - ax.plot(np.abs(regen[:, i])**2, label="regen") - ax.legend() - plt.show() - - - print(f"Best model: {best_model}") + # title_append, subtitle = hyper_training.build_title(best_trial) + # hyper_training.plot_model_response( + # best_trial, + # model=best_model, + # title_append=title_append, + # subtitle=subtitle, + # mode="eye", + # show=True, + # ) + # print(f"Best model found for trial {best_trial.number}") + # print(f"Best model error: {best_trial.value}") + # print(f"Best model params: {best_trial.params}") + # print() + # print(best_model) # eye_fig = hyper_training.plot_eye() ... 
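# --- sketch (not part of the commit above) ----------------------------------
# With two entries in `directions`, the study configured above is
# multi-objective, so Optuna raises on `study.best_trial`; that is consistent
# with the best_trial block above being commented out. The Pareto-optimal
# trials are exposed as `study.best_trials` instead. A minimal way to inspect
# them after run_study() -- the names are taken from this file, the reporting
# itself is illustrative:
#
#     for t in hyper_training.study.best_trials:
#         n_params, mse = t.values  # ordered like metrics_names: ("n_params", "mse")
#         print(f"trial {t.number}: n_params={int(n_params)}, mse={mse:.3e}")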
diff --git a/src/single-core-regen/regen_no_hyper.py b/src/single-core-regen/regen_no_hyper.py
index d136a78..45e18c0 100644
--- a/src/single-core-regen/regen_no_hyper.py
+++ b/src/single-core-regen/regen_no_hyper.py
@@ -95,11 +95,12 @@ class Training:
         self.writer = None
         self.console = Console()
 
-    def setup_tb_writer(self, study_name=None):
+    def setup_tb_writer(self, study_name=None, append=None):
         log_dir = (
-            self.pytorch_settings.summary_dir + "/" + (study_name or self.study_name)
+            self.pytorch_settings.summary_dir + "/" + (study_name or self.study_name) + (("_" + str(append)) if append else "")
         )
         self.writer = SummaryWriter(log_dir)
+        return self.writer
 
     def plot_eye(self, width=2, symbols=None, alpha=None, complex=False, show=True):
         if not hasattr(self, "eye_data"):
@@ -160,7 +161,7 @@ class Training:
         dataset = util.datasets.FiberRegenerationDataset(
             file_path=self.data_settings.config_path,
             symbols=symbols,
-            data_size=data_size,
+            output_dim=data_size,
             target_delay=self.data_settings.target_delay,
             xy_delay=xy_delay,
             drop_first=self.data_settings.drop_first,
@@ -212,7 +213,7 @@ class Training:
     def train_model(self, model, optimizer, train_loader, epoch):
         with Progress(
             TextColumn("[yellow] Training..."),
-            TextColumn("Loss: {task.description}"),
+            TextColumn("Error: {task.description}"),
             BarColumn(),
             TaskProgressColumn(),
             TextColumn("[green]Batch"),
@@ -256,7 +257,7 @@ class Training:
     def eval_model(self, model, valid_loader, epoch):
         with Progress(
             TextColumn("[green]Evaluating..."),
-            TextColumn("Loss: {task.description}"),
+            TextColumn("Error: {task.description}"),
             BarColumn(),
             TaskProgressColumn(),
             TextColumn("[green]Batch"),
@@ -325,18 +326,6 @@ class Training:
         ys = torch.vstack(ys).cpu()
         y_preds = torch.vstack(y_preds).cpu()
         return ys, xs, y_preds
-
-    def dummy_model(self, loader):
-        xs = []
-        ys = []
-        for x, y in loader:
-            y = y.cpu().view(y.shape[0], -1, 2)
-            x = x.cpu().view(x.shape[0], -1, 2)
-            xs.append(x[:, 0, :].squeeze())
-            ys.append(y.squeeze())
-        xs = torch.vstack(xs)
-        ys = torch.vstack(ys)
-        return xs, ys
 
     def objective(self, save=False, plot_before=False):
         try:
@@ -360,22 +349,18 @@ class Training:
             self.train_model(self.model, optimizer, train_loader, epoch)
             eval_loss = self.eval_model(self.model, valid_loader, epoch)
 
-            if save:
+            return eval_loss
+
+        except KeyboardInterrupt:
+            ...
+        finally:
+            if hasattr(self, "model"):
                 save_path = (
                     Path(self.pytorch_settings.model_dir) / f"{self.study_name}.pth"
                 )
                 save_path.parent.mkdir(parents=True, exist_ok=True)
                 torch.save(self.model, save_path)
-            return eval_loss
-
-        except KeyboardInterrupt:
-            pass
-        finally:
-            if hasattr(self, "model"):
-                except_save_path = Path(".models/exception") / f"{self.study_name}.pth"
-                except_save_path.parent.mkdir(parents=True, exist_ok=True)
-                torch.save(self.model, except_save_path)
 
     def _plot_model_response_plotter(self, fiber_in, fiber_out, regen, plot=True):
         fig, axs = plt.subplots(2)
         for i, ax in enumerate(axs):
diff --git a/src/single-core-regen/util/__init__.py b/src/single-core-regen/util/__init__.py
index 0cc2191..842276e 100644
--- a/src/single-core-regen/util/__init__.py
+++ b/src/single-core-regen/util/__init__.py
@@ -15,3 +15,5 @@ from . import complexNN  # noqa: F401
 # from .complexNN import UnitaryLayer  # noqa: F401
 # from .complexNN import complex_mse_loss  # noqa: F401
 # from .complexNN import complex_sse_loss  # noqa: F401
+
+from . 
import misc # noqa: F401 \ No newline at end of file diff --git a/src/single-core-regen/util/complexNN.py b/src/single-core-regen/util/complexNN.py index e42a836..972bafb 100644 --- a/src/single-core-regen/util/complexNN.py +++ b/src/single-core-regen/util/complexNN.py @@ -1,116 +1,160 @@ import torch import torch.nn as nn +import torch.nn.functional as F + def complex_mse_loss(input, target): """ Compute the mean squared error between two complex tensors. """ - return torch.mean(torch.square(input.real - target.real) + torch.square(input.imag - target.imag)) + if input.is_complex(): + return torch.mean( + torch.square(input.real - target.real) + + torch.square(input.imag - target.imag) + ) + else: + return F.mse_loss(input, target) + def complex_sse_loss(input, target): """ Compute the sum squared error between two complex tensors. """ if input.is_complex(): - return torch.sum(torch.square(input.real - target.real) + torch.square(input.imag - target.imag)) + return torch.sum( + torch.square(input.real - target.real) + + torch.square(input.imag - target.imag) + ) else: return torch.sum(torch.square(input - target)) - - class UnitaryLayer(nn.Module): - def __init__(self, in_features, out_features): - super(UnitaryLayer, self).__init__() + def __init__(self, in_features, out_features, dtype=None): assert in_features >= out_features + super(UnitaryLayer, self).__init__() self.in_features = in_features self.out_features = out_features - self.weight = nn.Parameter(torch.randn(in_features, out_features, dtype=torch.cfloat)) + self.weight = nn.Parameter(torch.randn(in_features, out_features, dtype=dtype)) self.reset_parameters() - + def reset_parameters(self): q, _ = torch.linalg.qr(self.weight) self.weight.data = q + + def forward(self, x): + return torch.matmul(x, self.weight) + + def __repr__(self): + return f"UnitaryLayer({self.in_features}, {self.out_features})" + +class SemiUnitaryLayer(nn.Module): + def __init__(self, input_dim, output_dim, dtype=None): + super(SemiUnitaryLayer, self).__init__() + self.input_dim = input_dim + self.output_dim = output_dim + + # Create a larger square matrix for QR decomposition + self.weight = nn.Parameter(torch.randn(max(input_dim, output_dim), max(input_dim, output_dim), dtype=dtype)) + self.reset_parameters() - @staticmethod - @torch.jit.script - def _unitary_forward(x, weight): - out = torch.matmul(x, weight) - return out + def reset_parameters(self): + # Ensure the weights are semi-unitary by QR decomposition + q, _ = torch.linalg.qr(self.weight) + if self.input_dim > self.output_dim: + self.weight.data = q[:self.input_dim, :self.output_dim] + else: + self.weight.data = q[:self.output_dim, :self.input_dim].t() def forward(self, x): - return self._unitary_forward(x, self.weight) + out = torch.matmul(x, self.weight) + return out + def __repr__(self): + return f"SemiUnitaryLayer({self.input_dim}, {self.output_dim})" + + +# class SpreadLayer(nn.Module): +# def __init__(self, in_features, out_features, dtype=None): +# super(SpreadLayer, self).__init__() +# self.in_features = in_features +# self.out_features = out_features +# self.mat = torch.ones(in_features, out_features, dtype=dtype)*torch.sqrt(torch.tensor(in_features/out_features)) + +# def forward(self, x): +# # N in_features -> M out_features, Enery is preserved (P = abs(x)^2) +# out = torch.matmul(x, self.mat) +# return out + #### as defined by zhang et al + class Identity(nn.Module): """ - implements the "activation" function + implements the "activation" function M(z) = z """ + def __init__(self): 
        super(Identity, self).__init__()
 
     def forward(self, x):
         return x
 
+
 class Mag(nn.Module):
     """
-    implements the activation function
+    implements the activation function
     M(z) = ||z||
     """
+
     def __init__(self):
         super(Mag, self).__init__()
 
-    @torch.jit.script
     def forward(self, x):
-        return torch.abs(x.real**2 + x.imag**2)
-
-# class Tanh(nn.Module):
-#     """
-#     implements the activation function
-#     M(z) = tanh(z) = sinh(z)/cosh(z) = (exp(z)-exp(-z))/(exp(z)+exp(-z)) = (exp(2*z)-1)/(exp(2*z)+1)
-#     """
-#     def __init__(self):
-#         super(Tanh, self).__init__()
+        return torch.abs(x).to(dtype=x.dtype)
+
-#     def forward(self, x):
-#         return torch.tanh(x)
-
 class ModReLU(nn.Module):
     """
-    implements the activation function
+    implements the activation function
     M(z) = ReLU(||z|| + b)*exp(j*theta_z) = ReLU(||z|| + b)*z/||z||
     """
+
     def __init__(self, b=0):
         super(ModReLU, self).__init__()
-        self.b = b
-        self.relu = nn.ReLU()
-
-    @staticmethod
-    # @torch.jit.script
-    def _mod_relu(x, b):
-        mod = torch.abs(x.real**2 + x.imag**2)
-        return torch.relu(mod + b) * x / mod
+        self.b = torch.tensor(b)
 
     def forward(self, x):
-        return self._mod_relu(x, self.b)
-
+        if x.is_complex():
+            mod = torch.abs(x)  # magnitude ||z||, as in the docstring
+            return torch.relu(mod + self.b) * x / mod
+
+        else:
+            return torch.relu(x + self.b)
+
+    def __repr__(self):
+        return f"ModReLU(b={self.b})"
+
+
 class CReLU(nn.Module):
     """
     implements the activation function
     M(z) = ReLU(Re(z)) + j*ReLU(Im(z))
     """
+
     def __init__(self):
         super(CReLU, self).__init__()
-        self.relu = nn.ReLU()
 
-    @torch.jit.script
     def forward(self, x):
-        return torch.relu(x.real) + 1j*torch.relu(x.imag)
-
+        if x.is_complex():
+            return torch.relu(x.real) + 1j * torch.relu(x.imag)
+        else:
+            return torch.relu(x)
+
+
 class ZReLU(nn.Module):
     """
     implements the activation function
@@ -122,20 +166,8 @@ class ZReLU(nn.Module):
     def __init__(self):
         super(ZReLU, self).__init__()
 
-    @torch.jit.script
     def forward(self, x):
-        return x * (torch.angle(x) >= 0) * (torch.angle(x) <= torch.pi/2)
-
-# class ComplexFeedForwardNN(nn.Module):
-#     def __init__(self, in_features, hidden_features, out_features):
-#         super(ComplexFeedForwardNN, self).__init__()
-#         self.in_features = in_features
-#         self.hidden_features = hidden_features
-#         self.out_features = out_features
-#         self.fc1 = UnitaryLayer(in_features, hidden_features)
-#         self.fc2 = UnitaryLayer(hidden_features, out_features)
-
-#     def forward(self, x):
-#         x = self.fc1(x)
-#         x = self.fc2(x)
-#         return x
\ No newline at end of file
+        if x.is_complex():
+            return x * (torch.angle(x) >= 0) * (torch.angle(x) <= torch.pi / 2)
+        else:
+            return torch.relu(x)
diff --git a/src/single-core-regen/util/datasets.py b/src/single-core-regen/util/datasets.py
index 1d1d543..9b2953e 100644
--- a/src/single-core-regen/util/datasets.py
+++ b/src/single-core-regen/util/datasets.py
@@ -41,9 +41,10 @@ def load_data(config_path, skipfirst=0, symbols=None, real=False, normalize=False
     data = np.load(datapath)[skipfirst * sps : symbols * sps + skipfirst * sps]
 
     if normalize:
-        a, b, c, d = data.T
+        # square gets normalized to 1, as the power is (proportional to) the square of the amplitude
+        a, b, c, d = np.square(data.T)
         a, b, c, d = a/np.max(np.abs(a)), b/np.max(np.abs(b)), c/np.max(np.abs(c)), d/np.max(np.abs(d))
-        data = np.array([a, b, c, d]).T
+        data = np.sqrt(np.array([a, b, c, d]).T)
 
     if real:
         data = np.abs(data)
@@ -98,7 +99,7 @@ class FiberRegenerationDataset(Dataset):
         file_path: str | Path,
         symbols: int | float,
         *,
-        data_size: int = None,
+        output_dim: int = None,
         target_delay: float | int 
= 0, xy_delay: float | int = 0, drop_first: float | int = 0, @@ -129,7 +130,7 @@ class FiberRegenerationDataset(Dataset): assert isinstance(symbols, (float, int)), ( "symbols must be a float or an integer" ) - assert data_size is None or isinstance(data_size, int), ( + assert output_dim is None or isinstance(output_dim, int), ( "output_len must be an integer" ) assert isinstance(target_delay, (float, int)), ( @@ -142,7 +143,7 @@ class FiberRegenerationDataset(Dataset): # check values assert symbols > 0, "symbols must be positive" - assert data_size is None or data_size > 0, "output_len must be positive or None" + assert output_dim is None or output_dim > 0, "output_len must be positive or None" assert drop_first >= 0, "drop_first must be non-negative" faux = kwargs.pop("faux", False) @@ -158,7 +159,7 @@ class FiberRegenerationDataset(Dataset): "glova": {"sps": 128}, } else: - data_raw, self.config = load_data(file_path, skipfirst=drop_first, real=real, normalize=True, device=device, dtype=dtype) + data_raw, self.config = load_data(file_path, skipfirst=drop_first, symbols=kwargs.pop("num_symbols", None), real=real, normalize=True, device=device, dtype=dtype) self.device = data_raw.device @@ -166,7 +167,7 @@ class FiberRegenerationDataset(Dataset): self.samples_per_slice = int(symbols * self.samples_per_symbol) self.symbols_per_slice = self.samples_per_slice / self.samples_per_symbol - self.data_size = data_size or self.samples_per_slice + self.output_dim = output_dim or self.samples_per_slice self.target_delay = target_delay or 0 self.xy_delay = xy_delay or 0 @@ -261,13 +262,13 @@ class FiberRegenerationDataset(Dataset): data, target = self.data[idx, 1].squeeze(), self.data[idx, 0].squeeze() # reduce by by taking self.output_dim equally spaced samples - data = data[:, : data.shape[1] // self.data_size * self.data_size] - data = data.view(data.shape[0], self.data_size, -1) + data = data[:, : data.shape[1] // self.output_dim * self.output_dim] + data = data.view(data.shape[0], self.output_dim, -1) data = data[:, :, 0] # target is corresponding to the middle of the data as the output sample is influenced by the data before and after it - target = target[:, : target.shape[1] // self.data_size * self.data_size] - target = target.view(target.shape[0], self.data_size, -1) + target = target[:, : target.shape[1] // self.output_dim * self.output_dim] + target = target.view(target.shape[0], self.output_dim, -1) target = target[:, 0, target.shape[2] // 2] data = data.transpose(0, 1).flatten().squeeze() diff --git a/src/single-core-regen/util/misc.py b/src/single-core-regen/util/misc.py new file mode 100644 index 0000000..b227305 --- /dev/null +++ b/src/single-core-regen/util/misc.py @@ -0,0 +1,21 @@ +def multi_getattr(objs, attr, fallback=None): + """ + tries to get the attribute from a list of objects, returning the first hit + if no object has the attribute, it returns the fallback value if provided, otherwise raises AttributeError + """ + try: + return _multi_getattr(objs, attr) + except AttributeError as e: + if fallback is not None: + return fallback + raise e + +def _multi_getattr(objs, attr): + if not isinstance(objs, (list, tuple)): + objs = [objs] + for obj in objs: + try: + return getattr(obj, attr) + except AttributeError: + pass + raise AttributeError(f"None of the objects has attribute {attr}") \ No newline at end of file diff --git a/src/single-core-regen/util/optuna_helpers.py b/src/single-core-regen/util/optuna_helpers.py index 1d924a9..123f61d 100644 --- 
a/src/single-core-regen/util/optuna_helpers.py +++ b/src/single-core-regen/util/optuna_helpers.py @@ -27,4 +27,19 @@ def optional_suggest_int(trial, name, range_or_value, step=None, log=False): return _optional_suggest(trial, name, range_or_value, step=step, log=log, type='int') def optional_suggest_float(trial, name, range_or_value, step=None, log=False): - return _optional_suggest(trial, name, range_or_value, step=step, log=log, type='float') \ No newline at end of file + return _optional_suggest(trial, name, range_or_value, step=step, log=log, type='float') + +def force_suggest_int(trial, name, range_or_value, step=1, log=False): + if not hasattr(range_or_value, '__iter__') or isinstance(range_or_value, str): + return trial.suggest_int(name, range_or_value, range_or_value, step=step, log=log) + return trial.suggest_int(name, *range_or_value, step=step, log=log) + +def force_suggest_float(trial, name, range_or_value, step=None, log=False): + if not hasattr(range_or_value, '__iter__') or isinstance(range_or_value, str): + return trial.suggest_float(name, range_or_value, range_or_value, step=step, log=log) + return trial.suggest_float(name, *range_or_value, step=step, log=log) + +def force_suggest_categorical(trial, name, range_or_value): + if not hasattr(range_or_value, '__iter__') or isinstance(range_or_value, str): + return trial.suggest_categorical(name, [range_or_value]) + return trial.suggest_categorical(name, range_or_value) \ No newline at end of file diff --git a/src/single-core-regen/util/plot.py b/src/single-core-regen/util/plot.py index 2bbc8ce..fe1f407 100644 --- a/src/single-core-regen/util/plot.py +++ b/src/single-core-regen/util/plot.py @@ -38,7 +38,7 @@ def eye(*, path=None, data=None, sps=None, title=None, symbols=1000, skipfirst=0 axs[0, 1].plot(xaxis, np.abs(outx), color="C0", alpha=alpha or 0.1) axs[1, 0].plot(xaxis, np.abs(iny), color="C0", alpha=alpha or 0.1) axs[1, 1].plot(xaxis, np.abs(outy), color="C0", alpha=alpha or 0.1) - axs[0,0].set_ylim(0, 1.1*np.max(np.abs(data))) + axs[0, 0].set_ylim(0, 1.1*np.max(np.abs(data))) axs2[0, 0].plot(xaxis, np.angle(inx), color="C1", alpha=alpha or 0.1) axs2[0, 1].plot(xaxis, np.angle(outx), color="C1", alpha=alpha or 0.1)
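# --- sketch (not part of the diff above) -------------------------------------
# The force_suggest_* helpers above let one settings field hold either a fixed
# value or a search range: a scalar (or bare string) is forwarded as a
# degenerate low == high range, respectively a one-element categorical list,
# so the parameter is still recorded in trial.params either way. Illustrative
# calls, mirroring how hypertraining.py uses them:
#
#     lr = force_suggest_float(trial, "lr", 1e-3, log=True)          # fixed value
#     lr = force_suggest_float(trial, "lr", (1e-5, 1e-1), log=True)  # searched range
#     opt = force_suggest_categorical(trial, "optimizer", "Adam")    # single choice
#
# Note: Optuna may warn that a low == high range is a single-point
# distribution; with these helpers that is expected and harmless.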