add autosampler support
This commit is contained in:
@@ -6,6 +6,7 @@ import matplotlib.pyplot as plt
|
||||
|
||||
import numpy as np
|
||||
import optuna
|
||||
import optunahub
|
||||
import warnings
|
||||
|
||||
import torch
|
||||
@@ -110,14 +111,18 @@ class HyperTraining:
|
||||
return optuna.get_all_study_summaries(storage=self.optuna_settings.storage)
|
||||
|
||||
def setup_study(self):
|
||||
module = optunahub.load_module(package="samplers/auto_sampler")
|
||||
self.study = optuna.create_study(
|
||||
study_name=self.optuna_settings.study_name,
|
||||
storage=self.optuna_settings.storage,
|
||||
load_if_exists=True,
|
||||
direction=self.optuna_settings.direction,
|
||||
directions=self.optuna_settings.directions,
|
||||
sampler=module.AutoSampler(),
|
||||
)
|
||||
|
||||
print("using sampler:", self.study.sampler)
|
||||
|
||||
with warnings.catch_warnings(action="ignore"):
|
||||
self.study.set_metric_names(self.optuna_settings.metrics_names)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user