move hypertraining class into separate file;
move settings dataclasses into separate file; add SemiUnitaryLayer; clean up model response plotting code; continue hyperparameter search
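The SemiUnitaryLayer mentioned above is added in another file and does not appear in this diff. As a rough, hypothetical sketch only (the class name comes from the commit message; the SVD re-projection approach and all signatures are assumptions, not the committed implementation), a semi-unitary complex layer could look like:

    import torch
    import torch.nn as nn

    class SemiUnitaryLayer(nn.Module):
        # Complex linear layer whose weight is periodically re-projected onto the
        # nearest semi-unitary matrix (W @ W^H = I on the smaller dimension).
        def __init__(self, in_features: int, out_features: int, dtype=torch.complex64):
            super().__init__()
            self.weight = nn.Parameter(torch.randn(out_features, in_features, dtype=dtype))

        @torch.no_grad()
        def project(self):
            # SVD projection: replacing W by U @ V^H sets all singular values to 1.
            u, _, vh = torch.linalg.svd(self.weight, full_matrices=False)
            self.weight.copy_(u @ vh)

        def forward(self, x):
            # Same contract as a bias-free nn.Linear: y = x @ W^T.
            return x @ self.weight.T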
@@ -1,464 +1,107 @@
import copy
from dataclasses import dataclass
from datetime import datetime
import matplotlib.pyplot as plt

import numpy as np
import optuna
import warnings

import torch
import torch.nn as nn
# import torch.nn.functional as F  # mse_loss doesn't support complex numbers
import torch.optim as optim
import torch.utils.data

from torch.utils.tensorboard import SummaryWriter

from rich.progress import Progress, TextColumn, BarColumn, TaskProgressColumn, TimeRemainingColumn, MofNCompleteColumn
from rich.console import Console

import multiprocessing

from util.datasets import FiberRegenerationDataset
from util.complexNN import complex_sse_loss
from util.optuna_helpers import optional_suggest_categorical, optional_suggest_float, optional_suggest_int
import util
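
# complex_sse_loss is imported from util.complexNN and is not shown in this diff.
# Consistent with the note above that F.mse_loss rejects complex inputs, a minimal
# hypothetical sketch of such a loss (an assumption, not the actual util.complexNN
# code) is the sum of squared error magnitudes:
#
#     def complex_sse_loss(y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
#         return torch.sum(torch.abs(y_pred - y) ** 2)
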
# global settings
@dataclass
class GlobalSettings:
    seed: int = 42


# data settings
@dataclass
class DataSettings:
    config_path: str = "data/*-128-16384-1000-0-0-17-0-PAM4-0.ini"
    dtype: torch.dtype = torch.complex64
    symbols_range: tuple|float|int = 16
    data_size_range: tuple|float|int = 32
    shuffle: bool = True
    target_delay: float = 0
    xy_delay_range: tuple|float|int = 0
    drop_first: int = 10
    train_split: float = 0.8


# pytorch settings
@dataclass
class PytorchSettings:
    device: str = "cuda"
    batchsize: int = 1024
    epochs: int = 10
    summary_dir: str = ".runs"


# model settings
@dataclass
class ModelSettings:
    output_size: int = 2
    n_layer_range: tuple|float|int = (2, 8)
    n_units_range: tuple|float|int = (2, 32)
    # activation_func_range: tuple = ("ReLU",)


@dataclass
class OptimizerSettings:
    # optimizer_range: tuple|str = ("Adam", "RMSprop", "SGD")
    optimizer_range: tuple|str = "RMSprop"
    # lr_range: tuple|float = (1e-5, 1e-1)
    lr_range: tuple|float = 2e-5


# optuna settings
@dataclass
class OptunaSettings:
    n_trials: int = 128
    n_threads: int = 8
    timeout: int = 600
    directions: tuple = ("minimize",)
    metrics_names: tuple = ("sse",)

    limit_examples: bool = True
    n_train_examples: int = PytorchSettings.batchsize * 50
    # n_valid_examples: int = PytorchSettings.batchsize * 100
    n_valid_examples: int = float("inf")
    storage: str = "sqlite:///optuna_single_core_regen.db"
    study_name: str = (
        f"single_core_regen_{datetime.now().strftime('%Y-%m-%d_%H:%M:%S')}"
    )

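
# The *_range settings above accept either a (low, high) tuple, which Optuna then
# searches over, or a single scalar, which fixes the value. The optional_suggest_*
# helpers from util.optuna_helpers are not part of this diff; a minimal hypothetical
# sketch of the int variant, under that assumption, would be:
#
#     def optional_suggest_int(trial, name, value, log=False):
#         if isinstance(value, (tuple, list)):
#             return trial.suggest_int(name, value[0], value[1], log=log)
#         return int(value)
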
class HyperTraining:
    def __init__(self):
        self.global_settings = GlobalSettings()
        self.data_settings = DataSettings()
        self.pytorch_settings = PytorchSettings()
        self.model_settings = ModelSettings()
        self.optimizer_settings = OptimizerSettings()
        self.optuna_settings = OptunaSettings()

        self.console = Console()

        # set some extra settings to make the code more readable
        self._extra_optuna_settings()

    def setup_tb_writer(self, study_name=None, append=None):
        log_dir = self.pytorch_settings.summary_dir + "/" + (study_name or self.optuna_settings.study_name)
        if append is not None:
            log_dir += "_" + str(append)

        return SummaryWriter(log_dir)

    def resume_latest_study(self, verbose=True):
        study_name = self.get_latest_study(verbose=verbose)

        if study_name:
            print(f"Resuming study: {study_name}")
            self.optuna_settings.study_name = study_name

    def get_latest_study(self, verbose=True):
        studies = self.get_studies()
        for study in studies:
            study.datetime_start = study.datetime_start or datetime.min
        if studies:
            study = sorted(studies, key=lambda x: x.datetime_start, reverse=True)[0]
            if verbose:
                print(f"Last study: {study.study_name}")
            study_name = study.study_name
        else:
            if verbose:
                print("No previous studies found")
            study_name = None
        return study_name

    def get_studies(self):
        return optuna.get_all_study_summaries(storage=self.optuna_settings.storage)

    def setup_study(self):
        self.study = optuna.create_study(
            study_name=self.optuna_settings.study_name,
            storage=self.optuna_settings.storage,
            load_if_exists=True,
            direction=self.optuna_settings.direction,
            directions=self.optuna_settings.directions,
        )

        with warnings.catch_warnings(action="ignore"):
            self.study.set_metric_names(self.optuna_settings.metrics_names)

        self.n_threads = min(
            self.optuna_settings.n_trials, self.optuna_settings.n_threads
        )
        self.processes = []
        if self.n_threads > 1:
            for _ in range(self.n_threads):
                p = multiprocessing.Process(
                    # target=lambda n_trials: self._run_optimize(self, n_trials),
                    target=self._run_optimize,
                    args=(self.optuna_settings.n_trials // self.n_threads,),
                )
                self.processes.append(p)

    def run_study(self):
        if self.processes:
            for p in self.processes:
                p.start()
            for p in self.processes:
                p.join()

            # each worker ran n_trials // n_threads trials; only the remainder is left
            remaining_trials = (
                self.optuna_settings.n_trials % self.optuna_settings.n_threads
            )
        else:
            remaining_trials = self.optuna_settings.n_trials

        if remaining_trials:
            self._run_optimize(remaining_trials)

    def _run_optimize(self, n_trials):
        self.study.optimize(
            self.objective, n_trials=n_trials, timeout=self.optuna_settings.timeout
        )

    def plot_eye(self, show=True):
        if not hasattr(self, "eye_data"):
            data, config = util.datasets.load_data(
                self.data_settings.config_path, skipfirst=10, symbols=1000
            )
            self.eye_data = {"data": data, "sps": int(config["glova"]["sps"])}
        return util.plot.eye(**self.eye_data, show=show)

    def _extra_optuna_settings(self):
        self.optuna_settings.multi_objective = len(self.optuna_settings.directions) > 1
        if self.optuna_settings.multi_objective:
            self.optuna_settings.direction = None
        else:
            self.optuna_settings.direction = self.optuna_settings.directions[0]
            self.optuna_settings.directions = None

        self.optuna_settings.n_train_examples = (
            self.optuna_settings.n_train_examples
            if self.optuna_settings.limit_examples
            else float("inf")
        )
        self.optuna_settings.n_valid_examples = (
            self.optuna_settings.n_valid_examples
            if self.optuna_settings.limit_examples
            else float("inf")
        )

    def define_model(self, trial: optuna.Trial, writer=None):
        n_layers = optional_suggest_int(trial, "model_n_layers", self.model_settings.n_layer_range)

        in_features = 2 * trial.params.get(
            "dataset_data_size",
            optional_suggest_int(trial, "dataset_data_size", self.data_settings.data_size_range),
        )
        trial.set_user_attr("input_dim", in_features)

        layers = []
        for i in range(n_layers):
            out_features = optional_suggest_int(trial, f"model_n_units_l{i}", self.model_settings.n_units_range, log=True)

            layers.append(nn.Linear(in_features, out_features, dtype=self.data_settings.dtype))
            # layers.append(getattr(nn, activation_func)())
            in_features = out_features

        layers.append(nn.Linear(in_features, self.model_settings.output_size, dtype=self.data_settings.dtype))

        if writer is not None:
            writer.add_graph(nn.Sequential(*layers), torch.zeros(1, trial.user_attrs["input_dim"], dtype=self.data_settings.dtype))

        return nn.Sequential(*layers)

    def get_sliced_data(self, trial: optuna.Trial):
        symbols = optional_suggest_float(trial, "dataset_symbols", self.data_settings.symbols_range)

        xy_delay = optional_suggest_float(trial, "dataset_xy_delay", self.data_settings.xy_delay_range)

        data_size = trial.params.get(
            "dataset_data_size",
            optional_suggest_int(trial, "dataset_data_size", self.data_settings.data_size_range)
        )

        # get dataset
        dataset = FiberRegenerationDataset(
            file_path=self.data_settings.config_path,
            symbols=symbols,
            data_size=data_size,
            target_delay=self.data_settings.target_delay,
            xy_delay=xy_delay,
            drop_first=self.data_settings.drop_first,
            dtype=self.data_settings.dtype,
        )

        dataset_size = len(dataset)
        indices = list(range(dataset_size))
        split = int(np.floor(self.data_settings.train_split * dataset_size))
        if self.data_settings.shuffle:
            np.random.seed(self.global_settings.seed)
            np.random.shuffle(indices)

        train_indices, valid_indices = indices[:split], indices[split:]

        train_sampler = torch.utils.data.SubsetRandomSampler(train_indices)
        valid_sampler = torch.utils.data.SubsetRandomSampler(valid_indices)

        train_loader = torch.utils.data.DataLoader(
            dataset, batch_size=self.pytorch_settings.batchsize, sampler=train_sampler, drop_last=True
        )
        valid_loader = torch.utils.data.DataLoader(
            dataset, batch_size=self.pytorch_settings.batchsize, sampler=valid_sampler, drop_last=True
        )

        return train_loader, valid_loader

    def train_model(self, model, optimizer, train_loader, epoch, writer=None, enable_progress=True):
        if enable_progress:
            progress = Progress(
                TextColumn("[yellow] Training..."),
                TextColumn(" Loss: {task.description}"),
                BarColumn(),
                TaskProgressColumn(),
                TextColumn("[green]Batch"),
                MofNCompleteColumn(),
                TimeRemainingColumn(),
                # description="Training",
                transient=False,
                console=self.console,
                refresh_per_second=10,
            )
            task = progress.add_task("-.---e--", total=len(train_loader))

        running_loss = 0.0
        last_loss = 0.0
        model.train()
        for batch_idx, (x, y) in enumerate(train_loader):
            if (
                batch_idx * train_loader.batch_size
                >= self.optuna_settings.n_train_examples
            ):
                break
            optimizer.zero_grad()
            x, y = (
                x.to(self.pytorch_settings.device),
                y.to(self.pytorch_settings.device),
            )
            y_pred = model(x)
            loss = complex_sse_loss(y_pred, y)
            loss.backward()
            optimizer.step()

            # clamp weights to keep energy bounded
            for p in model.parameters():
                p.data.clamp_(-1.0, 1.0)

            last_loss = loss.item()

            if enable_progress:
                progress.update(task, advance=1, description=f"{last_loss:.3e}")

            running_loss += loss.item()
            if writer is not None:
                if batch_idx % 10 == 0:
                    writer.add_scalar("training loss", running_loss/10, epoch*min(len(train_loader), self.optuna_settings.n_train_examples/train_loader.batch_size) + batch_idx)
                    running_loss = 0.0

        if enable_progress:
            progress.update(task, description=f"{last_loss:.3e}")
            progress.stop()


    def eval_model(self, model, valid_loader, epoch, writer=None, enable_progress=True):
        if enable_progress:
            progress = Progress(
                TextColumn("[green]Evaluating..."),
                TextColumn("Error: {task.description}"),
                BarColumn(),
                TaskProgressColumn(),
                TextColumn("[green]Batch"),
                MofNCompleteColumn(),
                TimeRemainingColumn(),
                # description="Training",
                transient=False,
                console=self.console,
                refresh_per_second=10,
            )
            task = progress.add_task("-.---e--", total=len(valid_loader))

        model.eval()
        running_error = 0
        running_error_2 = 0
        with torch.no_grad():
            for batch_idx, (x, y) in enumerate(valid_loader):
                if (
                    batch_idx * valid_loader.batch_size
                    >= self.optuna_settings.n_valid_examples
                ):
                    break
                x, y = (
                    x.to(self.pytorch_settings.device),
                    y.to(self.pytorch_settings.device),
                )
                y_pred = model(x)
                error = complex_sse_loss(y_pred, y)
                running_error += error.item()
                running_error_2 += error.item()

                if enable_progress:
                    progress.update(task, advance=1, description=f"{error.item():.3e}")

                if writer is not None:
                    if batch_idx % 10 == 0:
                        writer.add_scalar("sse", running_error_2/10, epoch*min(len(valid_loader), self.optuna_settings.n_valid_examples/valid_loader.batch_size) + batch_idx)
                        running_error_2 = 0.0

        running_error /= batch_idx + 1

        if enable_progress:
            progress.update(task, description=f"{running_error:.3e}")
            progress.stop()

        return running_error

    def run_model(self, model, loader):
        model.eval()
        y_preds = []
        with torch.no_grad():
            for x, y in loader:
                x, y = (
                    x.to(self.pytorch_settings.device),
                    y.to(self.pytorch_settings.device),
                )
                y_preds.append(model(x))
        return torch.stack(y_preds)


    def objective(self, trial: optuna.Trial):
        writer = self.setup_tb_writer(self.optuna_settings.study_name, f"{trial.number:0>{len(str(self.optuna_settings.n_trials))}}")
        train_loader, valid_loader = self.get_sliced_data(trial)

        model = self.define_model(trial, writer).to(self.pytorch_settings.device)

        optimizer_name = optional_suggest_categorical(trial, "optimizer", self.optimizer_settings.optimizer_range)

        lr = optional_suggest_float(trial, "lr", self.optimizer_settings.lr_range, log=True)

        optimizer = getattr(optim, optimizer_name)(model.parameters(), lr=lr)

        for epoch in range(self.pytorch_settings.epochs):
            enable_progress = self.optuna_settings.n_threads == 1
            if enable_progress:
                print(f"Epoch {epoch+1}/{self.pytorch_settings.epochs}")
            self.train_model(model, optimizer, train_loader, epoch, writer, enable_progress=enable_progress)
            sse = self.eval_model(model, valid_loader, epoch, writer, enable_progress=enable_progress)

            if not self.optuna_settings.multi_objective:
                trial.report(sse, epoch)
                if trial.should_prune():
                    raise optuna.exceptions.TrialPruned()

        writer.close()

        return sse


from hypertraining.hypertraining import HyperTraining
from hypertraining.settings import (
    GlobalSettings,
    DataSettings,
    PytorchSettings,
    ModelSettings,
    OptimizerSettings,
    OptunaSettings,
)

global_settings = GlobalSettings(
    seed = 42,
)

data_settings = DataSettings(
    config_path = "data/*-128-16384-100000-0-0-17-0-PAM4-0.ini",
    dtype = ("complex64", "float64", "complex32", "float32"),
    symbols = (1, 16),
    model_input_dim = (1, 32),
    shuffle = True,
    in_out_delay = 0,
    xy_delay = 0,
    drop_first = 1000,
    train_split = 0.8,
)

pytorch_settings = PytorchSettings(
    epochs = 25,
    batchsize = 2**10,
    device = "cuda",
    dataloader_workers = 2,
    dataloader_prefetch = 2,
    summary_dir = ".runs",
    write_every = 2**5,
    model_dir = ".models",
)

model_settings = ModelSettings(
    output_dim = 2,
    model_n_layers = (2, 8),
    unit_count = (2, 16),
    model_activation_func = "ModReLU",  # other options: "ZReLU", "Mag", "CReLU", "Identity"
)

optimizer_settings = OptimizerSettings(
    optimizer = ("Adam", "RMSprop"),  # "SGD"
    # learning_rate = (1e-5, 1e-1),
    learning_rate = 1e-3,
    # scheduler = "ReduceLROnPlateau",
    # scheduler_kwargs = {"mode": "min", "factor": 0.5, "patience": 10}
)

optuna_settings = OptunaSettings(
    n_trials = 4096,
    n_threads = 16,
    timeout = 600,
    directions = ("minimize", "minimize"),
    metrics_names = ("n_params", "mse"),

    limit_examples = True,
    n_train_batches = 100,
    n_valid_batches = 100,
    storage = "sqlite:///data/single_core_regen.db",
    study_name = f"single_core_regen_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
)


if __name__ == "__main__":
    hyper_training = HyperTraining(
        global_settings=global_settings,
        data_settings=data_settings,
        pytorch_settings=pytorch_settings,
        model_settings=model_settings,
        optimizer_settings=optimizer_settings,
        optuna_settings=optuna_settings,
    )

    # hyper_training.resume_latest_study()

    hyper_training.setup_study()

    hyper_training.run_study()
    # best_trial = hyper_training.study.best_trial

    best_model = hyper_training.define_model(hyper_training.study.best_trial).to(hyper_training.pytorch_settings.device)
    data_settings_backup = copy.copy(hyper_training.data_settings)
    hyper_training.data_settings.shuffle = False
    hyper_training.data_settings.train_split = 0.01
    plot_loader, _ = hyper_training.get_sliced_data(hyper_training.study.best_trial)

    regen = hyper_training.run_model(best_model, plot_loader)
    regen = regen.view(-1, 2)  # [batch_no, batch_size, 2] -> [no, 2]
    # best_model = hyper_training.define_model(best_trial).to(
    #     hyper_training.pytorch_settings.device
    # )

    original, _ = util.datasets.load_data(hyper_training.data_settings.config_path, skipfirst=hyper_training.data_settings.drop_first)
    original = original[:len(regen)]

    regen = regen.cpu().numpy()
    _, axs = plt.subplots(2)
    for i, ax in enumerate(axs):
        ax.plot(np.abs(original[:, i])**2, label="original")
        ax.plot(np.abs(regen[:, i])**2, label="regen")
        ax.legend()
    plt.show()


    print(f"Best model: {best_model}")
    # title_append, subtitle = hyper_training.build_title(best_trial)
    # hyper_training.plot_model_response(
    #     best_trial,
    #     model=best_model,
    #     title_append=title_append,
    #     subtitle=subtitle,
    #     mode="eye",
    #     show=True,
    # )

    # print(f"Best model found for trial {best_trial.number}")
    # print(f"Best model error: {best_trial.value}")
    # print(f"Best model params: {best_trial.params}")
    # print()
    # print(best_model)

    # eye_fig = hyper_training.plot_eye()
    ...