add: regen.py (main hyperparameter training file)

feat: add utility functions for fiber dataset visualization and hyperparameter training;
housekeeping: rename dataset.py -> datasets.py
Joseph Hopfmüller
2024-11-17 22:22:37 +01:00
parent 05a3ee9394
commit 9ec548757d
6 changed files with 774 additions and 53 deletions

regen.py

@@ -0,0 +1,287 @@
from dataclasses import dataclass
from datetime import datetime
import time
import matplotlib.pyplot as plt
import numpy as np
import optuna
import warnings
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data
import multiprocessing
from util.datasets import FiberRegenerationDataset
import util
# global settings
@dataclass
class GlobalSettings:
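    # seed for the numpy RNG that shuffles the train/validation split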
seed: int = 42
# data settings
@dataclass
class DataSettings:
config_path: str = "data/*-128-16384-10000-0-0-17-0-PAM4-0.ini"
symbols_range: tuple = (1, 100)
data_size_range: tuple = (1, 20)
target_delay: float = 0
xy_delay_range: tuple = (0, 1)
drop_first: int = 1000
train_split: float = 0.8
# pytorch settings
@dataclass
class PytorchSettings:
device: str = "cuda"
batchsize: int = 128
epochs: int = 100
# model settings
@dataclass
class ModelSettings:
output_size: int = 2
n_layer_range: tuple = (1, 3)
n_units_range: tuple = (4, 128)
activation_func_range: tuple = ("ReLU",)
@dataclass
class OptimizerSettings:
optimizer_range: tuple = ("Adam", "RMSprop", "SGD")
lr_range: tuple = (1e-5, 1e-1)
# optuna settings
@dataclass
class OptunaSettings:
n_trials: int = 128
n_threads: int = 16
timeout: int = 600
directions: tuple = ("maximize",)
limit_examples: bool = True
n_train_examples: int = PytorchSettings.batchsize * 30
n_valid_examples: int = PytorchSettings.batchsize * 10
storage: str = "sqlite:///optuna_single_core_regen.db"
study_name: str = (
f"single_core_regen_{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
)
metrics_names: tuple = ("accuracy",)
class HyperTraining:
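    """Optuna hyperparameter search for the fiber regeneration model:
    builds dataset slices, an MLP, and an optimizer from trial
    suggestions, then trains and evaluates each candidate."""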
def __init__(self):
self.global_settings = GlobalSettings()
self.data_settings = DataSettings()
self.pytorch_settings = PytorchSettings()
self.model_settings = ModelSettings()
self.optimizer_settings = OptimizerSettings()
self.optuna_settings = OptunaSettings()
# set some extra settings to make the code more readable
self._extra_optuna_settings()
def setup_study(self):
self.study = optuna.create_study(
study_name=self.optuna_settings.study_name,
storage=self.optuna_settings.storage,
load_if_exists=True,
direction=self.optuna_settings.direction,
directions=self.optuna_settings.directions,
)
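        # Study.set_metric_names is experimental in Optuna and emits an
        # ExperimentalWarning, which is silenced here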
with warnings.catch_warnings(action="ignore"):
self.study.set_metric_names(self.optuna_settings.metrics_names)
        self.n_threads = min(
            self.optuna_settings.n_trials, self.optuna_settings.n_threads
        )
        self.processes = []
        for _ in range(self.n_threads):
            # each worker process runs an equal share of the requested trials
            p = multiprocessing.Process(
                target=self._run_optimize,
                args=(self.optuna_settings.n_trials // self.n_threads,),
            )
            self.processes.append(p)
def run_study(self):
for p in self.processes:
p.start()
for p in self.processes:
p.join()
        # run the trials lost to integer division in the main process
        remaining_trials = self.optuna_settings.n_trials % self.n_threads
        if remaining_trials:
            self._run_optimize(remaining_trials)
def _run_optimize(self, n_trials):
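        # note: the timeout limits each optimize() call, i.e. each worker
        # process, rather than the study as a whole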
self.study.optimize(self.objective, n_trials=n_trials, timeout=self.optuna_settings.timeout)
def eye(self, show=True):
util.plot.eye(self.data_settings.config_path, show=show)
def _extra_optuna_settings(self):
self.optuna_settings.multi_objective = len(self.optuna_settings.directions) > 1
if self.optuna_settings.multi_objective:
self.optuna_settings.direction = None
else:
self.optuna_settings.direction = self.optuna_settings.directions[0]
self.optuna_settings.directions = None
self.optuna_settings.n_train_examples = (
self.optuna_settings.n_train_examples
if self.optuna_settings.limit_examples
else float("inf")
)
self.optuna_settings.n_valid_examples = (
self.optuna_settings.n_valid_examples
if self.optuna_settings.limit_examples
else float("inf")
)
def define_model(self, trial: optuna.Trial):
n_layers = trial.suggest_int(
"model_n_layers", *self.model_settings.n_layer_range
)
layers = []
        # "dataset_data_size" is suggested in get_sliced_data(), which
        # objective() calls before define_model(), so the suggested value is
        # available in trial.params; the factor 2 accounts for the two
        # channels (x, y)
        in_features = trial.params["dataset_data_size"] * 2
for i in range(n_layers):
out_features = trial.suggest_int(
f"model_n_units_l{i}", *self.model_settings.n_units_range
)
activation_func = trial.suggest_categorical(
f"model_activation_func_l{i}", self.model_settings.activation_func_range
)
layers.append(nn.Linear(in_features, out_features))
            # instantiate the activation module, not just the class
            layers.append(getattr(nn, activation_func)())
in_features = out_features
layers.append(nn.Linear(in_features, self.model_settings.output_size))
return nn.Sequential(*layers)
def get_sliced_data(self, trial: optuna.Trial):
        # no input-size check is needed here: the model input width is
        # data_size * 2 (x and y channel), which is always even
symbols = trial.suggest_float(
"dataset_symbols", *self.data_settings.symbols_range, log=True
)
xy_delay = trial.suggest_float(
"dataset_xy_delay", *self.data_settings.xy_delay_range
)
data_size = trial.suggest_int(
"dataset_data_size", *self.data_settings.data_size_range
)
# get dataset
dataset = FiberRegenerationDataset(
file_path=self.data_settings.config_path,
symbols=symbols,
data_size=data_size, # two channels (x,y)
target_delay=self.data_settings.target_delay,
xy_delay=xy_delay,
drop_first=self.data_settings.drop_first,
)
dataset_size = len(dataset)
indices = list(range(dataset_size))
split = int(np.floor(self.data_settings.train_split * dataset_size))
np.random.seed(self.global_settings.seed)
np.random.shuffle(indices)
train_indices, valid_indices = indices[:split], indices[split:]
train_sampler = torch.utils.data.SubsetRandomSampler(train_indices)
valid_sampler = torch.utils.data.SubsetRandomSampler(valid_indices)
train_loader = torch.utils.data.DataLoader(
dataset, batch_size=self.pytorch_settings.batchsize, sampler=train_sampler
)
valid_loader = torch.utils.data.DataLoader(
dataset, batch_size=self.pytorch_settings.batchsize, sampler=valid_sampler
)
return train_loader, valid_loader
def train_model(self, model, optimizer, train_loader):
model.train()
for batch_idx, (data, target) in enumerate(train_loader):
            # stop early once the per-epoch training example budget is spent
            if (
                batch_idx * train_loader.batch_size
                >= self.optuna_settings.n_train_examples
            ):
break
optimizer.zero_grad()
data, target = (
data.to(self.pytorch_settings.device),
target.to(self.pytorch_settings.device),
)
target_pred = model(data)
            loss = F.mse_loss(target_pred, target)
loss.backward()
optimizer.step()
def eval_model(self, model, valid_loader):
model.eval()
correct = 0
with torch.no_grad():
for batch_idx, (data, target) in enumerate(valid_loader):
if (
                    batch_idx * valid_loader.batch_size
>= self.optuna_settings.n_valid_examples
):
break
data, target = (
data.to(self.pytorch_settings.device),
target.to(self.pytorch_settings.device),
)
target_pred = model(data)
pred = target_pred.argmax(dim=1, keepdim=True)
correct += pred.eq(target.view_as(pred)).sum().item()
        # normalize by the number of examples actually evaluated
        accuracy = correct / min(
            len(valid_loader.dataset), self.optuna_settings.n_valid_examples
        )
# num_params = sum(p.numel() for p in model.parameters())
return accuracy
def objective(self, trial: optuna.Trial):
        # suggest the dataset parameters first so that define_model() can
        # read "dataset_data_size" from trial.params
        train_loader, valid_loader = self.get_sliced_data(trial)
        model = self.define_model(trial).to(self.pytorch_settings.device)
        optimizer_name = trial.suggest_categorical(
            "optimizer", self.optimizer_settings.optimizer_range
        )
        lr = trial.suggest_float("lr", *self.optimizer_settings.lr_range, log=True)
        optimizer = getattr(optim, optimizer_name)(model.parameters(), lr=lr)
for epoch in range(self.pytorch_settings.epochs):
self.train_model(model, optimizer, train_loader)
accuracy = self.eval_model(model, valid_loader)
            # reporting and pruning only apply to single-objective studies
            if not self.optuna_settings.multi_objective:
trial.report(accuracy, epoch)
if trial.should_prune():
raise optuna.exceptions.TrialPruned()
return accuracy
if __name__ == "__main__":
# plt.ion()
hyper_training = HyperTraining()
hyper_training.eye()
# hyper_training.setup_study()
# hyper_training.run_study()
for i in range(10):
        # simulate some work
print(i)
time.sleep(0.2)
plt.show()
...
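A minimal sketch of how the study is presumably meant to be launched once the plotting demo above is removed; setup_study() and run_study() are commented out in this commit, so this usage is an assumption, not part of the change:

    if __name__ == "__main__":
        hyper_training = HyperTraining()
        hyper_training.setup_study()  # create the study and worker processes
        hyper_training.run_study()    # run n_trials split across n_threads
        print(hyper_training.study.best_trial.params)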