clean up regen_no_hyper.py

This commit is contained in:
Joseph Hopfmüller
2024-11-29 15:50:34 +01:00
parent e02662ed4f
commit bdf6f5bfb8

View File

@@ -1,414 +1,130 @@
import copy from hypertraining.settings import (
from dataclasses import dataclass GlobalSettings,
from datetime import datetime DataSettings,
from pathlib import Path PytorchSettings,
import matplotlib.pyplot as plt ModelSettings,
OptimizerSettings,
import numpy as np
import torch
import torch.nn as nn
# import torch.nn.functional as F # mse_loss doesn't support complex numbers
import torch.optim as optim
import torch.utils.data
from torch.utils.tensorboard import SummaryWriter
from rich.progress import (
Progress,
TextColumn,
BarColumn,
TaskProgressColumn,
TimeRemainingColumn,
MofNCompleteColumn,
TimeElapsedColumn,
) )
from rich.console import Console
from rich import print as rprint
# from util.optuna_helpers import optional_suggest_categorical, optional_suggest_float, optional_suggest_int from hypertraining.training import Trainer
import torch
import json
import util import util
global_settings = GlobalSettings(
seed=42,
)
# global settings data_settings = DataSettings(
@dataclass config_path="data/*-128-16384-100000-0-0-17-0-PAM4-0.ini",
class GlobalSettings: dtype="complex64",
seed: int = 42 # symbols = (9, 20), # 13 symbol @ 10GBd <-> 1.3ns <-> 0.26m of fiber
symbols=13, # study: single_core_regen_20241123_011232
# output_size = (11, 32), # 26 taps -> 2 taps per input symbol -> 1 tap every 1cm (model has 52 inputs (x/y))
output_size=26, # study: single_core_regen_20241123_011232 (model_input_dim/2)
shuffle=True,
in_out_delay=0,
xy_delay=0,
drop_first=128*64,
train_split=0.8,
)
pytorch_settings = PytorchSettings(
epochs=10000,
batchsize=2**12,
device="cuda",
dataloader_workers=12,
dataloader_prefetch=8,
summary_dir=".runs",
write_every=2**5,
save_models=True,
model_dir=".models",
)
# data settings model_settings = ModelSettings(
@dataclass output_dim=2,
class DataSettings: n_hidden_layers=4,
config_path: str = "data/*-128-16384-100000-0-0-17-0-PAM4-0.ini" overrides={
dtype: torch.dtype = torch.complex64 "n_hidden_nodes_0": 8,
symbols_range: float | int = 8 "n_hidden_nodes_1": 8,
data_size_range: float | int = 64 "n_hidden_nodes_2": 4,
shuffle: bool = True "n_hidden_nodes_3": 6,
target_delay: float = 0 },
xy_delay_range: float | int = 0 model_activation_func="PowScale",
drop_first: int = 10 # dropout_prob=0.01,
train_split: float = 0.8 model_layer_function="ONN",
model_layer_parametrizations=[
{
"tensor_name": "weight",
"parametrization": torch.nn.utils.parametrizations.orthogonal,
},
{
"tensor_name": "scales",
"parametrization": util.complexNN.clamp,
},
{
"tensor_name": "scale",
"parametrization": util.complexNN.clamp,
},
{
"tensor_name": "bias",
"parametrization": util.complexNN.clamp,
},
# {
# "tensor_name": "V",
# "parametrization": torch.nn.utils.parametrizations.orthogonal,
# },
# {
# "tensor_name": "S",
# "parametrization": util.complexNN.clamp,
# },
],
)
optimizer_settings = OptimizerSettings(
optimizer="Adam",
learning_rate=0.05,
scheduler="ReduceLROnPlateau",
scheduler_kwargs={
"patience": 2**6,
"factor": 0.9,
# "threshold": 1e-3,
"min_lr": 1e-6,
"cooldown": 10,
},
)
# pytorch settings def save_dict_to_file(dictionary, filename):
@dataclass """
class PytorchSettings: Save the best dictionary to a JSON file.
epochs: int = 1000
batchsize: int = 2**12
device: str = "cuda"
summary_dir: str = ".runs"
model_dir: str = ".models"
:param best: Dictionary containing the best training results.
:type best: dict
:param filename: Path to the JSON file where the dictionary will be saved.
:type filename: str
"""
with open(filename, 'w') as f:
json.dump(dictionary, f, indent=4)
# model settings
@dataclass
class ModelSettings:
output_size: int = 2
# n_layer_range: float|int = 2
# n_units_range: float|int = 32
n_layers: int = 3
n_units: int = 32
activation_func: tuple | str = "ModReLU"
@dataclass
class OptimizerSettings:
optimizer_range: str = "Adam"
lr_range: float = 2e-3
class Training:
def __init__(self):
self.global_settings = GlobalSettings()
self.data_settings = DataSettings()
self.pytorch_settings = PytorchSettings()
self.model_settings = ModelSettings()
self.optimizer_settings = OptimizerSettings()
self.study_name = (
f"single_core_regen_{datetime.now().strftime('%Y-%m-%d_%H:%M:%S')}"
)
if not hasattr(self.pytorch_settings, "model_dir"):
self.pytorch_settings.model_dir = ".models"
self.writer = None
self.console = Console()
def setup_tb_writer(self, study_name=None, append=None):
log_dir = (
self.pytorch_settings.summary_dir + "/" + (study_name or self.study_name) + ("_" + str(append)) if append else ""
)
self.writer = SummaryWriter(log_dir)
return self.writer
def plot_eye(self, width=2, symbols=None, alpha=None, complex=False, show=True):
if not hasattr(self, "eye_data"):
data, config = util.datasets.load_data(
self.data_settings.config_path,
skipfirst=10,
symbols=symbols or 1000,
real=not self.data_settings.dtype.is_complex,
normalize=True,
)
self.eye_data = {"data": data, "sps": int(config["glova"]["sps"])}
return util.plot.eye(
**self.eye_data,
width=width,
show=show,
alpha=alpha,
complex=complex,
symbols=symbols or 1000,
skipfirst=0,
)
def define_model(self):
n_layers = self.model_settings.n_layers
in_features = 2 * self.data_settings.data_size_range
layers = []
for i in range(n_layers):
out_features = self.model_settings.n_units
layers.append(util.complexNN.UnitaryLayer(in_features, out_features))
# layers.append(getattr(nn, self.model_settings.activation_func)())
layers.append(
getattr(util.complexNN, self.model_settings.activation_func)()
)
in_features = out_features
layers.append(
util.complexNN.UnitaryLayer(in_features, self.model_settings.output_size)
)
if self.writer is not None:
self.writer.add_graph(
nn.Sequential(*layers),
torch.zeros(1, layers[0].in_features, dtype=self.data_settings.dtype),
)
return nn.Sequential(*layers)
def get_sliced_data(self):
symbols = self.data_settings.symbols_range
xy_delay = self.data_settings.xy_delay_range
data_size = self.data_settings.data_size_range
# get dataset
dataset = util.datasets.FiberRegenerationDataset(
file_path=self.data_settings.config_path,
symbols=symbols,
output_dim=data_size,
target_delay=self.data_settings.target_delay,
xy_delay=xy_delay,
drop_first=self.data_settings.drop_first,
dtype=self.data_settings.dtype,
real=not self.data_settings.dtype.is_complex,
# device=self.pytorch_settings.device,
)
dataset_size = len(dataset)
indices = list(range(dataset_size))
split = int(np.floor(self.data_settings.train_split * dataset_size))
if self.data_settings.shuffle:
np.random.seed(self.global_settings.seed)
np.random.shuffle(indices)
train_indices, valid_indices = indices[:split], indices[split:]
if self.data_settings.shuffle:
train_sampler = torch.utils.data.SubsetRandomSampler(train_indices)
valid_sampler = torch.utils.data.SubsetRandomSampler(valid_indices)
else:
train_sampler = train_indices
valid_sampler = valid_indices
train_loader = torch.utils.data.DataLoader(
dataset,
batch_size=self.pytorch_settings.batchsize,
sampler=train_sampler,
drop_last=True,
pin_memory=True,
num_workers=24,
prefetch_factor=4,
# persistent_workers=True
)
valid_loader = torch.utils.data.DataLoader(
dataset,
batch_size=self.pytorch_settings.batchsize,
sampler=valid_sampler,
drop_last=True,
pin_memory=True,
num_workers=24,
prefetch_factor=4,
# persistent_workers=True
)
return train_loader, valid_loader
def train_model(self, model, optimizer, train_loader, epoch):
with Progress(
TextColumn("[yellow] Training..."),
TextColumn("Error: {task.description}"),
BarColumn(),
TaskProgressColumn(),
TextColumn("[green]Batch"),
MofNCompleteColumn(),
TimeRemainingColumn(),
TimeElapsedColumn(),
# description="Training",
transient=False,
console=self.console,
refresh_per_second=10,
) as progress:
task = progress.add_task("-.---e--", total=len(train_loader))
running_loss = 0.0
model.train()
for batch_idx, (x, y) in enumerate(train_loader):
model.zero_grad(set_to_none=True)
x, y = (
x.to(self.pytorch_settings.device),
y.to(self.pytorch_settings.device),
)
y_pred = model(x)
loss = util.complexNN.complex_mse_loss(y_pred, y)
loss.backward()
optimizer.step()
progress.update(task, advance=1, description=f"{loss.item():.3e}")
running_loss += loss.item()
if self.writer is not None:
if (batch_idx + 1) % 10 == 0:
self.writer.add_scalar(
"training loss",
running_loss / 10,
epoch * len(train_loader) + batch_idx,
)
running_loss = 0.0
return running_loss
def eval_model(self, model, valid_loader, epoch):
with Progress(
TextColumn("[green]Evaluating..."),
TextColumn("Error: {task.description}"),
BarColumn(),
TaskProgressColumn(),
TextColumn("[green]Batch"),
MofNCompleteColumn(),
TimeRemainingColumn(),
TimeElapsedColumn(),
# description="Training",
transient=False,
console=self.console,
refresh_per_second=10,
) as progress:
task = progress.add_task("-.---e--", total=len(valid_loader))
model.eval()
running_loss = 0
running_loss2 = 0
with torch.no_grad():
for batch_idx, (x, y) in enumerate(valid_loader):
x, y = (
x.to(self.pytorch_settings.device),
y.to(self.pytorch_settings.device),
)
y_pred = model(x)
loss = util.complexNN.complex_mse_loss(y_pred, y)
running_loss += loss.item()
running_loss2 += loss.item()
progress.update(task, advance=1, description=f"{loss.item():.3e}")
if self.writer is not None:
if (batch_idx + 1) % 10 == 0:
self.writer.add_scalar(
"loss",
running_loss / 10,
epoch * len(valid_loader) + batch_idx,
)
running_loss = 0.0
if self.writer is not None:
self.writer.add_figure("fiber response", self.plot_model_response(model, plot=False), epoch+1)
return running_loss2 / len(valid_loader)
def run_model(self, model, loader):
model.eval()
xs = []
ys = []
y_preds = []
with torch.no_grad():
model = model.to(self.pytorch_settings.device)
for x, y in loader:
x, y = (
x.to(self.pytorch_settings.device),
y.to(self.pytorch_settings.device),
)
y_pred = model(x).cpu()
# x = x.cpu()
# y = y.cpu()
y_pred = y_pred.view(y_pred.shape[0], -1, 2)
y = y.view(y.shape[0], -1, 2)
x = x.view(x.shape[0], -1, 2)
xs.append(x[:, 0, :].squeeze())
ys.append(y.squeeze())
y_preds.append(y_pred.squeeze())
xs = torch.vstack(xs).cpu()
ys = torch.vstack(ys).cpu()
y_preds = torch.vstack(y_preds).cpu()
return ys, xs, y_preds
def objective(self, save=False, plot_before=False):
try:
rprint(*list(self.study_name.split("_")))
self.model = self.define_model().to(self.pytorch_settings.device)
if self.writer is not None:
self.writer.add_figure("fiber response", self.plot_model_response(plot=plot_before), 0)
train_loader, valid_loader = self.get_sliced_data()
optimizer_name = self.optimizer_settings.optimizer_range
lr = self.optimizer_settings.lr_range
optimizer = getattr(optim, optimizer_name)(self.model.parameters(), lr=lr)
for epoch in range(self.pytorch_settings.epochs):
self.console.rule(f"Epoch {epoch + 1}/{self.pytorch_settings.epochs}")
self.train_model(self.model, optimizer, train_loader, epoch)
eval_loss = self.eval_model(self.model, valid_loader, epoch)
return eval_loss
except KeyboardInterrupt:
...
finally:
if hasattr(self, "model"):
save_path = (
Path(self.pytorch_settings.model_dir) / f"{self.study_name}.pth"
)
save_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(self.model, save_path)
def _plot_model_response_plotter(self, fiber_in, fiber_out, regen, plot=True):
fig, axs = plt.subplots(2)
for i, ax in enumerate(axs):
ax.plot(np.abs(fiber_in[:, i]) ** 2, label="fiber in")
ax.plot(np.abs(fiber_out[:, i]) ** 2, label="fiber out")
ax.plot(np.abs(regen[:, i]) ** 2, label="regenerated")
ax.legend()
if plot:
plt.show()
return fig
def plot_model_response(self, model=None, plot=True):
data_settings_backup = copy.copy(self.data_settings)
self.data_settings.shuffle = False
self.data_settings.train_split = 0.01
self.data_settings.drop_first = 100
plot_loader, _ = self.get_sliced_data()
self.data_settings = data_settings_backup
fiber_in, fiber_out, regen = self.run_model(model or self.model, plot_loader)
fiber_in = fiber_in.view(-1, 2)
fiber_out = fiber_out.view(-1, 2)
regen = regen.view(-1, 2)
fiber_in = fiber_in.numpy()
fiber_out = fiber_out.numpy()
regen = regen.numpy()
# https://github.com/matplotlib/matplotlib/issues/27713#issue-2104110987
# https://github.com/matplotlib/matplotlib/issues/27713#issuecomment-1915497463
import gc
fig = self._plot_model_response_plotter(fiber_in, fiber_out, regen, plot=plot)
gc.collect()
return fig
if __name__ == "__main__": if __name__ == "__main__":
trainer = Training() trainer = Trainer(
global_settings=global_settings,
data_settings=data_settings,
pytorch_settings=pytorch_settings,
model_settings=model_settings,
optimizer_settings=optimizer_settings,
checkpoint_path='.models/20241128_084935_8885.tar',
settings_override={
"model_settings": {
# "model_activation_func": "PowScale",
"dropout_prob": 0,
}
},
reset_epoch=True,
)
# trainer.plot_eye() best = trainer.train()
trainer.setup_tb_writer() save_dict_to_file(best, ".models/best_results.json")
trainer.objective(save=True)
best_model = trainer.model
# best_model = trainer.define_model(trainer.study.best_trial).to(trainer.pytorch_settings.device)
trainer.plot_model_response(best_model)
# print(f"Best model: {best_model}")
... ...