add autosampler support
This commit is contained in:
@@ -6,6 +6,7 @@ import matplotlib.pyplot as plt
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import optuna
|
import optuna
|
||||||
|
import optunahub
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -110,14 +111,18 @@ class HyperTraining:
|
|||||||
return optuna.get_all_study_summaries(storage=self.optuna_settings.storage)
|
return optuna.get_all_study_summaries(storage=self.optuna_settings.storage)
|
||||||
|
|
||||||
def setup_study(self):
|
def setup_study(self):
|
||||||
|
module = optunahub.load_module(package="samplers/auto_sampler")
|
||||||
self.study = optuna.create_study(
|
self.study = optuna.create_study(
|
||||||
study_name=self.optuna_settings.study_name,
|
study_name=self.optuna_settings.study_name,
|
||||||
storage=self.optuna_settings.storage,
|
storage=self.optuna_settings.storage,
|
||||||
load_if_exists=True,
|
load_if_exists=True,
|
||||||
direction=self.optuna_settings.direction,
|
direction=self.optuna_settings.direction,
|
||||||
directions=self.optuna_settings.directions,
|
directions=self.optuna_settings.directions,
|
||||||
|
sampler=module.AutoSampler(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
print("using sampler:", self.study.sampler)
|
||||||
|
|
||||||
with warnings.catch_warnings(action="ignore"):
|
with warnings.catch_warnings(action="ignore"):
|
||||||
self.study.set_metric_names(self.optuna_settings.metrics_names)
|
self.study.set_metric_names(self.optuna_settings.metrics_names)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user