clean up regen_no_hyper.py

--- a/regen_no_hyper.py
+++ b/regen_no_hyper.py
@@ -1,414 +1,130 @@
-import copy
-from dataclasses import dataclass
-from datetime import datetime
-from pathlib import Path
-import matplotlib.pyplot as plt
-import numpy as np
-import torch
-import torch.nn as nn
-# import torch.nn.functional as F  # mse_loss doesn't support complex numbers
-import torch.optim as optim
-import torch.utils.data
-from torch.utils.tensorboard import SummaryWriter
-from rich.progress import (
-    Progress,
-    TextColumn,
-    BarColumn,
-    TaskProgressColumn,
-    TimeRemainingColumn,
-    MofNCompleteColumn,
-    TimeElapsedColumn,
-)
-from rich.console import Console
-from rich import print as rprint
-# from util.optuna_helpers import optional_suggest_categorical, optional_suggest_float, optional_suggest_int
+from hypertraining.settings import (
+    GlobalSettings,
+    DataSettings,
+    PytorchSettings,
+    ModelSettings,
+    OptimizerSettings,
+)
+from hypertraining.training import Trainer
+
+import torch
+import json
 import util
-
-# global settings
-@dataclass
-class GlobalSettings:
-    seed: int = 42
+
+global_settings = GlobalSettings(
+    seed=42,
+)

-
-# data settings
-@dataclass
-class DataSettings:
-    config_path: str = "data/*-128-16384-100000-0-0-17-0-PAM4-0.ini"
-    dtype: torch.dtype = torch.complex64
-    symbols_range: float | int = 8
-    data_size_range: float | int = 64
-    shuffle: bool = True
-    target_delay: float = 0
-    xy_delay_range: float | int = 0
-    drop_first: int = 10
-    train_split: float = 0.8
+
+data_settings = DataSettings(
+    config_path="data/*-128-16384-100000-0-0-17-0-PAM4-0.ini",
+    dtype="complex64",
+    # symbols = (9, 20),  # 13 symbols @ 10GBd <-> 1.3ns <-> 0.26m of fiber
+    symbols=13,  # study: single_core_regen_20241123_011232
+    # output_size = (11, 32),  # 26 taps -> 2 taps per input symbol -> 1 tap every 1cm (model has 52 inputs (x/y))
+    output_size=26,  # study: single_core_regen_20241123_011232 (model_input_dim/2)
+    shuffle=True,
+    in_out_delay=0,
+    xy_delay=0,
+    drop_first=128*64,
+    train_split=0.8,
+)

-
-# pytorch settings
-@dataclass
-class PytorchSettings:
-    epochs: int = 1000
-    batchsize: int = 2**12
-    device: str = "cuda"
-    summary_dir: str = ".runs"
-    model_dir: str = ".models"
+
+pytorch_settings = PytorchSettings(
+    epochs=10000,
+    batchsize=2**12,
+    device="cuda",
+    dataloader_workers=12,
+    dataloader_prefetch=8,
+    summary_dir=".runs",
+    write_every=2**5,
+    save_models=True,
+    model_dir=".models",
+)

-
-# model settings
-@dataclass
-class ModelSettings:
-    output_size: int = 2
-    # n_layer_range: float|int = 2
-    # n_units_range: float|int = 32
-    n_layers: int = 3
-    n_units: int = 32
-    activation_func: tuple | str = "ModReLU"
+
+model_settings = ModelSettings(
+    output_dim=2,
+    n_hidden_layers=4,
+    overrides={
+        "n_hidden_nodes_0": 8,
+        "n_hidden_nodes_1": 8,
+        "n_hidden_nodes_2": 4,
+        "n_hidden_nodes_3": 6,
+    },
+    model_activation_func="PowScale",
+    # dropout_prob=0.01,
+    model_layer_function="ONN",
+    model_layer_parametrizations=[
+        {
+            "tensor_name": "weight",
+            "parametrization": torch.nn.utils.parametrizations.orthogonal,
+        },
+        {
+            "tensor_name": "scales",
+            "parametrization": util.complexNN.clamp,
+        },
+        {
+            "tensor_name": "scale",
+            "parametrization": util.complexNN.clamp,
+        },
+        {
+            "tensor_name": "bias",
+            "parametrization": util.complexNN.clamp,
+        },
+        # {
+        #     "tensor_name": "V",
+        #     "parametrization": torch.nn.utils.parametrizations.orthogonal,
+        # },
+        # {
+        #     "tensor_name": "S",
+        #     "parametrization": util.complexNN.clamp,
+        # },
+    ],
+)

-
-@dataclass
-class OptimizerSettings:
-    optimizer_range: str = "Adam"
-    lr_range: float = 2e-3
+
+optimizer_settings = OptimizerSettings(
+    optimizer="Adam",
+    learning_rate=0.05,
+    scheduler="ReduceLROnPlateau",
+    scheduler_kwargs={
+        "patience": 2**6,
+        "factor": 0.9,
+        # "threshold": 1e-3,
+        "min_lr": 1e-6,
+        "cooldown": 10,
+    },
+)
+
+
+def save_dict_to_file(dictionary, filename):
+    """
+    Save the best dictionary to a JSON file.
+
+    :param dictionary: Dictionary containing the best training results.
+    :type dictionary: dict
+    :param filename: Path to the JSON file where the dictionary will be saved.
+    :type filename: str
+    """
+    with open(filename, 'w') as f:
+        json.dump(dictionary, f, indent=4)
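
Each entry in the new model_layer_parametrizations pairs a tensor name with a parametrization callable, which hypertraining presumably registers on the matching tensors of each ONN layer. As a minimal sketch of what the "weight" entry does, here is the same built-in PyTorch parametrization applied to a plain real-valued layer (illustrative only; the project itself goes through util.complexNN):

    import torch
    import torch.nn as nn
    from torch.nn.utils import parametrizations

    # Registering the orthogonal parametrization re-expresses layer.weight
    # so that it remains orthogonal through every optimizer step.
    layer = nn.Linear(4, 4, bias=False)
    parametrizations.orthogonal(layer, name="weight")

    w = layer.weight
    print(torch.allclose(w @ w.T, torch.eye(4), atol=1e-5))  # True
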
-
-
-class Training:
-    def __init__(self):
-        self.global_settings = GlobalSettings()
-        self.data_settings = DataSettings()
-        self.pytorch_settings = PytorchSettings()
-        self.model_settings = ModelSettings()
-        self.optimizer_settings = OptimizerSettings()
-        self.study_name = (
-            f"single_core_regen_{datetime.now().strftime('%Y-%m-%d_%H:%M:%S')}"
-        )
-
-        if not hasattr(self.pytorch_settings, "model_dir"):
-            self.pytorch_settings.model_dir = ".models"
-
-        self.writer = None
-        self.console = Console()
-
-    def setup_tb_writer(self, study_name=None, append=None):
-        log_dir = (
-            self.pytorch_settings.summary_dir + "/" + (study_name or self.study_name) + ("_" + str(append)) if append else ""
-        )
-        self.writer = SummaryWriter(log_dir)
-        return self.writer
-
-    def plot_eye(self, width=2, symbols=None, alpha=None, complex=False, show=True):
-        if not hasattr(self, "eye_data"):
-            data, config = util.datasets.load_data(
-                self.data_settings.config_path,
-                skipfirst=10,
-                symbols=symbols or 1000,
-                real=not self.data_settings.dtype.is_complex,
-                normalize=True,
-            )
-            self.eye_data = {"data": data, "sps": int(config["glova"]["sps"])}
-        return util.plot.eye(
-            **self.eye_data,
-            width=width,
-            show=show,
-            alpha=alpha,
-            complex=complex,
-            symbols=symbols or 1000,
-            skipfirst=0,
-        )
-
-    def define_model(self):
-        n_layers = self.model_settings.n_layers
-        in_features = 2 * self.data_settings.data_size_range
-
-        layers = []
-        for i in range(n_layers):
-            out_features = self.model_settings.n_units
-            layers.append(util.complexNN.UnitaryLayer(in_features, out_features))
-            # layers.append(getattr(nn, self.model_settings.activation_func)())
-            layers.append(
-                getattr(util.complexNN, self.model_settings.activation_func)()
-            )
-            in_features = out_features
-
-        layers.append(
-            util.complexNN.UnitaryLayer(in_features, self.model_settings.output_size)
-        )
-
-        if self.writer is not None:
-            self.writer.add_graph(
-                nn.Sequential(*layers),
-                torch.zeros(1, layers[0].in_features, dtype=self.data_settings.dtype),
-            )
-
-        return nn.Sequential(*layers)
-
-    def get_sliced_data(self):
-        symbols = self.data_settings.symbols_range
-        xy_delay = self.data_settings.xy_delay_range
-        data_size = self.data_settings.data_size_range
-
-        # get dataset
-        dataset = util.datasets.FiberRegenerationDataset(
-            file_path=self.data_settings.config_path,
-            symbols=symbols,
-            output_dim=data_size,
-            target_delay=self.data_settings.target_delay,
-            xy_delay=xy_delay,
-            drop_first=self.data_settings.drop_first,
-            dtype=self.data_settings.dtype,
-            real=not self.data_settings.dtype.is_complex,
-            # device=self.pytorch_settings.device,
-        )
-
-        dataset_size = len(dataset)
-        indices = list(range(dataset_size))
-        split = int(np.floor(self.data_settings.train_split * dataset_size))
-        if self.data_settings.shuffle:
-            np.random.seed(self.global_settings.seed)
-            np.random.shuffle(indices)
-
-        train_indices, valid_indices = indices[:split], indices[split:]
-
-        if self.data_settings.shuffle:
-            train_sampler = torch.utils.data.SubsetRandomSampler(train_indices)
-            valid_sampler = torch.utils.data.SubsetRandomSampler(valid_indices)
-        else:
-            train_sampler = train_indices
-            valid_sampler = valid_indices
-
-        train_loader = torch.utils.data.DataLoader(
-            dataset,
-            batch_size=self.pytorch_settings.batchsize,
-            sampler=train_sampler,
-            drop_last=True,
-            pin_memory=True,
-            num_workers=24,
-            prefetch_factor=4,
-            # persistent_workers=True
-        )
-        valid_loader = torch.utils.data.DataLoader(
-            dataset,
-            batch_size=self.pytorch_settings.batchsize,
-            sampler=valid_sampler,
-            drop_last=True,
-            pin_memory=True,
-            num_workers=24,
-            prefetch_factor=4,
-            # persistent_workers=True
-        )
-
-        return train_loader, valid_loader
-
-    def train_model(self, model, optimizer, train_loader, epoch):
-        with Progress(
-            TextColumn("[yellow] Training..."),
-            TextColumn("Error: {task.description}"),
-            BarColumn(),
-            TaskProgressColumn(),
-            TextColumn("[green]Batch"),
-            MofNCompleteColumn(),
-            TimeRemainingColumn(),
-            TimeElapsedColumn(),
-            # description="Training",
-            transient=False,
-            console=self.console,
-            refresh_per_second=10,
-        ) as progress:
-            task = progress.add_task("-.---e--", total=len(train_loader))
-
-            running_loss = 0.0
-            model.train()
-            for batch_idx, (x, y) in enumerate(train_loader):
-                model.zero_grad(set_to_none=True)
-                x, y = (
-                    x.to(self.pytorch_settings.device),
-                    y.to(self.pytorch_settings.device),
-                )
-                y_pred = model(x)
-                loss = util.complexNN.complex_mse_loss(y_pred, y)
-                loss.backward()
-                optimizer.step()
-
-                progress.update(task, advance=1, description=f"{loss.item():.3e}")
-
-                running_loss += loss.item()
-                if self.writer is not None:
-                    if (batch_idx + 1) % 10 == 0:
-                        self.writer.add_scalar(
-                            "training loss",
-                            running_loss / 10,
-                            epoch * len(train_loader) + batch_idx,
-                        )
-                        running_loss = 0.0
-
-        return running_loss
-
-    def eval_model(self, model, valid_loader, epoch):
-        with Progress(
-            TextColumn("[green]Evaluating..."),
-            TextColumn("Error: {task.description}"),
-            BarColumn(),
-            TaskProgressColumn(),
-            TextColumn("[green]Batch"),
-            MofNCompleteColumn(),
-            TimeRemainingColumn(),
-            TimeElapsedColumn(),
-            # description="Training",
-            transient=False,
-            console=self.console,
-            refresh_per_second=10,
-        ) as progress:
-            task = progress.add_task("-.---e--", total=len(valid_loader))
-
-            model.eval()
-            running_loss = 0
-            running_loss2 = 0
-            with torch.no_grad():
-                for batch_idx, (x, y) in enumerate(valid_loader):
-                    x, y = (
-                        x.to(self.pytorch_settings.device),
-                        y.to(self.pytorch_settings.device),
-                    )
-                    y_pred = model(x)
-                    loss = util.complexNN.complex_mse_loss(y_pred, y)
-                    running_loss += loss.item()
-                    running_loss2 += loss.item()
-
-                    progress.update(task, advance=1, description=f"{loss.item():.3e}")
-                    if self.writer is not None:
-                        if (batch_idx + 1) % 10 == 0:
-                            self.writer.add_scalar(
-                                "loss",
-                                running_loss / 10,
-                                epoch * len(valid_loader) + batch_idx,
-                            )
-                            running_loss = 0.0
-
-        if self.writer is not None:
-            self.writer.add_figure("fiber response", self.plot_model_response(model, plot=False), epoch+1)
-
-        return running_loss2 / len(valid_loader)
-
-    def run_model(self, model, loader):
-        model.eval()
-        xs = []
-        ys = []
-        y_preds = []
-        with torch.no_grad():
-            model = model.to(self.pytorch_settings.device)
-            for x, y in loader:
-                x, y = (
-                    x.to(self.pytorch_settings.device),
-                    y.to(self.pytorch_settings.device),
-                )
-                y_pred = model(x).cpu()
-                # x = x.cpu()
-                # y = y.cpu()
-                y_pred = y_pred.view(y_pred.shape[0], -1, 2)
-                y = y.view(y.shape[0], -1, 2)
-                x = x.view(x.shape[0], -1, 2)
-                xs.append(x[:, 0, :].squeeze())
-                ys.append(y.squeeze())
-                y_preds.append(y_pred.squeeze())
-
-        xs = torch.vstack(xs).cpu()
-        ys = torch.vstack(ys).cpu()
-        y_preds = torch.vstack(y_preds).cpu()
-        return ys, xs, y_preds
-
-    def objective(self, save=False, plot_before=False):
-        try:
-            rprint(*list(self.study_name.split("_")))
-
-            self.model = self.define_model().to(self.pytorch_settings.device)
-
-            if self.writer is not None:
-                self.writer.add_figure("fiber response", self.plot_model_response(plot=plot_before), 0)
-
-            train_loader, valid_loader = self.get_sliced_data()
-
-            optimizer_name = self.optimizer_settings.optimizer_range
-            lr = self.optimizer_settings.lr_range
-            optimizer = getattr(optim, optimizer_name)(self.model.parameters(), lr=lr)
-
-            for epoch in range(self.pytorch_settings.epochs):
-                self.console.rule(f"Epoch {epoch + 1}/{self.pytorch_settings.epochs}")
-                self.train_model(self.model, optimizer, train_loader, epoch)
-                eval_loss = self.eval_model(self.model, valid_loader, epoch)
-
-            return eval_loss
-
-        except KeyboardInterrupt:
-            ...
-        finally:
-            if hasattr(self, "model"):
-                save_path = (
-                    Path(self.pytorch_settings.model_dir) / f"{self.study_name}.pth"
-                )
-                save_path.parent.mkdir(parents=True, exist_ok=True)
-                torch.save(self.model, save_path)
-
-    def _plot_model_response_plotter(self, fiber_in, fiber_out, regen, plot=True):
-        fig, axs = plt.subplots(2)
-        for i, ax in enumerate(axs):
-            ax.plot(np.abs(fiber_in[:, i]) ** 2, label="fiber in")
-            ax.plot(np.abs(fiber_out[:, i]) ** 2, label="fiber out")
-            ax.plot(np.abs(regen[:, i]) ** 2, label="regenerated")
-            ax.legend()
-        if plot:
-            plt.show()
-        return fig
-
-    def plot_model_response(self, model=None, plot=True):
-        data_settings_backup = copy.copy(self.data_settings)
-        self.data_settings.shuffle = False
-        self.data_settings.train_split = 0.01
-        self.data_settings.drop_first = 100
-        plot_loader, _ = self.get_sliced_data()
-        self.data_settings = data_settings_backup
-
-        fiber_in, fiber_out, regen = self.run_model(model or self.model, plot_loader)
-        fiber_in = fiber_in.view(-1, 2)
-        fiber_out = fiber_out.view(-1, 2)
-        regen = regen.view(-1, 2)
-
-        fiber_in = fiber_in.numpy()
-        fiber_out = fiber_out.numpy()
-        regen = regen.numpy()
-
-        # https://github.com/matplotlib/matplotlib/issues/27713#issue-2104110987
-        # https://github.com/matplotlib/matplotlib/issues/27713#issuecomment-1915497463
-        import gc
-        fig = self._plot_model_response_plotter(fiber_in, fiber_out, regen, plot=plot)
-        gc.collect()
-
-        return fig
-
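
The deleted Training.get_sliced_data() built its train/validation split with the standard index-shuffle-plus-SubsetRandomSampler pattern; for reference, its core reduces to this self-contained sketch (toy sizes stand in for the real dataset):

    import numpy as np
    import torch

    dataset_size = 100            # stand-in for len(dataset)
    train_split, seed = 0.8, 42   # values from the old DataSettings / GlobalSettings

    indices = list(range(dataset_size))
    np.random.seed(seed)
    np.random.shuffle(indices)
    split = int(np.floor(train_split * dataset_size))
    train_indices, valid_indices = indices[:split], indices[split:]

    # Each sampler draws, in random order, only from its own index subset.
    train_sampler = torch.utils.data.SubsetRandomSampler(train_indices)
    valid_sampler = torch.utils.data.SubsetRandomSampler(valid_indices)
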
 if __name__ == "__main__":
-    trainer = Training()
+    trainer = Trainer(
+        global_settings=global_settings,
+        data_settings=data_settings,
+        pytorch_settings=pytorch_settings,
+        model_settings=model_settings,
+        optimizer_settings=optimizer_settings,
+        checkpoint_path='.models/20241128_084935_8885.tar',
+        settings_override={
+            "model_settings": {
+                # "model_activation_func": "PowScale",
+                "dropout_prob": 0,
+            }
+        },
+        reset_epoch=True,
+    )

-    # trainer.plot_eye()
-    trainer.setup_tb_writer()
-    trainer.objective(save=True)
-
-    best_model = trainer.model
-
-    # best_model = trainer.define_model(trainer.study.best_trial).to(trainer.pytorch_settings.device)
-    trainer.plot_model_response(best_model)
-
-    # print(f"Best model: {best_model}")
+    best = trainer.train()
+    save_dict_to_file(best, ".models/best_results.json")

     ...
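
Finally, a quick round-trip of the new save_dict_to_file helper with hypothetical values; like the new __main__ block, it assumes whatever trainer.train() returns is JSON-serializable (plain floats and ints, not tensors):

    import json

    best = {"eval_loss": 1.23e-4, "epoch": 17}  # hypothetical contents

    # Mirrors save_dict_to_file(best, "best_results.json").
    with open("best_results.json", "w") as f:
        json.dump(best, f, indent=4)

    with open("best_results.json") as f:
        print(json.load(f))  # {'eval_loss': 0.000123, 'epoch': 17}
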