diff --git a/src/single-core-regen/hypertraining/hypertraining.py b/src/single-core-regen/hypertraining/hypertraining.py
index 5894475..50c8da5 100644
--- a/src/single-core-regen/hypertraining/hypertraining.py
+++ b/src/single-core-regen/hypertraining/hypertraining.py
@@ -6,7 +6,8 @@
 import matplotlib.pyplot as plt
 import numpy as np
 import optuna
-import optunahub
+
+# import optunahub
 import warnings
 
 import torch
@@ -18,26 +19,28 @@ import torch.utils.data
 from torch.utils.tensorboard import SummaryWriter
 
-from rich.progress import (
-    Progress,
-    TextColumn,
-    BarColumn,
-    TaskProgressColumn,
-    TimeRemainingColumn,
-    MofNCompleteColumn,
-    TimeElapsedColumn,
-)
-from rich.console import Console
+# from rich.progress import (
+#     Progress,
+#     TextColumn,
+#     BarColumn,
+#     TaskProgressColumn,
+#     TimeRemainingColumn,
+#     MofNCompleteColumn,
+#     TimeElapsedColumn,
+# )
+# from rich.console import Console
 
 # from rich import print as rprint
 import multiprocessing
 
 from util.datasets import FiberRegenerationDataset
-from util.optuna_helpers import (
-    force_suggest_categorical,
-    force_suggest_float,
-    force_suggest_int,
-)
+
+# from util.optuna_helpers import (
+#     suggest_categorical_optional,  # noqa: F401
+#     suggest_float_optional,  # noqa: F401
+#     suggest_int_optional,  # noqa: F401
+# )
+from util.optuna_helpers import install_optional_suggests
 import util
 
 from .settings import (
@@ -49,6 +52,8 @@ from .settings import (
     PytorchSettings,
 )
 
+install_optional_suggests()
+
 
 class HyperTraining:
     def __init__(
@@ -60,7 +65,7 @@ class HyperTraining:
         model_settings,
         optimizer_settings,
         optuna_settings,
-        console=None,
+        # console=None,
     ):
         self.global_settings: GlobalSettings = global_settings
         self.data_settings: DataSettings = data_settings
@@ -68,18 +73,16 @@ class HyperTraining:
         self.model_settings: ModelSettings = model_settings
         self.optimizer_settings: OptimizerSettings = optimizer_settings
         self.optuna_settings: OptunaSettings = optuna_settings
+        self.processes = None
 
-        self.console = console or Console()
+        # self.console = console or Console()
 
         # set some extra settings to make the code more readable
         self._extra_optuna_settings()
+        self.stop_study = False  # set to True (e.g. on KeyboardInterrupt) to stop the study after running trials finish
 
     def setup_tb_writer(self, study_name=None, append=None):
-        log_dir = (
-            self.pytorch_settings.summary_dir
-            + "/"
-            + (study_name or self.optuna_settings.study_name)
-        )
+        log_dir = self.pytorch_settings.summary_dir + "/" + (study_name or self.optuna_settings.study_name)
 
         if append is not None:
             log_dir += "_" + str(append)
@@ -89,180 +92,211 @@ class HyperTraining:
         study_name = self.get_latest_study()
 
         if study_name:
-            print(f"Resuming study: {study_name}")
+            if verbose:
+                print(f"Resuming study: {study_name}")
             self.optuna_settings.study_name = study_name
 
-    def get_latest_study(self, verbose=True):
+    def get_latest_study(self, verbose=False) -> optuna.Study:
         studies = self.get_studies()
+        study = None
         for study in studies:
             study.datetime_start = study.datetime_start or datetime.min
         if studies:
             study = sorted(studies, key=lambda x: x.datetime_start, reverse=True)[0]
             if verbose:
                 print(f"Last study: {study.study_name}")
-            study_name = study.study_name
         else:
             if verbose:
                 print("No previous studies found")
-            study_name = None
-        return study_name
+            return None
+        return optuna.load_study(study_name=study.study_name, storage=self.optuna_settings.storage)
+
+    # def study(self) -> optuna.Study:
+    #     return optuna.load_study(self.optuna_settings.study_name, storage=self.optuna_settings.storage)
 
     def get_studies(self):
         return optuna.get_all_study_summaries(storage=self.optuna_settings.storage)
 
     def setup_study(self):
-        module = optunahub.load_module(package="samplers/auto_sampler")
+        # module = optunahub.load_module(package="samplers/auto_sampler")
+        if self.optuna_settings._parallel:
+            self.processes = []
+
+        pruner = getattr(optuna.pruners, self.optuna_settings.pruner, None)
+
+        if pruner and self.optuna_settings.pruner_kwargs is not None:
+            pruner = pruner(**self.optuna_settings.pruner_kwargs)
+        elif pruner:
+            pruner = pruner()
+
         self.study = optuna.create_study(
             study_name=self.optuna_settings.study_name,
             storage=self.optuna_settings.storage,
             load_if_exists=True,
-            direction=self.optuna_settings.direction,
-            directions=self.optuna_settings.directions,
-            sampler=module.AutoSampler(),
+            direction=self.optuna_settings._direction,
+            directions=self.optuna_settings._directions,
+            pruner=pruner,
+            # sampler=module.AutoSampler(),
         )
-        print("using sampler:", self.study.sampler)
+        # print("using sampler:", self.study.sampler)
 
         with warnings.catch_warnings(action="ignore"):
             self.study.set_metric_names(self.optuna_settings.metrics_names)
 
-        self.n_threads = min(
-            self.optuna_settings.n_trials, self.optuna_settings.n_threads
-        )
-        self.processes = []
-        if self.n_threads > 1:
-            for _ in range(self.n_threads):
-                p = multiprocessing.Process(
-                    # target=lambda n_trials: self._run_optimize(self, n_trials),
-                    target=self._run_optimize,
-                    args=(self.optuna_settings.n_trials // self.n_threads,),
-                )
-                self.processes.append(p)
-
-    # def plot_eye(self, width=2, symbols=None, alpha=None, complex=False, show=True):
-    #     data, config = util.datasets.load_data(
-    #         self.data_settings.config_path,
-    #         skipfirst=10,
-    #         symbols=symbols or 1000,
-    #         real=not complex,
-    #         normalize=True,
-    #     )
-    #     eye_data = {"data": data.numpy(), "sps": int(config["glova"]["sps"])}
-    #     return util.plot.eye(
-    #         **eye_data,
-    #         width=width,
-    #         show=show,
-    #         alpha=alpha,
-    #         complex=complex,
-    #         symbols=symbols or 1000,
-    #         skipfirst=0,
-    #     )
-
     def run_study(self):
-        if self.processes:
-            for p in self.processes:
-                p.start()
-            for p in self.processes:
-                p.join()
+        try:
+            if self.optuna_settings._parallel:
+                self._run_parallel_study()
+            else:
+                self._run_study()
+        except KeyboardInterrupt:
+            print("Stopping. Please wait for the processes to finish.")
+            self.stop_study = True
 
-            remaining_trials = self.optuna_settings.n_trials % self.n_threads
-        else:
-            remaining_trials = self.optuna_settings.n_trials
+    def trials_left(self):
+        return self.optuna_settings.n_trials - len(self.study.get_trials(states=self.optuna_settings.n_trials_filter))
 
-        if remaining_trials:
-            self._run_optimize(remaining_trials)
+    def remove_completed_processes(self):
+        if self.processes is None:
+            return
+        # iterate over a copy so finished workers can be removed without skipping entries
+        for process in list(self.processes):
+            if not process.is_alive():
+                process.join()
+                self.processes.remove(process)
 
-    def _run_optimize(self, n_trials):
+    def remove_outliers(self):
+        if self.optuna_settings.remove_outliers is not None:
+            trials = self.study.get_trials(states=(optuna.trial.TrialState.COMPLETE,))
+            if len(trials) == 0:
+                return
+            vals = [trial.value for trial in trials]
+            vals = np.log(vals)
+            mean = np.mean(vals)
+            std = np.std(vals)
+            outliers = [
+                trial for trial in trials if np.log(trial.value) > mean + self.optuna_settings.remove_outliers * std
+            ]
+            for trial in outliers:
+                trial.state = optuna.trial.TrialState.FAIL
+                trial.set_user_attr("outlier", True)
+
+    def _run_study(self):
+        while trials_left := self.trials_left():
+            self.remove_outliers()
+            self._run_optimize(n_trials=trials_left, timeout=self.optuna_settings.timeout)
+
+    def _run_parallel_study(self):
+        while trials_left := self.trials_left():
+            self.remove_outliers()
+            self.remove_completed_processes()
+
+            n_trials = max(trials_left, self.optuna_settings._n_threads) // self.optuna_settings._n_threads
+
+            def target_fun():
+                self._run_optimize(n_trials=n_trials, timeout=self.optuna_settings.timeout)
+
+            for _ in range(self.optuna_settings._n_threads - len(self.processes)):
+                self.processes.append(multiprocessing.Process(target=target_fun))
+                self.processes[-1].start()
+
+    def _run_optimize(self, **kwargs):
         self.study.optimize(
-            self.objective, n_trials=n_trials, timeout=self.optuna_settings.timeout
+            self.objective,
+            **kwargs,
+            show_progress_bar=not self.optuna_settings._parallel,
         )
 
     def _extra_optuna_settings(self):
-        self.optuna_settings.multi_objective = len(self.optuna_settings.directions) > 1
-        if self.optuna_settings.multi_objective:
-            self.optuna_settings.direction = None
-        else:
-            self.optuna_settings.direction = self.optuna_settings.directions[0]
-            self.optuna_settings.directions = None
+        self.optuna_settings._multi_objective = len(self.optuna_settings.directions) > 1
 
-        self.optuna_settings.n_train_batches = (
-            self.optuna_settings.n_train_batches
-            if self.optuna_settings.limit_examples
-            else float("inf")
+        if self.optuna_settings._multi_objective:
+            self.optuna_settings._direction = None
+            self.optuna_settings._directions = self.optuna_settings.directions
+        else:
+            self.optuna_settings._direction = self.optuna_settings.directions[0]
+            self.optuna_settings._directions = None
+
+        self.optuna_settings._n_train_batches = (
+            self.optuna_settings.n_train_batches if self.optuna_settings.limit_examples else float("inf")
        )
-        self.optuna_settings.n_valid_batches = (
-            self.optuna_settings.n_valid_batches
-            if self.optuna_settings.limit_examples
-            else float("inf")
+        self.optuna_settings._n_valid_batches = (
+            self.optuna_settings.n_valid_batches if self.optuna_settings.limit_examples else float("inf")
        )
 
+        self.optuna_settings._n_threads = self.optuna_settings.n_workers
+
+        self.optuna_settings._parallel = self.optuna_settings._n_threads > 1
+
     def define_model(self, trial: optuna.Trial, writer=None):
-        n_layers = force_suggest_int(
-            trial, "model_n_layers", self.model_settings.model_n_layers
-        )
+        n_layers = trial.suggest_int_optional("model_n_hidden_layers", self.model_settings.n_hidden_layers)
 
-        input_dim = 2 * trial.params.get(
+        input_dim = trial.suggest_int_optional(
             "model_input_dim",
-            force_suggest_int(trial, "model_input_dim", self.data_settings.model_input_dim),
+            self.data_settings.output_size,
+            step=2,
+            multiply=2,
+            set_new=False,
         )
 
-        dtype = trial.params.get(
-            "model_dtype",
-            force_suggest_categorical(trial, "model_dtype", self.data_settings.dtype),
-        )
+        # trial.set_user_attr("model_input_dim", input_dim)
+
+        dtype = trial.suggest_categorical_optional("model_dtype", self.data_settings.dtype, set_new=False)
         dtype = getattr(torch, dtype)
 
-        afunc = force_suggest_categorical(
-            trial, "model_activation_func", self.model_settings.model_activation_func
-        )
+        afunc = trial.suggest_categorical_optional("model_activation_func", self.model_settings.model_activation_func)
 
         layers = []
 
         last_dim = input_dim
+        n_nodes = last_dim
         for i in range(n_layers):
-            hidden_dim = force_suggest_int(
-                trial, f"model_hidden_dim_{i}", self.model_settings.unit_count
-            )
-            layers.append(
-                util.complexNN.SemiUnitaryLayer(last_dim, hidden_dim, dtype=dtype)
-            )
+            if hidden_dim_override := self.model_settings.overrides.get(f"n_hidden_nodes_{i}", False):
+                hidden_dim = trial.suggest_int_optional(f"model_hidden_dim_{i}", hidden_dim_override, force=True)
+            else:
+                hidden_dim = trial.suggest_int_optional(
+                    f"model_hidden_dim_{i}",
+                    self.model_settings.n_hidden_nodes,
+                    # step=2,
+                )
+            layers.append(util.complexNN.SemiUnitaryLayer(last_dim, hidden_dim, dtype=dtype))
             last_dim = hidden_dim
             layers.append(getattr(util.complexNN, afunc)())
+            n_nodes += last_dim
 
-        layers.append(
-            util.complexNN.UnitaryLayer(
-                hidden_dim, self.model_settings.output_dim, dtype=dtype
-            )
-        )
+        layers.append(util.complexNN.SemiUnitaryLayer(last_dim, self.model_settings.output_dim, dtype=dtype))
 
         model = nn.Sequential(*layers)
 
         if writer is not None:
-            writer.add_graph(
-                model, torch.zeros(1, input_dim, dtype=dtype), use_strict_trace=False
-            )
+            writer.add_graph(model, torch.zeros(1, input_dim, dtype=dtype), use_strict_trace=False)
+
+        n_params = sum(p.numel() for p in model.parameters())
+        trial.set_user_attr("model_n_params", n_params)
+        trial.set_user_attr("model_n_nodes", n_nodes)
 
         return model.to(self.pytorch_settings.device)
 
     def get_sliced_data(self, trial: optuna.Trial, override=None):
-        symbols = trial.params.get(
-            "dataset_symbols",
-            force_suggest_float(trial, "dataset_symbols", self.data_settings.symbols),
+        symbols = trial.suggest_float_optional("dataset_symbols", self.data_settings.symbols, set_new=False)
+
+        in_out_delay = trial.suggest_float_optional(
+            "dataset_in_out_delay", self.data_settings.in_out_delay, set_new=False
         )
 
-        xy_delay = trial.params.get(
-            "dataset_xy_delay",
-            force_suggest_float(trial, "dataset_xy_delay", self.data_settings.xy_delay),
+        xy_delay = trial.suggest_float_optional("dataset_xy_delay", self.data_settings.xy_delay, set_new=False)
+
+        data_size = int(
+            0.5
+            * trial.suggest_int_optional(
+                "model_input_dim",
+                self.data_settings.output_size,
+                step=2,
+                multiply=2,
+                set_new=False,
+            )
        )
 
-        data_size = trial.params.get(
-            "model_input_dim",
-            force_suggest_int(trial, "model_input_dim", self.data_settings.model_input_dim),
-        )
-
-        dtype = trial.params.get(
-            "model_dtype",
-            force_suggest_categorical(trial, "model_dtype", self.data_settings.dtype),
-        )
+        dtype = trial.suggest_categorical_optional("model_dtype", self.data_settings.dtype, set_new=False)
         dtype = getattr(torch, dtype)
 
         num_symbols = None
@@ -273,7 +307,7 @@ class HyperTraining:
             file_path=self.data_settings.config_path,
             symbols=symbols,
             output_dim=data_size,
-            target_delay=self.data_settings.in_out_delay,
+            target_delay=in_out_delay,
             xy_delay=xy_delay,
             drop_first=self.data_settings.drop_first,
             dtype=dtype,
@@ -327,31 +361,31 @@ class HyperTraining:
         train_loader,
         epoch,
         writer=None,
-        enable_progress=False,
+        # enable_progress=False,
     ):
-        if enable_progress:
-            progress = Progress(
-                TextColumn("[yellow] Training..."),
-                TextColumn("Error: {task.description}"),
-                BarColumn(),
-                TaskProgressColumn(),
-                TextColumn("[green]Batch"),
-                MofNCompleteColumn(),
-                TimeRemainingColumn(),
-                TimeElapsedColumn(),
-                # description="Training",
-                transient=False,
-                console=self.console,
-                refresh_per_second=10,
-            )
-            task = progress.add_task("-.---e--", total=len(train_loader))
-            progress.start()
+        # if enable_progress:
+        #     progress = Progress(
+        #         TextColumn("[yellow] Training..."),
+        #         TextColumn("Error: {task.description}"),
+        #         BarColumn(),
+        #         TaskProgressColumn(),
+        #         TextColumn("[green]Batch"),
+        #         MofNCompleteColumn(),
+        #         TimeRemainingColumn(),
+        #         TimeElapsedColumn(),
+        #         # description="Training",
+        #         transient=False,
+        #         console=self.console,
+        #         refresh_per_second=10,
+        #     )
+        #     task = progress.add_task("-.---e--", total=len(train_loader))
+        #     progress.start()
 
         running_loss2 = 0.0
         running_loss = 0.0
         model.train()
         for batch_idx, (x, y) in enumerate(train_loader):
-            if batch_idx >= self.optuna_settings.n_train_batches:
+            if batch_idx >= self.optuna_settings._n_train_batches:
                 break
             model.zero_grad(set_to_none=True)
             x, y = (
@@ -366,55 +400,56 @@ class HyperTraining:
             running_loss2 += loss_value
             running_loss += loss_value
 
-            if enable_progress:
-                progress.update(task, advance=1, description=f"{loss_value:.3e}")
+            # if enable_progress:
+            #     progress.update(task, advance=1, description=f"{loss_value:.3e}")
 
             if writer is not None:
                 if batch_idx % self.pytorch_settings.write_every == 0:
                     writer.add_scalar(
                         "training loss",
-                        running_loss2
-                        / (self.pytorch_settings.write_every if batch_idx > 0 else 1),
-                        epoch
-                        * min(len(train_loader), self.optuna_settings.n_train_batches)
-                        + batch_idx,
+                        running_loss2 / (self.pytorch_settings.write_every if batch_idx > 0 else 1),
+                        epoch * min(len(train_loader), self.optuna_settings._n_train_batches) + batch_idx,
                     )
                     running_loss2 = 0.0
 
-        if enable_progress:
-            progress.stop()
+        # if enable_progress:
+        #     progress.stop()
 
-        return running_loss / min(
-            len(train_loader), self.optuna_settings.n_train_batches
-        )
+        return running_loss / min(len(train_loader), self.optuna_settings._n_train_batches)
 
     def eval_model(
-        self, trial, model, valid_loader, epoch, writer=None, enable_progress=True
+        self,
+        trial,
+        model,
+        valid_loader,
+        epoch,
+        writer=None,
+        # enable_progress=True
     ):
-        if enable_progress:
-            progress = Progress(
-                TextColumn("[green]Evaluating..."),
-                TextColumn("Error: {task.description}"),
-                BarColumn(),
-                TaskProgressColumn(),
-                TextColumn("[green]Batch"),
-                MofNCompleteColumn(),
-                TimeRemainingColumn(),
-                TimeElapsedColumn(),
-                # description="Training",
-                transient=False,
-                console=self.console,
-                refresh_per_second=10,
-            )
-            progress.start()
-            task = progress.add_task("-.---e--", total=len(valid_loader))
+        # if enable_progress:
+        #     progress = Progress(
+        #         TextColumn("[green]Evaluating..."),
+        #         TextColumn("Error: {task.description}"),
+        #         BarColumn(),
+        #         TaskProgressColumn(),
+        #         TextColumn("[green]Batch"),
+        #         MofNCompleteColumn(),
+        #         TimeRemainingColumn(),
+        #         TimeElapsedColumn(),
+        #         # description="Training",
+        #         transient=False,
+        #         console=self.console,
+        #         refresh_per_second=10,
+        #     )
+        #     progress.start()
+        #     task = progress.add_task("-.---e--", total=len(valid_loader))
 
         model.eval()
         running_error = 0
         running_error_2 = 0
         with torch.no_grad():
             for batch_idx, (x, y) in enumerate(valid_loader):
-                if batch_idx >= self.optuna_settings.n_valid_batches:
+                if batch_idx >= self.optuna_settings._n_valid_batches:
                     break
                 x, y = (
                     x.to(self.pytorch_settings.device),
@@ -426,28 +461,19 @@ class HyperTraining:
                 running_error += error_value
                 running_error_2 += error_value
 
-                if enable_progress:
-                    progress.update(task, advance=1, description=f"{error_value:.3e}")
+                # if enable_progress:
+                #     progress.update(task, advance=1, description=f"{error_value:.3e}")
 
                 if writer is not None:
                     if batch_idx % self.pytorch_settings.write_every == 0:
                         writer.add_scalar(
                             "eval loss",
-                            running_error_2
-                            / (
-                                self.pytorch_settings.write_every
-                                if batch_idx > 0
-                                else 1
-                            ),
-                            epoch
-                            * min(
-                                len(valid_loader), self.optuna_settings.n_valid_batches
-                            )
-                            + batch_idx,
+                            running_error_2 / (self.pytorch_settings.write_every if batch_idx > 0 else 1),
+                            epoch * min(len(valid_loader), self.optuna_settings._n_valid_batches) + batch_idx,
                         )
                         running_error_2 = 0.0
 
-        running_error /= min(len(valid_loader), self.optuna_settings.n_valid_batches)
+        running_error /= min(len(valid_loader), self.optuna_settings._n_valid_batches)
 
         if writer is not None:
             title_append, subtitle = self.build_title(trial)
@@ -463,8 +489,8 @@ class HyperTraining:
                 epoch + 1,
             )
 
-        if enable_progress:
-            progress.stop()
+        # if enable_progress:
+        #     progress.stop()
 
         return running_error
 
@@ -496,103 +522,94 @@ class HyperTraining:
         return ys, xs, y_preds
 
     def objective(self, trial: optuna.Trial, plot_before=False):
+        if self.stop_study:
+            trial.study.stop()
         model = None
-        exc = None
-        try:
-            # rprint(*list(self.study_name.split("_")))
-            writer = self.setup_tb_writer(
-                self.optuna_settings.study_name,
-                f"{trial.number:0{len(str(self.optuna_settings.n_trials))}}",
+        writer = self.setup_tb_writer(
+            self.optuna_settings.study_name,
+            f"{trial.number:0{len(str(self.optuna_settings.n_trials)) + 2}}",
+        )
+
+        model = self.define_model(trial, writer)
+
+        # n_nodes = trial.params.get("model_n_hidden_layers", self.model_settings.model_n_layers) * trial.params.get("model_hidden_dim", self.model_settings.unit_count)
+
+        title_append, subtitle = self.build_title(trial)
+
+        writer.add_figure(
+            "fiber response",
+            self.plot_model_response(
+                trial,
+                model=model,
+                title_append=title_append,
+                subtitle=subtitle,
+                show=plot_before,
+            ),
+            0,
+        )
+
+        train_loader, valid_loader = self.get_sliced_data(trial)
+
+        optimizer_name = trial.suggest_categorical_optional("optimizer", self.optimizer_settings.optimizer)
+
+        lr = trial.suggest_float_optional("lr", self.optimizer_settings.learning_rate, log=True)
+
+        optimizer = getattr(optim, optimizer_name)(model.parameters(), lr=lr)
+        if self.optimizer_settings.scheduler is not None:
+            scheduler = getattr(optim.lr_scheduler, self.optimizer_settings.scheduler)(
+                optimizer, **self.optimizer_settings.scheduler_kwargs
             )
 
-            model = self.define_model(trial, writer)
-            n_params = sum(p.numel() for p in model.parameters())
-            # n_nodes = trial.params.get("model_n_layers", self.model_settings.model_n_layers) * trial.params.get("model_hidden_dim", self.model_settings.unit_count)
-
-            title_append, subtitle = self.build_title(trial)
-
-            writer.add_figure(
-                "fiber response",
-                self.plot_model_response(
-                    trial,
-                    model=model,
-                    title_append=title_append,
-                    subtitle=subtitle,
-                    show=plot_before,
-                ),
-                0,
+        for epoch in range(self.pytorch_settings.epochs):
+            trial.set_user_attr("epoch", epoch)
+            # enable_progress = self.optuna_settings.n_threads == 1
+            # if enable_progress:
+            #     self.console.rule(
+            #         f"Epoch {epoch + 1}/{self.pytorch_settings.epochs}"
+            #     )
+            self.train_model(
+                trial,
+                model,
+                optimizer,
+                train_loader,
+                epoch,
+                writer,
+                # enable_progress=enable_progress,
             )
-
-            train_loader, valid_loader = self.get_sliced_data(trial)
-
-            optimizer_name = force_suggest_categorical(
-                trial, "optimizer", self.optimizer_settings.optimizer
+            error = self.eval_model(
+                trial,
+                model,
+                valid_loader,
+                epoch,
+                writer,
+                # enable_progress=enable_progress,
             )
-
-            lr = force_suggest_float(
-                trial, "lr", self.optimizer_settings.learning_rate, log=True
-            )
-
-            optimizer = getattr(optim, optimizer_name)(model.parameters(), lr=lr)
             if self.optimizer_settings.scheduler is not None:
-                scheduler = getattr(optim.lr_scheduler, self.optimizer_settings.scheduler)(
-                    optimizer, **self.optimizer_settings.scheduler_kwargs)
+                scheduler.step(error)
 
-            for epoch in range(self.pytorch_settings.epochs):
-                enable_progress = self.optuna_settings.n_threads == 1
-                if enable_progress:
-                    self.console.rule(
-                        f"Epoch {epoch + 1}/{self.pytorch_settings.epochs}"
-                    )
-                self.train_model(
-                    trial,
-                    model,
-                    optimizer,
-                    train_loader,
-                    epoch,
-                    writer,
-                    enable_progress=enable_progress,
-                )
-                error = self.eval_model(
-                    trial,
-                    model,
-                    valid_loader,
-                    epoch,
-                    writer,
-                    enable_progress=enable_progress,
-                )
-                if self.optimizer_settings.scheduler is not None:
-                    scheduler.step(error)
-
-            writer.close()
+            trial.set_user_attr("mse", error)
+            trial.set_user_attr("log_mse", np.log10(error + np.finfo(float).eps))
+            trial.set_user_attr("neg_mse", -error)
+            trial.set_user_attr("neg_log_mse", -np.log10(error + np.finfo(float).eps))
+            if not self.optuna_settings._multi_objective:
+                trial.report(error, epoch)
+                if trial.should_prune():
+                    raise optuna.exceptions.TrialPruned()
 
-            if self.optuna_settings.multi_objective:
-                return n_params, error
-            trial.report(error, epoch)
-            if trial.should_prune():
-                raise optuna.exceptions.TrialPruned()
-            return error
-
-        except KeyboardInterrupt:
-            ...
-        # except Exception as e:
-        #     exc = e
-        finally:
-            if model is not None:
-                save_path = (
-                    Path(self.pytorch_settings.model_dir)
-                    / f"{self.optuna_settings.study_name}_{trial.number}.pth"
-                )
-                save_path.parent.mkdir(parents=True, exist_ok=True)
-                torch.save(model, save_path)
-            if exc is not None:
-                raise exc
-
+        writer.close()
 
-    def _plot_model_response_eye(
-        self, *signals, labels=None, sps=None, title_append="", subtitle="", show=True
-    ):
+        if self.optuna_settings._multi_objective:
+            return -np.log10(error + np.finfo(float).eps), trial.user_attrs.get("model_n_nodes", -1)
+
+        if self.pytorch_settings.save_models and model is not None:
+            save_path = Path(self.pytorch_settings.model_dir) / f"{self.optuna_settings.study_name}_{trial.number}.pth"
+            save_path.parent.mkdir(parents=True, exist_ok=True)
+            torch.save(model, save_path)
+
+        return error
+
+    def _plot_model_response_eye(self, *signals, labels=None, sps=None, title_append="", subtitle="", show=True):
         if sps is None:
             raise ValueError("sps must be provided")
         if not hasattr(labels, "__iter__") or isinstance(labels, (str, type(None))):
             labels = [labels]
@@ -608,9 +625,7 @@ class HyperTraining:
             labels = [f"signal {i + 1}" for i in range(len(signals))]
 
         fig, axs = plt.subplots(2, len(signals), sharex=True, sharey=True)
-        fig.suptitle(
-            f"Eye diagram{f' {title_append}' if title_append else ''}{f'\n{subtitle}' if subtitle else ''}"
-        )
+        fig.suptitle(f"Eye diagram{f' {title_append}' if title_append else ''}{f'\n{subtitle}' if subtitle else ''}")
         xaxis = np.linspace(0, 2, 2 * sps, endpoint=False)
         for j, (label, signal) in enumerate(zip(labels, signals)):
             # signal = signal.cpu().numpy()
@@ -628,9 +643,7 @@ class HyperTraining:
         if show:
             plt.show()
 
-    def _plot_model_response_head(
-        self, *signals, labels=None, sps=None, title_append="", subtitle="", show=True
-    ):
+    def _plot_model_response_head(self, *signals, labels=None, sps=None, title_append="", subtitle="", show=True):
         if not hasattr(labels, "__iter__") or isinstance(labels, (str, type(None))):
             labels = [labels]
         else:
@@ -644,16 +657,12 @@ class HyperTraining:
             labels = [f"signal {i + 1}" for i in range(len(signals))]
 
         fig, axs = plt.subplots(1, 2, sharex=True, sharey=True)
-        fig.set_size_inches(18,6)
-        fig.suptitle(
-            f"Fiber response{f' {title_append}' if title_append else ''}{f'\n{subtitle}' if subtitle else ''}"
-        )
+        fig.set_size_inches(18, 6)
+        fig.suptitle(f"Fiber response{f' {title_append}' if title_append else ''}{f'\n{subtitle}' if subtitle else ''}")
         for i, ax in enumerate(axs):
             for signal, label in zip(signals, labels):
                 if sps is not None:
-                    xaxis = np.linspace(
-                        0, len(signal) / sps, len(signal), endpoint=False
-                    )
+                    xaxis = np.linspace(0, len(signal) / sps, len(signal), endpoint=False)
                 else:
                     xaxis = np.arange(len(signal))
                 ax.plot(xaxis, np.abs(signal[:, i]) ** 2, label=label)
@@ -678,10 +687,10 @@ class HyperTraining:
         self.data_settings.drop_first = 100
         self.data_settings.shuffle = False
         self.data_settings.train_split = 1.0
-        self.pytorch_settings.batchsize = self.pytorch_settings.eye_symbols if mode == "eye" else self.pytorch_settings.head_symbols
-        plot_loader, _ = self.get_sliced_data(
-            trial, override={"num_symbols": self.pytorch_settings.batchsize}
+        self.pytorch_settings.batchsize = (
+            self.pytorch_settings.eye_symbols if mode == "eye" else self.pytorch_settings.head_symbols
         )
+        plot_loader, _ = self.get_sliced_data(trial, override={"num_symbols": self.pytorch_settings.batchsize})
 
         self.data_settings = data_settings_backup
         self.pytorch_settings = pytorch_settings_backup
@@ -728,13 +737,22 @@ class HyperTraining:
         return fig
 
     @staticmethod
-    def build_title(trial):
+    def build_title(trial: optuna.trial.Trial):
         title_append = f"for trial {trial.number}"
+        model_n_layers = util.misc.multi_getattr((trial.params, trial.user_attrs), "model_n_hidden_layers", 0)
+        model_hidden_dims = [
+            util.misc.multi_getattr((trial.params, trial.user_attrs), f"model_hidden_dim_{i}", 0)
+            for i in range(model_n_layers)
+        ]
+        model_activation_func = util.misc.multi_getattr(
+            (trial.params, trial.user_attrs),
+            "model_activation_func",
+            "unknown act. fun",
+        )
+        model_dtype = util.misc.multi_getattr((trial.params, trial.user_attrs), "model_dtype", "unknown dtype")
+
         subtitle = (
-            f"{trial.params['model_n_layers']} layers, "
-            f"{', '.join([str(trial.params[f'model_hidden_dim_{i}']) for i in range(trial.params['model_n_layers'])])} units, "
-            f"{trial.params['model_activation_func']}, "
-            f"{trial.params['model_dtype']}"
+            f"{model_n_layers} layers with ({', '.join(str(d) for d in model_hidden_dims)}) units, "
+            f"{model_activation_func}, {model_dtype}"
         )
 
         return title_append, subtitle
diff --git a/src/single-core-regen/hypertraining/settings.py b/src/single-core-regen/hypertraining/settings.py
index 1fca6f2..20b34f3 100644
--- a/src/single-core-regen/hypertraining/settings.py
+++ b/src/single-core-regen/hypertraining/settings.py
@@ -1,4 +1,4 @@
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from datetime import datetime
 
 
@@ -14,7 +14,7 @@ class DataSettings:
     config_path: str  # = "data/*-128-16384-100000-0-0-17-0-PAM4-0.ini"
     dtype: tuple = ("complex64", "float64")
     symbols: tuple | float | int = 8
-    model_input_dim: tuple | float | int = 64
+    output_size: tuple | float | int = 64
     shuffle: bool = True
     in_out_delay: float = 0
     xy_delay: tuple | float | int = 0
@@ -33,6 +33,7 @@ class PytorchSettings:
     dataloader_workers: int = 2
     dataloader_prefetch: int = 2
 
+    save_models: bool = True
     model_dir: str = ".models"
     summary_dir: str = ".runs"
 
@@ -45,11 +46,10 @@ class PytorchSettings:
 @dataclass
 class ModelSettings:
     output_dim: int = 2
-    model_n_layers: tuple | int = 3
-    unit_count: tuple | int = 8
-    # n_units_range: tuple | int = (2, 32)
-    # activation_func_range: tuple = ("ModReLU", "ZReLU", "CReLU", "Mag", "Identity")
-    model_activation_func: tuple = ("ModReLU",)
+    n_hidden_layers: tuple | int = 3
+    n_hidden_nodes: tuple | int = 8
+    model_activation_func: tuple | str = "ModReLU"
+    overrides: dict = field(default_factory=dict)
 
 
 @dataclass
@@ -60,18 +60,36 @@ class OptimizerSettings:
     scheduler_kwargs: dict | None = None
 
 
+def _pruner_default_kwargs():
+    # MedianPruner
+    return {
+        "n_startup_trials": 0,
+        "n_warmup_steps": 5,
+    }
+
+
 # optuna settings
 @dataclass
 class OptunaSettings:
     n_trials: int = 128
-    n_threads: int = 4
-    timeout: int = 600
+    n_workers: int = 1
+    timeout: int | None = None
+    pruner: str = "MedianPruner"
+    pruner_kwargs: dict = field(default_factory=_pruner_default_kwargs)
     directions: tuple = ("minimize",)
     metrics_names: tuple = ("mse",)
     limit_examples: bool = True
-    n_train_batches: int = 100
-    n_valid_batches: int = 100
+    n_train_batches: int = float("inf")
+    n_valid_batches: int = float("inf")
     storage: str = "sqlite:///example.db"
     study_name: str = (
         f"optuna_study_{datetime.now().strftime('%Y-%m-%d_%H:%M:%S')}"
-    )
\ No newline at end of file
+    )
+    n_trials_filter: tuple | list | None = None  # (optuna.trial.TrialState.COMPLETE,)
+    remove_outliers: float | int | None = None
+    ## reserved, set by HyperTraining
+    _multi_objective = None
+    _parallel = None
+    _n_threads = None
+    _directions = None
+    _direction = None
+    _n_train_batches = None
+    _n_valid_batches = None
diff --git a/src/single-core-regen/util/complexNN.py b/src/single-core-regen/util/complexNN.py
index 972bafb..3a42fdb 100644
--- a/src/single-core-regen/util/complexNN.py
+++ b/src/single-core-regen/util/complexNN.py
@@ -8,10 +8,7 @@ def complex_mse_loss(input, target):
     Compute the mean squared error between two complex tensors.
     """
     if input.is_complex():
-        return torch.mean(
-            torch.square(input.real - target.real)
-            + torch.square(input.imag - target.imag)
-        )
+        return torch.mean(torch.square(input.real - target.real) + torch.square(input.imag - target.imag))
     else:
         return F.mse_loss(input, target)
 
@@ -21,10 +18,7 @@ def complex_sse_loss(input, target):
     Compute the sum squared error between two complex tensors.
     """
     if input.is_complex():
-        return torch.sum(
-            torch.square(input.real - target.real)
-            + torch.square(input.imag - target.imag)
-        )
+        return torch.sum(torch.square(input.real - target.real) + torch.square(input.imag - target.imag))
     else:
         return torch.sum(torch.square(input - target))
 
@@ -41,19 +35,20 @@ class UnitaryLayer(nn.Module):
     def reset_parameters(self):
         q, _ = torch.linalg.qr(self.weight)
         self.weight.data = q
-    
+
     def forward(self, x):
         return torch.matmul(x, self.weight)
-    
+
     def __repr__(self):
         return f"UnitaryLayer({self.in_features}, {self.out_features})"
-    
+
+
 class SemiUnitaryLayer(nn.Module):
     def __init__(self, input_dim, output_dim, dtype=None):
         super(SemiUnitaryLayer, self).__init__()
         self.input_dim = input_dim
         self.output_dim = output_dim
-    
+
         # Create a larger square matrix for QR decomposition
         self.weight = nn.Parameter(torch.randn(max(input_dim, output_dim), max(input_dim, output_dim), dtype=dtype))
         self.reset_parameters()
@@ -62,14 +57,14 @@ class SemiUnitaryLayer(nn.Module):
         # Ensure the weights are semi-unitary by QR decomposition
         q, _ = torch.linalg.qr(self.weight)
         if self.input_dim > self.output_dim:
-            self.weight.data = q[:self.input_dim, :self.output_dim]
+            self.weight.data = q[: self.input_dim, : self.output_dim]
         else:
-            self.weight.data = q[:self.output_dim, :self.input_dim].t()
+            self.weight.data = q[: self.output_dim, : self.input_dim].t()
 
     def forward(self, x):
         out = torch.matmul(x, self.weight)
         return out
-    
+
     def __repr__(self):
         return f"SemiUnitaryLayer({self.input_dim}, {self.output_dim})"
diff --git a/src/single-core-regen/util/datasets.py b/src/single-core-regen/util/datasets.py
index 9b2953e..0aac64e 100644
--- a/src/single-core-regen/util/datasets.py
+++ b/src/single-core-regen/util/datasets.py
@@ -1,6 +1,7 @@
 from pathlib import Path
 import torch
 from torch.utils.data import Dataset
+
 # from torch.utils.data import Sampler
 import numpy as np
 import configparser
@@ -22,6 +23,7 @@
 #     def __len__(self):
 #         return len(self.indices)
 
+
 def load_data(config_path, skipfirst=0, symbols=None, real=False, normalize=False, device=None, dtype=None):
     filepath = Path(config_path)
     filepath = filepath.parent.glob(filepath.name)
@@ -43,18 +45,19 @@ def load_data(config_path, skipfirst=0, symbols=None, real=False, normalize=Fals
     if normalize:
         # square gets normalized to 1, as the power is (proportional to) the square of the amplitude
         a, b, c, d = np.square(data.T)
-        a, b, c, d = a/np.max(np.abs(a)), b/np.max(np.abs(b)), c/np.max(np.abs(c)), d/np.max(np.abs(d))
+        a, b, c, d = a / np.max(np.abs(a)), b / np.max(np.abs(b)), c / np.max(np.abs(c)), d / np.max(np.abs(d))
         data = np.sqrt(np.array([a, b, c, d]).T)
 
     if real:
         data = np.abs(data)
-    
+
     config["glova"]["nos"] = str(symbols)
 
     data = torch.tensor(data, device=device, dtype=dtype)
 
     return data, config
 
+
 def roll_along(arr, shifts, dim):
     # https://stackoverflow.com/a/76920720
     # (c) Mateen Ulhaq, 2023
@@ -67,6 +70,7 @@ def roll_along(arr, shifts, dim):
     indices = (dim_indices - shifts.unsqueeze(dim)) % arr.shape[dim]
     return torch.gather(arr, dim, indices)
 
+
 class FiberRegenerationDataset(Dataset):
     """
     Dataset for fiber regeneration training.
@@ -105,7 +109,7 @@ class FiberRegenerationDataset(Dataset):
         drop_first: float | int = 0,
         dtype: torch.dtype = None,
         real: bool = False,
-        device = None,
+        device=None,
         **kwargs,
     ):
         """
@@ -127,18 +131,10 @@ class FiberRegenerationDataset(Dataset):
         # check types
         assert isinstance(file_path, str), "file_path must be a string"
-        assert isinstance(symbols, (float, int)), (
-            "symbols must be a float or an integer"
-        )
-        assert output_dim is None or isinstance(output_dim, int), (
-            "output_len must be an integer"
-        )
-        assert isinstance(target_delay, (float, int)), (
-            "target_delay must be a float or an integer"
-        )
-        assert isinstance(xy_delay, (float, int)), (
-            "xy_delay must be a float or an integer"
-        )
+        assert isinstance(symbols, (float, int)), "symbols must be a float or an integer"
+        assert output_dim is None or isinstance(output_dim, int), "output_len must be an integer"
+        assert isinstance(target_delay, (float, int)), "target_delay must be a float or an integer"
+        assert isinstance(xy_delay, (float, int)), "xy_delay must be a float or an integer"
         assert isinstance(drop_first, int), "drop_first must be an integer"
 
         # check values
@@ -159,10 +155,18 @@ class FiberRegenerationDataset(Dataset):
                 "glova": {"sps": 128},
             }
         else:
-            data_raw, self.config = load_data(file_path, skipfirst=drop_first, symbols=kwargs.pop("num_symbols", None), real=real, normalize=True, device=device, dtype=dtype)
+            data_raw, self.config = load_data(
+                file_path,
+                skipfirst=drop_first,
+                symbols=kwargs.pop("num_symbols", None),
+                real=real,
+                normalize=True,
+                device=device,
+                dtype=dtype,
+            )
 
         self.device = data_raw.device
-    
+
         self.samples_per_symbol = int(self.config["glova"]["sps"])
         self.samples_per_slice = int(symbols * self.samples_per_symbol)
         self.symbols_per_slice = self.samples_per_slice / self.samples_per_symbol
@@ -180,9 +184,7 @@ class FiberRegenerationDataset(Dataset):
             else int(self.target_delay * self.samples_per_symbol)
         )
         self.xy_delay_samples = (
-            ovrd_xy_delay_samples
-            if ovrd_xy_delay_samples is not None
-            else int(self.xy_delay * self.samples_per_symbol)
+            ovrd_xy_delay_samples if ovrd_xy_delay_samples is not None else int(self.xy_delay * self.samples_per_symbol)
         )
 
         # data_raw = torch.tensor(data_raw, dtype=dtype)
@@ -190,15 +192,15 @@ class FiberRegenerationDataset(Dataset):
         # data layout
         # [ [E_in_x0, E_in_y0, E_out_x0, E_out_y0],
         #   [E_in_x1, E_in_y1, E_out_x1, E_out_y1],
-        #   ...
+        #   ...
         #   [E_in_xN, E_in_yN, E_out_xN, E_out_yN] ]
 
         data_raw = data_raw.transpose(0, 1)
 
         # data layout
-        # [ E_in_x[0:N],
-        #   E_in_y[0:N],
-        #   E_out_x[0:N],
+        # [ E_in_x[0:N],
+        #   E_in_y[0:N],
+        #   E_out_x[0:N],
         #   E_out_y[0:N] ]
 
         # shift x data by xy_delay_samples relative to the y data (example value: 3)
@@ -208,9 +210,7 @@
         #   E_out_y[0:N] ]      E_out_y[-3:N-3] ]      E_out_y[0:N-3] ]
 
         if self.xy_delay_samples != 0:
-            data_raw = roll_along(
-                data_raw, [0, self.xy_delay_samples, 0, self.xy_delay_samples], dim=1
-            )
+            data_raw = roll_along(data_raw, [0, self.xy_delay_samples, 0, self.xy_delay_samples], dim=1)
             if self.xy_delay_samples > 0:
                 data_raw = data_raw[:, self.xy_delay_samples :]
             elif self.xy_delay_samples < 0:
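
Usage note (not part of the change set): a minimal sketch of how the reworked API might be driven, assuming the package is importable as `hypertraining`, that GlobalSettings / PytorchSettings / OptimizerSettings can be constructed with their defaults, and that the example config path exists. Only the field names shown in settings.py above are used; everything else is illustrative.

    # hypothetical driver script; import paths and default-constructed settings are assumptions
    from hypertraining.hypertraining import HyperTraining
    from hypertraining.settings import (
        GlobalSettings, DataSettings, PytorchSettings,
        ModelSettings, OptimizerSettings, OptunaSettings,
    )

    trainer = HyperTraining(
        global_settings=GlobalSettings(),
        data_settings=DataSettings(config_path="data/example.ini"),  # illustrative path
        pytorch_settings=PytorchSettings(),
        model_settings=ModelSettings(n_hidden_layers=3, n_hidden_nodes=8),
        optimizer_settings=OptimizerSettings(),
        optuna_settings=OptunaSettings(
            n_trials=64,
            n_workers=1,  # >1 makes run_study() use the multiprocessing path
            pruner="MedianPruner",
            storage="sqlite:///example.db",
        ),
    )

    trainer.setup_study()  # creates/loads the study and builds the pruner from its name and kwargs
    trainer.run_study()    # loops until trials_left() reaches zero, removing outliers if configured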