add: regen.py (main hyperparameter training file)

feat: add utility functions for fiber dataset visualization and hyperparameter training;
housekeeping: rename dataset.py -> datasets.py
Joseph Hopfmüller
2024-11-17 22:22:37 +01:00
parent 05a3ee9394
commit 9ec548757d
6 changed files with 774 additions and 53 deletions

regen.py

@@ -0,0 +1,194 @@
from datetime import datetime
from pathlib import Path
import optuna
import warnings
from util.optuna_vis import show_figures
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data
from torchvision import datasets
from torchvision import transforms
import multiprocessing
# from util.datasets import SlicedDataset
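
# Training configuration: N_TRAIN_EXAMPLES / N_VALID_EXAMPLES cap each
# epoch at 30 training and 10 validation batches so individual trials
# stay cheap during the search.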
DEVICE = torch.device("cuda")
BATCHSIZE = 128
CLASSES = 10
DIR = Path(__file__).parent
EPOCHS = 100
N_TRAIN_EXAMPLES = BATCHSIZE * 30
N_VALID_EXAMPLES = BATCHSIZE * 10
n_trials = 128
n_threads = 16
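
# Search space: an MLP with 1-3 hidden layers; the width and dropout
# rate of each layer are sampled per trial.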
def define_model(trial):
n_layers = trial.suggest_int("n_layers", 1, 3)
layers = []
in_features = 28 * 28
for i in range(n_layers):
out_features = trial.suggest_int(f"n_units_l{i}", 4, 128)
layers.append(nn.Linear(in_features, out_features))
layers.append(nn.ReLU())
p = trial.suggest_float(f"dropout_l{i}", 0.2, 0.5)
layers.append(nn.Dropout(p))
in_features = out_features
layers.append(nn.Linear(in_features, CLASSES))
layers.append(nn.LogSoftmax(dim=1))
return nn.Sequential(*layers)
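
# FashionMNIST train/validation loaders; the dataset is downloaded to
# .data/ next to this file on first use.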
def get_mnist():
# Load FashionMNIST dataset.
train_loader = torch.utils.data.DataLoader(
datasets.FashionMNIST(
DIR / ".data", train=True, download=True, transform=transforms.ToTensor()
),
batch_size=BATCHSIZE,
shuffle=True,
)
valid_loader = torch.utils.data.DataLoader(
datasets.FashionMNIST(
DIR / ".data", train=False, transform=transforms.ToTensor()
),
batch_size=BATCHSIZE,
shuffle=True,
)
return train_loader, valid_loader
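
# Multi-objective objective: returns (accuracy, parameter count) so the
# study can trade model quality off against model size.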
def objective(trial):
model = define_model(trial).to(DEVICE)
optimizer_name = trial.suggest_categorical("optimizer", ["Adam", "RMSprop", "SGD"])
lr = trial.suggest_float("lr", 1e-5, 1e-1, log=True)
optimizer = getattr(optim, optimizer_name)(model.parameters(), lr=lr)
train_loader, valid_loader = get_mnist()
for epoch in range(EPOCHS):
train_model(model, optimizer, train_loader)
accuracy, num_params = eval_model(model, valid_loader)
return accuracy, num_params
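
# Validation accuracy over at most N_VALID_EXAMPLES samples, plus the
# model's total parameter count.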
def eval_model(model, valid_loader):
model.eval()
correct = 0
with torch.no_grad():
for batch_idx, (data, target) in enumerate(valid_loader):
if batch_idx * BATCHSIZE >= N_VALID_EXAMPLES:
break
data, target = data.view(data.size(0), -1).to(DEVICE), target.to(DEVICE)
output = model(data)
pred = output.argmax(dim=1, keepdim=True)
correct += pred.eq(target.view_as(pred)).sum().item()
accuracy = correct / min(len(valid_loader.dataset), N_VALID_EXAMPLES)
num_params = sum(p.numel() for p in model.parameters())
return accuracy, num_params
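
# One training pass over at most N_TRAIN_EXAMPLES samples.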
def train_model(model, optimizer, train_loader):
model.train()
for batch_idx, (data, target) in enumerate(train_loader):
if batch_idx * BATCHSIZE >= N_TRAIN_EXAMPLES:
break
data, target = data.view(data.size(0), -1).to(DEVICE), target.to(DEVICE)
optimizer.zero_grad()
output = model(data)
loss = F.nll_loss(output, target)
loss.backward()
optimizer.step()
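
# Worker entry point: each process runs its share of the trials against
# the shared study, with a 600 s timeout per worker. Handing the study
# object to child processes is assumed to work with fork-based process
# creation; with the "spawn" start method, each worker would typically
# re-attach via optuna.load_study(study_name=..., storage=...) instead.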
def run_optimize(n_trials, study):
study.optimize(objective, n_trials=n_trials, timeout=600)
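
# Driver: create a multi-objective study backed by local SQLite storage
# and fan the trials out over n_threads worker processes.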
if __name__ == "__main__":
study_name = f"{datetime.now().strftime('%Y-%m-%d %H:%M:%S')} mnist example"
storage = "sqlite:///db.sqlite3"
directions = ["maximize", "minimize"]
study = optuna.create_study(
directions=directions,
storage=storage,
study_name=study_name,
)
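    # set_metric_names() is still experimental in Optuna and emits an
    # ExperimentalWarning; catch_warnings(action=...) requires Python 3.11+.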
with warnings.catch_warnings(action="ignore"):
study.set_metric_names(["accuracy", "num params"])
n_threads = min(n_trials, n_threads)
processes = []
for _ in range(n_threads):
p = multiprocessing.Process(
target=run_optimize, args=(n_trials // n_threads, study)
)
p.start()
processes.append(p)
for p in processes:
p.join()
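    # The integer division above leaves n_trials % n_threads trials
    # unscheduled; run the remainder in the main process.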
remaining_trials = n_trials - ((n_trials // n_threads) * n_threads)
if remaining_trials:
print(
f"\nRunning last {remaining_trials} trial{'s' if remaining_trials > 1 else ''}:"
)
        run_optimize(remaining_trials, study)
print(f"Number of trials on the Pareto front: {len(study.best_trials)}")
    trial_with_highest_accuracy = max(study.best_trials, key=lambda t: t.values[0])
print("Trial with highest accuracy: ")
print(f"\tnumber: {trial_with_highest_accuracy.number}")
print(f"\tparams: {trial_with_highest_accuracy.params}")
print(f"\tvalues: {trial_with_highest_accuracy.values}")
    # Uncomment to print every Pareto-optimal trial:
    # for trial in study.best_trials:
    #     print(f"Trial {trial.number}")
    #     print(f"  Accuracy: {trial.values[0]}")
    #     print(f"  n_params: {int(trial.values[1])}")
    #     print("  Params: ")
    #     for key, value in trial.params.items():
    #         print(f"    {key}: {value}")
    #     print()
figures = []
figures.append(
optuna.visualization.plot_pareto_front(
study, target_names=["accuracy", "num_params"]
)
)
figures.append(optuna.visualization.plot_timeline(study))
plt = show_figures(*figures)
print()
# plt.show()