add: regen.py (main hyperparameter training file)
feat: add utility functions for fiber dataset visualization and hyperparameter training; housekeeping: rename dataset.py -> datasets.py
This commit is contained in:
194
src/single-core-regen/testing/learn_optuna.py
Normal file
194
src/single-core-regen/testing/learn_optuna.py
Normal file
@@ -0,0 +1,194 @@
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
import optuna
|
||||
import warnings
|
||||
from util.optuna_vis import show_figures
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.optim as optim
|
||||
import torch.utils.data
|
||||
|
||||
from torchvision import datasets
|
||||
from torchvision import transforms
|
||||
import multiprocessing
|
||||
|
||||
# from util.dataset import SlicedDataset
|
||||
|
||||
# Device that models and batches are moved to before use.
DEVICE = torch.device("cuda")
# Mini-batch size used by both data loaders.
BATCHSIZE = 128
# Width of the classifier's output layer.
CLASSES = 10
# Directory containing this file; the dataset cache lives in a sub-directory.
DIR = Path(__file__).parent
# Number of training passes performed per trial.
EPOCHS = 100
# Per-pass sample budgets: 30 training batches and 10 validation batches.
N_TRAIN_EXAMPLES = 30 * BATCHSIZE
N_VALID_EXAMPLES = 10 * BATCHSIZE

# Total number of trials, and the upper bound on parallel worker processes.
n_trials = 128
n_threads = 16
|
||||
|
||||
|
||||
def define_model(trial):
    """Build a fully connected classifier whose shape is sampled from ``trial``.

    The trial picks the number of hidden layers (1 to 3) and, for each hidden
    layer, its width (4 to 128 units) and its dropout probability (0.2 to
    0.5).  The network expects flattened inputs of ``28 * 28`` features and
    ends with a ``CLASSES``-wide linear layer followed by ``LogSoftmax``.

    Args:
        trial: Object exposing ``suggest_int`` and ``suggest_float``.

    Returns:
        An ``nn.Sequential`` containing the sampled layer stack.
    """
    depth = trial.suggest_int("n_layers", 1, 3)

    modules = []
    width = 28 * 28
    for idx in range(depth):
        # Keep the suggest order (units first, then dropout) per layer.
        hidden = trial.suggest_int(f"n_units_l{idx}", 4, 128)
        linear = nn.Linear(width, hidden)
        drop_prob = trial.suggest_float(f"dropout_l{idx}", 0.2, 0.5)
        modules.extend([linear, nn.ReLU(), nn.Dropout(drop_prob)])
        # The next layer consumes what this one produced.
        width = hidden

    # Output head: class scores as log-probabilities.
    modules.extend([nn.Linear(width, CLASSES), nn.LogSoftmax(dim=1)])
    return nn.Sequential(*modules)
|
||||
|
||||
|
||||
def get_mnist():
    """Create the training and validation loaders for FashionMNIST.

    Despite the function name, the dataset loaded is ``FashionMNIST``.  Data
    is stored under ``DIR / ".data"``; only the training split is created
    with ``download=True``.

    Returns:
        A ``(train_loader, valid_loader)`` tuple.  Both loaders yield batches
        of ``BATCHSIZE`` samples and have shuffling enabled.
    """
    data_root = DIR / ".data"

    train_set = datasets.FashionMNIST(
        data_root, train=True, download=True, transform=transforms.ToTensor()
    )
    valid_set = datasets.FashionMNIST(
        data_root, train=False, transform=transforms.ToTensor()
    )

    # Both loaders share the same batching and shuffling configuration.
    loader_options = {"batch_size": BATCHSIZE, "shuffle": True}
    train_loader = torch.utils.data.DataLoader(train_set, **loader_options)
    valid_loader = torch.utils.data.DataLoader(valid_set, **loader_options)

    return train_loader, valid_loader
|
||||
|
||||
|
||||
def objective(trial):
    """Train one sampled model and report its two objective values.

    The trial selects the architecture (via ``define_model``), the optimizer
    class (``Adam``, ``RMSprop`` or ``SGD``) and a log-uniform learning rate
    between ``1e-5`` and ``1e-1``.  The model is then trained for ``EPOCHS``
    passes, with an evaluation after every pass.

    Args:
        trial: Object exposing the ``suggest_*`` methods used here and in
            ``define_model``.

    Returns:
        ``(accuracy, num_params)`` as produced by the evaluation that follows
        the final training pass.
    """
    net = define_model(trial).to(DEVICE)

    chosen_optimizer = trial.suggest_categorical(
        "optimizer", ["Adam", "RMSprop", "SGD"]
    )
    learning_rate = trial.suggest_float("lr", 1e-5, 1e-1, log=True)
    # Resolve the optimizer class by name from ``torch.optim``.
    optimizer_cls = getattr(optim, chosen_optimizer)
    opt = optimizer_cls(net.parameters(), lr=learning_rate)

    train_loader, valid_loader = get_mnist()

    for _ in range(EPOCHS):
        train_model(net, opt, train_loader)
        # Only the values from the last pass survive the loop.
        accuracy, num_params = eval_model(net, valid_loader)

    return accuracy, num_params
