from dataclasses import dataclass, field from datetime import datetime # global settings @dataclass(frozen=True) class GlobalSettings: seed: int = 42 # data settings @dataclass class DataSettings: config_path: str # = "data/*-128-16384-100000-0-0-17-0-PAM4-0.ini" dtype: tuple = ("complex64", "float64") symbols: tuple | float | int = 8 output_size: tuple | float | int = 64 shuffle: bool = True in_out_delay: float = 0 xy_delay: tuple | float | int = 0 drop_first: int = 1000 train_split: float = 0.8 # pytorch settings @dataclass class PytorchSettings: epochs: int = 1 batchsize: int = 2**10 device: str = "cuda" dataloader_workers: int = 2 dataloader_prefetch: int = 2 save_models: bool = True model_dir: str = ".models" summary_dir: str = ".runs" write_every: int = 10 head_symbols: int = 40 eye_symbols: int = 400 # model settings @dataclass class ModelSettings: output_dim: int = 2 n_hidden_layers: tuple | int = 3 n_hidden_nodes: tuple | int = 8 model_activation_func: tuple | str = "ModReLU" overrides: dict = field(default_factory=dict) dropout_prob: float | None = None model_layer_function: str | None = None model_layer_parametrizations: list= field(default_factory=list) @dataclass class OptimizerSettings: optimizer: tuple | str = ("Adam", "RMSprop", "SGD") learning_rate: tuple | float = (1e-5, 1e-1) scheduler: str | None = None scheduler_kwargs: dict | None = None def _pruner_default_kwargs(): # MedianPruner return { "n_startup_trials": 0, "n_warmup_steps": 5, } # optuna settings @dataclass class OptunaSettings: n_trials: int = 128 n_workers: int = 1 timeout: int = None pruner: str = "MedianPruner" pruner_kwargs: dict = field(default_factory=_pruner_default_kwargs) directions: tuple = ("minimize",) metrics_names: tuple = ("mse",) limit_examples: bool = True n_train_batches: int = float("inf") n_valid_batches: int = float("inf") storage: str = "sqlite:///example.db" study_name: str = ( f"optuna_study_{datetime.now().strftime('%Y-%m-%d_%H:%M:%S')}" ) n_trials_filter: tuple|list = None, #(optuna.trial.TrialState.COMPLETE,) remove_outliers: float|int = None ## reserved, set by HyperTraining _multi_objective = None _parallel = None _n_threads = None _directions = None _direction = None _n_train_batches = None _n_valid_batches = None