training loop speedup

Joseph Hopfmüller
2024-11-20 11:29:18 +01:00
parent 1622c38582
commit cdca5de473
11 changed files with 1026 additions and 151 deletions


@@ -1,6 +1,6 @@
import copy
from dataclasses import dataclass
from datetime import datetime
import time
import matplotlib.pyplot as plt
import numpy as np
@@ -9,16 +9,21 @@ import warnings
import torch
import torch.nn as nn
import torch.functional as F
# import torch.nn.functional as F # mse_loss doesn't support complex numbers
import torch.optim as optim
import torch.utils.data
from torch.utils.tensorboard import SummaryWriter
from rich.progress import Progress, TextColumn, BarColumn, TaskProgressColumn, TimeRemainingColumn, MofNCompleteColumn
from rich.console import Console
import multiprocessing
from util.datasets import FiberRegenerationDataset
from util.complexNN import complex_sse_loss
from util.optuna_helpers import optional_suggest_categorical, optional_suggest_float, optional_suggest_int
import util
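# Note: complex_sse_loss is imported from util.complexNN, which is not part of
# this diff. A minimal sketch of what a complex-valued sum-of-squared-errors
# loss presumably does (torch.nn.functional.mse_loss rejects complex tensors,
# hence the custom loss); the real implementation may differ:
def _complex_sse_loss_sketch(y_pred, y):
    diff = y_pred - y
    # sum of |diff|^2 over all elements; .real drops the zero imaginary part
    return (diff * diff.conj()).real.sum()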
# global settings
@dataclass
class GlobalSettings:
@@ -28,12 +33,14 @@ class GlobalSettings:
# data settings
@dataclass
class DataSettings:
config_path: str = "data/*-128-16384-10000-0-0-17-0-PAM4-0.ini"
symbols_range: tuple = (1, 100)
data_size_range: tuple = (1, 20)
config_path: str = "data/*-128-16384-1000-0-0-17-0-PAM4-0.ini"
dtype: torch.dtype = torch.complex64
symbols_range: tuple|float|int = 16
data_size_range: tuple|float|int = 32
shuffle: bool = True
target_delay: float = 0
xy_delay_range: tuple = (0, 1)
drop_first: int = 1000
xy_delay_range: tuple|float|int = 0
drop_first: int = 10
train_split: float = 0.8
@@ -41,41 +48,46 @@ class DataSettings:
@dataclass
class PytorchSettings:
device: str = "cuda"
batchsize: int = 128
epochs: int = 100
batchsize: int = 1024
epochs: int = 10
summary_dir: str = ".runs"
# model settings
@dataclass
class ModelSettings:
output_size: int = 2
n_layer_range: tuple = (1, 3)
n_units_range: tuple = (4, 128)
activation_func_range: tuple = ("ReLU",)
n_layer_range: tuple|float|int = (2,8)
n_units_range: tuple|float|int = (2,32)
# activation_func_range: tuple = ("ReLU",)
@dataclass
class OptimizerSettings:
optimizer_range: tuple = ("Adam", "RMSprop", "SGD")
lr_range: tuple = (1e-5, 1e-1)
# optimizer_range: tuple|str = ("Adam", "RMSprop", "SGD")
optimizer_range: tuple|str = "RMSprop"
# lr_range: tuple|float = (1e-5, 1e-1)
lr_range: tuple|float = 2e-5
# optuna settings
@dataclass
class OptunaSettings:
n_trials: int = 128
n_threads: int = 16
n_threads: int = 8
timeout: int = 600
directions: tuple = ("maximize",)
directions: tuple = ("minimize",)
metrics_names: tuple = ("sse",)
limit_examples: bool = True
n_train_examples: int = PytorchSettings.batchsize * 30
n_valid_examples: int = PytorchSettings.batchsize * 10
n_train_examples: int = PytorchSettings.batchsize * 50
# n_valid_examples: int = PytorchSettings.batchsize * 100
n_valid_examples: int = float("inf")
storage: str = "sqlite:///optuna_single_core_regen.db"
study_name: str = (
f"single_core_regen_{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
f"single_core_regen_{datetime.now().strftime('%Y-%m-%d_%H:%M:%S')}"
)
metrics_names: tuple = ("accuracy",)
class HyperTraining:
def __init__(self):
@@ -86,9 +98,43 @@ class HyperTraining:
self.optimizer_settings = OptimizerSettings()
self.optuna_settings = OptunaSettings()
self.console = Console()
# set some extra settings to make the code more readable
self._extra_optuna_settings()
def setup_tb_writer(self, study_name=None, append=None):
log_dir = self.pytorch_settings.summary_dir + "/" + (study_name or self.optuna_settings.study_name)
if append is not None:
log_dir += "_" + str(append)
return SummaryWriter(log_dir)
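# The per-trial writers end up in <summary_dir>/<study_name>_<trial number>
# (".runs" by default); the logged curves can be viewed with the standard
# TensorBoard CLI, e.g. `tensorboard --logdir .runs`.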
def resume_latest_study(self, verbose=True):
study_name = self.get_latest_study(verbose=verbose)
if study_name:
print(f"Resuming study: {study_name}")
self.optuna_settings.study_name = study_name
def get_latest_study(self, verbose=True):
studies = self.get_studies()
for study in studies:
study.datetime_start = study.datetime_start or datetime.min
if studies:
study = sorted(studies, key=lambda x: x.datetime_start, reverse=True)[0]
if verbose:
print(f"Last study: {study.study_name}")
study_name = study.study_name
else:
if verbose:
print("No previous studies found")
study_name = None
return study_name
def get_studies(self):
return optuna.get_all_study_summaries(storage=self.optuna_settings.storage)
def setup_study(self):
self.study = optuna.create_study(
study_name=self.optuna_settings.study_name,
@@ -100,29 +146,49 @@ class HyperTraining:
with warnings.catch_warnings(action="ignore"):
self.study.set_metric_names(self.optuna_settings.metrics_names)
self.n_threads = min(self.optuna_settings.n_trials, self.optuna_settings.n_threads)
self.n_threads = min(
self.optuna_settings.n_trials, self.optuna_settings.n_threads
)
self.processes = []
for _ in range(self.n_threads):
p = multiprocessing.Process(target=self._run_optimize)
self.processes.append(p)
if self.n_threads > 1:
for _ in range(self.n_threads):
p = multiprocessing.Process(
# target=lambda n_trials: self._run_optimize(self, n_trials),
target=self._run_optimize,
args=(self.optuna_settings.n_trials // self.n_threads,),
)
self.processes.append(p)
def run_study(self):
for p in self.processes:
p.start()
for p in self.processes:
p.join()
remaining_trials = self.optuna_settings.n_trials - self.optuna_settings.n_trials % self.optuna_settings.n_threads
if self.processes:
for p in self.processes:
p.start()
for p in self.processes:
p.join()
remaining_trials = self.optuna_settings.n_trials % self.n_threads
else:
remaining_trials = self.optuna_settings.n_trials
if remaining_trials:
self._run_optimize(remaining_trials)
def _run_optimize(self, n_trials):
self.study.optimize(self.objective, n_trials=n_trials, timeout=self.optuna_settings.timeout)
def eye(self, show=True):
util.plot.eye(self.data_settings.config_path, show=show)
def _run_optimize(self, n_trials):
self.study.optimize(
self.objective, n_trials=n_trials, timeout=self.optuna_settings.timeout
)
def plot_eye(self, show=True):
if not hasattr(self, "eye_data"):
data, config = util.datasets.load_data(
self.data_settings.config_path, skipfirst=10, symbols=1000
)
self.eye_data = {"data": data, "sps": int(config["glova"]["sps"])}
return util.plot.eye(**self.eye_data, show=show)
def _extra_optuna_settings(self):
self.optuna_settings.multi_objective = len(self.optuna_settings.directions) > 1
@@ -143,145 +209,256 @@ class HyperTraining:
else float("inf")
)
def define_model(self, trial: optuna.Trial):
n_layers = trial.suggest_int(
"model_n_layers", *self.model_settings.n_layer_range
def define_model(self, trial: optuna.Trial, writer=None):
n_layers = optional_suggest_int(trial, "model_n_layers", self.model_settings.n_layer_range)
in_features = 2 * trial.params.get(
"dataset_data_size",
optional_suggest_int(trial, "dataset_data_size", self.data_settings.data_size_range),
)
trial.set_user_attr("input_dim", in_features)
layers = []
# REVIEW does that work?
in_features = trial.params["dataset_data_size"] * 2
for i in range(n_layers):
out_features = trial.suggest_int(
f"model_n_units_l{i}", *self.model_settings.n_units_range
)
activation_func = trial.suggest_categorical(
f"model_activation_func_l{i}", self.model_settings.activation_func_range
)
layers.append(nn.Linear(in_features, out_features))
layers.append(getattr(nn, activation_func))
out_features = optional_suggest_int(trial, f"model_n_units_l{i}", self.model_settings.n_units_range, log=True)
layers.append(nn.Linear(in_features, out_features, dtype=self.data_settings.dtype))
# layers.append(getattr(nn, activation_func)())
in_features = out_features
layers.append(nn.Linear(in_features, self.model_settings.output_size))
layers.append(nn.Linear(in_features, self.model_settings.output_size, dtype=self.data_settings.dtype))
if writer is not None:
writer.add_graph(nn.Sequential(*layers), torch.zeros(1, trial.user_attrs["input_dim"], dtype=self.data_settings.dtype))
return nn.Sequential(*layers)
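# Note: the optional_suggest_* helpers come from util.optuna_helpers, which is
# not shown in this diff. Judging from the tuple|float|int setting types, they
# presumably let a setting be either a fixed value or an Optuna search range;
# a rough sketch (the real helpers may differ):
def _optional_suggest_int_sketch(trial, name, value, **kwargs):
    if isinstance(value, tuple):
        # (low, high) range -> let Optuna pick the value
        return trial.suggest_int(name, *value, **kwargs)
    # scalar -> use it as-is, no hyperparameter is registered
    return int(value)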
def get_sliced_data(self, trial: optuna.Trial):
assert ModelSettings.input_size % 2 == 0, "input_dim must be even"
symbols = trial.suggest_float(
"dataset_symbols", *self.data_settings.symbols_range, log=True
)
xy_delay = trial.suggest_float(
"dataset_xy_delay", *self.data_settings.xy_delay_range
)
data_size = trial.suggest_int(
"dataset_data_size", *self.data_settings.data_size_range
symbols = optional_suggest_float(trial, "dataset_symbols", self.data_settings.symbols_range)
xy_delay = optional_suggest_float(trial, "dataset_xy_delay", self.data_settings.xy_delay_range)
data_size = trial.params.get(
"dataset_data_size",
optional_suggest_int(trial, "dataset_data_size", self.data_settings.data_size_range)
)
# get dataset
dataset = FiberRegenerationDataset(
file_path=self.data_settings.config_path,
symbols=symbols,
data_size=data_size, # two channels (x,y)
data_size=data_size,
target_delay=self.data_settings.target_delay,
xy_delay=xy_delay,
drop_first=self.data_settings.drop_first,
dtype=self.data_settings.dtype,
)
dataset_size = len(dataset)
indices = list(range(dataset_size))
split = int(np.floor(self.data_settings.train_split * dataset_size))
np.random.seed(self.global_settings.seed)
np.random.shuffle(indices)
if self.data_settings.shuffle:
np.random.seed(self.global_settings.seed)
np.random.shuffle(indices)
train_indices, valid_indices = indices[:split], indices[split:]
train_sampler = torch.utils.data.SubsetRandomSampler(train_indices)
valid_sampler = torch.utils.data.SubsetRandomSampler(valid_indices)
train_loader = torch.utils.data.DataLoader(
dataset, batch_size=self.pytorch_settings.batchsize, sampler=train_sampler
dataset, batch_size=self.pytorch_settings.batchsize, sampler=train_sampler, drop_last=True
)
valid_loader = torch.utils.data.DataLoader(
dataset, batch_size=self.pytorch_settings.batchsize, sampler=valid_sampler
dataset, batch_size=self.pytorch_settings.batchsize, sampler=valid_sampler, drop_last=True
)
return train_loader, valid_loader
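# Shape assumption implied by define_model: each batch from these loaders is
# expected to yield a complex input x of shape (batchsize, 2 * data_size),
# presumably the two polarisations interleaved, and a target y of shape
# (batchsize, 2), matching in_features = 2 * data_size and output_size = 2.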
def train_model(self, model, optimizer, train_loader, epoch, writer=None, enable_progress=True):
if enable_progress:
progress = Progress(
TextColumn("[yellow] Training..."),
TextColumn(" Loss: {task.description}"),
BarColumn(),
TaskProgressColumn(),
TextColumn("[green]Batch"),
MofNCompleteColumn(),
TimeRemainingColumn(),
# description="Training",
transient=False,
console=self.console,
refresh_per_second=10,
)
task = progress.add_task("-.---e--", total=len(train_loader))
def train_model(self, model, optimizer, train_loader):
running_loss = 0.0
last_loss = 0.0
model.train()
for batch_idx, (data, target) in enumerate(train_loader):
if (batch_idx * train_loader.batchsize
>= self.optuna_settings.n_train_examples):
for batch_idx, (x, y) in enumerate(train_loader):
if (
batch_idx * train_loader.batch_size
>= self.optuna_settings.n_train_examples
):
break
optimizer.zero_grad()
data, target = (
data.to(self.pytorch_settings.device),
target.to(self.pytorch_settings.device),
x, y = (
x.to(self.pytorch_settings.device),
y.to(self.pytorch_settings.device),
)
target_pred = model(data)
loss = F.mean_squared_error(target_pred, target)
y_pred = model(x)
loss = complex_sse_loss(y_pred, y)
loss.backward()
optimizer.step()
# clamp weights to keep energy bounded
for p in model.parameters():
p.data.clamp_(-1.0, 1.0)
def eval_model(self, model, valid_loader):
last_loss = loss.item()
if enable_progress:
progress.update(task, advance=1, description=f"{last_loss:.3e}")
running_loss += loss.item()
if writer is not None:
if batch_idx % 10 == 0:
writer.add_scalar("training loss", running_loss/10, epoch*min(len(train_loader), self.optuna_settings.n_train_examples/train_loader.batch_size) + batch_idx)
running_loss = 0.0
if enable_progress:
progress.update(task, description=f"{last_loss:.3e}")
progress.stop()
def eval_model(self, model, valid_loader, epoch, writer=None, enable_progress=True):
if enable_progress:
progress = Progress(
TextColumn("[green]Evaluating..."),
TextColumn("Error: {task.description}"),
BarColumn(),
TaskProgressColumn(),
TextColumn("[green]Batch"),
MofNCompleteColumn(),
TimeRemainingColumn(),
# description="Training",
transient=False,
console=self.console,
refresh_per_second=10,
)
task = progress.add_task("-.---e--", total=len(valid_loader))
model.eval()
correct = 0
running_error = 0
running_error_2 = 0
with torch.no_grad():
for batch_idx, (data, target) in enumerate(valid_loader):
for batch_idx, (x, y) in enumerate(valid_loader):
if (
batch_idx * valid_loader.batchsize
batch_idx * valid_loader.batch_size
>= self.optuna_settings.n_valid_examples
):
break
data, target = (
data.to(self.pytorch_settings.device),
target.to(self.pytorch_settings.device),
x, y = (
x.to(self.pytorch_settings.device),
y.to(self.pytorch_settings.device),
)
target_pred = model(data)
pred = target_pred.argmax(dim=1, keepdim=True)
correct += pred.eq(target.view_as(pred)).sum().item()
y_pred = model(x)
error = complex_sse_loss(y_pred, y)
running_error += error.item()
running_error_2 += error.item()
accuracy = correct / len(valid_loader.dataset)
# num_params = sum(p.numel() for p in model.parameters())
return accuracy
if enable_progress:
progress.update(task, advance=1, description=f"{error.item():.3e}")
if writer is not None:
if batch_idx % 10 == 0:
writer.add_scalar("sse", running_error_2/10, epoch*min(len(valid_loader), self.optuna_settings.n_valid_examples/valid_loader.batch_size) + batch_idx)
running_error_2 = 0.0
running_error /= batch_idx + 1
if enable_progress:
progress.update(task, description=f"{running_error:.3e}")
progress.stop()
return running_error
def run_model(self, model, loader):
model.eval()
y_preds = []
with torch.no_grad():
for x, y in loader:
x, y = (
x.to(self.pytorch_settings.device),
y.to(self.pytorch_settings.device),
)
y_preds.append(model(x))
return torch.stack(y_preds)
def objective(self, trial: optuna.Trial):
model = self.define_model(trial).to(self.pytorch_settings.device)
optimizer_name = trial.suggest_categorical(
"optimizer", self.optimizer_settings.optimizer_range
)
lr = trial.suggest_float("lr", *self.optimizer_settings.lr_range, log=True)
optimizer = getattr(optim, optimizer_name)(model.parameters(), lr=lr)
writer = self.setup_tb_writer(self.optuna_settings.study_name, f"{trial.number:0>{len(str(self.optuna_settings.n_trials))}}")
train_loader, valid_loader = self.get_sliced_data(trial)
for epoch in range(self.pytorch_settings.epochs):
self.train_model(model, optimizer, train_loader)
accuracy = self.eval_model(model, valid_loader)
model = self.define_model(trial, writer).to(self.pytorch_settings.device)
if len(self.optuna_settings.directions) == 1:
trial.report(accuracy, epoch)
optimizer_name = optional_suggest_categorical(trial, "optimizer", self.optimizer_settings.optimizer_range)
lr = optional_suggest_float(trial, "lr", self.optimizer_settings.lr_range, log=True)
optimizer = getattr(optim, optimizer_name)(model.parameters(), lr=lr)
for epoch in range(self.pytorch_settings.epochs):
enable_progress = self.optuna_settings.n_threads == 1
if enable_progress:
print(f"Epoch {epoch+1}/{self.pytorch_settings.epochs}")
self.train_model(model, optimizer, train_loader, epoch, writer, enable_progress=enable_progress)
sse = self.eval_model(model, valid_loader, epoch, writer, enable_progress=enable_progress)
if not self.optuna_settings.multi_objective:
trial.report(sse, epoch)
if trial.should_prune():
raise optuna.exceptions.TrialPruned()
return accuracy
writer.close()
return sse
if __name__ == "__main__":
# plt.ion()
hyper_training = HyperTraining()
hyper_training.eye()
# hyper_training.setup_study()
# hyper_training.run_study()
for i in range(10):
#simulate some work
print(i)
time.sleep(0.2)
# hyper_training.resume_latest_study()
hyper_training.setup_study()
hyper_training.run_study()
best_model = hyper_training.define_model(hyper_training.study.best_trial).to(hyper_training.pytorch_settings.device)
data_settings_backup = copy.copy(hyper_training.data_settings)
hyper_training.data_settings.shuffle = False
hyper_training.data_settings.train_split = 0.01
plot_loader, _ = hyper_training.get_sliced_data(hyper_training.study.best_trial)
regen = hyper_training.run_model(best_model, plot_loader)
regen = regen.view(-1, 2)
# [batch_no, batch_size, 2] -> [no, 2]
original, _ = util.datasets.load_data(hyper_training.data_settings.config_path, skipfirst=hyper_training.data_settings.drop_first)
original = original[:len(regen)]
regen = regen.cpu().numpy()
_, axs = plt.subplots(2)
for i, ax in enumerate(axs):
ax.plot(np.abs(original[:, i])**2, label="original")
ax.plot(np.abs(regen[:, i])**2, label="regen")
ax.legend()
plt.show()
...
print(f"Best model: {best_model}")
# eye_fig = hyper_training.plot_eye()
...
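# Usage note: to continue the most recent study instead of starting a fresh
# one, the new resume_latest_study helper can be called before setup_study
# (this assumes the optuna.create_study call above uses load_if_exists=True):
#
#     hyper_training = HyperTraining()
#     hyper_training.resume_latest_study()
#     hyper_training.setup_study()
#     hyper_training.run_study()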