99 lines
2.5 KiB
Python
99 lines
2.5 KiB
Python
from dataclasses import dataclass, field
|
|
from datetime import datetime
|
|
|
|
|
|
# global settings
|
|
@dataclass(frozen=True)
class GlobalSettings:
    """Immutable container for run-wide settings.

    The dataclass is frozen, so attributes cannot be reassigned once an
    instance has been created.
    """

    # Seed value; presumably fed to the random number generators in use
    # -- TODO confirm against the code that consumes it.
    seed: int = 42
|
|
|
|
|
|
# data settings
|
|
@dataclass
|
|
class DataSettings:
|
|
config_path: str # = "data/*-128-16384-100000-0-0-17-0-PAM4-0.ini"
|
|
dtype: tuple = ("complex64", "float64")
|
|
symbols: tuple | float | int = 8
|
|
output_size: tuple | float | int = 64
|
|
shuffle: bool = True
|
|
in_out_delay: float = 0
|
|
xy_delay: tuple | float | int = 0
|
|
drop_first: int = 1000
|
|
train_split: float = 0.8
|
|
|
|
|
|
# pytorch settings
|
|
@dataclass
class PytorchSettings:
    """Settings for the PyTorch training run.

    All fields have defaults, so ``PytorchSettings()`` is a valid
    configuration. The field meanings noted below are inferred from the
    names only -- TODO confirm against the training code.
    """

    # Training loop.
    epochs: int = 1
    batchsize: int = 1024  # 2**10

    # Device string; the default targets CUDA.
    device: str = "cuda"

    # Data loading (presumably forwarded to the DataLoader -- TODO confirm).
    dataloader_workers: int = 2
    dataloader_prefetch: int = 2

    # Model checkpoints.
    save_models: bool = True
    model_dir: str = ".models"

    # Summaries / logging.
    summary_dir: str = ".runs"
    write_every: int = 10
    head_symbols: int = 40
    eye_symbols: int = 400
|
|
|
|
|
|
# model settings
|
|
@dataclass
|
|
class ModelSettings:
|
|
output_dim: int = 2
|
|
n_hidden_layers: tuple | int = 3
|
|
n_hidden_nodes: tuple | int = 8
|
|
model_activation_func: tuple | str = "ModReLU"
|
|
overrides: dict = field(default_factory=dict)
|
|
dropout_prob: float | None = None
|
|
model_layer_function: str | None = None
|
|
model_layer_parametrizations: list= field(default_factory=list)
|
|
|
|
|
|
@dataclass
|
|
class OptimizerSettings:
|
|
optimizer: tuple | str = ("Adam", "RMSprop", "SGD")
|
|
learning_rate: tuple | float = (1e-5, 1e-1)
|
|
scheduler: str | None = None
|
|
scheduler_kwargs: dict | None = None
|
|
|
|
|
|
def _pruner_default_kwargs():
|
|
# MedianPruner
|
|
return {
|
|
"n_startup_trials": 0,
|
|
"n_warmup_steps": 5,
|
|
}
|
|
# optuna settings
|
|
@dataclass
|
|
class OptunaSettings:
|
|
n_trials: int = 128
|
|
n_workers: int = 1
|
|
timeout: int = None
|
|
pruner: str = "MedianPruner"
|
|
pruner_kwargs: dict = field(default_factory=_pruner_default_kwargs)
|
|
directions: tuple = ("minimize",)
|
|
metrics_names: tuple = ("mse",)
|
|
limit_examples: bool = True
|
|
n_train_batches: int = float("inf")
|
|
n_valid_batches: int = float("inf")
|
|
storage: str = "sqlite:///example.db"
|
|
study_name: str = (
|
|
f"optuna_study_{datetime.now().strftime('%Y-%m-%d_%H:%M:%S')}"
|
|
)
|
|
n_trials_filter: tuple|list = None, #(optuna.trial.TrialState.COMPLETE,)
|
|
remove_outliers: float|int = None
|
|
## reserved, set by HyperTraining
|
|
_multi_objective = None
|
|
_parallel = None
|
|
_n_threads = None
|
|
_directions = None
|
|
_direction = None
|
|
_n_train_batches = None
|
|
_n_valid_batches = None
|