clean up regen_no_hyper.py
@@ -1,414 +1,130 @@
import copy
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
import matplotlib.pyplot as plt

import numpy as np

import torch
import torch.nn as nn

# import torch.nn.functional as F  # mse_loss doesn't support complex numbers
import torch.optim as optim
import torch.utils.data

from torch.utils.tensorboard import SummaryWriter

from rich.progress import (
    Progress,
    TextColumn,
    BarColumn,
    TaskProgressColumn,
    TimeRemainingColumn,
    MofNCompleteColumn,
    TimeElapsedColumn,
)

from hypertraining.settings import (
    GlobalSettings,
    DataSettings,
    PytorchSettings,
    ModelSettings,
    OptimizerSettings,
)
from rich.console import Console
from rich import print as rprint

# from util.optuna_helpers import optional_suggest_categorical, optional_suggest_float, optional_suggest_int
from hypertraining.training import Trainer
import json
import util

global_settings = GlobalSettings(
    seed=42,
)


# global settings
@dataclass
class GlobalSettings:
    seed: int = 42


data_settings = DataSettings(
    config_path="data/*-128-16384-100000-0-0-17-0-PAM4-0.ini",
    dtype="complex64",
    # symbols = (9, 20),  # 13 symbols @ 10GBd <-> 1.3ns <-> 0.26m of fiber
    symbols=13,  # study: single_core_regen_20241123_011232
    # output_size = (11, 32),  # 26 taps -> 2 taps per input symbol -> 1 tap every 1cm (model has 52 inputs (x/y))
    output_size=26,  # study: single_core_regen_20241123_011232 (model_input_dim/2)
    shuffle=True,
    in_out_delay=0,
    xy_delay=0,
    drop_first=128 * 64,
    train_split=0.8,
)
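
# Sanity check on the symbol comment above: 13 symbols at 10 GBd last
# 13 / 10e9 s = 1.3 ns; at a fiber group velocity of roughly c / 1.5 ≈ 2e8 m/s
# that corresponds to 1.3e-9 s * 2e8 m/s ≈ 0.26 m of fiber.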

pytorch_settings = PytorchSettings(
    epochs=10000,
    batchsize=2**12,
    device="cuda",
    dataloader_workers=12,
    dataloader_prefetch=8,
    summary_dir=".runs",
    write_every=2**5,
    save_models=True,
    model_dir=".models",
)
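
# For reference: batchsize=2**12 is 4096 samples per batch and
# write_every=2**5 is 32; the exact write interval unit (batches vs. epochs)
# is defined inside hypertraining's Trainer.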

# data settings
@dataclass
class DataSettings:
    config_path: str = "data/*-128-16384-100000-0-0-17-0-PAM4-0.ini"
    dtype: torch.dtype = torch.complex64
    symbols_range: float | int = 8
    data_size_range: float | int = 64
    shuffle: bool = True
    target_delay: float = 0
    xy_delay_range: float | int = 0
    drop_first: int = 10
    train_split: float = 0.8

model_settings = ModelSettings(
    output_dim=2,
    n_hidden_layers=4,
    overrides={
        "n_hidden_nodes_0": 8,
        "n_hidden_nodes_1": 8,
        "n_hidden_nodes_2": 4,
        "n_hidden_nodes_3": 6,
    },
    model_activation_func="PowScale",
    # dropout_prob=0.01,
    model_layer_function="ONN",
    model_layer_parametrizations=[
        {
            "tensor_name": "weight",
            "parametrization": torch.nn.utils.parametrizations.orthogonal,
        },
        {
            "tensor_name": "scales",
            "parametrization": util.complexNN.clamp,
        },
        {
            "tensor_name": "scale",
            "parametrization": util.complexNN.clamp,
        },
        {
            "tensor_name": "bias",
            "parametrization": util.complexNN.clamp,
        },
        # {
        #     "tensor_name": "V",
        #     "parametrization": torch.nn.utils.parametrizations.orthogonal,
        # },
        # {
        #     "tensor_name": "S",
        #     "parametrization": util.complexNN.clamp,
        # },
    ],
)
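
# Each entry in model_layer_parametrizations names a tensor and the
# parametrization to register on it. A minimal sketch of how such a list
# might be applied (hypothetical loop; the real wiring lives in
# hypertraining.training):
#
#   for layer in model:
#       for p in model_settings.model_layer_parametrizations:
#           if hasattr(layer, p["tensor_name"]):
#               p["parametrization"](layer, p["tensor_name"])
#
# torch.nn.utils.parametrizations.orthogonal(module, name) keeps the named
# tensor orthogonal (unitary for complex dtypes), so each ONN layer stays
# energy-preserving during training.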

optimizer_settings = OptimizerSettings(
    optimizer="Adam",
    learning_rate=0.05,
    scheduler="ReduceLROnPlateau",
    scheduler_kwargs={
        "patience": 2**6,
        "factor": 0.9,
        # "threshold": 1e-3,
        "min_lr": 1e-6,
        "cooldown": 10,
    },
)
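
# ReduceLROnPlateau multiplies the LR by factor=0.9 once the monitored loss
# has stalled for patience=2**6=64 scheduler steps, waits cooldown=10 steps
# before counting again, and never drops below min_lr=1e-6. Typical wiring
# (sketch; the actual stepping happens inside the Trainer):
#
#   scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, **scheduler_kwargs)
#   scheduler.step(eval_loss)  # once per evaluation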

# pytorch settings
@dataclass
class PytorchSettings:
    epochs: int = 1000
    batchsize: int = 2**12
    device: str = "cuda"
    summary_dir: str = ".runs"
    model_dir: str = ".models"


def save_dict_to_file(dictionary, filename):
    """
    Save a results dictionary to a JSON file.

    :param dictionary: Dictionary containing the best training results.
    :type dictionary: dict
    :param filename: Path to the JSON file where the dictionary will be saved.
    :type filename: str
    """
    with open(filename, "w") as f:
        json.dump(dictionary, f, indent=4)
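
# Example with hypothetical values:
#   save_dict_to_file({"best_loss": 1.2e-3, "epoch": 417}, ".models/best_results.json")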

# model settings
@dataclass
class ModelSettings:
    output_size: int = 2
    # n_layer_range: float | int = 2
    # n_units_range: float | int = 32
    n_layers: int = 3
    n_units: int = 32
    activation_func: tuple | str = "ModReLU"


@dataclass
class OptimizerSettings:
    optimizer_range: str = "Adam"
    lr_range: float = 2e-3

class Training:
    def __init__(self):
        self.global_settings = GlobalSettings()
        self.data_settings = DataSettings()
        self.pytorch_settings = PytorchSettings()
        self.model_settings = ModelSettings()
        self.optimizer_settings = OptimizerSettings()
        self.study_name = (
            f"single_core_regen_{datetime.now().strftime('%Y-%m-%d_%H:%M:%S')}"
        )

        if not hasattr(self.pytorch_settings, "model_dir"):
            self.pytorch_settings.model_dir = ".models"

        self.writer = None
        self.console = Console()

    def setup_tb_writer(self, study_name=None, append=None):
        # parenthesize the conditional suffix: without the inner parentheses,
        # `... if append else ""` swallows the whole concatenation and
        # log_dir becomes "" whenever append is falsy
        log_dir = (
            self.pytorch_settings.summary_dir
            + "/"
            + (study_name or self.study_name)
            + (("_" + str(append)) if append else "")
        )
        self.writer = SummaryWriter(log_dir)
        return self.writer
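
    # With the parenthesization above, setup_tb_writer() logs to
    # ".runs/<study_name>" and setup_tb_writer(append=2) to
    # ".runs/<study_name>_2".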

    def plot_eye(self, width=2, symbols=None, alpha=None, complex=False, show=True):
        if not hasattr(self, "eye_data"):
            data, config = util.datasets.load_data(
                self.data_settings.config_path,
                skipfirst=10,
                symbols=symbols or 1000,
                real=not self.data_settings.dtype.is_complex,
                normalize=True,
            )
            self.eye_data = {"data": data, "sps": int(config["glova"]["sps"])}
        return util.plot.eye(
            **self.eye_data,
            width=width,
            show=show,
            alpha=alpha,
            complex=complex,
            symbols=symbols or 1000,
            skipfirst=0,
        )

    def define_model(self):
        n_layers = self.model_settings.n_layers

        in_features = 2 * self.data_settings.data_size_range

        layers = []
        for i in range(n_layers):
            out_features = self.model_settings.n_units

            layers.append(util.complexNN.UnitaryLayer(in_features, out_features))
            # layers.append(getattr(nn, self.model_settings.activation_func)())
            layers.append(
                getattr(util.complexNN, self.model_settings.activation_func)()
            )
            in_features = out_features

        layers.append(
            util.complexNN.UnitaryLayer(in_features, self.model_settings.output_size)
        )

        if self.writer is not None:
            self.writer.add_graph(
                nn.Sequential(*layers),
                torch.zeros(1, layers[0].in_features, dtype=self.data_settings.dtype),
            )

        return nn.Sequential(*layers)
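
    # With the defaults above (n_layers=3, n_units=32, data_size_range=64,
    # output_size=2), define_model builds, schematically:
    #   UnitaryLayer(128, 32) -> ModReLU -> UnitaryLayer(32, 32) -> ModReLU
    #   -> UnitaryLayer(32, 32) -> ModReLU -> UnitaryLayer(32, 2)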

    def get_sliced_data(self):
        symbols = self.data_settings.symbols_range
        xy_delay = self.data_settings.xy_delay_range
        data_size = self.data_settings.data_size_range

        # get dataset
        dataset = util.datasets.FiberRegenerationDataset(
            file_path=self.data_settings.config_path,
            symbols=symbols,
            output_dim=data_size,
            target_delay=self.data_settings.target_delay,
            xy_delay=xy_delay,
            drop_first=self.data_settings.drop_first,
            dtype=self.data_settings.dtype,
            real=not self.data_settings.dtype.is_complex,
            # device=self.pytorch_settings.device,
        )

        dataset_size = len(dataset)
        indices = list(range(dataset_size))
        split = int(np.floor(self.data_settings.train_split * dataset_size))
        if self.data_settings.shuffle:
            np.random.seed(self.global_settings.seed)
            np.random.shuffle(indices)

        train_indices, valid_indices = indices[:split], indices[split:]

        if self.data_settings.shuffle:
            train_sampler = torch.utils.data.SubsetRandomSampler(train_indices)
            valid_sampler = torch.utils.data.SubsetRandomSampler(valid_indices)
        else:
            train_sampler = train_indices
            valid_sampler = valid_indices

        train_loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=self.pytorch_settings.batchsize,
            sampler=train_sampler,
            drop_last=True,
            pin_memory=True,
            num_workers=24,
            prefetch_factor=4,
            # persistent_workers=True,
        )
        valid_loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=self.pytorch_settings.batchsize,
            sampler=valid_sampler,
            drop_last=True,
            pin_memory=True,
            num_workers=24,
            prefetch_factor=4,
            # persistent_workers=True,
        )

        return train_loader, valid_loader
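
    # Both loaders wrap the same Dataset; SubsetRandomSampler merely restricts
    # each one to its index subset, so the split copies no data. The minimal
    # equivalent pattern (sketch):
    #
    #   indices = np.random.permutation(len(dataset))
    #   split = int(0.8 * len(dataset))
    #   train = torch.utils.data.DataLoader(
    #       dataset, sampler=torch.utils.data.SubsetRandomSampler(indices[:split]))
    #   valid = torch.utils.data.DataLoader(
    #       dataset, sampler=torch.utils.data.SubsetRandomSampler(indices[split:]))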

    def train_model(self, model, optimizer, train_loader, epoch):
        with Progress(
            TextColumn("[yellow] Training..."),
            TextColumn("Error: {task.description}"),
            BarColumn(),
            TaskProgressColumn(),
            TextColumn("[green]Batch"),
            MofNCompleteColumn(),
            TimeRemainingColumn(),
            TimeElapsedColumn(),
            # description="Training",
            transient=False,
            console=self.console,
            refresh_per_second=10,
        ) as progress:
            task = progress.add_task("-.---e--", total=len(train_loader))

            running_loss = 0.0
            model.train()
            for batch_idx, (x, y) in enumerate(train_loader):
                model.zero_grad(set_to_none=True)
                x, y = (
                    x.to(self.pytorch_settings.device),
                    y.to(self.pytorch_settings.device),
                )
                y_pred = model(x)
                loss = util.complexNN.complex_mse_loss(y_pred, y)
                loss.backward()
                optimizer.step()

                progress.update(task, advance=1, description=f"{loss.item():.3e}")

                running_loss += loss.item()
                if self.writer is not None:
                    if (batch_idx + 1) % 10 == 0:
                        self.writer.add_scalar(
                            "training loss",
                            running_loss / 10,
                            epoch * len(train_loader) + batch_idx,
                        )
                        running_loss = 0.0

        return running_loss
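
    # util.complexNN.complex_mse_loss stands in for F.mse_loss, which rejects
    # complex tensors (see the commented-out import at the top). A minimal
    # equivalent (sketch; the util implementation may differ):
    #
    #   def complex_mse_loss(y_pred, y):
    #       diff = y_pred - y
    #       return (diff * diff.conj()).real.mean()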

    def eval_model(self, model, valid_loader, epoch):
        with Progress(
            TextColumn("[green]Evaluating..."),
            TextColumn("Error: {task.description}"),
            BarColumn(),
            TaskProgressColumn(),
            TextColumn("[green]Batch"),
            MofNCompleteColumn(),
            TimeRemainingColumn(),
            TimeElapsedColumn(),
            # description="Training",
            transient=False,
            console=self.console,
            refresh_per_second=10,
        ) as progress:
            task = progress.add_task("-.---e--", total=len(valid_loader))

            model.eval()
            running_loss = 0
            running_loss2 = 0
            with torch.no_grad():
                for batch_idx, (x, y) in enumerate(valid_loader):
                    x, y = (
                        x.to(self.pytorch_settings.device),
                        y.to(self.pytorch_settings.device),
                    )
                    y_pred = model(x)
                    loss = util.complexNN.complex_mse_loss(y_pred, y)
                    running_loss += loss.item()
                    running_loss2 += loss.item()

                    progress.update(task, advance=1, description=f"{loss.item():.3e}")
                    if self.writer is not None:
                        if (batch_idx + 1) % 10 == 0:
                            self.writer.add_scalar(
                                "loss",
                                running_loss / 10,
                                epoch * len(valid_loader) + batch_idx,
                            )
                            running_loss = 0.0

        if self.writer is not None:
            self.writer.add_figure(
                "fiber response",
                self.plot_model_response(model, plot=False),
                epoch + 1,
            )

        return running_loss2 / len(valid_loader)
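
    # Note: running_loss feeds the 10-batch TensorBoard averages and is reset
    # after every write, while running_loss2 accumulates over the whole pass,
    # so the return value is the mean validation loss across all batches.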

    def run_model(self, model, loader):
        model.eval()
        xs = []
        ys = []
        y_preds = []
        with torch.no_grad():
            model = model.to(self.pytorch_settings.device)
            for x, y in loader:
                x, y = (
                    x.to(self.pytorch_settings.device),
                    y.to(self.pytorch_settings.device),
                )
                y_pred = model(x).cpu()
                # x = x.cpu()
                # y = y.cpu()
                y_pred = y_pred.view(y_pred.shape[0], -1, 2)
                y = y.view(y.shape[0], -1, 2)
                x = x.view(x.shape[0], -1, 2)
                xs.append(x[:, 0, :].squeeze())
                ys.append(y.squeeze())
                y_preds.append(y_pred.squeeze())

        xs = torch.vstack(xs).cpu()
        ys = torch.vstack(ys).cpu()
        y_preds = torch.vstack(y_preds).cpu()
        return ys, xs, y_preds
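
    # Return order matters at the call site: the regeneration target y is the
    # clean fiber input and the model input x is the distorted fiber output,
    # so plot_model_response unpacks this as (fiber_in, fiber_out, regen).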

    def objective(self, save=False, plot_before=False):
        try:
            rprint(*list(self.study_name.split("_")))

            self.model = self.define_model().to(self.pytorch_settings.device)

            if self.writer is not None:
                self.writer.add_figure(
                    "fiber response", self.plot_model_response(plot=plot_before), 0
                )

            train_loader, valid_loader = self.get_sliced_data()

            optimizer_name = self.optimizer_settings.optimizer_range
            lr = self.optimizer_settings.lr_range
            optimizer = getattr(optim, optimizer_name)(self.model.parameters(), lr=lr)

            for epoch in range(self.pytorch_settings.epochs):
                self.console.rule(f"Epoch {epoch + 1}/{self.pytorch_settings.epochs}")
                self.train_model(self.model, optimizer, train_loader, epoch)
                eval_loss = self.eval_model(self.model, valid_loader, epoch)

            return eval_loss

        except KeyboardInterrupt:
            ...
        finally:
            if hasattr(self, "model"):
                save_path = (
                    Path(self.pytorch_settings.model_dir) / f"{self.study_name}.pth"
                )
                save_path.parent.mkdir(parents=True, exist_ok=True)
                torch.save(self.model, save_path)
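
    # The finally block persists the model even when a run is interrupted with
    # Ctrl-C (the KeyboardInterrupt is swallowed above); note that the `save`
    # flag is currently unused, so saving happens whenever self.model exists.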

    def _plot_model_response_plotter(self, fiber_in, fiber_out, regen, plot=True):
        fig, axs = plt.subplots(2)
        for i, ax in enumerate(axs):
            ax.plot(np.abs(fiber_in[:, i]) ** 2, label="fiber in")
            ax.plot(np.abs(fiber_out[:, i]) ** 2, label="fiber out")
            ax.plot(np.abs(regen[:, i]) ** 2, label="regenerated")
            ax.legend()
        if plot:
            plt.show()
        return fig

    def plot_model_response(self, model=None, plot=True):
        data_settings_backup = copy.copy(self.data_settings)
        self.data_settings.shuffle = False
        self.data_settings.train_split = 0.01
        self.data_settings.drop_first = 100
        plot_loader, _ = self.get_sliced_data()
        self.data_settings = data_settings_backup

        fiber_in, fiber_out, regen = self.run_model(model or self.model, plot_loader)
        fiber_in = fiber_in.view(-1, 2)
        fiber_out = fiber_out.view(-1, 2)
        regen = regen.view(-1, 2)

        fiber_in = fiber_in.numpy()
        fiber_out = fiber_out.numpy()
        regen = regen.numpy()

        # https://github.com/matplotlib/matplotlib/issues/27713#issue-2104110987
        # https://github.com/matplotlib/matplotlib/issues/27713#issuecomment-1915497463
        import gc

        fig = self._plot_model_response_plotter(fiber_in, fiber_out, regen, plot=plot)
        gc.collect()

        return fig
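
    # The explicit gc.collect() after building the figure is the workaround
    # from the matplotlib issues linked above: repeatedly created figures are
    # not freed promptly outside pyplot's interactive loop.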


if __name__ == "__main__":
    trainer = Training()
    trainer = Trainer(
        global_settings=global_settings,
        data_settings=data_settings,
        pytorch_settings=pytorch_settings,
        model_settings=model_settings,
        optimizer_settings=optimizer_settings,
        checkpoint_path='.models/20241128_084935_8885.tar',
        settings_override={
            "model_settings": {
                # "model_activation_func": "PowScale",
                "dropout_prob": 0,
            }
        },
        reset_epoch=True,
    )

    # trainer.plot_eye()
    trainer.setup_tb_writer()
    trainer.objective(save=True)

    best_model = trainer.model

    # best_model = trainer.define_model(trainer.study.best_trial).to(trainer.pytorch_settings.device)
    trainer.plot_model_response(best_model)

    # print(f"Best model: {best_model}")
    best = trainer.train()
    save_dict_to_file(best, ".models/best_results.json")

    ...