|
||||
|
||||
|
||||
def eval_model(model, valid_loader):
    """Measure classification accuracy on at most ``N_VALID_EXAMPLES`` samples.

    The sample budget and the accuracy denominator are both based on the
    number of samples actually scored, so the result stays correct when the
    loader's batch size differs from ``BATCHSIZE`` or its last batch is short.

    Args:
        model: Network called on flattened batches; its output is reduced
            with ``argmax`` over dimension 1 to get predicted classes.
        valid_loader: Iterable yielding ``(data, target)`` batches.

    Returns:
        ``(accuracy, num_params)``: the fraction of scored samples predicted
        correctly (``0.0`` if the loader produced no samples) and the total
        number of elements across ``model.parameters()``.
    """
    model.eval()
    correct = 0
    seen = 0
    with torch.no_grad():
        for data, target in valid_loader:
            # Stop once the evaluation budget has been used up.
            if seen >= N_VALID_EXAMPLES:
                break

            # Flatten each sample to a vector before moving it to the device.
            data = data.view(data.size(0), -1).to(DEVICE)
            target = target.to(DEVICE)

            pred = model(data).argmax(dim=1)
            correct += (pred == target).sum().item()
            seen += target.size(0)

    # Avoid dividing by zero when nothing was evaluated.
    accuracy = correct / seen if seen else 0.0
    num_params = sum(p.numel() for p in model.parameters())
    return accuracy, num_params
|
||||
|
||||
|
||||
def train_model(model, optimizer, train_loader):
    """Run one training pass limited to roughly ``N_TRAIN_EXAMPLES`` samples.

    Each batch is flattened to one vector per sample, moved to ``DEVICE`` and
    used for a single optimizer step on the negative log-likelihood loss.

    Args:
        model: Network to train; switched to training mode first.
        optimizer: Optimizer whose ``zero_grad`` / ``step`` drive the update.
        train_loader: Iterable yielding ``(data, target)`` batches.
    """
    model.train()
    for step, (inputs, labels) in enumerate(train_loader):
        # The budget check counts batches as if each held BATCHSIZE samples.
        if step * BATCHSIZE >= N_TRAIN_EXAMPLES:
            break

        flat_inputs = inputs.view(inputs.size(0), -1).to(DEVICE)
        labels = labels.to(DEVICE)

        optimizer.zero_grad()
        loss = F.nll_loss(model(flat_inputs), labels)
        loss.backward()
        optimizer.step()
|
||||
|
||||
|
||||
def run_optimize(n_trials, study, timeout=600):
    """Optimise ``objective`` on ``study`` for a bounded number of trials.

    Args:
        n_trials: Maximum number of trials this call may run.
        study: Study object whose ``optimize`` method is invoked.
        timeout: Time limit forwarded to ``study.optimize`` (seconds, per the
            Optuna API).  Defaults to 600, the value that used to be
            hard-coded, so existing two-argument callers behave as before.
    """
    study.optimize(objective, n_trials=n_trials, timeout=timeout)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
study_name = f"{datetime.now().strftime('%Y-%m-%d %H:%M:%S')} mnist example"
|
||||
storage = "sqlite:///db.sqlite3"
|
||||
directions = ["maximize", "minimize"]
|
||||
|
||||
study = optuna.create_study(
|
||||
directions=directions,
|
||||
storage=storage,
|
||||
study_name=study_name,
|
||||
)
|
||||
|
||||
with warnings.catch_warnings(action="ignore"):
|
||||
study.set_metric_names(["accuracy", "num params"])
|
||||
|
||||
n_threads = min(n_trials, n_threads)
|
||||
|
||||
processes = []
|
||||
for _ in range(n_threads):
|
||||
p = multiprocessing.Process(
|
||||
target=run_optimize, args=(n_trials // n_threads, study)
|
||||
)
|
||||
p.start()
|
||||
processes.append(p)
|
||||
|
||||
for p in processes:
|
||||
p.join()
|
||||
|
||||
remaining_trials = n_trials - ((n_trials // n_threads) * n_threads)
|
||||
if remaining_trials:
|
||||
print(
|
||||
f"\nRunning last {remaining_trials} trial{'s' if remaining_trials > 1 else ''}:"
|
||||
)
|
||||
run_optimize(directions, remaining_trials, study_name, storage)
|
||||
|
||||
print(f"Number of trials on the Pareto front: {len(study.best_trials)}")
|
||||
|
||||
trial_with_highest_accuracy = max(study.best_trials, key=lambda t: t.values[1])
|
||||
print("Trial with highest accuracy: ")
|
||||
print(f"\tnumber: {trial_with_highest_accuracy.number}")
|
||||
print(f"\tparams: {trial_with_highest_accuracy.params}")
|
||||
print(f"\tvalues: {trial_with_highest_accuracy.values}")
|
||||
|
||||
# for trial in trials:
|
||||
# print(f"Trial {trial.number}")
|
||||
# print(f" Accuracy: {trial.values[0]}")
|
||||
# print(f" n_params: {int(trial.values[1])}")
|
||||
# print( " Params: ")
|
||||
# for key, value in trial.params.items():
|
||||
# print(" {}: {}".format(key, value))
|
||||
# print()
|
||||
|
||||
# print(" Value: ", trial.value)
|
||||
|
||||
# print(" Params: ")
|
||||
# for key, value in trial.params.items():
|
||||
# print(" {}: {}".format(key, value))
|
||||
|
||||
figures = []
|
||||
figures.append(
|
||||
optuna.visualization.plot_pareto_front(
|
||||
study, target_names=["accuracy", "num_params"]
|
||||
)
|
||||
)
|
||||
figures.append(optuna.visualization.plot_timeline(study))
|
||||
|
||||
plt = show_figures(*figures)
|
||||
|
||||
print()
|
||||
# plt.show()
|
||||
Reference in New Issue
Block a user