Compare commits
3 Commits
80e9a3379e
...
0422c81f3b
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0422c81f3b | ||
|
|
7343ccb3a5 | ||
|
|
9a16a5637d |
@@ -1,3 +1,3 @@
|
|||||||
version https://git-lfs.github.com/spec/v1
|
version https://git-lfs.github.com/spec/v1
|
||||||
oid sha256:7231dea2c9107f443de9122fdc971d9ce6df93db2ee27a9d68a5e22c986373eb
|
oid sha256:f3510d41f9f0605e438a09767c43edda38162601292be1207f50747117ae5479
|
||||||
size 937984
|
size 9863168
|
||||||
|
|||||||
@@ -6,7 +6,8 @@ import matplotlib.pyplot as plt
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import optuna
|
import optuna
|
||||||
import optunahub
|
|
||||||
|
# import optunahub
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -18,26 +19,28 @@ import torch.utils.data
|
|||||||
|
|
||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
|
||||||
from rich.progress import (
|
# from rich.progress import (
|
||||||
Progress,
|
# Progress,
|
||||||
TextColumn,
|
# TextColumn,
|
||||||
BarColumn,
|
# BarColumn,
|
||||||
TaskProgressColumn,
|
# TaskProgressColumn,
|
||||||
TimeRemainingColumn,
|
# TimeRemainingColumn,
|
||||||
MofNCompleteColumn,
|
# MofNCompleteColumn,
|
||||||
TimeElapsedColumn,
|
# TimeElapsedColumn,
|
||||||
)
|
# )
|
||||||
from rich.console import Console
|
# from rich.console import Console
|
||||||
# from rich import print as rprint
|
# from rich import print as rprint
|
||||||
|
|
||||||
import multiprocessing
|
import multiprocessing
|
||||||
|
|
||||||
from util.datasets import FiberRegenerationDataset
|
from util.datasets import FiberRegenerationDataset
|
||||||
from util.optuna_helpers import (
|
|
||||||
force_suggest_categorical,
|
# from util.optuna_helpers import (
|
||||||
force_suggest_float,
|
# suggest_categorical_optional, # noqa: F401
|
||||||
force_suggest_int,
|
# suggest_float_optional, # noqa: F401
|
||||||
)
|
# suggest_int_optional, # noqa: F401
|
||||||
|
# )
|
||||||
|
from util.optuna_helpers import install_optional_suggests
|
||||||
import util
|
import util
|
||||||
|
|
||||||
from .settings import (
|
from .settings import (
|
||||||
@@ -49,6 +52,8 @@ from .settings import (
|
|||||||
PytorchSettings,
|
PytorchSettings,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
install_optional_suggests()
|
||||||
|
|
||||||
|
|
||||||
class HyperTraining:
|
class HyperTraining:
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -60,7 +65,7 @@ class HyperTraining:
|
|||||||
model_settings,
|
model_settings,
|
||||||
optimizer_settings,
|
optimizer_settings,
|
||||||
optuna_settings,
|
optuna_settings,
|
||||||
console=None,
|
# console=None,
|
||||||
):
|
):
|
||||||
self.global_settings: GlobalSettings = global_settings
|
self.global_settings: GlobalSettings = global_settings
|
||||||
self.data_settings: DataSettings = data_settings
|
self.data_settings: DataSettings = data_settings
|
||||||
@@ -68,18 +73,16 @@ class HyperTraining:
|
|||||||
self.model_settings: ModelSettings = model_settings
|
self.model_settings: ModelSettings = model_settings
|
||||||
self.optimizer_settings: OptimizerSettings = optimizer_settings
|
self.optimizer_settings: OptimizerSettings = optimizer_settings
|
||||||
self.optuna_settings: OptunaSettings = optuna_settings
|
self.optuna_settings: OptunaSettings = optuna_settings
|
||||||
|
self.processes = None
|
||||||
|
|
||||||
self.console = console or Console()
|
# self.console = console or Console()
|
||||||
|
|
||||||
# set some extra settings to make the code more readable
|
# set some extra settings to make the code more readable
|
||||||
self._extra_optuna_settings()
|
self._extra_optuna_settings()
|
||||||
|
self.stop_study = True
|
||||||
|
|
||||||
def setup_tb_writer(self, study_name=None, append=None):
|
def setup_tb_writer(self, study_name=None, append=None):
|
||||||
log_dir = (
|
log_dir = self.pytorch_settings.summary_dir + "/" + (study_name or self.optuna_settings.study_name)
|
||||||
self.pytorch_settings.summary_dir
|
|
||||||
+ "/"
|
|
||||||
+ (study_name or self.optuna_settings.study_name)
|
|
||||||
)
|
|
||||||
if append is not None:
|
if append is not None:
|
||||||
log_dir += "_" + str(append)
|
log_dir += "_" + str(append)
|
||||||
|
|
||||||
@@ -89,180 +92,211 @@ class HyperTraining:
|
|||||||
study_name = self.get_latest_study()
|
study_name = self.get_latest_study()
|
||||||
|
|
||||||
if study_name:
|
if study_name:
|
||||||
print(f"Resuming study: {study_name}")
|
if verbose:
|
||||||
|
print(f"Resuming study: {study_name}")
|
||||||
self.optuna_settings.study_name = study_name
|
self.optuna_settings.study_name = study_name
|
||||||
|
|
||||||
def get_latest_study(self, verbose=True):
|
def get_latest_study(self, verbose=False) -> optuna.Study:
|
||||||
studies = self.get_studies()
|
studies = self.get_studies()
|
||||||
|
study = None
|
||||||
for study in studies:
|
for study in studies:
|
||||||
study.datetime_start = study.datetime_start or datetime.min
|
study.datetime_start = study.datetime_start or datetime.min
|
||||||
if studies:
|
if studies:
|
||||||
study = sorted(studies, key=lambda x: x.datetime_start, reverse=True)[0]
|
study = sorted(studies, key=lambda x: x.datetime_start, reverse=True)[0]
|
||||||
if verbose:
|
if verbose:
|
||||||
print(f"Last study: {study.study_name}")
|
print(f"Last study: {study.study_name}")
|
||||||
study_name = study.study_name
|
|
||||||
else:
|
else:
|
||||||
if verbose:
|
if verbose:
|
||||||
print("No previous studies found")
|
print("No previous studies found")
|
||||||
study_name = None
|
return optuna.load_study(study_name=study.study_name, storage=self.optuna_settings.storage)
|
||||||
return study_name
|
|
||||||
|
# def study(self) -> optuna.Study:
|
||||||
|
# return optuna.load_study(self.optuna_settings.study_name, storage=self.optuna_settings.storage)
|
||||||
|
|
||||||
def get_studies(self):
|
def get_studies(self):
|
||||||
return optuna.get_all_study_summaries(storage=self.optuna_settings.storage)
|
return optuna.get_all_study_summaries(storage=self.optuna_settings.storage)
|
||||||
|
|
||||||
def setup_study(self):
|
def setup_study(self):
|
||||||
module = optunahub.load_module(package="samplers/auto_sampler")
|
# module = optunahub.load_module(package="samplers/auto_sampler")
|
||||||
|
if self.optuna_settings._parallel:
|
||||||
|
self.processes = []
|
||||||
|
|
||||||
|
pruner = getattr(optuna.pruners, self.optuna_settings.pruner, None)
|
||||||
|
|
||||||
|
if pruner and self.optuna_settings.pruner_kwargs is not None:
|
||||||
|
pruner = pruner(**self.optuna_settings.pruner_kwargs)
|
||||||
|
elif pruner:
|
||||||
|
pruner = pruner()
|
||||||
|
|
||||||
self.study = optuna.create_study(
|
self.study = optuna.create_study(
|
||||||
study_name=self.optuna_settings.study_name,
|
study_name=self.optuna_settings.study_name,
|
||||||
storage=self.optuna_settings.storage,
|
storage=self.optuna_settings.storage,
|
||||||
load_if_exists=True,
|
load_if_exists=True,
|
||||||
direction=self.optuna_settings.direction,
|
direction=self.optuna_settings._direction,
|
||||||
directions=self.optuna_settings.directions,
|
directions=self.optuna_settings._directions,
|
||||||
sampler=module.AutoSampler(),
|
pruner=pruner,
|
||||||
|
# sampler=module.AutoSampler(),
|
||||||
)
|
)
|
||||||
|
|
||||||
print("using sampler:", self.study.sampler)
|
# print("using sampler:", self.study.sampler)
|
||||||
|
|
||||||
with warnings.catch_warnings(action="ignore"):
|
with warnings.catch_warnings(action="ignore"):
|
||||||
self.study.set_metric_names(self.optuna_settings.metrics_names)
|
self.study.set_metric_names(self.optuna_settings.metrics_names)
|
||||||
|
|
||||||
self.n_threads = min(
|
|
||||||
self.optuna_settings.n_trials, self.optuna_settings.n_threads
|
|
||||||
)
|
|
||||||
self.processes = []
|
|
||||||
if self.n_threads > 1:
|
|
||||||
for _ in range(self.n_threads):
|
|
||||||
p = multiprocessing.Process(
|
|
||||||
# target=lambda n_trials: self._run_optimize(self, n_trials),
|
|
||||||
target=self._run_optimize,
|
|
||||||
args=(self.optuna_settings.n_trials // self.n_threads,),
|
|
||||||
)
|
|
||||||
self.processes.append(p)
|
|
||||||
|
|
||||||
# def plot_eye(self, width=2, symbols=None, alpha=None, complex=False, show=True):
|
|
||||||
# data, config = util.datasets.load_data(
|
|
||||||
# self.data_settings.config_path,
|
|
||||||
# skipfirst=10,
|
|
||||||
# symbols=symbols or 1000,
|
|
||||||
# real=not complex,
|
|
||||||
# normalize=True,
|
|
||||||
# )
|
|
||||||
# eye_data = {"data": data.numpy(), "sps": int(config["glova"]["sps"])}
|
|
||||||
# return util.plot.eye(
|
|
||||||
# **eye_data,
|
|
||||||
# width=width,
|
|
||||||
# show=show,
|
|
||||||
# alpha=alpha,
|
|
||||||
# complex=complex,
|
|
||||||
# symbols=symbols or 1000,
|
|
||||||
# skipfirst=0,
|
|
||||||
# )
|
|
||||||
|
|
||||||
def run_study(self):
|
def run_study(self):
|
||||||
if self.processes:
|
try:
|
||||||
for p in self.processes:
|
if self.optuna_settings._parallel:
|
||||||
p.start()
|
self._run_parallel_study()
|
||||||
for p in self.processes:
|
else:
|
||||||
p.join()
|
self._run_study()
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print("Stopping. Please wait for the processes to finish.")
|
||||||
|
self.stop_study = True
|
||||||
|
|
||||||
remaining_trials = self.optuna_settings.n_trials % self.n_threads
|
def trials_left(self):
|
||||||
else:
|
return self.optuna_settings.n_trials - len(self.study.get_trials(states=self.optuna_settings.n_trials_filter))
|
||||||
remaining_trials = self.optuna_settings.n_trials
|
|
||||||
|
|
||||||
if remaining_trials:
|
def remove_completed_processes(self):
|
||||||
self._run_optimize(remaining_trials)
|
if self.processes is None:
|
||||||
|
return
|
||||||
|
for p, process in enumerate(self.processes):
|
||||||
|
if not process.is_alive():
|
||||||
|
process.join()
|
||||||
|
self.processes.pop(p)
|
||||||
|
|
||||||
def _run_optimize(self, n_trials):
|
def remove_outliers(self):
|
||||||
|
if self.optuna_settings.remove_outliers is not None:
|
||||||
|
trials = self.study.get_trials(states=(optuna.trial.TrialState.COMPLETE,))
|
||||||
|
if len(trials) == 0:
|
||||||
|
return
|
||||||
|
vals = [trial.value for trial in trials]
|
||||||
|
vals = np.log(vals)
|
||||||
|
mean = np.mean(vals)
|
||||||
|
std = np.std(vals)
|
||||||
|
outliers = [
|
||||||
|
trial for trial in trials if np.log(trial.value) > mean + self.optuna_settings.remove_outliers * std
|
||||||
|
]
|
||||||
|
for trial in outliers:
|
||||||
|
trial: optuna.trial.Trial = trial
|
||||||
|
trial.state = optuna.trial.TrialState.FAIL
|
||||||
|
trial.set_user_attr("outlier", True)
|
||||||
|
|
||||||
|
def _run_study(self):
|
||||||
|
while trials_left := self.trials_left():
|
||||||
|
self.remove_outliers()
|
||||||
|
self._run_optimize(n_trials=trials_left, timeout=self.optuna_settings.timeout)
|
||||||
|
|
||||||
|
def _run_parallel_study(self):
|
||||||
|
while trials_left := self.trials_left():
|
||||||
|
self.remove_outliers()
|
||||||
|
self.remove_completed_processes()
|
||||||
|
|
||||||
|
n_trials = max(trials_left, self.optuna_settings._n_threads) // self.optuna_settings._n_threads
|
||||||
|
|
||||||
|
def target_fun():
|
||||||
|
self._run_optimize(n_trials=n_trials, timeout=self.optuna_settings.timeout)
|
||||||
|
|
||||||
|
for _ in range(self.optuna_settings._n_threads - len(self.processes)):
|
||||||
|
self.processes.append(multiprocessing.Process(target=target_fun))
|
||||||
|
self.processes[-1].start()
|
||||||
|
|
||||||
|
def _run_optimize(self, **kwargs):
|
||||||
self.study.optimize(
|
self.study.optimize(
|
||||||
self.objective, n_trials=n_trials, timeout=self.optuna_settings.timeout
|
self.objective,
|
||||||
|
**kwargs,
|
||||||
|
show_progress_bar=not self.optuna_settings._parallel,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _extra_optuna_settings(self):
|
def _extra_optuna_settings(self):
|
||||||
self.optuna_settings.multi_objective = len(self.optuna_settings.directions) > 1
|
self.optuna_settings._multi_objective = len(self.optuna_settings.directions) > 1
|
||||||
if self.optuna_settings.multi_objective:
|
|
||||||
self.optuna_settings.direction = None
|
|
||||||
else:
|
|
||||||
self.optuna_settings.direction = self.optuna_settings.directions[0]
|
|
||||||
self.optuna_settings.directions = None
|
|
||||||
|
|
||||||
self.optuna_settings.n_train_batches = (
|
if self.optuna_settings._multi_objective:
|
||||||
self.optuna_settings.n_train_batches
|
self.optuna_settings._direction = None
|
||||||
if self.optuna_settings.limit_examples
|
self.optuna_settings._directions = self.optuna_settings.directions
|
||||||
else float("inf")
|
else:
|
||||||
|
self.optuna_settings._direction = self.optuna_settings.directions[0]
|
||||||
|
self.optuna_settings._directions = None
|
||||||
|
|
||||||
|
self.optuna_settings._n_train_batches = (
|
||||||
|
self.optuna_settings.n_train_batches if self.optuna_settings.limit_examples else float("inf")
|
||||||
)
|
)
|
||||||
self.optuna_settings.n_valid_batches = (
|
self.optuna_settings._n_valid_batches = (
|
||||||
self.optuna_settings.n_valid_batches
|
self.optuna_settings.n_valid_batches if self.optuna_settings.limit_examples else float("inf")
|
||||||
if self.optuna_settings.limit_examples
|
|
||||||
else float("inf")
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.optuna_settings._n_threads = self.optuna_settings.n_workers
|
||||||
|
|
||||||
|
self.optuna_settings._parallel = self.optuna_settings._n_threads > 1
|
||||||
|
|
||||||
def define_model(self, trial: optuna.Trial, writer=None):
|
def define_model(self, trial: optuna.Trial, writer=None):
|
||||||
n_layers = force_suggest_int(
|
n_layers = trial.suggest_int_optional("model_n_hidden_layers", self.model_settings.n_hidden_layers)
|
||||||
trial, "model_n_layers", self.model_settings.model_n_layers
|
|
||||||
)
|
|
||||||
|
|
||||||
input_dim = 2 * trial.params.get(
|
input_dim = trial.suggest_int_optional(
|
||||||
"model_input_dim",
|
"model_input_dim",
|
||||||
force_suggest_int(trial, "model_input_dim", self.data_settings.model_input_dim),
|
self.data_settings.output_size,
|
||||||
|
step=2,
|
||||||
|
multiply=2,
|
||||||
|
set_new=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
dtype = trial.params.get(
|
# trial.set_user_attr("model_input_dim", input_dim)
|
||||||
"model_dtype",
|
|
||||||
force_suggest_categorical(trial, "model_dtype", self.data_settings.dtype),
|
dtype = trial.suggest_categorical_optional("model_dtype", self.data_settings.dtype, set_new=False)
|
||||||
)
|
|
||||||
dtype = getattr(torch, dtype)
|
dtype = getattr(torch, dtype)
|
||||||
|
|
||||||
afunc = force_suggest_categorical(
|
afunc = trial.suggest_categorical_optional("model_activation_func", self.model_settings.model_activation_func)
|
||||||
trial, "model_activation_func", self.model_settings.model_activation_func
|
|
||||||
)
|
|
||||||
|
|
||||||
layers = []
|
layers = []
|
||||||
last_dim = input_dim
|
last_dim = input_dim
|
||||||
|
n_nodes = last_dim
|
||||||
for i in range(n_layers):
|
for i in range(n_layers):
|
||||||
hidden_dim = force_suggest_int(
|
if hidden_dim_override := self.model_settings.overrides.get(f"n_hidden_nodes_{i}", False):
|
||||||
trial, f"model_hidden_dim_{i}", self.model_settings.unit_count
|
hidden_dim = trial.suggest_int_optional(f"model_hidden_dim_{i}", hidden_dim_override, force=True)
|
||||||
)
|
else:
|
||||||
layers.append(
|
hidden_dim = trial.suggest_int_optional(
|
||||||
util.complexNN.SemiUnitaryLayer(last_dim, hidden_dim, dtype=dtype)
|
f"model_hidden_dim_{i}",
|
||||||
)
|
self.model_settings.n_hidden_nodes,
|
||||||
|
# step=2,
|
||||||
|
)
|
||||||
|
layers.append(util.complexNN.SemiUnitaryLayer(last_dim, hidden_dim, dtype=dtype))
|
||||||
last_dim = hidden_dim
|
last_dim = hidden_dim
|
||||||
layers.append(getattr(util.complexNN, afunc)())
|
layers.append(getattr(util.complexNN, afunc)())
|
||||||
|
n_nodes += last_dim
|
||||||
|
|
||||||
layers.append(
|
layers.append(util.complexNN.SemiUnitaryLayer(last_dim, self.model_settings.output_dim, dtype=dtype))
|
||||||
util.complexNN.UnitaryLayer(
|
|
||||||
hidden_dim, self.model_settings.output_dim, dtype=dtype
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
model = nn.Sequential(*layers)
|
model = nn.Sequential(*layers)
|
||||||
|
|
||||||
if writer is not None:
|
if writer is not None:
|
||||||
writer.add_graph(
|
writer.add_graph(model, torch.zeros(1, input_dim, dtype=dtype), use_strict_trace=False)
|
||||||
model, torch.zeros(1, input_dim, dtype=dtype), use_strict_trace=False
|
|
||||||
)
|
n_params = sum(p.numel() for p in model.parameters())
|
||||||
|
trial.set_user_attr("model_n_params", n_params)
|
||||||
|
trial.set_user_attr("model_n_nodes", n_nodes)
|
||||||
|
|
||||||
return model.to(self.pytorch_settings.device)
|
return model.to(self.pytorch_settings.device)
|
||||||
|
|
||||||
def get_sliced_data(self, trial: optuna.Trial, override=None):
|
def get_sliced_data(self, trial: optuna.Trial, override=None):
|
||||||
symbols = trial.params.get(
|
symbols = trial.suggest_float_optional("dataset_symbols", self.data_settings.symbols, set_new=False)
|
||||||
"dataset_symbols",
|
|
||||||
force_suggest_float(trial, "dataset_symbols", self.data_settings.symbols),
|
in_out_delay = trial.suggest_float_optional(
|
||||||
|
"dataset_in_out_delay", self.data_settings.in_out_delay, set_new=False
|
||||||
)
|
)
|
||||||
|
|
||||||
xy_delay = trial.params.get(
|
xy_delay = trial.suggest_float_optional("dataset_xy_delay", self.data_settings.xy_delay, set_new=False)
|
||||||
"dataset_xy_delay",
|
|
||||||
force_suggest_float(trial, "dataset_xy_delay", self.data_settings.xy_delay),
|
data_size = int(
|
||||||
|
0.5
|
||||||
|
* trial.suggest_int_optional(
|
||||||
|
"model_input_dim",
|
||||||
|
self.data_settings.output_size,
|
||||||
|
step=2,
|
||||||
|
multiply=2,
|
||||||
|
set_new=False,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
data_size = trial.params.get(
|
dtype = trial.suggest_categorical_optional("model_dtype", self.data_settings.dtype, set_new=False)
|
||||||
"model_input_dim",
|
|
||||||
force_suggest_int(trial, "model_input_dim", self.data_settings.model_input_dim),
|
|
||||||
)
|
|
||||||
|
|
||||||
dtype = trial.params.get(
|
|
||||||
"model_dtype",
|
|
||||||
force_suggest_categorical(trial, "model_dtype", self.data_settings.dtype),
|
|
||||||
)
|
|
||||||
dtype = getattr(torch, dtype)
|
dtype = getattr(torch, dtype)
|
||||||
|
|
||||||
num_symbols = None
|
num_symbols = None
|
||||||
@@ -273,7 +307,7 @@ class HyperTraining:
|
|||||||
file_path=self.data_settings.config_path,
|
file_path=self.data_settings.config_path,
|
||||||
symbols=symbols,
|
symbols=symbols,
|
||||||
output_dim=data_size,
|
output_dim=data_size,
|
||||||
target_delay=self.data_settings.in_out_delay,
|
target_delay=in_out_delay,
|
||||||
xy_delay=xy_delay,
|
xy_delay=xy_delay,
|
||||||
drop_first=self.data_settings.drop_first,
|
drop_first=self.data_settings.drop_first,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
@@ -327,31 +361,31 @@ class HyperTraining:
|
|||||||
train_loader,
|
train_loader,
|
||||||
epoch,
|
epoch,
|
||||||
writer=None,
|
writer=None,
|
||||||
enable_progress=False,
|
# enable_progress=False,
|
||||||
):
|
):
|
||||||
if enable_progress:
|
# if enable_progress:
|
||||||
progress = Progress(
|
# progress = Progress(
|
||||||
TextColumn("[yellow] Training..."),
|
# TextColumn("[yellow] Training..."),
|
||||||
TextColumn("Error: {task.description}"),
|
# TextColumn("Error: {task.description}"),
|
||||||
BarColumn(),
|
# BarColumn(),
|
||||||
TaskProgressColumn(),
|
# TaskProgressColumn(),
|
||||||
TextColumn("[green]Batch"),
|
# TextColumn("[green]Batch"),
|
||||||
MofNCompleteColumn(),
|
# MofNCompleteColumn(),
|
||||||
TimeRemainingColumn(),
|
# TimeRemainingColumn(),
|
||||||
TimeElapsedColumn(),
|
# TimeElapsedColumn(),
|
||||||
# description="Training",
|
# # description="Training",
|
||||||
transient=False,
|
# transient=False,
|
||||||
console=self.console,
|
# console=self.console,
|
||||||
refresh_per_second=10,
|
# refresh_per_second=10,
|
||||||
)
|
# )
|
||||||
task = progress.add_task("-.---e--", total=len(train_loader))
|
# task = progress.add_task("-.---e--", total=len(train_loader))
|
||||||
progress.start()
|
# progress.start()
|
||||||
|
|
||||||
running_loss2 = 0.0
|
running_loss2 = 0.0
|
||||||
running_loss = 0.0
|
running_loss = 0.0
|
||||||
model.train()
|
model.train()
|
||||||
for batch_idx, (x, y) in enumerate(train_loader):
|
for batch_idx, (x, y) in enumerate(train_loader):
|
||||||
if batch_idx >= self.optuna_settings.n_train_batches:
|
if batch_idx >= self.optuna_settings._n_train_batches:
|
||||||
break
|
break
|
||||||
model.zero_grad(set_to_none=True)
|
model.zero_grad(set_to_none=True)
|
||||||
x, y = (
|
x, y = (
|
||||||
@@ -366,55 +400,56 @@ class HyperTraining:
|
|||||||
running_loss2 += loss_value
|
running_loss2 += loss_value
|
||||||
running_loss += loss_value
|
running_loss += loss_value
|
||||||
|
|
||||||
if enable_progress:
|
# if enable_progress:
|
||||||
progress.update(task, advance=1, description=f"{loss_value:.3e}")
|
# progress.update(task, advance=1, description=f"{loss_value:.3e}")
|
||||||
|
|
||||||
if writer is not None:
|
if writer is not None:
|
||||||
if batch_idx % self.pytorch_settings.write_every == 0:
|
if batch_idx % self.pytorch_settings.write_every == 0:
|
||||||
writer.add_scalar(
|
writer.add_scalar(
|
||||||
"training loss",
|
"training loss",
|
||||||
running_loss2
|
running_loss2 / (self.pytorch_settings.write_every if batch_idx > 0 else 1),
|
||||||
/ (self.pytorch_settings.write_every if batch_idx > 0 else 1),
|
epoch * min(len(train_loader), self.optuna_settings._n_train_batches) + batch_idx,
|
||||||
epoch
|
|
||||||
* min(len(train_loader), self.optuna_settings.n_train_batches)
|
|
||||||
+ batch_idx,
|
|
||||||
)
|
)
|
||||||
running_loss2 = 0.0
|
running_loss2 = 0.0
|
||||||
|
|
||||||
if enable_progress:
|
# if enable_progress:
|
||||||
progress.stop()
|
# progress.stop()
|
||||||
|
|
||||||
return running_loss / min(
|
return running_loss / min(len(train_loader), self.optuna_settings._n_train_batches)
|
||||||
len(train_loader), self.optuna_settings.n_train_batches
|
|
||||||
)
|
|
||||||
|
|
||||||
def eval_model(
|
def eval_model(
|
||||||
self, trial, model, valid_loader, epoch, writer=None, enable_progress=True
|
self,
|
||||||
|
trial,
|
||||||
|
model,
|
||||||
|
valid_loader,
|
||||||
|
epoch,
|
||||||
|
writer=None,
|
||||||
|
# enable_progress=True
|
||||||
):
|
):
|
||||||
if enable_progress:
|
# if enable_progress:
|
||||||
progress = Progress(
|
# progress = Progress(
|
||||||
TextColumn("[green]Evaluating..."),
|
# TextColumn("[green]Evaluating..."),
|
||||||
TextColumn("Error: {task.description}"),
|
# TextColumn("Error: {task.description}"),
|
||||||
BarColumn(),
|
# BarColumn(),
|
||||||
TaskProgressColumn(),
|
# TaskProgressColumn(),
|
||||||
TextColumn("[green]Batch"),
|
# TextColumn("[green]Batch"),
|
||||||
MofNCompleteColumn(),
|
# MofNCompleteColumn(),
|
||||||
TimeRemainingColumn(),
|
# TimeRemainingColumn(),
|
||||||
TimeElapsedColumn(),
|
# TimeElapsedColumn(),
|
||||||
# description="Training",
|
# # description="Training",
|
||||||
transient=False,
|
# transient=False,
|
||||||
console=self.console,
|
# console=self.console,
|
||||||
refresh_per_second=10,
|
# refresh_per_second=10,
|
||||||
)
|
# )
|
||||||
progress.start()
|
# progress.start()
|
||||||
task = progress.add_task("-.---e--", total=len(valid_loader))
|
# task = progress.add_task("-.---e--", total=len(valid_loader))
|
||||||
|
|
||||||
model.eval()
|
model.eval()
|
||||||
running_error = 0
|
running_error = 0
|
||||||
running_error_2 = 0
|
running_error_2 = 0
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
for batch_idx, (x, y) in enumerate(valid_loader):
|
for batch_idx, (x, y) in enumerate(valid_loader):
|
||||||
if batch_idx >= self.optuna_settings.n_valid_batches:
|
if batch_idx >= self.optuna_settings._n_valid_batches:
|
||||||
break
|
break
|
||||||
x, y = (
|
x, y = (
|
||||||
x.to(self.pytorch_settings.device),
|
x.to(self.pytorch_settings.device),
|
||||||
@@ -426,28 +461,19 @@ class HyperTraining:
|
|||||||
running_error += error_value
|
running_error += error_value
|
||||||
running_error_2 += error_value
|
running_error_2 += error_value
|
||||||
|
|
||||||
if enable_progress:
|
# if enable_progress:
|
||||||
progress.update(task, advance=1, description=f"{error_value:.3e}")
|
# progress.update(task, advance=1, description=f"{error_value:.3e}")
|
||||||
|
|
||||||
if writer is not None:
|
if writer is not None:
|
||||||
if batch_idx % self.pytorch_settings.write_every == 0:
|
if batch_idx % self.pytorch_settings.write_every == 0:
|
||||||
writer.add_scalar(
|
writer.add_scalar(
|
||||||
"eval loss",
|
"eval loss",
|
||||||
running_error_2
|
running_error_2 / (self.pytorch_settings.write_every if batch_idx > 0 else 1),
|
||||||
/ (
|
epoch * min(len(valid_loader), self.optuna_settings._n_valid_batches) + batch_idx,
|
||||||
self.pytorch_settings.write_every
|
|
||||||
if batch_idx > 0
|
|
||||||
else 1
|
|
||||||
),
|
|
||||||
epoch
|
|
||||||
* min(
|
|
||||||
len(valid_loader), self.optuna_settings.n_valid_batches
|
|
||||||
)
|
|
||||||
+ batch_idx,
|
|
||||||
)
|
)
|
||||||
running_error_2 = 0.0
|
running_error_2 = 0.0
|
||||||
|
|
||||||
running_error /= min(len(valid_loader), self.optuna_settings.n_valid_batches)
|
running_error /= min(len(valid_loader), self.optuna_settings._n_valid_batches)
|
||||||
|
|
||||||
if writer is not None:
|
if writer is not None:
|
||||||
title_append, subtitle = self.build_title(trial)
|
title_append, subtitle = self.build_title(trial)
|
||||||
@@ -463,8 +489,8 @@ class HyperTraining:
|
|||||||
epoch + 1,
|
epoch + 1,
|
||||||
)
|
)
|
||||||
|
|
||||||
if enable_progress:
|
# if enable_progress:
|
||||||
progress.stop()
|
# progress.stop()
|
||||||
|
|
||||||
return running_error
|
return running_error
|
||||||
|
|
||||||
@@ -496,103 +522,94 @@ class HyperTraining:
|
|||||||
return ys, xs, y_preds
|
return ys, xs, y_preds
|
||||||
|
|
||||||
def objective(self, trial: optuna.Trial, plot_before=False):
|
def objective(self, trial: optuna.Trial, plot_before=False):
|
||||||
|
if self.stop_study:
|
||||||
|
trial.study.stop()
|
||||||
model = None
|
model = None
|
||||||
exc = None
|
|
||||||
try:
|
|
||||||
# rprint(*list(self.study_name.split("_")))
|
|
||||||
|
|
||||||
writer = self.setup_tb_writer(
|
writer = self.setup_tb_writer(
|
||||||
self.optuna_settings.study_name,
|
self.optuna_settings.study_name,
|
||||||
f"{trial.number:0{len(str(self.optuna_settings.n_trials))}}",
|
f"{trial.number:0{len(str(self.optuna_settings.n_trials)) + 2}}",
|
||||||
|
)
|
||||||
|
|
||||||
|
model = self.define_model(trial, writer)
|
||||||
|
|
||||||
|
# n_nodes = trial.params.get("model_n_hidden_layers", self.model_settings.model_n_layers) * trial.params.get("model_hidden_dim", self.model_settings.unit_count)
|
||||||
|
|
||||||
|
title_append, subtitle = self.build_title(trial)
|
||||||
|
|
||||||
|
writer.add_figure(
|
||||||
|
"fiber response",
|
||||||
|
self.plot_model_response(
|
||||||
|
trial,
|
||||||
|
model=model,
|
||||||
|
title_append=title_append,
|
||||||
|
subtitle=subtitle,
|
||||||
|
show=plot_before,
|
||||||
|
),
|
||||||
|
0,
|
||||||
|
)
|
||||||
|
|
||||||
|
train_loader, valid_loader = self.get_sliced_data(trial)
|
||||||
|
|
||||||
|
optimizer_name = trial.suggest_categorical_optional("optimizer", self.optimizer_settings.optimizer)
|
||||||
|
|
||||||
|
lr = trial.suggest_float_optional("lr", self.optimizer_settings.learning_rate, log=True)
|
||||||
|
|
||||||
|
optimizer = getattr(optim, optimizer_name)(model.parameters(), lr=lr)
|
||||||
|
if self.optimizer_settings.scheduler is not None:
|
||||||
|
scheduler = getattr(optim.lr_scheduler, self.optimizer_settings.scheduler)(
|
||||||
|
optimizer, **self.optimizer_settings.scheduler_kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
model = self.define_model(trial, writer)
|
for epoch in range(self.pytorch_settings.epochs):
|
||||||
n_params = sum(p.numel() for p in model.parameters())
|
trial.set_user_attr("epoch", epoch)
|
||||||
# n_nodes = trial.params.get("model_n_layers", self.model_settings.model_n_layers) * trial.params.get("model_hidden_dim", self.model_settings.unit_count)
|
# enable_progress = self.optuna_settings.n_threads == 1
|
||||||
|
# if enable_progress:
|
||||||
title_append, subtitle = self.build_title(trial)
|
# self.console.rule(
|
||||||
|
# f"Epoch {epoch + 1}/{self.pytorch_settings.epochs}"
|
||||||
writer.add_figure(
|
# )
|
||||||
"fiber response",
|
self.train_model(
|
||||||
self.plot_model_response(
|
trial,
|
||||||
trial,
|
model,
|
||||||
model=model,
|
optimizer,
|
||||||
title_append=title_append,
|
train_loader,
|
||||||
subtitle=subtitle,
|
epoch,
|
||||||
show=plot_before,
|
writer,
|
||||||
),
|
# enable_progress=enable_progress,
|
||||||
0,
|
|
||||||
)
|
)
|
||||||
|
error = self.eval_model(
|
||||||
train_loader, valid_loader = self.get_sliced_data(trial)
|
trial,
|
||||||
|
model,
|
||||||
optimizer_name = force_suggest_categorical(
|
valid_loader,
|
||||||
trial, "optimizer", self.optimizer_settings.optimizer
|
epoch,
|
||||||
|
writer,
|
||||||
|
# enable_progress=enable_progress,
|
||||||
)
|
)
|
||||||
|
|
||||||
lr = force_suggest_float(
|
|
||||||
trial, "lr", self.optimizer_settings.learning_rate, log=True
|
|
||||||
)
|
|
||||||
|
|
||||||
optimizer = getattr(optim, optimizer_name)(model.parameters(), lr=lr)
|
|
||||||
if self.optimizer_settings.scheduler is not None:
|
if self.optimizer_settings.scheduler is not None:
|
||||||
scheduler = getattr(optim.lr_scheduler, self.optimizer_settings.scheduler)(
|
scheduler.step(error)
|
||||||
optimizer, **self.optimizer_settings.scheduler_kwargs)
|
|
||||||
|
|
||||||
for epoch in range(self.pytorch_settings.epochs):
|
trial.set_user_attr("mse", error)
|
||||||
enable_progress = self.optuna_settings.n_threads == 1
|
trial.set_user_attr("log_mse", np.log10(error + np.finfo(float).eps))
|
||||||
if enable_progress:
|
trial.set_user_attr("neg_mse", -error)
|
||||||
self.console.rule(
|
trial.set_user_attr("neg_log_mse", -np.log10(error + np.finfo(float).eps))
|
||||||
f"Epoch {epoch + 1}/{self.pytorch_settings.epochs}"
|
if not self.optuna_settings._multi_objective:
|
||||||
)
|
trial.report(error, epoch)
|
||||||
self.train_model(
|
if trial.should_prune():
|
||||||
trial,
|
raise optuna.exceptions.TrialPruned()
|
||||||
model,
|
|
||||||
optimizer,
|
|
||||||
train_loader,
|
|
||||||
epoch,
|
|
||||||
writer,
|
|
||||||
enable_progress=enable_progress,
|
|
||||||
)
|
|
||||||
error = self.eval_model(
|
|
||||||
trial,
|
|
||||||
model,
|
|
||||||
valid_loader,
|
|
||||||
epoch,
|
|
||||||
writer,
|
|
||||||
enable_progress=enable_progress,
|
|
||||||
)
|
|
||||||
if self.optimizer_settings.scheduler is not None:
|
|
||||||
scheduler.step(error)
|
|
||||||
|
|
||||||
writer.close()
|
writer.close()
|
||||||
|
|
||||||
if self.optuna_settings.multi_objective:
|
if self.optuna_settings._multi_objective:
|
||||||
return n_params, error
|
return -np.log10(error + np.finfo(float).eps), trial.user_attrs.get("model_n_nodes", -1)
|
||||||
trial.report(error, epoch)
|
|
||||||
if trial.should_prune():
|
|
||||||
raise optuna.exceptions.TrialPruned()
|
|
||||||
return error
|
|
||||||
|
|
||||||
except KeyboardInterrupt:
|
if self.pytorch_settings.save_models and model is not None:
|
||||||
...
|
save_path = Path(self.pytorch_settings.model_dir) / f"{self.optuna_settings.study_name}_{trial.number}.pth"
|
||||||
# except Exception as e:
|
save_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
# exc = e
|
torch.save(model, save_path)
|
||||||
finally:
|
|
||||||
if model is not None:
|
|
||||||
save_path = (
|
|
||||||
Path(self.pytorch_settings.model_dir)
|
|
||||||
/ f"{self.optuna_settings.study_name}_{trial.number}.pth"
|
|
||||||
)
|
|
||||||
save_path.parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
torch.save(model, save_path)
|
|
||||||
if exc is not None:
|
|
||||||
raise exc
|
|
||||||
|
|
||||||
|
return error
|
||||||
|
|
||||||
def _plot_model_response_eye(
|
def _plot_model_response_eye(self, *signals, labels=None, sps=None, title_append="", subtitle="", show=True):
|
||||||
self, *signals, labels=None, sps=None, title_append="", subtitle="", show=True
|
|
||||||
):
|
|
||||||
if sps is None:
|
if sps is None:
|
||||||
raise ValueError("sps must be provided")
|
raise ValueError("sps must be provided")
|
||||||
if not hasattr(labels, "__iter__") or isinstance(labels, (str, type(None))):
|
if not hasattr(labels, "__iter__") or isinstance(labels, (str, type(None))):
|
||||||
@@ -608,9 +625,7 @@ class HyperTraining:
|
|||||||
labels = [f"signal {i + 1}" for i in range(len(signals))]
|
labels = [f"signal {i + 1}" for i in range(len(signals))]
|
||||||
|
|
||||||
fig, axs = plt.subplots(2, len(signals), sharex=True, sharey=True)
|
fig, axs = plt.subplots(2, len(signals), sharex=True, sharey=True)
|
||||||
fig.suptitle(
|
fig.suptitle(f"Eye diagram{f' {title_append}' if title_append else ''}{f'\n{subtitle}' if subtitle else ''}")
|
||||||
f"Eye diagram{f' {title_append}' if title_append else ''}{f'\n{subtitle}' if subtitle else ''}"
|
|
||||||
)
|
|
||||||
xaxis = np.linspace(0, 2, 2 * sps, endpoint=False)
|
xaxis = np.linspace(0, 2, 2 * sps, endpoint=False)
|
||||||
for j, (label, signal) in enumerate(zip(labels, signals)):
|
for j, (label, signal) in enumerate(zip(labels, signals)):
|
||||||
# signal = signal.cpu().numpy()
|
# signal = signal.cpu().numpy()
|
||||||
@@ -628,9 +643,7 @@ class HyperTraining:
|
|||||||
if show:
|
if show:
|
||||||
plt.show()
|
plt.show()
|
||||||
|
|
||||||
def _plot_model_response_head(
|
def _plot_model_response_head(self, *signals, labels=None, sps=None, title_append="", subtitle="", show=True):
|
||||||
self, *signals, labels=None, sps=None, title_append="", subtitle="", show=True
|
|
||||||
):
|
|
||||||
if not hasattr(labels, "__iter__") or isinstance(labels, (str, type(None))):
|
if not hasattr(labels, "__iter__") or isinstance(labels, (str, type(None))):
|
||||||
labels = [labels]
|
labels = [labels]
|
||||||
else:
|
else:
|
||||||
@@ -644,16 +657,12 @@ class HyperTraining:
|
|||||||
labels = [f"signal {i + 1}" for i in range(len(signals))]
|
labels = [f"signal {i + 1}" for i in range(len(signals))]
|
||||||
|
|
||||||
fig, axs = plt.subplots(1, 2, sharex=True, sharey=True)
|
fig, axs = plt.subplots(1, 2, sharex=True, sharey=True)
|
||||||
fig.set_size_inches(18,6)
|
fig.set_size_inches(18, 6)
|
||||||
fig.suptitle(
|
fig.suptitle(f"Fiber response{f' {title_append}' if title_append else ''}{f'\n{subtitle}' if subtitle else ''}")
|
||||||
f"Fiber response{f' {title_append}' if title_append else ''}{f'\n{subtitle}' if subtitle else ''}"
|
|
||||||
)
|
|
||||||
for i, ax in enumerate(axs):
|
for i, ax in enumerate(axs):
|
||||||
for signal, label in zip(signals, labels):
|
for signal, label in zip(signals, labels):
|
||||||
if sps is not None:
|
if sps is not None:
|
||||||
xaxis = np.linspace(
|
xaxis = np.linspace(0, len(signal) / sps, len(signal), endpoint=False)
|
||||||
0, len(signal) / sps, len(signal), endpoint=False
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
xaxis = np.arange(len(signal))
|
xaxis = np.arange(len(signal))
|
||||||
ax.plot(xaxis, np.abs(signal[:, i]) ** 2, label=label)
|
ax.plot(xaxis, np.abs(signal[:, i]) ** 2, label=label)
|
||||||
@@ -678,10 +687,10 @@ class HyperTraining:
|
|||||||
self.data_settings.drop_first = 100
|
self.data_settings.drop_first = 100
|
||||||
self.data_settings.shuffle = False
|
self.data_settings.shuffle = False
|
||||||
self.data_settings.train_split = 1.0
|
self.data_settings.train_split = 1.0
|
||||||
self.pytorch_settings.batchsize = self.pytorch_settings.eye_symbols if mode == "eye" else self.pytorch_settings.head_symbols
|
self.pytorch_settings.batchsize = (
|
||||||
plot_loader, _ = self.get_sliced_data(
|
self.pytorch_settings.eye_symbols if mode == "eye" else self.pytorch_settings.head_symbols
|
||||||
trial, override={"num_symbols": self.pytorch_settings.batchsize}
|
|
||||||
)
|
)
|
||||||
|
plot_loader, _ = self.get_sliced_data(trial, override={"num_symbols": self.pytorch_settings.batchsize})
|
||||||
self.data_settings = data_settings_backup
|
self.data_settings = data_settings_backup
|
||||||
self.pytorch_settings = pytorch_settings_backup
|
self.pytorch_settings = pytorch_settings_backup
|
||||||
|
|
||||||
@@ -728,13 +737,22 @@ class HyperTraining:
|
|||||||
return fig
|
return fig
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def build_title(trial):
|
def build_title(trial: optuna.trial.Trial):
|
||||||
title_append = f"for trial {trial.number}"
|
title_append = f"for trial {trial.number}"
|
||||||
|
model_n_layers = util.misc.multi_getattr((trial.params, trial.user_attrs), "model_n_layers", 0)
|
||||||
|
model_hidden_dims = [
|
||||||
|
util.misc.multi_getattr((trial.params, trial.user_attrs), f"model_hidden_dim_{i}", 0)
|
||||||
|
for i in range(model_n_layers)
|
||||||
|
]
|
||||||
|
model_activation_func = util.misc.multi_getattr(
|
||||||
|
(trial.params, trial.user_attrs),
|
||||||
|
"model_activation_func",
|
||||||
|
"unknown act. fun",
|
||||||
|
)
|
||||||
|
model_dtype = util.misc.multi_getattr((trial.params, trial.user_attrs), "model_dtype", "unknown dtype")
|
||||||
|
|
||||||
subtitle = (
|
subtitle = (
|
||||||
f"{trial.params['model_n_layers']} layers, "
|
f"{model_n_layers} layers à ({', '.join(model_hidden_dims)}) units, {model_activation_func}, {model_dtype}"
|
||||||
f"{', '.join([str(trial.params[f'model_hidden_dim_{i}']) for i in range(trial.params['model_n_layers'])])} units, "
|
|
||||||
f"{trial.params['model_activation_func']}, "
|
|
||||||
f"{trial.params['model_dtype']}"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return title_append, subtitle
|
return title_append, subtitle
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass, field
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
|
|
||||||
@@ -14,7 +14,7 @@ class DataSettings:
|
|||||||
config_path: str # = "data/*-128-16384-100000-0-0-17-0-PAM4-0.ini"
|
config_path: str # = "data/*-128-16384-100000-0-0-17-0-PAM4-0.ini"
|
||||||
dtype: tuple = ("complex64", "float64")
|
dtype: tuple = ("complex64", "float64")
|
||||||
symbols: tuple | float | int = 8
|
symbols: tuple | float | int = 8
|
||||||
model_input_dim: tuple | float | int = 64
|
output_size: tuple | float | int = 64
|
||||||
shuffle: bool = True
|
shuffle: bool = True
|
||||||
in_out_delay: float = 0
|
in_out_delay: float = 0
|
||||||
xy_delay: tuple | float | int = 0
|
xy_delay: tuple | float | int = 0
|
||||||
@@ -33,6 +33,7 @@ class PytorchSettings:
|
|||||||
dataloader_workers: int = 2
|
dataloader_workers: int = 2
|
||||||
dataloader_prefetch: int = 2
|
dataloader_prefetch: int = 2
|
||||||
|
|
||||||
|
save_models: bool = True
|
||||||
model_dir: str = ".models"
|
model_dir: str = ".models"
|
||||||
|
|
||||||
summary_dir: str = ".runs"
|
summary_dir: str = ".runs"
|
||||||
@@ -45,11 +46,10 @@ class PytorchSettings:
|
|||||||
@dataclass
|
@dataclass
|
||||||
class ModelSettings:
|
class ModelSettings:
|
||||||
output_dim: int = 2
|
output_dim: int = 2
|
||||||
model_n_layers: tuple | int = 3
|
n_hidden_layers: tuple | int = 3
|
||||||
unit_count: tuple | int = 8
|
n_hidden_nodes: tuple | int = 8
|
||||||
# n_units_range: tuple | int = (2, 32)
|
model_activation_func: tuple = "ModReLU"
|
||||||
# activation_func_range: tuple = ("ModReLU", "ZReLU", "CReLU", "Mag", "Identity")
|
overrides: dict = field(default_factory=dict)
|
||||||
model_activation_func: tuple = ("ModReLU",)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -60,18 +60,36 @@ class OptimizerSettings:
|
|||||||
scheduler_kwargs: dict | None = None
|
scheduler_kwargs: dict | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def _pruner_default_kwargs():
|
||||||
|
# MedianPruner
|
||||||
|
return {
|
||||||
|
"n_startup_trials": 0,
|
||||||
|
"n_warmup_steps": 5,
|
||||||
|
}
|
||||||
# optuna settings
|
# optuna settings
|
||||||
@dataclass
|
@dataclass
|
||||||
class OptunaSettings:
|
class OptunaSettings:
|
||||||
n_trials: int = 128
|
n_trials: int = 128
|
||||||
n_threads: int = 4
|
n_workers: int = 1
|
||||||
timeout: int = 600
|
timeout: int = None
|
||||||
|
pruner: str = "MedianPruner"
|
||||||
|
pruner_kwargs: dict = field(default_factory=_pruner_default_kwargs)
|
||||||
directions: tuple = ("minimize",)
|
directions: tuple = ("minimize",)
|
||||||
metrics_names: tuple = ("mse",)
|
metrics_names: tuple = ("mse",)
|
||||||
limit_examples: bool = True
|
limit_examples: bool = True
|
||||||
n_train_batches: int = 100
|
n_train_batches: int = float("inf")
|
||||||
n_valid_batches: int = 100
|
n_valid_batches: int = float("inf")
|
||||||
storage: str = "sqlite:///example.db"
|
storage: str = "sqlite:///example.db"
|
||||||
study_name: str = (
|
study_name: str = (
|
||||||
f"optuna_study_{datetime.now().strftime('%Y-%m-%d_%H:%M:%S')}"
|
f"optuna_study_{datetime.now().strftime('%Y-%m-%d_%H:%M:%S')}"
|
||||||
)
|
)
|
||||||
|
n_trials_filter: tuple|list = None, #(optuna.trial.TrialState.COMPLETE,)
|
||||||
|
remove_outliers: float|int = None
|
||||||
|
## reserved, set by HyperTraining
|
||||||
|
_multi_objective = None
|
||||||
|
_parallel = None
|
||||||
|
_n_threads = None
|
||||||
|
_directions = None
|
||||||
|
_direction = None
|
||||||
|
_n_train_batches = None
|
||||||
|
_n_valid_batches = None
|
||||||
|
|||||||
@@ -1,4 +1,6 @@
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
|
import optuna
|
||||||
from hypertraining.hypertraining import HyperTraining
|
from hypertraining.hypertraining import HyperTraining
|
||||||
from hypertraining.settings import (
|
from hypertraining.settings import (
|
||||||
GlobalSettings,
|
GlobalSettings,
|
||||||
@@ -10,59 +12,72 @@ from hypertraining.settings import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
global_settings = GlobalSettings(
|
global_settings = GlobalSettings(
|
||||||
seed = 42,
|
seed=42,
|
||||||
)
|
)
|
||||||
|
|
||||||
data_settings = DataSettings(
|
data_settings = DataSettings(
|
||||||
config_path = "data/*-128-16384-100000-0-0-17-0-PAM4-0.ini",
|
config_path="data/*-128-16384-100000-0-0-17-0-PAM4-0.ini",
|
||||||
dtype = ("complex128", "complex64", "float64", "float32"),
|
dtype="complex64",
|
||||||
symbols = (1, 16),
|
# symbols = (9, 20), # 13 symbol @ 10GBd <-> 1.3ns <-> 0.26m of fiber
|
||||||
model_input_dim = (1, 32),
|
symbols=13, # study: single_core_regen_20241123_011232
|
||||||
shuffle = True,
|
# output_size = (11, 32), # ballpark 26 taps -> 2 taps per input symbol -> 1 tap every 0.01m (model has 52 inputs)
|
||||||
in_out_delay = 0,
|
output_size=26, # study: single_core_regen_20241123_011232 (model_input_dim/2)
|
||||||
xy_delay = 0,
|
shuffle=True,
|
||||||
drop_first = 1000,
|
in_out_delay=0,
|
||||||
train_split = 0.8,
|
xy_delay=0,
|
||||||
|
drop_first=128 * 100,
|
||||||
|
train_split=0.8,
|
||||||
)
|
)
|
||||||
|
|
||||||
pytorch_settings = PytorchSettings(
|
pytorch_settings = PytorchSettings(
|
||||||
epochs = 25,
|
epochs=10,
|
||||||
batchsize = 2**10,
|
batchsize=2**10,
|
||||||
device = "cuda",
|
device="cuda",
|
||||||
dataloader_workers = 2,
|
dataloader_workers=2,
|
||||||
dataloader_prefetch = 2,
|
dataloader_prefetch=4,
|
||||||
summary_dir = ".runs",
|
summary_dir=".runs",
|
||||||
write_every = 2**5,
|
write_every=2**5,
|
||||||
model_dir = ".models",
|
save_models=True,
|
||||||
|
model_dir=".models",
|
||||||
)
|
)
|
||||||
|
|
||||||
model_settings = ModelSettings(
|
model_settings = ModelSettings(
|
||||||
output_dim = 2,
|
output_dim=2,
|
||||||
model_n_layers = (2, 8),
|
# n_hidden_layers = (3, 8),
|
||||||
unit_count = (2, 16),
|
n_hidden_layers=(4, 6), # study: single_core_regen_20241123_011232
|
||||||
model_activation_func = ("ModReLU")#, "ZReLU", "Mag")#, "CReLU", "Identity"),
|
n_hidden_nodes=(4,20),
|
||||||
|
# overrides={
|
||||||
|
# "n_hidden_nodes_0": (14, 20), # study: single_core_regen_20241123_011232
|
||||||
|
# "n_hidden_nodes_1": (8, 16),
|
||||||
|
# "n_hidden_nodes_2": (10, 16),
|
||||||
|
# # "n_hidden_nodes_3": (4, 20), # study: single_core_regen_20241123_135749
|
||||||
|
# "n_hidden_nodes_4": (2, 8),
|
||||||
|
# "n_hidden_nodes_5": (10, 16),
|
||||||
|
# },
|
||||||
|
# model_activation_func = ("ModReLU", "Mag", "Identity")
|
||||||
|
model_activation_func="Mag", # study: single_core_regen_20241123_011232
|
||||||
)
|
)
|
||||||
|
|
||||||
optimizer_settings = OptimizerSettings(
|
optimizer_settings = OptimizerSettings(
|
||||||
optimizer = ("Adam", "RMSprop"),#, "SGD"),
|
optimizer="Adam",
|
||||||
# learning_rate = (1e-5, 1e-1),
|
# learning_rate = (1e-5, 1e-1),
|
||||||
learning_rate=1e-3,
|
learning_rate=5e-4,
|
||||||
# scheduler = "ReduceLROnPlateau",
|
|
||||||
# scheduler_kwargs = {"mode": "min", "factor": 0.5, "patience": 10}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
optuna_settings = OptunaSettings(
|
optuna_settings = OptunaSettings(
|
||||||
n_trials = 4096,
|
n_trials=512,
|
||||||
n_threads = 16,
|
n_workers=14,
|
||||||
timeout = 600,
|
timeout=3600,
|
||||||
directions = ("minimize","minimize"),
|
directions=("maximize", "minimize"),
|
||||||
metrics_names = ("n_params","mse"),
|
metrics_names=("neg_log_mse","n_nodes"),
|
||||||
|
limit_examples=True,
|
||||||
limit_examples = True,
|
n_train_batches=500,
|
||||||
n_train_batches = 100,
|
# n_valid_batches = 100,
|
||||||
n_valid_batches = 100,
|
storage="sqlite:///data/single_core_regen.db",
|
||||||
storage = "sqlite:///data/single_core_regen.db",
|
study_name=f"single_core_regen_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
|
||||||
study_name = f"single_core_regen_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
|
n_trials_filter=(optuna.trial.TrialState.COMPLETE, optuna.trial.TrialState.PRUNED),
|
||||||
|
pruner="MedianPruner",
|
||||||
|
pruner_kwargs=None
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -78,8 +93,6 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
hyper_training.setup_study()
|
hyper_training.setup_study()
|
||||||
|
|
||||||
# hyper_training.resume_latest_study()
|
|
||||||
|
|
||||||
hyper_training.run_study()
|
hyper_training.run_study()
|
||||||
# best_trial = hyper_training.study.best_trial
|
# best_trial = hyper_training.study.best_trial
|
||||||
|
|
||||||
|
|||||||
@@ -8,10 +8,7 @@ def complex_mse_loss(input, target):
|
|||||||
Compute the mean squared error between two complex tensors.
|
Compute the mean squared error between two complex tensors.
|
||||||
"""
|
"""
|
||||||
if input.is_complex():
|
if input.is_complex():
|
||||||
return torch.mean(
|
return torch.mean(torch.square(input.real - target.real) + torch.square(input.imag - target.imag))
|
||||||
torch.square(input.real - target.real)
|
|
||||||
+ torch.square(input.imag - target.imag)
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
return F.mse_loss(input, target)
|
return F.mse_loss(input, target)
|
||||||
|
|
||||||
@@ -21,10 +18,7 @@ def complex_sse_loss(input, target):
|
|||||||
Compute the sum squared error between two complex tensors.
|
Compute the sum squared error between two complex tensors.
|
||||||
"""
|
"""
|
||||||
if input.is_complex():
|
if input.is_complex():
|
||||||
return torch.sum(
|
return torch.sum(torch.square(input.real - target.real) + torch.square(input.imag - target.imag))
|
||||||
torch.square(input.real - target.real)
|
|
||||||
+ torch.square(input.imag - target.imag)
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
return torch.sum(torch.square(input - target))
|
return torch.sum(torch.square(input - target))
|
||||||
|
|
||||||
@@ -48,6 +42,7 @@ class UnitaryLayer(nn.Module):
|
|||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return f"UnitaryLayer({self.in_features}, {self.out_features})"
|
return f"UnitaryLayer({self.in_features}, {self.out_features})"
|
||||||
|
|
||||||
|
|
||||||
class SemiUnitaryLayer(nn.Module):
|
class SemiUnitaryLayer(nn.Module):
|
||||||
def __init__(self, input_dim, output_dim, dtype=None):
|
def __init__(self, input_dim, output_dim, dtype=None):
|
||||||
super(SemiUnitaryLayer, self).__init__()
|
super(SemiUnitaryLayer, self).__init__()
|
||||||
@@ -62,9 +57,9 @@ class SemiUnitaryLayer(nn.Module):
|
|||||||
# Ensure the weights are semi-unitary by QR decomposition
|
# Ensure the weights are semi-unitary by QR decomposition
|
||||||
q, _ = torch.linalg.qr(self.weight)
|
q, _ = torch.linalg.qr(self.weight)
|
||||||
if self.input_dim > self.output_dim:
|
if self.input_dim > self.output_dim:
|
||||||
self.weight.data = q[:self.input_dim, :self.output_dim]
|
self.weight.data = q[: self.input_dim, : self.output_dim]
|
||||||
else:
|
else:
|
||||||
self.weight.data = q[:self.output_dim, :self.input_dim].t()
|
self.weight.data = q[: self.output_dim, : self.input_dim].t()
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
out = torch.matmul(x, self.weight)
|
out = torch.matmul(x, self.weight)
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import torch
|
import torch
|
||||||
from torch.utils.data import Dataset
|
from torch.utils.data import Dataset
|
||||||
|
|
||||||
# from torch.utils.data import Sampler
|
# from torch.utils.data import Sampler
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import configparser
|
import configparser
|
||||||
@@ -22,6 +23,7 @@ import configparser
|
|||||||
# def __len__(self):
|
# def __len__(self):
|
||||||
# return len(self.indices)
|
# return len(self.indices)
|
||||||
|
|
||||||
|
|
||||||
def load_data(config_path, skipfirst=0, symbols=None, real=False, normalize=False, device=None, dtype=None):
|
def load_data(config_path, skipfirst=0, symbols=None, real=False, normalize=False, device=None, dtype=None):
|
||||||
filepath = Path(config_path)
|
filepath = Path(config_path)
|
||||||
filepath = filepath.parent.glob(filepath.name)
|
filepath = filepath.parent.glob(filepath.name)
|
||||||
@@ -43,7 +45,7 @@ def load_data(config_path, skipfirst=0, symbols=None, real=False, normalize=Fals
|
|||||||
if normalize:
|
if normalize:
|
||||||
# square gets normalized to 1, as the power is (proportional to) the square of the amplitude
|
# square gets normalized to 1, as the power is (proportional to) the square of the amplitude
|
||||||
a, b, c, d = np.square(data.T)
|
a, b, c, d = np.square(data.T)
|
||||||
a, b, c, d = a/np.max(np.abs(a)), b/np.max(np.abs(b)), c/np.max(np.abs(c)), d/np.max(np.abs(d))
|
a, b, c, d = a / np.max(np.abs(a)), b / np.max(np.abs(b)), c / np.max(np.abs(c)), d / np.max(np.abs(d))
|
||||||
data = np.sqrt(np.array([a, b, c, d]).T)
|
data = np.sqrt(np.array([a, b, c, d]).T)
|
||||||
|
|
||||||
if real:
|
if real:
|
||||||
@@ -55,6 +57,7 @@ def load_data(config_path, skipfirst=0, symbols=None, real=False, normalize=Fals
|
|||||||
|
|
||||||
return data, config
|
return data, config
|
||||||
|
|
||||||
|
|
||||||
def roll_along(arr, shifts, dim):
|
def roll_along(arr, shifts, dim):
|
||||||
# https://stackoverflow.com/a/76920720
|
# https://stackoverflow.com/a/76920720
|
||||||
# (c) Mateen Ulhaq, 2023
|
# (c) Mateen Ulhaq, 2023
|
||||||
@@ -67,6 +70,7 @@ def roll_along(arr, shifts, dim):
|
|||||||
indices = (dim_indices - shifts.unsqueeze(dim)) % arr.shape[dim]
|
indices = (dim_indices - shifts.unsqueeze(dim)) % arr.shape[dim]
|
||||||
return torch.gather(arr, dim, indices)
|
return torch.gather(arr, dim, indices)
|
||||||
|
|
||||||
|
|
||||||
class FiberRegenerationDataset(Dataset):
|
class FiberRegenerationDataset(Dataset):
|
||||||
"""
|
"""
|
||||||
Dataset for fiber regeneration training.
|
Dataset for fiber regeneration training.
|
||||||
@@ -105,7 +109,7 @@ class FiberRegenerationDataset(Dataset):
|
|||||||
drop_first: float | int = 0,
|
drop_first: float | int = 0,
|
||||||
dtype: torch.dtype = None,
|
dtype: torch.dtype = None,
|
||||||
real: bool = False,
|
real: bool = False,
|
||||||
device = None,
|
device=None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@@ -127,18 +131,10 @@ class FiberRegenerationDataset(Dataset):
|
|||||||
|
|
||||||
# check types
|
# check types
|
||||||
assert isinstance(file_path, str), "file_path must be a string"
|
assert isinstance(file_path, str), "file_path must be a string"
|
||||||
assert isinstance(symbols, (float, int)), (
|
assert isinstance(symbols, (float, int)), "symbols must be a float or an integer"
|
||||||
"symbols must be a float or an integer"
|
assert output_dim is None or isinstance(output_dim, int), "output_len must be an integer"
|
||||||
)
|
assert isinstance(target_delay, (float, int)), "target_delay must be a float or an integer"
|
||||||
assert output_dim is None or isinstance(output_dim, int), (
|
assert isinstance(xy_delay, (float, int)), "xy_delay must be a float or an integer"
|
||||||
"output_len must be an integer"
|
|
||||||
)
|
|
||||||
assert isinstance(target_delay, (float, int)), (
|
|
||||||
"target_delay must be a float or an integer"
|
|
||||||
)
|
|
||||||
assert isinstance(xy_delay, (float, int)), (
|
|
||||||
"xy_delay must be a float or an integer"
|
|
||||||
)
|
|
||||||
assert isinstance(drop_first, int), "drop_first must be an integer"
|
assert isinstance(drop_first, int), "drop_first must be an integer"
|
||||||
|
|
||||||
# check values
|
# check values
|
||||||
@@ -159,7 +155,15 @@ class FiberRegenerationDataset(Dataset):
|
|||||||
"glova": {"sps": 128},
|
"glova": {"sps": 128},
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
data_raw, self.config = load_data(file_path, skipfirst=drop_first, symbols=kwargs.pop("num_symbols", None), real=real, normalize=True, device=device, dtype=dtype)
|
data_raw, self.config = load_data(
|
||||||
|
file_path,
|
||||||
|
skipfirst=drop_first,
|
||||||
|
symbols=kwargs.pop("num_symbols", None),
|
||||||
|
real=real,
|
||||||
|
normalize=True,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
)
|
||||||
|
|
||||||
self.device = data_raw.device
|
self.device = data_raw.device
|
||||||
|
|
||||||
@@ -180,9 +184,7 @@ class FiberRegenerationDataset(Dataset):
|
|||||||
else int(self.target_delay * self.samples_per_symbol)
|
else int(self.target_delay * self.samples_per_symbol)
|
||||||
)
|
)
|
||||||
self.xy_delay_samples = (
|
self.xy_delay_samples = (
|
||||||
ovrd_xy_delay_samples
|
ovrd_xy_delay_samples if ovrd_xy_delay_samples is not None else int(self.xy_delay * self.samples_per_symbol)
|
||||||
if ovrd_xy_delay_samples is not None
|
|
||||||
else int(self.xy_delay * self.samples_per_symbol)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# data_raw = torch.tensor(data_raw, dtype=dtype)
|
# data_raw = torch.tensor(data_raw, dtype=dtype)
|
||||||
@@ -208,9 +210,7 @@ class FiberRegenerationDataset(Dataset):
|
|||||||
# E_out_y[0:N] ] E_out_y[-3:N-3] ] E_out_y[0:N-3] ]
|
# E_out_y[0:N] ] E_out_y[-3:N-3] ] E_out_y[0:N-3] ]
|
||||||
|
|
||||||
if self.xy_delay_samples != 0:
|
if self.xy_delay_samples != 0:
|
||||||
data_raw = roll_along(
|
data_raw = roll_along(data_raw, [0, self.xy_delay_samples, 0, self.xy_delay_samples], dim=1)
|
||||||
data_raw, [0, self.xy_delay_samples, 0, self.xy_delay_samples], dim=1
|
|
||||||
)
|
|
||||||
if self.xy_delay_samples > 0:
|
if self.xy_delay_samples > 0:
|
||||||
data_raw = data_raw[:, self.xy_delay_samples :]
|
data_raw = data_raw[:, self.xy_delay_samples :]
|
||||||
elif self.xy_delay_samples < 0:
|
elif self.xy_delay_samples < 0:
|
||||||
|
|||||||
@@ -1,45 +1,335 @@
|
|||||||
def _optional_suggest(trial, name, range_or_value, log=False, step=None, type='int'):
|
from typing import Any
|
||||||
# not a range
|
from optuna import trial
|
||||||
if not hasattr(range_or_value, '__iter__') or isinstance(range_or_value, str):
|
|
||||||
return range_or_value
|
|
||||||
|
def install_optional_suggests():
|
||||||
|
trial.Trial.suggest_categorical_optional = suggest_categorical_optional_wrapper
|
||||||
|
trial.Trial.suggest_int_optional = suggest_int_optional_wrapper
|
||||||
|
trial.Trial.suggest_float_optional = suggest_float_optional_wrapper
|
||||||
|
|
||||||
|
|
||||||
|
def _is_listlike(obj: Any) -> bool:
|
||||||
|
return hasattr(obj, "__iter__") and not isinstance(obj, str)
|
||||||
|
|
||||||
|
|
||||||
|
def _optional_suggest(
|
||||||
|
*,
|
||||||
|
trial: trial.Trial,
|
||||||
|
name: str,
|
||||||
|
range_or_value: Any,
|
||||||
|
type: str,
|
||||||
|
log: bool = False,
|
||||||
|
step: int | float | None = None,
|
||||||
|
add_user: bool = False,
|
||||||
|
force: bool = False,
|
||||||
|
multiply: float | int = 1,
|
||||||
|
set_new: bool = True,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Suggest a value for a parameter with more control over the process
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
type : str
|
||||||
|
The type of the parameter
|
||||||
|
trial : optuna.trial.Trial
|
||||||
|
The trial object
|
||||||
|
name : str
|
||||||
|
The name of the parameter
|
||||||
|
range_or_value : Any
|
||||||
|
The range of values or a single value
|
||||||
|
log : bool, optional
|
||||||
|
Whether to use a logarithmic scale, by default False
|
||||||
|
step : int|float|None, optional
|
||||||
|
The step size, by default None
|
||||||
|
add_user : bool, optional
|
||||||
|
Whether to add the suggested value to the user attributes if not added as a parameter, by default False
|
||||||
|
force : bool, optional
|
||||||
|
Whether to force a single value to be suggested, by default False
|
||||||
|
multiply : float| int, optional
|
||||||
|
A multiplier to apply to the range or value, by default 1. Ignored for type "categorical".
|
||||||
|
set_new : bool, optional
|
||||||
|
Whether to override the parameter if it already exists, by default True
|
||||||
|
"""
|
||||||
|
|
||||||
|
# value should be retrieved from trial
|
||||||
|
if not set_new and name in trial.params:
|
||||||
|
return trial.params[name]
|
||||||
|
|
||||||
|
# value is not a list or tuple
|
||||||
|
if not _is_listlike(range_or_value):
|
||||||
|
range_or_value = (range_or_value,)
|
||||||
|
|
||||||
# range with only one value
|
# range with only one value
|
||||||
if len(range_or_value) == 1:
|
if len(range_or_value) == 1 and not force:
|
||||||
|
if add_user:
|
||||||
|
trial.set_user_attr(name, range_or_value[0])
|
||||||
return range_or_value[0]
|
return range_or_value[0]
|
||||||
|
|
||||||
if type == 'int':
|
# normal operation
|
||||||
step = step or 1
|
if type == "categorical":
|
||||||
return trial.suggest_int(name, *range_or_value, step=step, log=log)
|
|
||||||
|
|
||||||
if type == 'float':
|
|
||||||
return trial.suggest_float(name, *range_or_value, step=step, log=log)
|
|
||||||
|
|
||||||
if type == 'categorical':
|
|
||||||
return trial.suggest_categorical(name, range_or_value)
|
return trial.suggest_categorical(name, range_or_value)
|
||||||
|
|
||||||
|
# multiply range
|
||||||
|
range_or_value = tuple(multiply * x for x in range_or_value)
|
||||||
|
#
|
||||||
|
if len(range_or_value) > 2:
|
||||||
|
raise UserWarning("More than two values in range, using highest and lowest")
|
||||||
|
low = min(range_or_value)
|
||||||
|
high = max(range_or_value)
|
||||||
|
|
||||||
|
if type == "float":
|
||||||
|
return trial.suggest_float(name, low, high, step=step, log=log)
|
||||||
|
|
||||||
|
if type == "int":
|
||||||
|
step = step or 1
|
||||||
|
lowi = int(low)
|
||||||
|
highi = int(high)
|
||||||
|
if lowi != low or highi != high:
|
||||||
|
raise ValueError(f"Range {low} to {high} (using multiplier {multiply}) is not valid for int")
|
||||||
|
return trial.suggest_int(name, lowi, highi, step=step, log=log)
|
||||||
|
|
||||||
raise ValueError(f"Unknown type: {type}")
|
raise ValueError(f"Unknown type: {type}")
|
||||||
|
|
||||||
|
|
||||||
def optional_suggest_categorical(trial, name, choices_or_value):
|
def suggest_categorical_optional(
|
||||||
return _optional_suggest(trial, name, choices_or_value, type='categorical')
|
trial: trial.Trial,
|
||||||
|
name: str,
|
||||||
|
choices_or_value: tuple[Any] | list[Any] | Any,
|
||||||
|
add_user: bool = False,
|
||||||
|
force: bool = False,
|
||||||
|
set_new: bool = True,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Suggest a value for a categorical parameter with more control over the process
|
||||||
|
|
||||||
def optional_suggest_int(trial, name, range_or_value, step=None, log=False):
|
Parameters
|
||||||
return _optional_suggest(trial, name, range_or_value, step=step, log=log, type='int')
|
----------
|
||||||
|
trial : optuna.trial.Trial
|
||||||
|
The trial object
|
||||||
|
name : str
|
||||||
|
The name of the parameter
|
||||||
|
choices_or_value : tuple|list|Any
|
||||||
|
The choices or a single value
|
||||||
|
add_user : bool, optional
|
||||||
|
Whether to add the suggested value to the user attributes if not added as a parameter, by default False
|
||||||
|
force : bool, optional
|
||||||
|
Whether to suggest a single value as a parameter, by default False
|
||||||
|
set_new : bool, optional
|
||||||
|
Whether to override the parameter if it already exists, by default True
|
||||||
|
"""
|
||||||
|
return _optional_suggest(
|
||||||
|
trial=trial, name=name, range_or_value=choices_or_value, type="categorical", add_user=add_user, force=force, set_new=set_new
|
||||||
|
)
|
||||||
|
|
||||||
def optional_suggest_float(trial, name, range_or_value, step=None, log=False):
|
|
||||||
return _optional_suggest(trial, name, range_or_value, step=step, log=log, type='float')
|
|
||||||
|
|
||||||
def force_suggest_int(trial, name, range_or_value, step=1, log=False):
|
def suggest_int_optional(
|
||||||
if not hasattr(range_or_value, '__iter__') or isinstance(range_or_value, str):
|
trial: trial.Trial,
|
||||||
return trial.suggest_int(name, range_or_value, range_or_value, step=step, log=log)
|
name: str,
|
||||||
return trial.suggest_int(name, *range_or_value, step=step, log=log)
|
range_or_value: tuple[int] | list[int] | int,
|
||||||
|
step: int = 1,
|
||||||
|
log: bool = False,
|
||||||
|
add_user: bool = False,
|
||||||
|
force: bool = False,
|
||||||
|
multiply: int = 1,
|
||||||
|
set_new: bool = True,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Suggest a value for an integer parameter with more control over the process
|
||||||
|
|
||||||
def force_suggest_float(trial, name, range_or_value, step=None, log=False):
|
Parameters
|
||||||
if not hasattr(range_or_value, '__iter__') or isinstance(range_or_value, str):
|
----------
|
||||||
return trial.suggest_float(name, range_or_value, range_or_value, step=step, log=log)
|
trial : optuna.trial.Trial
|
||||||
return trial.suggest_float(name, *range_or_value, step=step, log=log)
|
The trial object
|
||||||
|
name : str
|
||||||
|
The name of the parameter
|
||||||
|
range_or_value : tuple|list|int
|
||||||
|
The range of values or a single value.
|
||||||
|
step : int, optional
|
||||||
|
The step size, by default 1
|
||||||
|
log : bool, optional
|
||||||
|
Whether to use a logarithmic scale, by default False
|
||||||
|
add_user : bool, optional
|
||||||
|
Whether to add the suggested value to the user attributes if not added as a parameter, by default False
|
||||||
|
force : bool, optional
|
||||||
|
Whether to suggest a single value as a parameter, by default False
|
||||||
|
"""
|
||||||
|
return _optional_suggest(
|
||||||
|
trial=trial,
|
||||||
|
name=name,
|
||||||
|
range_or_value=range_or_value,
|
||||||
|
step=step,
|
||||||
|
log=log,
|
||||||
|
type="int",
|
||||||
|
add_user=add_user,
|
||||||
|
force=force,
|
||||||
|
multiply=multiply,
|
||||||
|
set_new=set_new,
|
||||||
|
)
|
||||||
|
|
||||||
def force_suggest_categorical(trial, name, range_or_value):
|
|
||||||
if not hasattr(range_or_value, '__iter__') or isinstance(range_or_value, str):
|
def suggest_float_optional(
|
||||||
return trial.suggest_categorical(name, [range_or_value])
|
trial: trial.Trial,
|
||||||
return trial.suggest_categorical(name, range_or_value)
|
name: str,
|
||||||
|
range_or_value: tuple[float] | list[float] | float,
|
||||||
|
step: float | None = None,
|
||||||
|
log: bool = False,
|
||||||
|
add_user: bool = False,
|
||||||
|
force: bool = False,
|
||||||
|
multiply: float = 1,
|
||||||
|
set_new: bool = True,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Suggest a value for a float parameter with more control over the process
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
trial : optuna.trial.Trial
|
||||||
|
The trial object
|
||||||
|
name : str
|
||||||
|
The name of the parameter
|
||||||
|
range_or_value : tuple|list|float
|
||||||
|
The range of values or a single value
|
||||||
|
step : float|None, optional
|
||||||
|
The step size, by default None
|
||||||
|
log : bool, optional
|
||||||
|
Whether to use a logarithmic scale, by default False
|
||||||
|
add_user : bool, optional
|
||||||
|
Whether to add the suggested value to the user attributes if not added as a parameter, by default False
|
||||||
|
force : bool, optional
|
||||||
|
Whether to suggest a single value as a parameter, by default False
|
||||||
|
multiply : float, optional
|
||||||
|
A multiplier to apply to the range or value, by default 1
|
||||||
|
set_new : bool, optional
|
||||||
|
Whether to override the parameter if it already exists, by default True
|
||||||
|
"""
|
||||||
|
|
||||||
|
return _optional_suggest(
|
||||||
|
trial=trial,
|
||||||
|
name=name,
|
||||||
|
range_or_value=range_or_value,
|
||||||
|
step=step,
|
||||||
|
log=log,
|
||||||
|
type="float",
|
||||||
|
add_user=add_user,
|
||||||
|
force=force,
|
||||||
|
multiply=multiply,
|
||||||
|
set_new=set_new,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def suggest_categorical_optional_wrapper(
|
||||||
|
self: trial.Trial,
|
||||||
|
name: str,
|
||||||
|
choices_or_value: tuple[Any] | list[Any] | Any,
|
||||||
|
add_user: bool = False,
|
||||||
|
force: bool = False,
|
||||||
|
set_new: bool = True,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Suggest a value for a categorical parameter with more control over the process
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
name : str
|
||||||
|
The name of the parameter
|
||||||
|
choices_or_value : tuple|list|Any
|
||||||
|
The choices or a single value
|
||||||
|
add_user : bool, optional
|
||||||
|
Whether to add the suggested value to the user attributes if not added as a parameter, by default False
|
||||||
|
force : bool, optional
|
||||||
|
Whether to suggest a single value as a parameter, by default False
|
||||||
|
set_new : bool, optional
|
||||||
|
Whether to override the parameter if it already exists, by default True
|
||||||
|
"""
|
||||||
|
return suggest_categorical_optional(
|
||||||
|
trial=self, name=name, choices_or_value=choices_or_value, add_user=add_user, force=force, set_new=set_new
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def suggest_int_optional_wrapper(
|
||||||
|
self: trial.Trial,
|
||||||
|
name: str,
|
||||||
|
range_or_value: tuple[int] | list[int] | int,
|
||||||
|
step: int = 1,
|
||||||
|
log: bool = False,
|
||||||
|
add_user: bool = False,
|
||||||
|
force: bool = False,
|
||||||
|
multiply: int = 1,
|
||||||
|
set_new: bool = True,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Suggest a value for an integer parameter with more control over the process
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
name : str
|
||||||
|
The name of the parameter
|
||||||
|
range_or_value : tuple|list|int
|
||||||
|
The range of values or a single value.
|
||||||
|
step : int, optional
|
||||||
|
The step size, by default 1
|
||||||
|
log : bool, optional
|
||||||
|
Whether to use a logarithmic scale, by default False
|
||||||
|
add_user : bool, optional
|
||||||
|
Whether to add the suggested value to the user attributes if not added as a parameter, by default False
|
||||||
|
force : bool, optional
|
||||||
|
Whether to suggest a single value as a parameter, by default False
|
||||||
|
"""
|
||||||
|
return suggest_int_optional(
|
||||||
|
trial=self,
|
||||||
|
name=name,
|
||||||
|
range_or_value=range_or_value,
|
||||||
|
step=step,
|
||||||
|
log=log,
|
||||||
|
add_user=add_user,
|
||||||
|
force=force,
|
||||||
|
multiply=multiply,
|
||||||
|
set_new=set_new,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def suggest_float_optional_wrapper(
|
||||||
|
self: trial.Trial,
|
||||||
|
name: str,
|
||||||
|
range_or_value: tuple[float] | list[float] | float,
|
||||||
|
step: float | None = None,
|
||||||
|
log: bool = False,
|
||||||
|
add_user: bool = False,
|
||||||
|
force: bool = False,
|
||||||
|
multiply: float = 1,
|
||||||
|
set_new: bool = True,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Suggest a value for a float parameter with more control over the process
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
name : str
|
||||||
|
The name of the parameter
|
||||||
|
range_or_value : tuple|list|float
|
||||||
|
The range of values or a single value
|
||||||
|
step : float|None, optional
|
||||||
|
The step size, by default None
|
||||||
|
log : bool, optional
|
||||||
|
Whether to use a logarithmic scale, by default False
|
||||||
|
add_user : bool, optional
|
||||||
|
Whether to add the suggested value to the user attributes if not added as a parameter, by default False
|
||||||
|
force : bool, optional
|
||||||
|
Whether to suggest a single value as a parameter, by default False
|
||||||
|
multiply : float, optional
|
||||||
|
A multiplier to apply to the range or value, by default 1
|
||||||
|
set_new : bool, optional
|
||||||
|
Whether to override the parameter if it already exists, by default True
|
||||||
|
"""
|
||||||
|
return suggest_float_optional(
|
||||||
|
trial=self,
|
||||||
|
name=name,
|
||||||
|
range_or_value=range_or_value,
|
||||||
|
step=step,
|
||||||
|
log=log,
|
||||||
|
add_user=add_user,
|
||||||
|
force=force,
|
||||||
|
multiply=multiply,
|
||||||
|
set_new=set_new,
|
||||||
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user