move hypertraining class into separate file;
move settings dataclasses into separate file; add SemiUnitaryLayer; clean up model response plotting code; cnt hyperparameter search
This commit is contained in:
77
src/single-core-regen/hypertraining/settings.py
Normal file
77
src/single-core-regen/hypertraining/settings.py
Normal file
@@ -0,0 +1,77 @@
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
# global settings
|
||||
@dataclass(frozen=True)
|
||||
class GlobalSettings:
|
||||
seed: int = 42
|
||||
|
||||
|
||||
# data settings
|
||||
@dataclass
|
||||
class DataSettings:
|
||||
config_path: str # = "data/*-128-16384-100000-0-0-17-0-PAM4-0.ini"
|
||||
dtype: tuple = ("complex64", "float64")
|
||||
symbols: tuple | float | int = 8
|
||||
model_input_dim: tuple | float | int = 64
|
||||
shuffle: bool = True
|
||||
in_out_delay: float = 0
|
||||
xy_delay: tuple | float | int = 0
|
||||
drop_first: int = 1000
|
||||
train_split: float = 0.8
|
||||
|
||||
|
||||
# pytorch settings
|
||||
@dataclass
|
||||
class PytorchSettings:
|
||||
epochs: int = 1
|
||||
batchsize: int = 2**10
|
||||
|
||||
device: str = "cuda"
|
||||
|
||||
dataloader_workers: int = 2
|
||||
dataloader_prefetch: int = 2
|
||||
|
||||
model_dir: str = ".models"
|
||||
|
||||
summary_dir: str = ".runs"
|
||||
write_every: int = 10
|
||||
head_symbols: int = 40
|
||||
eye_symbols: int = 1000
|
||||
|
||||
|
||||
# model settings
|
||||
@dataclass
|
||||
class ModelSettings:
|
||||
output_dim: int = 2
|
||||
model_n_layers: tuple | int = 3
|
||||
unit_count: tuple | int = 8
|
||||
# n_units_range: tuple | int = (2, 32)
|
||||
# activation_func_range: tuple = ("ModReLU", "ZReLU", "CReLU", "Mag", "Identity")
|
||||
model_activation_func: tuple = ("ModReLU",)
|
||||
|
||||
|
||||
@dataclass
|
||||
class OptimizerSettings:
|
||||
optimizer: tuple | str = ("Adam", "RMSprop", "SGD")
|
||||
learning_rate: tuple | float = (1e-5, 1e-1)
|
||||
scheduler: str | None = None
|
||||
scheduler_kwargs: dict | None = None
|
||||
|
||||
|
||||
# optuna settings
|
||||
@dataclass
|
||||
class OptunaSettings:
|
||||
n_trials: int = 128
|
||||
n_threads: int = 4
|
||||
timeout: int = 600
|
||||
directions: tuple = ("minimize",)
|
||||
metrics_names: tuple = ("mse",)
|
||||
limit_examples: bool = True
|
||||
n_train_batches: int = 100
|
||||
n_valid_batches: int = 100
|
||||
storage: str = "sqlite:///example.db"
|
||||
study_name: str = (
|
||||
f"optuna_study_{datetime.now().strftime('%Y-%m-%d_%H:%M:%S')}"
|
||||
)
|
||||
Reference in New Issue
Block a user