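"""Optuna hyperparameter study for the single-core regeneration model.

The settings objects below configure the data pipeline, the PyTorch training
loop, the model, the optimizer, and the Optuna study consumed by
HyperTraining. Judging by the commented-out alternatives, a scalar fixes a
setting while a (low, high) tuple presumably defines a search range for
Optuna to sample per trial.
"""
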
from datetime import datetime

import optuna
import torch

import util
from hypertraining.hypertraining import HyperTraining
from hypertraining.settings import (
    GlobalSettings,
    DataSettings,
    PytorchSettings,
    ModelSettings,
    OptimizerSettings,
    OptunaSettings,
)

global_settings = GlobalSettings(
    seed=42,
)

data_settings = DataSettings(
    # config_path="data/*-128-16384-100000-0-0-17-0-PAM4-0.ini",
    config_path="data/20241204-131003-128-16384-100000-0-0-17-0-PAM4-0.ini",
    dtype="complex64",
    # symbols=(9, 20),  # 13 symbols @ 10 GBd <-> 1.3 ns <-> 0.26 m of fiber
    # symbols=13,  # study: single_core_regen_20241123_011232
    # symbols=(3, 13),
    symbols=4,
    # output_size=(11, 32),  # ballpark 26 taps -> 2 taps per input symbol -> 1 tap every 0.01 m (model has 52 inputs)
    # output_size=26,  # study: single_core_regen_20241123_011232 (model_input_dim/2)
    output_size=(8, 30),
    shuffle=True,
    in_out_delay=0,
    xy_delay=0,
    drop_first=256,
    train_split=0.8,
    randomise_polarisations=False,
)

pytorch_settings = PytorchSettings(
    epochs=10,
    batchsize=2**10,  # 1024
    device="cuda",
    dataloader_workers=4,
    dataloader_prefetch=4,
    summary_dir=".runs",
    write_every=2**5,  # 32
    save_models=True,
    model_dir=".models",
)

model_settings = ModelSettings(
    output_dim=2,
    n_hidden_layers=(2, 5),
    n_hidden_nodes=(2, 16),
    model_activation_func="EOActivation",
    dropout_prob=0,
    model_layer_function="ONNRect",
    model_layer_kwargs={"square": True},
    # scale=(False, True),
    scale=False,
    model_layer_parametrizations=[
        {
            "tensor_name": "weight",
            "parametrization": util.complexNN.energy_conserving,
        },
        {
            "tensor_name": "alpha",
            "parametrization": util.complexNN.clamp,
        },
        {
            "tensor_name": "gain",
            "parametrization": util.complexNN.clamp,
            "kwargs": {
                "min": 0,
                "max": float("inf"),
            },
        },
        {
            "tensor_name": "phase_bias",
            "parametrization": util.complexNN.clamp,
            "kwargs": {
                "min": 0,
                "max": 2 * torch.pi,
            },
        },
        {
            "tensor_name": "scales",
            "parametrization": util.complexNN.clamp,
        },
        {
            "tensor_name": "angle",
            "parametrization": util.complexNN.clamp,
            "kwargs": {
                "min": -torch.pi,
                "max": torch.pi,
            },
        },
        {
            "tensor_name": "loss",
            "parametrization": util.complexNN.clamp,
        },
    ],
)

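# Note on model_layer_parametrizations: each entry names a tensor and a
# constraint from util.complexNN. A minimal sketch of how such a list could
# be applied, assuming the entries map onto torch.nn.utils.parametrize and
# that each constraint builds a module from the listed kwargs (HyperTraining's
# actual wiring may differ):
#
#     from torch.nn.utils import parametrize
#
#     for entry in model_settings.model_layer_parametrizations:
#         for module in model.modules():
#             if hasattr(module, entry["tensor_name"]):
#                 parametrize.register_parametrization(
#                     module,
#                     entry["tensor_name"],
#                     entry["parametrization"](**entry.get("kwargs", {})),
#                 )
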
optimizer_settings = OptimizerSettings(
    optimizer="AdamW",
    optimizer_kwargs={
        "lr": 5e-3,
        "amsgrad": True,
        # "weight_decay": 1e-7,
    },
)

optuna_settings = OptunaSettings(
    n_trials=1024,
    n_workers=8,
    timeout=3600,  # seconds
    directions=("minimize",),
    metrics_names=("mse",),
    limit_examples=False,
    n_train_batches=500,
    # n_valid_batches=100,
    storage="sqlite:///data/single_core_regen.db",
    study_name=f"single_core_regen_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
    n_trials_filter=(optuna.trial.TrialState.COMPLETE, optuna.trial.TrialState.PRUNED),
    pruner="MedianPruner",
    pruner_kwargs=None,
)

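# For orientation: the study described above corresponds roughly to the plain
# Optuna calls below (a sketch; HyperTraining presumably builds the study and
# distributes trials over n_workers itself, and `objective` stands in for the
# trial function it defines):
#
#     study = optuna.create_study(
#         study_name=optuna_settings.study_name,
#         storage=optuna_settings.storage,
#         directions=list(optuna_settings.directions),
#         pruner=optuna.pruners.MedianPruner(),
#     )
#     study.optimize(objective, n_trials=optuna_settings.n_trials,
#                    timeout=optuna_settings.timeout)
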
if __name__ == "__main__":
    hyper_training = HyperTraining(
        global_settings=global_settings,
        data_settings=data_settings,
        pytorch_settings=pytorch_settings,
        model_settings=model_settings,
        optimizer_settings=optimizer_settings,
        optuna_settings=optuna_settings,
    )

    hyper_training.setup_study()
    hyper_training.run_study()

    # best_trial = hyper_training.study.best_trial

    # best_model = hyper_training.define_model(best_trial).to(
    #     hyper_training.pytorch_settings.device
    # )

    # title_append, subtitle = hyper_training.build_title(best_trial)
    # hyper_training.plot_model_response(
    #     best_trial,
    #     model=best_model,
    #     title_append=title_append,
    #     subtitle=subtitle,
    #     mode="eye",
    #     show=True,
    # )

    # print(f"Best model found for trial {best_trial.number}")
    # print(f"Best model error: {best_trial.value}")
    # print(f"Best model params: {best_trial.params}")
    # print()
    # print(best_model)

    # eye_fig = hyper_training.plot_eye()
    ...
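
    # Results persist in the SQLite storage configured above, so a finished
    # study can be reloaded later for offline analysis (standard Optuna API;
    # the timestamped study name must be filled in by hand):
    #
    #     study = optuna.load_study(
    #         study_name="single_core_regen_<timestamp>",
    #         storage=optuna_settings.storage,
    #     )
    #     completed = [t for t in study.trials
    #                  if t.state == optuna.trial.TrialState.COMPLETE]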