diff --git a/src/single-core-regen/hypertraining/hypertraining.py b/src/single-core-regen/hypertraining/hypertraining.py index 3d5cca7..5894475 100644 --- a/src/single-core-regen/hypertraining/hypertraining.py +++ b/src/single-core-regen/hypertraining/hypertraining.py @@ -6,6 +6,7 @@ import matplotlib.pyplot as plt import numpy as np import optuna +import optunahub import warnings import torch @@ -110,14 +111,18 @@ class HyperTraining: return optuna.get_all_study_summaries(storage=self.optuna_settings.storage) def setup_study(self): + module = optunahub.load_module(package="samplers/auto_sampler") self.study = optuna.create_study( study_name=self.optuna_settings.study_name, storage=self.optuna_settings.storage, load_if_exists=True, direction=self.optuna_settings.direction, directions=self.optuna_settings.directions, + sampler=module.AutoSampler(), ) + print("using sampler:", self.study.sampler) + with warnings.catch_warnings(action="ignore"): self.study.set_metric_names(self.optuna_settings.metrics_names)