# optical-regeneration/src/single-core-regen/regen.py
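"""Hyperparameter search for a single-core optical-regeneration model.

Defines data, training, model, optimizer, and Optuna study settings, then runs
the study through HyperTraining. This summary is inferred from the settings
below; hypertraining and util are project-local packages.
"""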

from datetime import datetime
import optuna
import torch
import util
from hypertraining.hypertraining import HyperTraining
from hypertraining.settings import (
    GlobalSettings,
    DataSettings,
    PytorchSettings,
    ModelSettings,
    OptimizerSettings,
    OptunaSettings,
)

global_settings = GlobalSettings(
    seed=42,
)

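# Dataset and windowing configuration. Tuple-valued fields appear to define
# Optuna search ranges while scalars pin a value (inferred from the
# commented-out alternatives below, e.g. symbols=(3, 13) vs. symbols=4).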
data_settings = DataSettings(
    # config_path="data/*-128-16384-100000-0-0-17-0-PAM4-0.ini",
    config_path="data/20241204-131003-128-16384-100000-0-0-17-0-PAM4-0.ini",
    dtype="complex64",
    # symbols=(9, 20),  # 13 symbols @ 10 GBd <-> 1.3 ns <-> 0.26 m of fiber
    # symbols=13,  # study: single_core_regen_20241123_011232
    # symbols=(3, 13),
    symbols=4,
    # output_size=(11, 32),  # ballpark 26 taps -> 2 taps per input symbol -> 1 tap every 0.01 m (model has 52 inputs)
    # output_size=26,  # study: single_core_regen_20241123_011232 (model_input_dim/2)
    output_size=(8, 30),
    shuffle=True,
    in_out_delay=0,
    xy_delay=0,
    drop_first=256,
    train_split=0.8,
    randomise_polarisations=False,
)

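# Training-loop and logging settings; summary_dir and write_every suggest
# TensorBoard-style summaries, and model_dir periodic checkpointing (assumed
# from the field names).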
pytorch_settings = PytorchSettings(
    epochs=10,
    batchsize=2**10,
    device="cuda",
    dataloader_workers=4,
    dataloader_prefetch=4,
    summary_dir=".runs",
    write_every=2**5,
    save_models=True,
    model_dir=".models",
)

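# Architecture search space: complex-valued optical layers ("ONNRect") with
# electro-optic activations ("EOActivation"). model_layer_parametrizations
# constrains the named tensors, apparently via torch.nn.utils.parametrize-style
# hooks implemented in util.complexNN.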
model_settings = ModelSettings(
    output_dim=2,
    n_hidden_layers=(2, 5),
    n_hidden_nodes=(2, 16),
    model_activation_func="EOActivation",
    dropout_prob=0,
    model_layer_function="ONNRect",
    model_layer_kwargs={"square": True},
    # scale=(False, True),
    scale=False,
    model_layer_parametrizations=[
        {
            "tensor_name": "weight",
            "parametrization": util.complexNN.energy_conserving,
        },
        {
            "tensor_name": "alpha",
            "parametrization": util.complexNN.clamp,
        },
        {
            "tensor_name": "gain",
            "parametrization": util.complexNN.clamp,
            "kwargs": {
                "min": 0,
                "max": float("inf"),  # non-negative, unbounded above
            },
        },
        {
            "tensor_name": "phase_bias",
            "parametrization": util.complexNN.clamp,
            "kwargs": {
                "min": 0,
                "max": 2 * torch.pi,  # one full phase period
            },
        },
        {
            "tensor_name": "scales",
            "parametrization": util.complexNN.clamp,
        },
        {
            "tensor_name": "angle",
            "parametrization": util.complexNN.clamp,
            "kwargs": {
                "min": -torch.pi,
                "max": torch.pi,
            },
        },
        {
            "tensor_name": "loss",
            "parametrization": util.complexNN.clamp,
        },
    ],
)

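# Optimizer choice and kwargs; the string name is presumably resolved against
# torch.optim inside HyperTraining.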
optimizer_settings = OptimizerSettings(
    optimizer="AdamW",
    optimizer_kwargs={
        "lr": 5e-3,
        "amsgrad": True,
        # "weight_decay": 1e-7,
    },
)

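# Study configuration: single-objective MSE minimisation, SQLite-backed so the
# 8 workers can share one study, with median-rule pruning of weak trials.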
optuna_settings = OptunaSettings(
    n_trials=1024,
    n_workers=8,
    timeout=3600,
    directions=("minimize",),
    metrics_names=("mse",),
    limit_examples=False,
    n_train_batches=500,
    # n_valid_batches=100,
    storage="sqlite:///data/single_core_regen.db",
    study_name=f"single_core_regen_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
    n_trials_filter=(optuna.trial.TrialState.COMPLETE, optuna.trial.TrialState.PRUNED),
    pruner="MedianPruner",
    pruner_kwargs=None,
)

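# Results can be inspected later by reloading the study from the same storage
# with the standard Optuna API; a minimal sketch (the timestamped study name
# must be filled in from the run's output):
#
#   study = optuna.load_study(
#       study_name="single_core_regen_YYYYMMDD_HHMMSS",
#       storage="sqlite:///data/single_core_regen.db",
#   )
#   print(study.best_trial.params)
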
if __name__ == "__main__":
    hyper_training = HyperTraining(
        global_settings=global_settings,
        data_settings=data_settings,
        pytorch_settings=pytorch_settings,
        model_settings=model_settings,
        optimizer_settings=optimizer_settings,
        optuna_settings=optuna_settings,
    )
    hyper_training.setup_study()
    hyper_training.run_study()
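
    # Optional post-run inspection and eye-diagram plotting, kept disabled: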
    # best_trial = hyper_training.study.best_trial
    # best_model = hyper_training.define_model(best_trial).to(
    #     hyper_training.pytorch_settings.device
    # )
    # title_append, subtitle = hyper_training.build_title(best_trial)
    # hyper_training.plot_model_response(
    #     best_trial,
    #     model=best_model,
    #     title_append=title_append,
    #     subtitle=subtitle,
    #     mode="eye",
    #     show=True,
    # )
    # print(f"Best model found for trial {best_trial.number}")
    # print(f"Best model error: {best_trial.value}")
    # print(f"Best model params: {best_trial.params}")
    # print()
    # print(best_model)
    # eye_fig = hyper_training.plot_eye()
...