add: regen.py (main hyperparameter training file)
feat: add utility functions for fiber dataset visualization and hyperparameter training; housekeeping: rename dataset.py -> datasets.py
287
src/single-core-regen/regen.py
Normal file
@@ -0,0 +1,287 @@
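"""Hyperparameter search for the single-core regeneration model.

Optuna samples the dataset slicing, MLP architecture, and optimizer settings;
trials run in parallel worker processes against a shared SQLite storage.
"""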
from dataclasses import dataclass
from datetime import datetime
import time
import matplotlib.pyplot as plt

import numpy as np
import optuna
import warnings

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data

import multiprocessing

from util.datasets import FiberRegenerationDataset
import util

# global settings
@dataclass
class GlobalSettings:
    seed: int = 42


# data settings
@dataclass
class DataSettings:
    config_path: str = "data/*-128-16384-10000-0-0-17-0-PAM4-0.ini"
    symbols_range: tuple = (1, 100)
    data_size_range: tuple = (1, 20)
    target_delay: float = 0
    xy_delay_range: tuple = (0, 1)
    drop_first: int = 1000
    train_split: float = 0.8


# pytorch settings
@dataclass
class PytorchSettings:
    device: str = "cuda"
    batchsize: int = 128
    epochs: int = 100


# model settings
@dataclass
class ModelSettings:
    output_size: int = 2
    n_layer_range: tuple = (1, 3)
    n_units_range: tuple = (4, 128)
    activation_func_range: tuple = ("ReLU",)


# optimizer settings
@dataclass
class OptimizerSettings:
    optimizer_range: tuple = ("Adam", "RMSprop", "SGD")
    lr_range: tuple = (1e-5, 1e-1)


# optuna settings
@dataclass
class OptunaSettings:
    n_trials: int = 128
    n_threads: int = 16
    timeout: int = 600
    directions: tuple = ("maximize",)

    limit_examples: bool = True
    n_train_examples: int = PytorchSettings.batchsize * 30
    n_valid_examples: int = PytorchSettings.batchsize * 10
    storage: str = "sqlite:///optuna_single_core_regen.db"
    study_name: str = (
        f"single_core_regen_{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
    )
    metrics_names: tuple = ("accuracy",)

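# HyperTraining ties the settings above together: it builds the Optuna study,
# runs the search across multiple worker processes sharing the SQLite storage,
# and defines the objective (data slicing, model, training, evaluation).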
class HyperTraining:
    def __init__(self):
        self.global_settings = GlobalSettings()
        self.data_settings = DataSettings()
        self.pytorch_settings = PytorchSettings()
        self.model_settings = ModelSettings()
        self.optimizer_settings = OptimizerSettings()
        self.optuna_settings = OptunaSettings()

        # derive some extra settings to make the code more readable
        self._extra_optuna_settings()

    def setup_study(self):
        self.study = optuna.create_study(
            study_name=self.optuna_settings.study_name,
            storage=self.optuna_settings.storage,
            load_if_exists=True,
            direction=self.optuna_settings.direction,
            directions=self.optuna_settings.directions,
        )

        with warnings.catch_warnings(action="ignore"):
            self.study.set_metric_names(self.optuna_settings.metrics_names)

        # spawn one worker per thread; each runs its share of the trials
        self.n_threads = min(
            self.optuna_settings.n_trials, self.optuna_settings.n_threads
        )
        trials_per_worker = self.optuna_settings.n_trials // self.n_threads
        self.processes = []
        for _ in range(self.n_threads):
            p = multiprocessing.Process(
                target=self._run_optimize, args=(trials_per_worker,)
            )
            self.processes.append(p)

    def run_study(self):
        for p in self.processes:
            p.start()

        for p in self.processes:
            p.join()

        # run the leftover trials that did not divide evenly across the workers
        remaining_trials = self.optuna_settings.n_trials % self.n_threads
        if remaining_trials:
            self._run_optimize(remaining_trials)

    def _run_optimize(self, n_trials):
        self.study.optimize(
            self.objective, n_trials=n_trials, timeout=self.optuna_settings.timeout
        )

    def eye(self, show=True):
        util.plot.eye(self.data_settings.config_path, show=show)

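    # Optuna expects exactly one of `direction` / `directions`; normalize the
    # settings so single-objective studies use `direction` and multi-objective
    # studies use `directions`.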
    def _extra_optuna_settings(self):
        self.optuna_settings.multi_objective = len(self.optuna_settings.directions) > 1
        if self.optuna_settings.multi_objective:
            self.optuna_settings.direction = None
        else:
            self.optuna_settings.direction = self.optuna_settings.directions[0]
            self.optuna_settings.directions = None

        self.optuna_settings.n_train_examples = (
            self.optuna_settings.n_train_examples
            if self.optuna_settings.limit_examples
            else float("inf")
        )
        self.optuna_settings.n_valid_examples = (
            self.optuna_settings.n_valid_examples
            if self.optuna_settings.limit_examples
            else float("inf")
        )

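    # Build an MLP whose depth, width, and per-layer activation are sampled by
    # Optuna; the input width follows from the dataset slice size suggested in
    # get_sliced_data.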
    def define_model(self, trial: optuna.Trial):
        n_layers = trial.suggest_int(
            "model_n_layers", *self.model_settings.n_layer_range
        )
        layers = []

        # dataset_data_size is suggested in get_sliced_data, which objective()
        # calls before define_model(), so the parameter is already available
        in_features = trial.params["dataset_data_size"] * 2
        for i in range(n_layers):
            out_features = trial.suggest_int(
                f"model_n_units_l{i}", *self.model_settings.n_units_range
            )
            activation_func = trial.suggest_categorical(
                f"model_activation_func_l{i}", self.model_settings.activation_func_range
            )

            layers.append(nn.Linear(in_features, out_features))
            layers.append(getattr(nn, activation_func)())
            in_features = out_features

        layers.append(nn.Linear(in_features, self.model_settings.output_size))

        return nn.Sequential(*layers)

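    # Sample the dataset slicing parameters, build the fiber dataset, and wrap
    # a reproducible shuffled train/validation split in two DataLoaders.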
    def get_sliced_data(self, trial: optuna.Trial):
        symbols = trial.suggest_float(
            "dataset_symbols", *self.data_settings.symbols_range, log=True
        )
        xy_delay = trial.suggest_float(
            "dataset_xy_delay", *self.data_settings.xy_delay_range
        )
        data_size = trial.suggest_int(
            "dataset_data_size", *self.data_settings.data_size_range
        )
        # build the dataset; the model input width is data_size * 2 (x and y
        # channels), so it is always even
        dataset = FiberRegenerationDataset(
            file_path=self.data_settings.config_path,
            symbols=symbols,
            data_size=data_size,  # two channels (x, y)
            target_delay=self.data_settings.target_delay,
            xy_delay=xy_delay,
            drop_first=self.data_settings.drop_first,
        )

        # reproducible shuffled split into train and validation indices
        dataset_size = len(dataset)
        indices = list(range(dataset_size))
        split = int(np.floor(self.data_settings.train_split * dataset_size))
        np.random.seed(self.global_settings.seed)
        np.random.shuffle(indices)
        train_indices, valid_indices = indices[:split], indices[split:]

        train_sampler = torch.utils.data.SubsetRandomSampler(train_indices)
        valid_sampler = torch.utils.data.SubsetRandomSampler(valid_indices)

        train_loader = torch.utils.data.DataLoader(
            dataset, batch_size=self.pytorch_settings.batchsize, sampler=train_sampler
        )
        valid_loader = torch.utils.data.DataLoader(
            dataset, batch_size=self.pytorch_settings.batchsize, sampler=valid_sampler
        )

        return train_loader, valid_loader

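    # One training pass over the loader, capped at n_train_examples so each
    # trial stays cheap.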
    def train_model(self, model, optimizer, train_loader):
        model.train()
        for batch_idx, (data, target) in enumerate(train_loader):
            if (
                batch_idx * train_loader.batch_size
                >= self.optuna_settings.n_train_examples
            ):
                break
            optimizer.zero_grad()
            data, target = (
                data.to(self.pytorch_settings.device),
                target.to(self.pytorch_settings.device),
            )
            target_pred = model(data)
            loss = F.mse_loss(target_pred, target)
            loss.backward()
            optimizer.step()

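    # Evaluate on the validation loader, capped at n_valid_examples; accuracy
    # is the fraction of argmax predictions matching the targets.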
    def eval_model(self, model, valid_loader):
        model.eval()
        correct = 0
        with torch.no_grad():
            for batch_idx, (data, target) in enumerate(valid_loader):
                if (
                    batch_idx * valid_loader.batch_size
                    >= self.optuna_settings.n_valid_examples
                ):
                    break
                data, target = (
                    data.to(self.pytorch_settings.device),
                    target.to(self.pytorch_settings.device),
                )
                target_pred = model(data)
                pred = target_pred.argmax(dim=1, keepdim=True)
                correct += pred.eq(target.view_as(pred)).sum().item()

        # normalize by the number of examples actually evaluated
        accuracy = correct / min(
            len(valid_loader.sampler), self.optuna_settings.n_valid_examples
        )
        # num_params = sum(p.numel() for p in model.parameters())
        return accuracy

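    # The Optuna objective: sample data slicing, model, and optimizer, then
    # train with per-epoch pruning (single-objective studies only) and return
    # the final validation accuracy.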
    def objective(self, trial: optuna.Trial):
        # build the loaders first so that dataset_data_size is available to
        # define_model via trial.params
        train_loader, valid_loader = self.get_sliced_data(trial)

        model = self.define_model(trial).to(self.pytorch_settings.device)

        optimizer_name = trial.suggest_categorical(
            "optimizer", self.optimizer_settings.optimizer_range
        )
        lr = trial.suggest_float("lr", *self.optimizer_settings.lr_range, log=True)
        optimizer = getattr(optim, optimizer_name)(model.parameters(), lr=lr)

        for epoch in range(self.pytorch_settings.epochs):
            self.train_model(model, optimizer, train_loader)
            accuracy = self.eval_model(model, valid_loader)

            # pruning is only supported for single-objective studies
            if not self.optuna_settings.multi_objective:
                trial.report(accuracy, epoch)
                if trial.should_prune():
                    raise optuna.exceptions.TrialPruned()

        return accuracy

if __name__ == "__main__":
    # plt.ion()
    hyper_training = HyperTraining()
    hyper_training.eye()  # plot the eye diagram of the configured dataset

    # hyper_training.setup_study()
    # hyper_training.run_study()
    for i in range(10):
        # simulate some work
        print(i)
        time.sleep(0.2)

    plt.show()