refactor complex loss functions for improved readability; update settings and dataset classes for consistency
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@ class DataSettings:
|
||||
config_path: str # = "data/*-128-16384-100000-0-0-17-0-PAM4-0.ini"
|
||||
dtype: tuple = ("complex64", "float64")
|
||||
symbols: tuple | float | int = 8
|
||||
model_input_dim: tuple | float | int = 64
|
||||
output_size: tuple | float | int = 64
|
||||
shuffle: bool = True
|
||||
in_out_delay: float = 0
|
||||
xy_delay: tuple | float | int = 0
|
||||
@@ -33,6 +33,7 @@ class PytorchSettings:
|
||||
dataloader_workers: int = 2
|
||||
dataloader_prefetch: int = 2
|
||||
|
||||
save_models: bool = True
|
||||
model_dir: str = ".models"
|
||||
|
||||
summary_dir: str = ".runs"
|
||||
@@ -45,11 +46,10 @@ class PytorchSettings:
|
||||
@dataclass
|
||||
class ModelSettings:
|
||||
output_dim: int = 2
|
||||
model_n_layers: tuple | int = 3
|
||||
unit_count: tuple | int = 8
|
||||
# n_units_range: tuple | int = (2, 32)
|
||||
# activation_func_range: tuple = ("ModReLU", "ZReLU", "CReLU", "Mag", "Identity")
|
||||
model_activation_func: tuple = ("ModReLU",)
|
||||
n_hidden_layers: tuple | int = 3
|
||||
n_hidden_nodes: tuple | int = 8
|
||||
model_activation_func: tuple = "ModReLU"
|
||||
overrides: dict = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -60,18 +60,36 @@ class OptimizerSettings:
|
||||
scheduler_kwargs: dict | None = None
|
||||
|
||||
|
||||
def _pruner_default_kwargs():
|
||||
# MedianPruner
|
||||
return {
|
||||
"n_startup_trials": 0,
|
||||
"n_warmup_steps": 5,
|
||||
}
|
||||
# optuna settings
|
||||
@dataclass
|
||||
class OptunaSettings:
|
||||
n_trials: int = 128
|
||||
n_threads: int = 4
|
||||
timeout: int = 600
|
||||
n_workers: int = 1
|
||||
timeout: int = None
|
||||
pruner: str = "MedianPruner"
|
||||
pruner_kwargs: dict = field(default_factory=_pruner_default_kwargs)
|
||||
directions: tuple = ("minimize",)
|
||||
metrics_names: tuple = ("mse",)
|
||||
limit_examples: bool = True
|
||||
n_train_batches: int = 100
|
||||
n_valid_batches: int = 100
|
||||
n_train_batches: int = float("inf")
|
||||
n_valid_batches: int = float("inf")
|
||||
storage: str = "sqlite:///example.db"
|
||||
study_name: str = (
|
||||
f"optuna_study_{datetime.now().strftime('%Y-%m-%d_%H:%M:%S')}"
|
||||
)
|
||||
)
|
||||
n_trials_filter: tuple|list = None, #(optuna.trial.TrialState.COMPLETE,)
|
||||
remove_outliers: float|int = None
|
||||
## reserved, set by HyperTraining
|
||||
_multi_objective = None
|
||||
_parallel = None
|
||||
_n_threads = None
|
||||
_directions = None
|
||||
_direction = None
|
||||
_n_train_batches = None
|
||||
_n_valid_batches = None
|
||||
|
||||
Reference in New Issue
Block a user