Files
optical-regeneration/src/single-core-regen/hypertraining/settings.py
Joseph Hopfmüller ff32aefd52 minor fixes and changes
2024-11-29 15:49:10 +01:00

99 lines
2.5 KiB
Python

from dataclasses import dataclass, field
from datetime import datetime
# global settings
@dataclass(frozen=True)
class GlobalSettings:
    """Immutable settings shared across the whole run.

    ``frozen=True`` makes instances read-only (assigning to a field after
    construction raises ``dataclasses.FrozenInstanceError``) and, together
    with the generated ``__eq__``, hashable.
    """

    # Seed value; whatever applies it to RNGs lives outside this file.
    seed: int = field(default=42)
# data settings
@dataclass
class DataSettings:
config_path: str # = "data/*-128-16384-100000-0-0-17-0-PAM4-0.ini"
dtype: tuple = ("complex64", "float64")
symbols: tuple | float | int = 8
output_size: tuple | float | int = 64
shuffle: bool = True
in_out_delay: float = 0
xy_delay: tuple | float | int = 0
drop_first: int = 1000
train_split: float = 0.8
# pytorch settings
@dataclass
class PytorchSettings:
    """PyTorch-side options: training loop, data loading and output locations."""

    # Training loop.
    epochs: int = 1
    batchsize: int = 1024  # == 2**10
    device: str = "cuda"

    # Data-loader tuning.
    dataloader_workers: int = 2
    dataloader_prefetch: int = 2

    # Model checkpoints and summaries (defaults are relative directory names).
    save_models: bool = True
    model_dir: str = ".models"
    summary_dir: str = ".runs"
    write_every: int = 10

    # Symbol counts, presumably for diagnostic plots - TODO confirm.
    head_symbols: int = 40
    eye_symbols: int = 400
# model settings
@dataclass
class ModelSettings:
output_dim: int = 2
n_hidden_layers: tuple | int = 3
n_hidden_nodes: tuple | int = 8
model_activation_func: tuple | str = "ModReLU"
overrides: dict = field(default_factory=dict)
dropout_prob: float | None = None
model_layer_function: str | None = None
model_layer_parametrizations: list= field(default_factory=list)
@dataclass
class OptimizerSettings:
optimizer: tuple | str = ("Adam", "RMSprop", "SGD")
learning_rate: tuple | float = (1e-5, 1e-1)
scheduler: str | None = None
scheduler_kwargs: dict | None = None
def _pruner_default_kwargs():
# MedianPruner
return {
"n_startup_trials": 0,
"n_warmup_steps": 5,
}
# optuna settings
@dataclass
class OptunaSettings:
n_trials: int = 128
n_workers: int = 1
timeout: int = None
pruner: str = "MedianPruner"
pruner_kwargs: dict = field(default_factory=_pruner_default_kwargs)
directions: tuple = ("minimize",)
metrics_names: tuple = ("mse",)
limit_examples: bool = True
n_train_batches: int = float("inf")
n_valid_batches: int = float("inf")
storage: str = "sqlite:///example.db"
study_name: str = (
f"optuna_study_{datetime.now().strftime('%Y-%m-%d_%H:%M:%S')}"
)
n_trials_filter: tuple|list = None, #(optuna.trial.TrialState.COMPLETE,)
remove_outliers: float|int = None
## reserved, set by HyperTraining
_multi_objective = None
_parallel = None
_n_threads = None
_directions = None
_direction = None
_n_train_batches = None
_n_valid_batches = None