refactor complex loss functions for improved readability; update settings and dataset classes for consistency

This commit is contained in:
Joseph Hopfmüller
2024-11-24 01:55:32 +01:00
parent 9a16a5637d
commit 7343ccb3a5
4 changed files with 392 additions and 361 deletions

View File

@@ -1,4 +1,4 @@
from dataclasses import dataclass
from dataclasses import dataclass, field
from datetime import datetime
@@ -14,7 +14,7 @@ class DataSettings:
config_path: str # = "data/*-128-16384-100000-0-0-17-0-PAM4-0.ini"
dtype: tuple = ("complex64", "float64")
symbols: tuple | float | int = 8
model_input_dim: tuple | float | int = 64
output_size: tuple | float | int = 64
shuffle: bool = True
in_out_delay: float = 0
xy_delay: tuple | float | int = 0
@@ -33,6 +33,7 @@ class PytorchSettings:
dataloader_workers: int = 2
dataloader_prefetch: int = 2
save_models: bool = True
model_dir: str = ".models"
summary_dir: str = ".runs"
@@ -45,11 +46,10 @@ class PytorchSettings:
@dataclass
class ModelSettings:
output_dim: int = 2
model_n_layers: tuple | int = 3
unit_count: tuple | int = 8
# n_units_range: tuple | int = (2, 32)
# activation_func_range: tuple = ("ModReLU", "ZReLU", "CReLU", "Mag", "Identity")
model_activation_func: tuple = ("ModReLU",)
n_hidden_layers: tuple | int = 3
n_hidden_nodes: tuple | int = 8
model_activation_func: tuple = "ModReLU"
overrides: dict = field(default_factory=dict)
@dataclass
@@ -60,18 +60,36 @@ class OptimizerSettings:
scheduler_kwargs: dict | None = None
def _pruner_default_kwargs():
# MedianPruner
return {
"n_startup_trials": 0,
"n_warmup_steps": 5,
}
# optuna settings
@dataclass
class OptunaSettings:
n_trials: int = 128
n_threads: int = 4
timeout: int = 600
n_workers: int = 1
timeout: int = None
pruner: str = "MedianPruner"
pruner_kwargs: dict = field(default_factory=_pruner_default_kwargs)
directions: tuple = ("minimize",)
metrics_names: tuple = ("mse",)
limit_examples: bool = True
n_train_batches: int = 100
n_valid_batches: int = 100
n_train_batches: int = float("inf")
n_valid_batches: int = float("inf")
storage: str = "sqlite:///example.db"
study_name: str = (
f"optuna_study_{datetime.now().strftime('%Y-%m-%d_%H:%M:%S')}"
)
)
n_trials_filter: tuple|list = None, #(optuna.trial.TrialState.COMPLETE,)
remove_outliers: float|int = None
## reserved, set by HyperTraining
_multi_objective = None
_parallel = None
_n_threads = None
_directions = None
_direction = None
_n_train_batches = None
_n_valid_batches = None