"""Optuna hyperparameter search for the single-core regeneration model, driven by HyperTraining."""

from datetime import datetime

import optuna
import torch

import util
from hypertraining.hypertraining import HyperTraining
from hypertraining.settings import (
    GlobalSettings,
    DataSettings,
    PytorchSettings,
    ModelSettings,
    OptimizerSettings,
    OptunaSettings,
)

global_settings = GlobalSettings(
    seed=42,
)

data_settings = DataSettings(
    # config_path="data/*-128-16384-100000-0-0-17-0-PAM4-0.ini",
    config_path="data/20241204-131003-128-16384-100000-0-0-17-0-PAM4-0.ini",
    dtype="complex64",
    # symbols=(9, 20),  # 13 symbols @ 10 GBd <-> 1.3 ns <-> 0.26 m of fiber
    # symbols=13,  # study: single_core_regen_20241123_011232
    # symbols=(3, 13),
    symbols=4,
    # output_size=(11, 32),  # ballpark 26 taps -> 2 taps per input symbol -> 1 tap every 0.01 m (model has 52 inputs)
    # output_size=26,  # study: single_core_regen_20241123_011232 (model_input_dim/2)
    output_size=(8, 30),
    shuffle=True,
    in_out_delay=0,
    xy_delay=0,
    drop_first=256,
    train_split=0.8,
    randomise_polarisations=False,
)

pytorch_settings = PytorchSettings(
    epochs=10,
    batchsize=2**10,
    device="cuda",
    dataloader_workers=4,
    dataloader_prefetch=4,
    summary_dir=".runs",
    write_every=2**5,
    save_models=True,
    model_dir=".models",
)

model_settings = ModelSettings(
    output_dim=2,
    n_hidden_layers=(2, 5),
    n_hidden_nodes=(2, 16),
    model_activation_func="EOActivation",
    dropout_prob=0,
    model_layer_function="ONNRect",
    model_layer_kwargs={"square": True},
    # scale=(False, True),
    scale=False,
    model_layer_parametrizations=[
        {
            "tensor_name": "weight",
            "parametrization": util.complexNN.energy_conserving,
        },
        {
            "tensor_name": "alpha",
            "parametrization": util.complexNN.clamp,
        },
        {
            "tensor_name": "gain",
            "parametrization": util.complexNN.clamp,
            "kwargs": {
                "min": 0,
                "max": float("inf"),
            },
        },
        {
            "tensor_name": "phase_bias",
            "parametrization": util.complexNN.clamp,
            "kwargs": {
                "min": 0,
                "max": 2 * torch.pi,
            },
        },
        {
            "tensor_name": "scales",
            "parametrization": util.complexNN.clamp,
        },
        {
            "tensor_name": "angle",
            "parametrization": util.complexNN.clamp,
            "kwargs": {
                "min": -torch.pi,
                "max": torch.pi,
            },
        },
        {
            "tensor_name": "loss",
            "parametrization": util.complexNN.clamp,
        },
    ],
)

optimizer_settings = OptimizerSettings(
    optimizer="AdamW",
    optimizer_kwargs={
        "lr": 5e-3,
        "amsgrad": True,
        # "weight_decay": 1e-7,
    },
)

optuna_settings = OptunaSettings(
    n_trials=1024,
    n_workers=8,
    timeout=3600,
    directions=("minimize",),
    metrics_names=("mse",),
    limit_examples=False,
    n_train_batches=500,
    # n_valid_batches=100,
    storage="sqlite:///data/single_core_regen.db",
    study_name=f"single_core_regen_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
    n_trials_filter=(optuna.trial.TrialState.COMPLETE, optuna.trial.TrialState.PRUNED),
    pruner="MedianPruner",
    pruner_kwargs=None,
)

if __name__ == "__main__":
    hyper_training = HyperTraining(
        global_settings=global_settings,
        data_settings=data_settings,
        pytorch_settings=pytorch_settings,
        model_settings=model_settings,
        optimizer_settings=optimizer_settings,
        optuna_settings=optuna_settings,
    )

    hyper_training.setup_study()
    hyper_training.run_study()

    # best_trial = hyper_training.study.best_trial
    # best_model = hyper_training.define_model(best_trial).to(
    #     hyper_training.pytorch_settings.device
    # )

    # title_append, subtitle = hyper_training.build_title(best_trial)
    # hyper_training.plot_model_response(
    #     best_trial,
    #     model=best_model,
    #     title_append=title_append,
    #     subtitle=subtitle,
    #     mode="eye",
    #     show=True,
    # )

    # print(f"Best model found for trial {best_trial.number}")
    # print(f"Best model error: {best_trial.value}")
    # print(f"Best model params: {best_trial.params}")
    # print()
    # print(best_model)

    # eye_fig = hyper_training.plot_eye()
    ...
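
# Optional sketch (kept commented out like the post-processing block above):
# reload a finished study from the sqlite storage configured in
# optuna_settings and inspect the best trial offline. This assumes the exact
# timestamped study name is known; the name used here is a placeholder.
#
# study = optuna.load_study(
#     study_name="single_core_regen_YYYYMMDD_HHMMSS",  # placeholder, copy from the study setup output
#     storage="sqlite:///data/single_core_regen.db",
# )
# best = study.best_trial  # single-objective study ("minimize" mse), so best_trial is defined
# print(f"Best trial {best.number}: mse={best.value}")
# print(best.params)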