diff --git a/.gitattributes b/.gitattributes
index e555e85..68c01d9 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -1,4 +1,5 @@
 data/**/* filter=lfs diff=lfs merge=lfs -text
+data/*.db filter=lfs diff=lfs merge=lfs -text
 data/*.ini filter=lfs diff=lfs merge=lfs text
 
 ## lfs setup
diff --git a/.gitignore b/.gitignore
index a7efef9..75bcb68 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,8 +1,5 @@
 src/**/*.ini
-.data
-
-# VSCode
-.vscode
+.*
 
 # Byte-compiled / optimized / DLL files
 __pycache__/
diff --git a/data/optuna_single_core_regen.db b/data/optuna_single_core_regen.db
new file mode 100644
index 0000000..5d9549f
--- /dev/null
+++ b/data/optuna_single_core_regen.db
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:72460af57347d35df91cd76982231bcf538a82fd7f1b8522795202fa298a2dcb
+size 696320
diff --git a/src/single-core-regen/regen.py b/src/single-core-regen/regen.py
index c334e51..c3791db 100644
--- a/src/single-core-regen/regen.py
+++ b/src/single-core-regen/regen.py
@@ -1,6 +1,6 @@
+import copy
 from dataclasses import dataclass
 from datetime import datetime
-import time
 
 import matplotlib.pyplot as plt
 import numpy as np
@@ -9,16 +9,21 @@
 import warnings
 
 import torch
 import torch.nn as nn
-import torch.functional as F
+# import torch.nn.functional as F  # mse_loss doesn't support complex numbers
 import torch.optim as optim
 import torch.utils.data
+from torch.utils.tensorboard import SummaryWriter
+
+from rich.progress import Progress, TextColumn, BarColumn, TaskProgressColumn, TimeRemainingColumn, MofNCompleteColumn
+from rich.console import Console
+
 import multiprocessing
 
 from util.datasets import FiberRegenerationDataset
+from util.complexNN import complex_sse_loss
+from util.optuna_helpers import optional_suggest_categorical, optional_suggest_float, optional_suggest_int
 import util
-
-
 # global settings
 @dataclass
 class GlobalSettings:
@@ -28,12 +33,14 @@ class GlobalSettings:
 # data settings
 @dataclass
 class DataSettings:
-    config_path: str = "data/*-128-16384-10000-0-0-17-0-PAM4-0.ini"
-    symbols_range: tuple = (1, 100)
-    data_size_range: tuple = (1, 20)
+    config_path: str = "data/*-128-16384-1000-0-0-17-0-PAM4-0.ini"
+    dtype: torch.dtype = torch.complex64
+    symbols_range: tuple | float | int = 16
+    data_size_range: tuple | float | int = 32
+    shuffle: bool = True
     target_delay: float = 0
-    xy_delay_range: tuple = (0, 1)
-    drop_first: int = 1000
+    xy_delay_range: tuple | float | int = 0
+    drop_first: int = 10
     train_split: float = 0.8
 
@@ -41,41 +48,46 @@ class DataSettings:
 @dataclass
 class PytorchSettings:
     device: str = "cuda"
-    batchsize: int = 128
-    epochs: int = 100
+    batchsize: int = 1024
+    epochs: int = 10
+    summary_dir: str = ".runs"
 
 
 # model settings
 @dataclass
 class ModelSettings:
     output_size: int = 2
-    n_layer_range: tuple = (1, 3)
-    n_units_range: tuple = (4, 128)
-    activation_func_range: tuple = ("ReLU",)
+    n_layer_range: tuple | float | int = (2, 8)
+    n_units_range: tuple | float | int = (2, 32)
+    # activation_func_range: tuple = ("ReLU",)
 
 
 @dataclass
 class OptimizerSettings:
-    optimizer_range: tuple = ("Adam", "RMSprop", "SGD")
-    lr_range: tuple = (1e-5, 1e-1)
+    # optimizer_range: tuple | str = ("Adam", "RMSprop", "SGD")
+    optimizer_range: tuple | str = "RMSprop"
+    # lr_range: tuple | float = (1e-5, 1e-1)
+    lr_range: tuple | float = 2e-5
 
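+# note: the *_range settings accept either a (low, high) tuple, which Optuna
+# then searches, or a bare value, which pins that hyperparameter for every
+# trial (see util/optuna_helpers.py)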
 # optuna settings
 @dataclass
 class OptunaSettings:
     n_trials: int = 128
-    n_threads: int = 16
+    n_threads: int = 8
     timeout: int = 600
-    directions: tuple = ("maximize",)
+    directions: tuple = ("minimize",)
+    metrics_names: tuple = ("sse",)
     limit_examples: bool = True
-    n_train_examples: int = PytorchSettings.batchsize * 30
-    n_valid_examples: int = PytorchSettings.batchsize * 10
+    n_train_examples: int = PytorchSettings.batchsize * 50
+    # n_valid_examples: int = PytorchSettings.batchsize * 100
+    n_valid_examples: int = float("inf")
     storage: str = "sqlite:///optuna_single_core_regen.db"
     study_name: str = (
-        f"single_core_regen_{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
+        f"single_core_regen_{datetime.now().strftime('%Y-%m-%d_%H:%M:%S')}"
     )
-    metrics_names: tuple = ("accuracy",)
+
 
 class HyperTraining:
     def __init__(self):
@@ -86,9 +98,43 @@ class HyperTraining:
         self.optimizer_settings = OptimizerSettings()
         self.optuna_settings = OptunaSettings()
 
+        self.console = Console()
+
         # set some extra settings to make the code more readable
         self._extra_optuna_settings()
+
+    def setup_tb_writer(self, study_name=None, append=None):
+        log_dir = self.pytorch_settings.summary_dir + "/" + (study_name or self.optuna_settings.study_name)
+        if append is not None:
+            log_dir += "_" + str(append)
+
+        return SummaryWriter(log_dir)
+
+    def resume_latest_study(self, verbose=True):
+        study_name = self.get_latest_study(verbose=verbose)
+        if study_name:
+            print(f"Resuming study: {study_name}")
+            self.optuna_settings.study_name = study_name
+
+    def get_latest_study(self, verbose=True):
+        studies = self.get_studies()
+        for study in studies:
+            study.datetime_start = study.datetime_start or datetime.min
+        if studies:
+            study = sorted(studies, key=lambda x: x.datetime_start, reverse=True)[0]
+            if verbose:
+                print(f"Last study: {study.study_name}")
+            study_name = study.study_name
+        else:
+            if verbose:
+                print("No previous studies found")
+            study_name = None
+        return study_name
+
+    def get_studies(self):
+        return optuna.get_all_study_summaries(storage=self.optuna_settings.storage)
+
     def setup_study(self):
         self.study = optuna.create_study(
             study_name=self.optuna_settings.study_name,
@@ -100,29 +146,49 @@ class HyperTraining:
         with warnings.catch_warnings(action="ignore"):
             self.study.set_metric_names(self.optuna_settings.metrics_names)
-
-        self.n_threads = min(self.optuna_settings.n_trials, self.optuna_settings.n_threads)
+
+        self.n_threads = min(
+            self.optuna_settings.n_trials, self.optuna_settings.n_threads
+        )
 
         self.processes = []
-        for _ in range(self.n_threads):
-            p = multiprocessing.Process(target=self._run_optimize)
-            self.processes.append(p)
-
+        if self.n_threads > 1:
+            for _ in range(self.n_threads):
+                p = multiprocessing.Process(
+                    # target=lambda n_trials: self._run_optimize(self, n_trials),
+                    target=self._run_optimize,
+                    args=(self.optuna_settings.n_trials // self.n_threads,),
+                )
+                self.processes.append(p)
+
     def run_study(self):
-        for p in self.processes:
-            p.start()
-
-        for p in self.processes:
-            p.join()
-
-        remaining_trials = self.optuna_settings.n_trials - self.optuna_settings.n_trials % self.optuna_settings.n_threads
+        if self.processes:
+            for p in self.processes:
+                p.start()
+            for p in self.processes:
+                p.join()
+
+            # trials left over after the integer division among the workers
+            remaining_trials = self.optuna_settings.n_trials % self.n_threads
+        else:
+            remaining_trials = self.optuna_settings.n_trials
+
         if remaining_trials:
             self._run_optimize(remaining_trials)
-
-    def _run_optimize(self, n_trials):
-        self.study.optimize(self.objective, n_trials=n_trials, timeout=self.optuna_settings.timeout)
 
-    def eye(self, show=True):
-        util.plot.eye(self.data_settings.config_path, show=show)
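+    # each worker process runs its own optimize() loop against the shared
+    # SQLite storage, so the per-process trial counts add up across the study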
+    def _run_optimize(self, n_trials):
+        self.study.optimize(
+            self.objective, n_trials=n_trials, timeout=self.optuna_settings.timeout
+        )
+
+    def plot_eye(self, show=True):
+        if not hasattr(self, "eye_data"):
+            data, config = util.datasets.load_data(
+                self.data_settings.config_path, skipfirst=10, symbols=1000
+            )
+            self.eye_data = {"data": data, "sps": int(config["glova"]["sps"])}
+        return util.plot.eye(**self.eye_data, show=show)
 
     def _extra_optuna_settings(self):
         self.optuna_settings.multi_objective = len(self.optuna_settings.directions) > 1
@@ -143,145 +209,256 @@ class HyperTraining:
             else float("inf")
         )
 
-    def define_model(self, trial: optuna.Trial):
-        n_layers = trial.suggest_int(
-            "model_n_layers", *self.model_settings.n_layer_range
+    def define_model(self, trial: optuna.Trial, writer=None):
+        n_layers = optional_suggest_int(trial, "model_n_layers", self.model_settings.n_layer_range)
+
+        in_features = 2 * trial.params.get(
+            "dataset_data_size",
+            optional_suggest_int(trial, "dataset_data_size", self.data_settings.data_size_range),
         )
+        trial.set_user_attr("input_dim", in_features)
+
         layers = []
-
-        # REVIEW does that work?
-        in_features = trial.params["dataset_data_size"] * 2
         for i in range(n_layers):
-            out_features = trial.suggest_int(
-                f"model_n_units_l{i}", *self.model_settings.n_units_range
-            )
-            activation_func = trial.suggest_categorical(
-                f"model_activation_func_l{i}", self.model_settings.activation_func_range
-            )
-
-            layers.append(nn.Linear(in_features, out_features))
-            layers.append(getattr(nn, activation_func))
+            out_features = optional_suggest_int(trial, f"model_n_units_l{i}", self.model_settings.n_units_range, log=True)
+
+            layers.append(nn.Linear(in_features, out_features, dtype=self.data_settings.dtype))
+            # layers.append(getattr(nn, activation_func)())
             in_features = out_features
 
-        layers.append(nn.Linear(in_features, self.model_settings.output_size))
+        layers.append(nn.Linear(in_features, self.model_settings.output_size, dtype=self.data_settings.dtype))
+
+        if writer is not None:
+            writer.add_graph(nn.Sequential(*layers), torch.zeros(1, trial.user_attrs["input_dim"], dtype=self.data_settings.dtype))
 
         return nn.Sequential(*layers)
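+    # with the default data_size of 32 this builds Linear(64 -> ...) blocks down
+    # to 2 complex outputs (x and y polarization); activations stay disabled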
     def get_sliced_data(self, trial: optuna.Trial):
-        assert ModelSettings.input_size % 2 == 0, "input_dim must be even"
-        symbols = trial.suggest_float(
-            "dataset_symbols", *self.data_settings.symbols_range, log=True
-        )
-        xy_delay = trial.suggest_float(
-            "dataset_xy_delay", *self.data_settings.xy_delay_range
-        )
-        data_size = trial.suggest_int(
-            "dataset_data_size", *self.data_settings.data_size_range
+        symbols = optional_suggest_float(trial, "dataset_symbols", self.data_settings.symbols_range)
+
+        xy_delay = optional_suggest_float(trial, "dataset_xy_delay", self.data_settings.xy_delay_range)
+
+        data_size = trial.params.get(
+            "dataset_data_size",
+            optional_suggest_int(trial, "dataset_data_size", self.data_settings.data_size_range)
         )
 
+        # get dataset
         dataset = FiberRegenerationDataset(
             file_path=self.data_settings.config_path,
             symbols=symbols,
-            data_size=data_size,  # two channels (x,y)
+            data_size=data_size,
             target_delay=self.data_settings.target_delay,
             xy_delay=xy_delay,
             drop_first=self.data_settings.drop_first,
+            dtype=self.data_settings.dtype,
         )
 
         dataset_size = len(dataset)
         indices = list(range(dataset_size))
         split = int(np.floor(self.data_settings.train_split * dataset_size))
-        np.random.seed(self.global_settings.seed)
-        np.random.shuffle(indices)
+        if self.data_settings.shuffle:
+            np.random.seed(self.global_settings.seed)
+            np.random.shuffle(indices)
+
         train_indices, valid_indices = indices[:split], indices[split:]
 
         train_sampler = torch.utils.data.SubsetRandomSampler(train_indices)
         valid_sampler = torch.utils.data.SubsetRandomSampler(valid_indices)
 
         train_loader = torch.utils.data.DataLoader(
-            dataset, batch_size=self.pytorch_settings.batchsize, sampler=train_sampler
+            dataset, batch_size=self.pytorch_settings.batchsize, sampler=train_sampler, drop_last=True
         )
         valid_loader = torch.utils.data.DataLoader(
-            dataset, batch_size=self.pytorch_settings.batchsize, sampler=valid_sampler
+            dataset, batch_size=self.pytorch_settings.batchsize, sampler=valid_sampler, drop_last=True
         )
 
         return train_loader, valid_loader
 
-    def train_model(self, model, optimizer, train_loader):
+    def train_model(self, model, optimizer, train_loader, epoch, writer=None, enable_progress=True):
+        if enable_progress:
+            progress = Progress(
+                TextColumn("[yellow] Training..."),
+                TextColumn(" Loss: {task.description}"),
+                BarColumn(),
+                TaskProgressColumn(),
+                TextColumn("[green]Batch"),
+                MofNCompleteColumn(),
+                TimeRemainingColumn(),
+                # description="Training",
+                transient=False,
+                console=self.console,
+                refresh_per_second=10,
+            )
+            task = progress.add_task("-.---e--", total=len(train_loader))
+
+        running_loss = 0.0
+        last_loss = 0.0
         model.train()
-        for batch_idx, (data, target) in enumerate(train_loader):
-            if (batch_idx * train_loader.batchsize
-                >= self.optuna_settings.n_train_examples):
+        for batch_idx, (x, y) in enumerate(train_loader):
+            if (
+                batch_idx * train_loader.batch_size
+                >= self.optuna_settings.n_train_examples
+            ):
                 break
 
             optimizer.zero_grad()
-            data, target = (
-                data.to(self.pytorch_settings.device),
-                target.to(self.pytorch_settings.device),
+            x, y = (
+                x.to(self.pytorch_settings.device),
+                y.to(self.pytorch_settings.device),
             )
-            target_pred = model(data)
-            loss = F.mean_squared_error(target_pred, target)
+            y_pred = model(x)
+            loss = complex_sse_loss(y_pred, y)
             loss.backward()
             optimizer.step()
+            # clamp weights to keep energy bounded; clamp_ has no complex
+            # support, so clamp the real and imaginary parts separately
+            for p in model.parameters():
+                p.data.real.clamp_(-1.0, 1.0)
+                p.data.imag.clamp_(-1.0, 1.0)
 
+            last_loss = loss.item()
+
+            if enable_progress:
+                progress.update(task, advance=1, description=f"{last_loss:.3e}")
+
+            running_loss += loss.item()
+            if writer is not None:
+                if batch_idx % 10 == 0:
+                    writer.add_scalar("training loss", running_loss/10, epoch*min(len(train_loader), self.optuna_settings.n_train_examples/train_loader.batch_size) + batch_idx)
+                    running_loss = 0.0
+
+        if enable_progress:
+            progress.update(task, description=f"{last_loss:.3e}")
+            progress.stop()
+
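+    # eval_model mirrors the training loop but only accumulates complex_sse_loss;
+    # the mean per-batch error is what objective() reports ("minimize" direction)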
-    def eval_model(self, model, valid_loader):
+    def eval_model(self, model, valid_loader, epoch, writer=None, enable_progress=True):
+        if enable_progress:
+            progress = Progress(
+                TextColumn("[green]Evaluating..."),
+                TextColumn("Error: {task.description}"),
+                BarColumn(),
+                TaskProgressColumn(),
+                TextColumn("[green]Batch"),
+                MofNCompleteColumn(),
+                TimeRemainingColumn(),
+                # description="Training",
+                transient=False,
+                console=self.console,
+                refresh_per_second=10,
+            )
+            task = progress.add_task("-.---e--", total=len(valid_loader))
+
         model.eval()
-        correct = 0
+        running_error = 0
+        running_error_2 = 0
         with torch.no_grad():
-            for batch_idx, (data, target) in enumerate(valid_loader):
+            for batch_idx, (x, y) in enumerate(valid_loader):
                 if (
-                    batch_idx * valid_loader.batchsize
+                    batch_idx * valid_loader.batch_size
                     >= self.optuna_settings.n_valid_examples
                 ):
                     break
-                data, target = (
-                    data.to(self.pytorch_settings.device),
-                    target.to(self.pytorch_settings.device),
+                x, y = (
+                    x.to(self.pytorch_settings.device),
+                    y.to(self.pytorch_settings.device),
                 )
-                target_pred = model(data)
-                pred = target_pred.argmax(dim=1, keepdim=True)
-                correct += pred.eq(target.view_as(pred)).sum().item()
+                y_pred = model(x)
+                error = complex_sse_loss(y_pred, y)
+                running_error += error.item()
+                running_error_2 += error.item()
 
-        accuracy = correct / len(valid_loader.dataset)
-        # num_params = sum(p.numel() for p in model.parameters())
-        return accuracy
+                if enable_progress:
+                    progress.update(task, advance=1, description=f"{error.item():.3e}")
+
+                if writer is not None:
+                    if batch_idx % 10 == 0:
+                        writer.add_scalar("sse", running_error_2/10, epoch*min(len(valid_loader), self.optuna_settings.n_valid_examples/valid_loader.batch_size) + batch_idx)
+                        running_error_2 = 0.0
+
+        running_error /= batch_idx + 1
+
+        if enable_progress:
+            progress.update(task, description=f"{running_error:.3e}")
+            progress.stop()
+
+        return running_error
+
+    def run_model(self, model, loader):
+        model.eval()
+        y_preds = []
+        with torch.no_grad():
+            for x, y in loader:
+                x, y = (
+                    x.to(self.pytorch_settings.device),
+                    y.to(self.pytorch_settings.device),
+                )
+                y_preds.append(model(x))
+        return torch.stack(y_preds)
 
     def objective(self, trial: optuna.Trial):
-        model = self.define_model(trial).to(self.pytorch_settings.device)
-
-        optimizer_name = trial.suggest_categorical(
-            "optimizer", self.optimizer_settings.optimizer_range
-        )
-        lr = trial.suggest_float("lr", *self.optimizer_settings.lr_range, log=True)
-        optimizer = getattr(optim, optimizer_name)(model.parameters(), lr=lr)
-
+        # zero-pad the trial number to the width of the total trial count
+        writer = self.setup_tb_writer(
+            self.optuna_settings.study_name,
+            f"{trial.number:0>{len(str(self.optuna_settings.n_trials))}}",
+        )
         train_loader, valid_loader = self.get_sliced_data(trial)
 
-        for epoch in range(self.pytorch_settings.epochs):
-            self.train_model(model, optimizer, train_loader)
-            accuracy = self.eval_model(model, valid_loader)
+        model = self.define_model(trial, writer).to(self.pytorch_settings.device)
 
-            if len(self.optuna_settings.directions) == 1:
-                trial.report(accuracy, epoch)
+        optimizer_name = optional_suggest_categorical(trial, "optimizer", self.optimizer_settings.optimizer_range)
+
+        lr = optional_suggest_float(trial, "lr", self.optimizer_settings.lr_range, log=True)
+
+        optimizer = getattr(optim, optimizer_name)(model.parameters(), lr=lr)
+
+        for epoch in range(self.pytorch_settings.epochs):
+            enable_progress = self.optuna_settings.n_threads == 1
+            if enable_progress:
+                print(f"Epoch {epoch+1}/{self.pytorch_settings.epochs}")
+            self.train_model(model, optimizer, train_loader, epoch, writer, enable_progress=enable_progress)
+            sse = self.eval_model(model, valid_loader, epoch, writer, enable_progress=enable_progress)
+
+            if not self.optuna_settings.multi_objective:
+                trial.report(sse, epoch)
 
                 if trial.should_prune():
                     raise optuna.exceptions.TrialPruned()
 
-        return accuracy
+        writer.close()
+
+        return sse
+
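+# note: define_model() below rebuilds only the *architecture* of the best
+# trial; the weights are freshly initialized, not restored from its training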
 if __name__ == "__main__":
-    # plt.ion()
     hyper_training = HyperTraining()
-    hyper_training.eye()
-    # hyper_training.setup_study()
-    # hyper_training.run_study()
-    for i in range(10):
-        #simulate some work
-        print(i)
-        time.sleep(0.2)
+    # hyper_training.resume_latest_study()
+
+    hyper_training.setup_study()
+    hyper_training.run_study()
+
+    best_model = hyper_training.define_model(hyper_training.study.best_trial).to(hyper_training.pytorch_settings.device)
+
+    data_settings_backup = copy.copy(hyper_training.data_settings)
+    hyper_training.data_settings.shuffle = False
+    hyper_training.data_settings.train_split = 0.01
+    plot_loader, _ = hyper_training.get_sliced_data(hyper_training.study.best_trial)
+    hyper_training.data_settings = data_settings_backup
+
+    regen = hyper_training.run_model(best_model, plot_loader)
+    regen = regen.view(-1, 2)
+    # [batch_no, batch_size, 2] -> [no, 2]
+
+    original, _ = util.datasets.load_data(hyper_training.data_settings.config_path, skipfirst=hyper_training.data_settings.drop_first)
+    original = original[:len(regen)]
+
+    regen = regen.cpu().numpy()
+    _, axs = plt.subplots(2)
+    for i, ax in enumerate(axs):
+        ax.plot(np.abs(original[:, i])**2, label="original")
+        ax.plot(np.abs(regen[:, i])**2, label="regen")
+        ax.legend()
     plt.show()
-    ...
+
+    print(f"Best model: {best_model}")
+
+
+    # eye_fig = hyper_training.plot_eye()
+    ...
diff --git a/src/single-core-regen/regen_no_hyper.py b/src/single-core-regen/regen_no_hyper.py
new file mode 100644
index 0000000..d136a78
--- /dev/null
+++ b/src/single-core-regen/regen_no_hyper.py
@@ -0,0 +1,429 @@
+import copy
+from dataclasses import dataclass
+from datetime import datetime
+from pathlib import Path
+
+import matplotlib.pyplot as plt
+
+import numpy as np
+
+import torch
+import torch.nn as nn
+
+# import torch.nn.functional as F  # mse_loss doesn't support complex numbers
+import torch.optim as optim
+import torch.utils.data
+
+from torch.utils.tensorboard import SummaryWriter
+
+from rich.progress import (
+    Progress,
+    TextColumn,
+    BarColumn,
+    TaskProgressColumn,
+    TimeRemainingColumn,
+    MofNCompleteColumn,
+    TimeElapsedColumn,
+)
+from rich.console import Console
+from rich import print as rprint
+
+# from util.optuna_helpers import optional_suggest_categorical, optional_suggest_float, optional_suggest_int
+import util
+
+
+# global settings
+@dataclass
+class GlobalSettings:
+    seed: int = 42
+
+
+# data settings
+@dataclass
+class DataSettings:
+    config_path: str = "data/*-128-16384-100000-0-0-17-0-PAM4-0.ini"
+    dtype: torch.dtype = torch.complex64
+    symbols_range: float | int = 8
+    data_size_range: float | int = 64
+    shuffle: bool = True
+    target_delay: float = 0
+    xy_delay_range: float | int = 0
+    drop_first: int = 10
+    train_split: float = 0.8
+
+
+# pytorch settings
+@dataclass
+class PytorchSettings:
+    epochs: int = 1000
+    batchsize: int = 2**12
+    device: str = "cuda"
+    summary_dir: str = ".runs"
+    model_dir: str = ".models"
+
+
+# model settings
+@dataclass
+class ModelSettings:
+    output_size: int = 2
+    # n_layer_range: float|int = 2
+    # n_units_range: float|int = 32
+    n_layers: int = 3
+    n_units: int = 32
+    activation_func: tuple | str = "ModReLU"
+
+
+@dataclass
+class OptimizerSettings:
+    optimizer_range: str = "Adam"
+    lr_range: float = 2e-3
+
+
+class Training:
+    def __init__(self):
+        self.global_settings = GlobalSettings()
+        self.data_settings = DataSettings()
+        self.pytorch_settings = PytorchSettings()
+        self.model_settings = ModelSettings()
+        self.optimizer_settings = OptimizerSettings()
+        self.study_name = (
+            f"single_core_regen_{datetime.now().strftime('%Y-%m-%d_%H:%M:%S')}"
+        )
+
+        if not hasattr(self.pytorch_settings, "model_dir"):
+            self.pytorch_settings.model_dir = ".models"
+
+        self.writer = None
+        self.console = Console()
+
+    def setup_tb_writer(self, study_name=None):
+        log_dir = (
+            self.pytorch_settings.summary_dir + "/" + (study_name or self.study_name)
+        )
+        self.writer = SummaryWriter(log_dir)
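+    # plot_eye caches the loaded waveform in self.eye_data, so repeated calls
+    # do not re-read the .npy file from disk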
+    def plot_eye(self, width=2, symbols=None, alpha=None, complex=False, show=True):
+        if not hasattr(self, "eye_data"):
+            data, config = util.datasets.load_data(
+                self.data_settings.config_path,
+                skipfirst=10,
+                symbols=symbols or 1000,
+                real=not self.data_settings.dtype.is_complex,
+                normalize=True,
+            )
+            self.eye_data = {"data": data, "sps": int(config["glova"]["sps"])}
+        return util.plot.eye(
+            **self.eye_data,
+            width=width,
+            show=show,
+            alpha=alpha,
+            complex=complex,
+            symbols=symbols or 1000,
+            skipfirst=0,
+        )
+
+    def define_model(self):
+        n_layers = self.model_settings.n_layers
+
+        in_features = 2 * self.data_settings.data_size_range
+
+        layers = []
+        for i in range(n_layers):
+            out_features = self.model_settings.n_units
+
+            layers.append(util.complexNN.UnitaryLayer(in_features, out_features))
+            # layers.append(getattr(nn, self.model_settings.activation_func)())
+            layers.append(
+                getattr(util.complexNN, self.model_settings.activation_func)()
+            )
+            in_features = out_features
+
+        layers.append(
+            util.complexNN.UnitaryLayer(in_features, self.model_settings.output_size)
+        )
+
+        if self.writer is not None:
+            self.writer.add_graph(
+                nn.Sequential(*layers),
+                torch.zeros(1, layers[0].in_features, dtype=self.data_settings.dtype),
+            )
+
+        return nn.Sequential(*layers)
+
+    def get_sliced_data(self):
+        symbols = self.data_settings.symbols_range
+
+        xy_delay = self.data_settings.xy_delay_range
+
+        data_size = self.data_settings.data_size_range
+
+        # get dataset
+        dataset = util.datasets.FiberRegenerationDataset(
+            file_path=self.data_settings.config_path,
+            symbols=symbols,
+            data_size=data_size,
+            target_delay=self.data_settings.target_delay,
+            xy_delay=xy_delay,
+            drop_first=self.data_settings.drop_first,
+            dtype=self.data_settings.dtype,
+            real=not self.data_settings.dtype.is_complex,
+            # device=self.pytorch_settings.device,
+        )
+
+        dataset_size = len(dataset)
+        indices = list(range(dataset_size))
+        split = int(np.floor(self.data_settings.train_split * dataset_size))
+        if self.data_settings.shuffle:
+            np.random.seed(self.global_settings.seed)
+            np.random.shuffle(indices)
+
+        train_indices, valid_indices = indices[:split], indices[split:]
+
+        if self.data_settings.shuffle:
+            train_sampler = torch.utils.data.SubsetRandomSampler(train_indices)
+            valid_sampler = torch.utils.data.SubsetRandomSampler(valid_indices)
+        else:
+            train_sampler = train_indices
+            valid_sampler = valid_indices
+
+        train_loader = torch.utils.data.DataLoader(
+            dataset,
+            batch_size=self.pytorch_settings.batchsize,
+            sampler=train_sampler,
+            drop_last=True,
+            pin_memory=True,
+            num_workers=24,
+            prefetch_factor=4,
+            # persistent_workers=True
+        )
+        valid_loader = torch.utils.data.DataLoader(
+            dataset,
+            batch_size=self.pytorch_settings.batchsize,
+            sampler=valid_sampler,
+            drop_last=True,
+            pin_memory=True,
+            num_workers=24,
+            prefetch_factor=4,
+            # persistent_workers=True
+        )
+
+        return train_loader, valid_loader
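+    # num_workers=24 and prefetch_factor=4 above are machine-specific choices;
+    # scale them to the local CPU core count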
+    def train_model(self, model, optimizer, train_loader, epoch):
+        with Progress(
+            TextColumn("[yellow] Training..."),
+            TextColumn("Loss: {task.description}"),
+            BarColumn(),
+            TaskProgressColumn(),
+            TextColumn("[green]Batch"),
+            MofNCompleteColumn(),
+            TimeRemainingColumn(),
+            TimeElapsedColumn(),
+            # description="Training",
+            transient=False,
+            console=self.console,
+            refresh_per_second=10,
+        ) as progress:
+            task = progress.add_task("-.---e--", total=len(train_loader))
+
+            running_loss = 0.0
+            model.train()
+            for batch_idx, (x, y) in enumerate(train_loader):
+                model.zero_grad(set_to_none=True)
+                x, y = (
+                    x.to(self.pytorch_settings.device),
+                    y.to(self.pytorch_settings.device),
+                )
+                y_pred = model(x)
+                loss = util.complexNN.complex_mse_loss(y_pred, y)
+                loss.backward()
+                optimizer.step()
+
+                progress.update(task, advance=1, description=f"{loss.item():.3e}")
+
+                running_loss += loss.item()
+                if self.writer is not None:
+                    if (batch_idx + 1) % 10 == 0:
+                        self.writer.add_scalar(
+                            "training loss",
+                            running_loss / 10,
+                            epoch * len(train_loader) + batch_idx,
+                        )
+                        running_loss = 0.0
+
+        return running_loss
+
+    def eval_model(self, model, valid_loader, epoch):
+        with Progress(
+            TextColumn("[green]Evaluating..."),
+            TextColumn("Loss: {task.description}"),
+            BarColumn(),
+            TaskProgressColumn(),
+            TextColumn("[green]Batch"),
+            MofNCompleteColumn(),
+            TimeRemainingColumn(),
+            TimeElapsedColumn(),
+            # description="Training",
+            transient=False,
+            console=self.console,
+            refresh_per_second=10,
+        ) as progress:
+            task = progress.add_task("-.---e--", total=len(valid_loader))
+
+            model.eval()
+            running_loss = 0
+            running_loss2 = 0
+            with torch.no_grad():
+                for batch_idx, (x, y) in enumerate(valid_loader):
+                    x, y = (
+                        x.to(self.pytorch_settings.device),
+                        y.to(self.pytorch_settings.device),
+                    )
+                    y_pred = model(x)
+                    loss = util.complexNN.complex_mse_loss(y_pred, y)
+                    running_loss += loss.item()
+                    running_loss2 += loss.item()
+
+                    progress.update(task, advance=1, description=f"{loss.item():.3e}")
+                    if self.writer is not None:
+                        if (batch_idx + 1) % 10 == 0:
+                            self.writer.add_scalar(
+                                "loss",
+                                running_loss / 10,
+                                epoch * len(valid_loader) + batch_idx,
+                            )
+                            running_loss = 0.0
+
+        if self.writer is not None:
+            self.writer.add_figure("fiber response", self.plot_model_response(model, plot=False), epoch+1)
+
+        return running_loss2 / len(valid_loader)
+
+    def run_model(self, model, loader):
+        model.eval()
+        xs = []
+        ys = []
+        y_preds = []
+        with torch.no_grad():
+            model = model.to(self.pytorch_settings.device)
+            for x, y in loader:
+                x, y = (
+                    x.to(self.pytorch_settings.device),
+                    y.to(self.pytorch_settings.device),
+                )
+                y_pred = model(x).cpu()
+                # x = x.cpu()
+                # y = y.cpu()
+                y_pred = y_pred.view(y_pred.shape[0], -1, 2)
+                y = y.view(y.shape[0], -1, 2)
+                x = x.view(x.shape[0], -1, 2)
+                xs.append(x[:, 0, :].squeeze())
+                ys.append(y.squeeze())
+                y_preds.append(y_pred.squeeze())
+
+        xs = torch.vstack(xs).cpu()
+        ys = torch.vstack(ys).cpu()
+        y_preds = torch.vstack(y_preds).cpu()
+        return ys, xs, y_preds
+
+    def dummy_model(self, loader):
+        xs = []
+        ys = []
+        for x, y in loader:
+            y = y.cpu().view(y.shape[0], -1, 2)
+            x = x.cpu().view(x.shape[0], -1, 2)
+            xs.append(x[:, 0, :].squeeze())
+            ys.append(y.squeeze())
+        xs = torch.vstack(xs)
+        ys = torch.vstack(ys)
+        return xs, ys
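+    # objective() snapshots the current model to .models/exception/ in its
+    # finally block on every exit, clean or interrupted (e.g. Ctrl-C)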
"model"): + except_save_path = Path(".models/exception") / f"{self.study_name}.pth" + except_save_path.parent.mkdir(parents=True, exist_ok=True) + torch.save(self.model, except_save_path) + + def _plot_model_response_plotter(self, fiber_in, fiber_out, regen, plot=True): + fig, axs = plt.subplots(2) + for i, ax in enumerate(axs): + ax.plot(np.abs(fiber_in[:, i]) ** 2, label="fiber in") + ax.plot(np.abs(fiber_out[:, i]) ** 2, label="fiber out") + ax.plot(np.abs(regen[:, i]) ** 2, label="regenerated") + ax.legend() + if plot: + plt.show() + return fig + + def plot_model_response(self, model=None, plot=True): + data_settings_backup = copy.copy(self.data_settings) + self.data_settings.shuffle = False + self.data_settings.train_split = 0.01 + self.data_settings.drop_first = 100 + plot_loader, _ = self.get_sliced_data() + self.data_settings = data_settings_backup + + fiber_in, fiber_out, regen = self.run_model(model or self.model, plot_loader) + fiber_in = fiber_in.view(-1, 2) + fiber_out = fiber_out.view(-1, 2) + regen = regen.view(-1, 2) + + fiber_in = fiber_in.numpy() + fiber_out = fiber_out.numpy() + regen = regen.numpy() + + # https://github.com/matplotlib/matplotlib/issues/27713#issue-2104110987 + # https://github.com/matplotlib/matplotlib/issues/27713#issuecomment-1915497463 + import gc + fig = self._plot_model_response_plotter(fiber_in, fiber_out, regen, plot=plot) + gc.collect() + + return fig + +if __name__ == "__main__": + trainer = Training() + + # trainer.plot_eye() + trainer.setup_tb_writer() + trainer.objective(save=True) + + best_model = trainer.model + + # best_model = trainer.define_model(trainer.study.best_trial).to(trainer.pytorch_settings.device) + trainer.plot_model_response(best_model) + + # print(f"Best model: {best_model}") + + ... diff --git a/src/single-core-regen/util/__init__.py b/src/single-core-regen/util/__init__.py index 4e5dfc4..0cc2191 100644 --- a/src/single-core-regen/util/__init__.py +++ b/src/single-core-regen/util/__init__.py @@ -1,2 +1,17 @@ -from .datasets import FiberRegenerationDataset # noqa: F401 -from .plot import eye # noqa: F401 \ No newline at end of file +from . import datasets # noqa: F401 +# from .datasets import FiberRegenerationDataset # noqa: F401 +# from .datasets import load_data # noqa: F401 + + +from . import plot # noqa: F401 +# from .plot import eye # noqa: F401 + +from . import optuna_helpers # noqa: F401 +# from .optuna_helpers import optional_suggest_categorical # noqa: F401 +# from .optuna_helpers import optional_suggest_float # noqa: F401 +# from .optuna_helpers import optional_suggest_int # noqa: F401 + +from . import complexNN # noqa: F401 +# from .complexNN import UnitaryLayer # noqa: F401 +# from .complexNN import complex_mse_loss # noqa: F401 +# from .complexNN import complex_sse_loss # noqa: F401 diff --git a/src/single-core-regen/util/complexNN.py b/src/single-core-regen/util/complexNN.py new file mode 100644 index 0000000..e42a836 --- /dev/null +++ b/src/single-core-regen/util/complexNN.py @@ -0,0 +1,141 @@ +import torch +import torch.nn as nn + +def complex_mse_loss(input, target): + """ + Compute the mean squared error between two complex tensors. + """ + return torch.mean(torch.square(input.real - target.real) + torch.square(input.imag - target.imag)) + +def complex_sse_loss(input, target): + """ + Compute the sum squared error between two complex tensors. 
+ """ + if input.is_complex(): + return torch.sum(torch.square(input.real - target.real) + torch.square(input.imag - target.imag)) + else: + return torch.sum(torch.square(input - target)) + + + + +class UnitaryLayer(nn.Module): + def __init__(self, in_features, out_features): + super(UnitaryLayer, self).__init__() + assert in_features >= out_features + self.in_features = in_features + self.out_features = out_features + self.weight = nn.Parameter(torch.randn(in_features, out_features, dtype=torch.cfloat)) + self.reset_parameters() + + def reset_parameters(self): + q, _ = torch.linalg.qr(self.weight) + self.weight.data = q + + @staticmethod + @torch.jit.script + def _unitary_forward(x, weight): + out = torch.matmul(x, weight) + return out + + def forward(self, x): + return self._unitary_forward(x, self.weight) + + +#### as defined by zhang et al + +class Identity(nn.Module): + """ + implements the "activation" function + M(z) = z + """ + def __init__(self): + super(Identity, self).__init__() + + def forward(self, x): + return x + +class Mag(nn.Module): + """ + implements the activation function + M(z) = ||z|| + """ + def __init__(self): + super(Mag, self).__init__() + + @torch.jit.script + def forward(self, x): + return torch.abs(x.real**2 + x.imag**2) + +# class Tanh(nn.Module): +# """ +# implements the activation function +# M(z) = tanh(z) = sinh(z)/cosh(z) = (exp(z)-exp(-z))/(exp(z)+exp(-z)) = (exp(2*z)-1)/(exp(2*z)+1) +# """ +# def __init__(self): +# super(Tanh, self).__init__() + +# def forward(self, x): +# return torch.tanh(x) + +class ModReLU(nn.Module): + """ + implements the activation function + M(z) = ReLU(||z|| + b)*exp(j*theta_z) + = ReLU(||z|| + b)*z/||z|| + """ + def __init__(self, b=0): + super(ModReLU, self).__init__() + self.b = b + self.relu = nn.ReLU() + + @staticmethod + # @torch.jit.script + def _mod_relu(x, b): + mod = torch.abs(x.real**2 + x.imag**2) + return torch.relu(mod + b) * x / mod + + def forward(self, x): + return self._mod_relu(x, self.b) + +class CReLU(nn.Module): + """ + implements the activation function + M(z) = ReLU(Re(z)) + j*ReLU(Im(z)) + """ + def __init__(self): + super(CReLU, self).__init__() + self.relu = nn.ReLU() + + @torch.jit.script + def forward(self, x): + return torch.relu(x.real) + 1j*torch.relu(x.imag) + +class ZReLU(nn.Module): + """ + implements the activation function + + M(z) = z if 0 <= angle(z) <= pi/2 + = 0 otherwise + """ + + def __init__(self): + super(ZReLU, self).__init__() + + @torch.jit.script + def forward(self, x): + return x * (torch.angle(x) >= 0) * (torch.angle(x) <= torch.pi/2) + +# class ComplexFeedForwardNN(nn.Module): +# def __init__(self, in_features, hidden_features, out_features): +# super(ComplexFeedForwardNN, self).__init__() +# self.in_features = in_features +# self.hidden_features = hidden_features +# self.out_features = out_features +# self.fc1 = UnitaryLayer(in_features, hidden_features) +# self.fc2 = UnitaryLayer(hidden_features, out_features) + +# def forward(self, x): +# x = self.fc1(x) +# x = self.fc2(x) +# return x \ No newline at end of file diff --git a/src/single-core-regen/util/datasets.py b/src/single-core-regen/util/datasets.py index 7bbdec5..1d1d543 100644 --- a/src/single-core-regen/util/datasets.py +++ b/src/single-core-regen/util/datasets.py @@ -1,11 +1,28 @@ from pathlib import Path import torch from torch.utils.data import Dataset +# from torch.utils.data import Sampler import numpy as np import configparser +# class SubsetSampler(Sampler[int]): +# """ +# Samples elements from a given list 
+# class ComplexFeedForwardNN(nn.Module):
+#     def __init__(self, in_features, hidden_features, out_features):
+#         super(ComplexFeedForwardNN, self).__init__()
+#         self.in_features = in_features
+#         self.hidden_features = hidden_features
+#         self.out_features = out_features
+#         self.fc1 = UnitaryLayer(in_features, hidden_features)
+#         self.fc2 = UnitaryLayer(hidden_features, out_features)
+
+#     def forward(self, x):
+#         x = self.fc1(x)
+#         x = self.fc2(x)
+#         return x
\ No newline at end of file
diff --git a/src/single-core-regen/util/datasets.py b/src/single-core-regen/util/datasets.py
index 7bbdec5..1d1d543 100644
--- a/src/single-core-regen/util/datasets.py
+++ b/src/single-core-regen/util/datasets.py
@@ -1,11 +1,28 @@
 from pathlib import Path
 import torch
 from torch.utils.data import Dataset
+# from torch.utils.data import Sampler
 import numpy as np
 import configparser
 
+# class SubsetSampler(Sampler[int]):
+#     """
+#     Samples elements from a given list of indices.
 
-def load_data(config_path, skipfirst=0, num_symbols=None):
+#     :param indices: List of indices to sample from.
+#     :type indices: list[int]
+#     """
+
+#     def __init__(self, indices):
+#         self.indices = indices
+
+#     def __iter__(self):
+#         return iter(self.indices)
+
+#     def __len__(self):
+#         return len(self.indices)
+
+def load_data(config_path, skipfirst=0, symbols=None, real=False, normalize=False, device=None, dtype=None):
     filepath = Path(config_path)
     filepath = filepath.parent.glob(filepath.name)
     config = configparser.ConfigParser()
@@ -18,15 +35,25 @@ def load_data(config_path, skipfirst=0, num_symbols=None):
     datapath = Path("/".join(path_elements).replace('"', ""))
 
     sps = int(config["glova"]["sps"])
-    if num_symbols is None:
-        num_symbols = int(config["glova"]["nos"]) - skipfirst
+    if symbols is None:
+        symbols = int(config["glova"]["nos"]) - skipfirst
 
-    data = np.load(datapath)[skipfirst * sps : num_symbols * sps + skipfirst * sps]
-    config["glova"]["nos"] = str(num_symbols)
+    data = np.load(datapath)[skipfirst * sps : symbols * sps + skipfirst * sps]
+
+    if normalize:
+        a, b, c, d = data.T
+        a, b, c, d = a/np.max(np.abs(a)), b/np.max(np.abs(b)), c/np.max(np.abs(c)), d/np.max(np.abs(d))
+        data = np.array([a, b, c, d]).T
+
+    if real:
+        data = np.abs(data)
+
+    config["glova"]["nos"] = str(symbols)
+
+    data = torch.tensor(data, device=device, dtype=dtype)
 
     return data, config
 
-
 def roll_along(arr, shifts, dim):
     # https://stackoverflow.com/a/76920720
     # (c) Mateen Ulhaq, 2023
@@ -39,7 +66,6 @@ def roll_along(arr, shifts, dim):
     indices = (dim_indices - shifts.unsqueeze(dim)) % arr.shape[dim]
     return torch.gather(arr, dim, indices)
 
-
 class FiberRegenerationDataset(Dataset):
     """
    Dataset for fiber regeneration training.
@@ -76,6 +102,9 @@ class FiberRegenerationDataset(Dataset):
         target_delay: float | int = 0,
         xy_delay: float | int = 0,
         drop_first: float | int = 0,
+        dtype: torch.dtype = None,
+        real: bool = False,
+        device = None,
         **kwargs,
     ):
         """
@@ -123,13 +152,16 @@ class FiberRegenerationDataset(Dataset):
                 [[i + 0.1j, i + 0.2j, i + 1.1j, i + 1.2j] for i in range(12800)],
                 dtype=np.complex128,
             )
+            data_raw = torch.tensor(data_raw, device=device, dtype=dtype)
             self.config = {
                 "data": {"dir": '"."', "npy_dir": '"."', "file": "faux"},
                 "glova": {"sps": 128},
             }
         else:
-            data_raw, self.config = load_data(file_path)
+            data_raw, self.config = load_data(file_path, skipfirst=drop_first, real=real, normalize=True, device=device, dtype=dtype)
 
+        self.device = data_raw.device
+
         self.samples_per_symbol = int(self.config["glova"]["sps"])
         self.samples_per_slice = int(symbols * self.samples_per_symbol)
         self.symbols_per_slice = self.samples_per_slice / self.samples_per_symbol
@@ -140,7 +172,6 @@ class FiberRegenerationDataset(Dataset):
 
         ovrd_target_delay_samples = kwargs.pop("ovrd_target_delay_samples", None)
         ovrd_xy_delay_samples = kwargs.pop("ovrd_xy_delay_samples", None)
-        ovrd_drop_first_samples = kwargs.pop("ovrd_drop_first_samples", None)
 
         self.target_delay_samples = (
             ovrd_target_delay_samples
@@ -152,14 +183,8 @@ class FiberRegenerationDataset(Dataset):
             if ovrd_xy_delay_samples is not None
             else int(self.xy_delay * self.samples_per_symbol)
         )
-        drop_first_samples = (
-            ovrd_drop_first_samples
-            if ovrd_drop_first_samples is not None
-            else int(drop_first * self.samples_per_symbol)
-        )
 
-        # drop samples from the beginning
-        data_raw = data_raw[drop_first_samples:]
+        # data_raw = torch.tensor(data_raw, dtype=dtype)
 
         # data layout
         # [ [E_in_x0, E_in_y0, E_out_x0, E_out_y0],
@@ -240,10 +265,10 @@ class FiberRegenerationDataset(Dataset):
         data = data.view(data.shape[0], self.data_size, -1)
         data = data[:, :, 0]
 
-        # target is corresponding to the latest data point -> try to regenerate that
+        # the target corresponds to the middle of the data window, since the output sample is influenced by the samples before and after it
         target = target[:, : target.shape[1] // self.data_size * self.data_size]
         target = target.view(target.shape[0], self.data_size, -1)
-        target = target[:, 0, 0]
+        target = target[:, 0, target.shape[2] // 2]
 
         data = data.transpose(0, 1).flatten().squeeze()
         target = target.flatten().squeeze()
diff --git a/src/single-core-regen/util/optuna_helpers.py b/src/single-core-regen/util/optuna_helpers.py
new file mode 100644
index 0000000..1d924a9
--- /dev/null
+++ b/src/single-core-regen/util/optuna_helpers.py
@@ -0,0 +1,30 @@
+def _optional_suggest(trial, name, range_or_value, log=False, step=None, type='int'):
+    # not a range
+    if not hasattr(range_or_value, '__iter__') or isinstance(range_or_value, str):
+        return range_or_value
+
+    # range with only one value
+    if len(range_or_value) == 1:
+        return range_or_value[0]
+
+    if type == 'int':
+        step = step or 1
+        return trial.suggest_int(name, *range_or_value, step=step, log=log)
+
+    if type == 'float':
+        return trial.suggest_float(name, *range_or_value, step=step, log=log)
+
+    if type == 'categorical':
+        return trial.suggest_categorical(name, range_or_value)
+
+    raise ValueError(f"Unknown type: {type}")
+
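+# usage sketch: a tuple is searched by the trial, a bare value is returned as-is
+#   lr = optional_suggest_float(trial, "lr", (1e-5, 1e-1), log=True)  # searched
+#   lr = optional_suggest_float(trial, "lr", 2e-5)                    # fixed
+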
+
+def optional_suggest_categorical(trial, name, choices_or_value):
+    return _optional_suggest(trial, name, choices_or_value, type='categorical')
+
+def optional_suggest_int(trial, name, range_or_value, step=None, log=False):
+    return _optional_suggest(trial, name, range_or_value, step=step, log=log, type='int')
+
+def optional_suggest_float(trial, name, range_or_value, step=None, log=False):
+    return _optional_suggest(trial, name, range_or_value, step=step, log=log, type='float')
\ No newline at end of file
diff --git a/src/single-core-regen/util/optuna_vis.py b/src/single-core-regen/util/optuna_vis.py
new file mode 100644
index 0000000..106dbb4
--- /dev/null
+++ b/src/single-core-regen/util/optuna_vis.py
@@ -0,0 +1,18 @@
+from dash import Dash, dcc, html
+import logging
+import dash_bootstrap_components as dbc
+
+
+def show_figures(*figures):
+    for figure in figures:
+        figure.layout.template = 'plotly_dark'
+
+    app = Dash(external_stylesheets=[dbc.themes.DARKLY])
+    app.layout = html.Div([
+        dcc.Graph(figure=figure) for figure in figures
+    ])
+    log = logging.getLogger('werkzeug')
+    log.setLevel(logging.ERROR)
+
+    app.show = lambda *args, **kwargs: app.run_server(*args, **kwargs, debug=False)
+    return app
\ No newline at end of file
diff --git a/src/single-core-regen/util/plot.py b/src/single-core-regen/util/plot.py
index 8a3a196..2bbc8ce 100644
--- a/src/single-core-regen/util/plot.py
+++ b/src/single-core-regen/util/plot.py
@@ -2,33 +2,72 @@
 import matplotlib.pyplot as plt
 import numpy as np
 from .datasets import load_data
 
-def eye(path, title=None, head=1000, skipfirst=1000, show=True):
-    """Plot an eye diagram for the data given by filepath.
+def eye(*, path=None, data=None, sps=None, title=None, symbols=1000, skipfirst=0, width=2, alpha=None, complex=False, show=True):
+    """Plot an eye diagram for the data given by path or passed in directly.
+
+    Either path or data and sps must be given.
 
     Args:
         path (str): Path to the data description file.
+        data (np.ndarray): Data to plot.
+        sps (int): Samples per symbol.
         title (str): Title of the plot.
-        head (int): Number of symbols to plot.
+        symbols (int): Number of symbols to plot.
+        width (int): Width of the eye window in symbols.
+        alpha (float): Transparency of the traces (default 0.1).
+        complex (bool): Also plot the phase on a secondary axis.
         skipfirst (int): Number of symbols to skip.
         show (bool): Whether to call plt.show().
     """
-    data, config = load_data(path, skipfirst, head)
-    sps = int(config["glova"]["sps"])
+    if path is None and data is None:
+        raise ValueError("Either path or data and sps must be given.")
+    if path is not None:
+        data, config = load_data(path, skipfirst, symbols)
+        sps = int(config["glova"]["sps"])
+    if sps is None:
+        raise ValueError("sps not set.")
 
-    xaxis = np.linspace(0, 2, 2*sps, endpoint=False)
+    xaxis = np.linspace(0, width, width*sps, endpoint=False)
 
     fig, axs = plt.subplots(2, 2, figsize=(10, 10), sharex=True, sharey=True)
-    for i in range(head-1):
-        inx, iny, outx, outy = data[i*sps:(i+2)*sps].T
-        axs[0, 0].plot(xaxis, np.abs(inx)**2, color="C0", alpha=0.1)
-        axs[0, 1].plot(xaxis, np.abs(outx)**2, color="C0", alpha=0.1)
-        axs[1, 0].plot(xaxis, np.abs(iny)**2, color="C0", alpha=0.1)
-        axs[1, 1].plot(xaxis, np.abs(outy)**2, color="C0", alpha=0.1)
+    if complex:
+        # create secondary axis for phase
+        axs2 = axs[0, 0].twinx(), axs[0, 1].twinx(), axs[1, 0].twinx(), axs[1, 1].twinx()
+        axs2 = np.reshape(axs2, (2, 2))
+
+    for i in range(symbols-(width-1)):
+        inx, iny, outx, outy = data[i*sps:(i+width)*sps].T
+        if complex:
+            axs[0, 0].plot(xaxis, np.abs(inx), color="C0", alpha=alpha or 0.1)
+            axs[0, 1].plot(xaxis, np.abs(outx), color="C0", alpha=alpha or 0.1)
+            axs[1, 0].plot(xaxis, np.abs(iny), color="C0", alpha=alpha or 0.1)
+            axs[1, 1].plot(xaxis, np.abs(outy), color="C0", alpha=alpha or 0.1)
+            axs[0, 0].set_ylim(0, 1.1*np.max(np.abs(data)))
+
+            axs2[0, 0].plot(xaxis, np.angle(inx), color="C1", alpha=alpha or 0.1)
+            axs2[0, 1].plot(xaxis, np.angle(outx), color="C1", alpha=alpha or 0.1)
+            axs2[1, 0].plot(xaxis, np.angle(iny), color="C1", alpha=alpha or 0.1)
+            axs2[1, 1].plot(xaxis, np.angle(outy), color="C1", alpha=alpha or 0.1)
+        else:
+            axs[0, 0].plot(xaxis, np.abs(inx)**2, color="C0", alpha=alpha or 0.1)
+            axs[0, 1].plot(xaxis, np.abs(outx)**2, color="C0", alpha=alpha or 0.1)
+            axs[1, 0].plot(xaxis, np.abs(iny)**2, color="C0", alpha=alpha or 0.1)
+            axs[1, 1].plot(xaxis, np.abs(outy)**2, color="C0", alpha=alpha or 0.1)
+
+    if complex:
+        axs2[0, 0].sharey(axs2[0, 1])
+        axs2[0, 1].sharey(axs2[1, 0])
+        axs2[1, 0].sharey(axs2[1, 1])
+        # make y axis symmetric
+        ylim = np.max(np.abs(np.angle(data)))*1.1
+        if ylim != 0:
+            axs2[0, 0].set_ylim(-ylim, ylim)
+    else:
+        axs[0, 0].set_ylim(0, 1.1*np.max(np.abs(data))**2)
 
     axs[0, 0].set_title("Input x")
     axs[0, 1].set_title("Output x")
     axs[1, 0].set_title("Input y")
     axs[1, 1].set_title("Output y")
-    fig.suptitle(title)
+    fig.suptitle(title or "Eye diagram")
 
     if show:
-        plt.show(block=False)
+        plt.show()
+
+    return fig