move hypertraining class into separate file;
move settings dataclasses into separate file; add SemiUnitaryLayer; clean up model response plotting code; continue hyperparameter search
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:72460af57347d35df91cd76982231bcf538a82fd7f1b8522795202fa298a2dcb
-size 696320
+oid sha256:e12f0c21fca93620a165fbb6ed58d0b313093e972ef4416694c29c9cea6dc867
+size 831488
src/single-core-regen/hypertraining/hypertraining.py (new file, 735 lines)
@@ -0,0 +1,735 @@
import copy
from datetime import datetime
from pathlib import Path
from typing import Literal
import matplotlib.pyplot as plt

import numpy as np
import optuna
import warnings

import torch
import torch.nn as nn

# import torch.nn.functional as F  # mse_loss doesn't support complex numbers
import torch.optim as optim
import torch.utils.data

from torch.utils.tensorboard import SummaryWriter

from rich.progress import (
    Progress,
    TextColumn,
    BarColumn,
    TaskProgressColumn,
    TimeRemainingColumn,
    MofNCompleteColumn,
    TimeElapsedColumn,
)
from rich.console import Console
# from rich import print as rprint

import multiprocessing

from util.datasets import FiberRegenerationDataset
from util.optuna_helpers import (
    force_suggest_categorical,
    force_suggest_float,
    force_suggest_int,
)
import util

from .settings import (
    GlobalSettings,
    DataSettings,
    ModelSettings,
    OptunaSettings,
    OptimizerSettings,
    PytorchSettings,
)


class HyperTraining:
    def __init__(
        self,
        *,
        global_settings,
        data_settings,
        pytorch_settings,
        model_settings,
        optimizer_settings,
        optuna_settings,
        console=None,
    ):
        self.global_settings: GlobalSettings = global_settings
        self.data_settings: DataSettings = data_settings
        self.pytorch_settings: PytorchSettings = pytorch_settings
        self.model_settings: ModelSettings = model_settings
        self.optimizer_settings: OptimizerSettings = optimizer_settings
        self.optuna_settings: OptunaSettings = optuna_settings

        self.console = console or Console()

        # set some extra settings to make the code more readable
        self._extra_optuna_settings()

    def setup_tb_writer(self, study_name=None, append=None):
        log_dir = (
            self.pytorch_settings.summary_dir
            + "/"
            + (study_name or self.optuna_settings.study_name)
        )
        if append is not None:
            log_dir += "_" + str(append)

        return SummaryWriter(log_dir)

    def resume_latest_study(self, verbose=True):
        study_name = self.get_latest_study()

        if study_name:
            print(f"Resuming study: {study_name}")
            self.optuna_settings.study_name = study_name

    def get_latest_study(self, verbose=True):
        studies = self.get_studies()
        for study in studies:
            study.datetime_start = study.datetime_start or datetime.min
        if studies:
            study = sorted(studies, key=lambda x: x.datetime_start, reverse=True)[0]
            if verbose:
                print(f"Last study: {study.study_name}")
            study_name = study.study_name
        else:
            if verbose:
                print("No previous studies found")
            study_name = None
        return study_name

    def get_studies(self):
        return optuna.get_all_study_summaries(storage=self.optuna_settings.storage)

    def setup_study(self):
        self.study = optuna.create_study(
            study_name=self.optuna_settings.study_name,
            storage=self.optuna_settings.storage,
            load_if_exists=True,
            direction=self.optuna_settings.direction,
            directions=self.optuna_settings.directions,
        )

        with warnings.catch_warnings(action="ignore"):
            self.study.set_metric_names(self.optuna_settings.metrics_names)

        self.n_threads = min(
            self.optuna_settings.n_trials, self.optuna_settings.n_threads
        )
        self.processes = []
        if self.n_threads > 1:
            for _ in range(self.n_threads):
                p = multiprocessing.Process(
                    # target=lambda n_trials: self._run_optimize(self, n_trials),
                    target=self._run_optimize,
                    args=(self.optuna_settings.n_trials // self.n_threads,),
                )
                self.processes.append(p)

    # def plot_eye(self, width=2, symbols=None, alpha=None, complex=False, show=True):
    #     data, config = util.datasets.load_data(
    #         self.data_settings.config_path,
    #         skipfirst=10,
    #         symbols=symbols or 1000,
    #         real=not complex,
    #         normalize=True,
    #     )
    #     eye_data = {"data": data.numpy(), "sps": int(config["glova"]["sps"])}
    #     return util.plot.eye(
    #         **eye_data,
    #         width=width,
    #         show=show,
    #         alpha=alpha,
    #         complex=complex,
    #         symbols=symbols or 1000,
    #         skipfirst=0,
    #     )

    def run_study(self):
        if self.processes:
            for p in self.processes:
                p.start()
            for p in self.processes:
                p.join()

            remaining_trials = self.optuna_settings.n_trials % self.n_threads
        else:
            remaining_trials = self.optuna_settings.n_trials

        if remaining_trials:
            self._run_optimize(remaining_trials)

    def _run_optimize(self, n_trials):
        self.study.optimize(
            self.objective, n_trials=n_trials, timeout=self.optuna_settings.timeout
        )

    def _extra_optuna_settings(self):
        self.optuna_settings.multi_objective = len(self.optuna_settings.directions) > 1
        if self.optuna_settings.multi_objective:
            self.optuna_settings.direction = None
        else:
            self.optuna_settings.direction = self.optuna_settings.directions[0]
            self.optuna_settings.directions = None

        self.optuna_settings.n_train_batches = (
            self.optuna_settings.n_train_batches
            if self.optuna_settings.limit_examples
            else float("inf")
        )
        self.optuna_settings.n_valid_batches = (
            self.optuna_settings.n_valid_batches
            if self.optuna_settings.limit_examples
            else float("inf")
        )

    def define_model(self, trial: optuna.Trial, writer=None):
        n_layers = force_suggest_int(
            trial, "model_n_layers", self.model_settings.model_n_layers
        )

        input_dim = 2 * trial.params.get(
            "model_input_dim",
            force_suggest_int(trial, "model_input_dim", self.data_settings.model_input_dim),
        )

        dtype = trial.params.get(
            "model_dtype",
            force_suggest_categorical(trial, "model_dtype", self.data_settings.dtype),
        )
        dtype = getattr(torch, dtype)

        afunc = force_suggest_categorical(
            trial, "model_activation_func", self.model_settings.model_activation_func
        )

        layers = []
        last_dim = input_dim
        for i in range(n_layers):
            hidden_dim = force_suggest_int(
                trial, f"model_hidden_dim_{i}", self.model_settings.unit_count
            )
            layers.append(
                util.complexNN.SemiUnitaryLayer(last_dim, hidden_dim, dtype=dtype)
            )
            last_dim = hidden_dim
            layers.append(getattr(util.complexNN, afunc)())

        layers.append(
            util.complexNN.UnitaryLayer(
                hidden_dim, self.model_settings.output_dim, dtype=dtype
            )
        )

        model = nn.Sequential(*layers)

        if writer is not None:
            writer.add_graph(
                model, torch.zeros(1, input_dim, dtype=dtype), use_strict_trace=False
            )

        return model.to(self.pytorch_settings.device)

    def get_sliced_data(self, trial: optuna.Trial, override=None):
        symbols = trial.params.get(
            "dataset_symbols",
            force_suggest_float(trial, "dataset_symbols", self.data_settings.symbols),
        )

        xy_delay = trial.params.get(
            "dataset_xy_delay",
            force_suggest_float(trial, "dataset_xy_delay", self.data_settings.xy_delay),
        )

        data_size = trial.params.get(
            "model_input_dim",
            force_suggest_int(trial, "model_input_dim", self.data_settings.model_input_dim),
        )

        dtype = trial.params.get(
            "model_dtype",
            force_suggest_categorical(trial, "model_dtype", self.data_settings.dtype),
        )
        dtype = getattr(torch, dtype)

        num_symbols = None
        if override is not None:
            num_symbols = override.get("num_symbols", None)
        # get dataset
        dataset = FiberRegenerationDataset(
            file_path=self.data_settings.config_path,
            symbols=symbols,
            output_dim=data_size,
            target_delay=self.data_settings.in_out_delay,
            xy_delay=xy_delay,
            drop_first=self.data_settings.drop_first,
            dtype=dtype,
            real=not dtype.is_complex,
            num_symbols=num_symbols,
        )

        dataset_size = len(dataset)
        indices = list(range(dataset_size))
        split = int(np.floor(self.data_settings.train_split * dataset_size))
        if self.data_settings.shuffle:
            np.random.seed(self.global_settings.seed)
            np.random.shuffle(indices)

        train_indices, valid_indices = indices[:split], indices[split:]

        if self.data_settings.shuffle:
            train_sampler = torch.utils.data.SubsetRandomSampler(train_indices)
            valid_sampler = torch.utils.data.SubsetRandomSampler(valid_indices)
        else:
            train_sampler = train_indices
            valid_sampler = valid_indices

        train_loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=self.pytorch_settings.batchsize,
            sampler=train_sampler,
            drop_last=True,
            pin_memory=True,
            num_workers=self.pytorch_settings.dataloader_workers,
            prefetch_factor=self.pytorch_settings.dataloader_prefetch,
        )

        valid_loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=self.pytorch_settings.batchsize,
            sampler=valid_sampler,
            drop_last=True,
            pin_memory=True,
            num_workers=self.pytorch_settings.dataloader_workers,
            prefetch_factor=self.pytorch_settings.dataloader_prefetch,
        )

        return train_loader, valid_loader

    def train_model(
        self,
        trial,
        model,
        optimizer,
        train_loader,
        epoch,
        writer=None,
        enable_progress=False,
    ):
        if enable_progress:
            progress = Progress(
                TextColumn("[yellow] Training..."),
                TextColumn("Error: {task.description}"),
                BarColumn(),
                TaskProgressColumn(),
                TextColumn("[green]Batch"),
                MofNCompleteColumn(),
                TimeRemainingColumn(),
                TimeElapsedColumn(),
                # description="Training",
                transient=False,
                console=self.console,
                refresh_per_second=10,
            )
            task = progress.add_task("-.---e--", total=len(train_loader))
            progress.start()

        running_loss2 = 0.0
        running_loss = 0.0
        model.train()
        for batch_idx, (x, y) in enumerate(train_loader):
            if batch_idx >= self.optuna_settings.n_train_batches:
                break
            model.zero_grad(set_to_none=True)
            x, y = (
                x.to(self.pytorch_settings.device),
                y.to(self.pytorch_settings.device),
            )
            y_pred = model(x)
            loss = util.complexNN.complex_mse_loss(y_pred, y)
            loss_value = loss.item()
            loss.backward()
            optimizer.step()
            running_loss2 += loss_value
            running_loss += loss_value

            if enable_progress:
                progress.update(task, advance=1, description=f"{loss_value:.3e}")

            if writer is not None:
                if batch_idx % self.pytorch_settings.write_every == 0:
                    writer.add_scalar(
                        "training loss",
                        running_loss2 / (self.pytorch_settings.write_every if batch_idx > 0 else 1),
                        epoch * min(len(train_loader), self.optuna_settings.n_train_batches) + batch_idx,
                    )
                    running_loss2 = 0.0

        if enable_progress:
            progress.stop()

        return running_loss / min(
            len(train_loader), self.optuna_settings.n_train_batches
        )

    def eval_model(
        self, trial, model, valid_loader, epoch, writer=None, enable_progress=True
    ):
        if enable_progress:
            progress = Progress(
                TextColumn("[green]Evaluating..."),
                TextColumn("Error: {task.description}"),
                BarColumn(),
                TaskProgressColumn(),
                TextColumn("[green]Batch"),
                MofNCompleteColumn(),
                TimeRemainingColumn(),
                TimeElapsedColumn(),
                # description="Training",
                transient=False,
                console=self.console,
                refresh_per_second=10,
            )
            progress.start()
            task = progress.add_task("-.---e--", total=len(valid_loader))

        model.eval()
        running_error = 0
        running_error_2 = 0
        with torch.no_grad():
            for batch_idx, (x, y) in enumerate(valid_loader):
                if batch_idx >= self.optuna_settings.n_valid_batches:
                    break
                x, y = (
                    x.to(self.pytorch_settings.device),
                    y.to(self.pytorch_settings.device),
                )
                y_pred = model(x)
                error = util.complexNN.complex_mse_loss(y_pred, y)
                error_value = error.item()
                running_error += error_value
                running_error_2 += error_value

                if enable_progress:
                    progress.update(task, advance=1, description=f"{error_value:.3e}")

                if writer is not None:
                    if batch_idx % self.pytorch_settings.write_every == 0:
                        writer.add_scalar(
                            "eval loss",
                            running_error_2 / (self.pytorch_settings.write_every if batch_idx > 0 else 1),
                            epoch * min(len(valid_loader), self.optuna_settings.n_valid_batches) + batch_idx,
                        )
                        running_error_2 = 0.0

        running_error /= min(len(valid_loader), self.optuna_settings.n_valid_batches)

        if writer is not None:
            title_append, subtitle = self.build_title(trial)
            writer.add_figure(
                "fiber response",
                self.plot_model_response(
                    trial,
                    model=model,
                    title_append=title_append,
                    subtitle=subtitle,
                    show=False,
                ),
                epoch + 1,
            )

        if enable_progress:
            progress.stop()

        return running_error

    def run_model(self, model, loader):
        model.eval()
        xs = []
        ys = []
        y_preds = []
        with torch.no_grad():
            model = model.to(self.pytorch_settings.device)
            for x, y in loader:
                x, y = (
                    x.to(self.pytorch_settings.device),
                    y.to(self.pytorch_settings.device),
                )
                y_pred = model(x).cpu()
                # x = x.cpu()
                # y = y.cpu()
                y_pred = y_pred.view(y_pred.shape[0], -1, 2)
                y = y.view(y.shape[0], -1, 2)
                x = x.view(x.shape[0], -1, 2)
                xs.append(x[:, 0, :].squeeze())
                ys.append(y.squeeze())
                y_preds.append(y_pred.squeeze())

        xs = torch.vstack(xs).cpu()
        ys = torch.vstack(ys).cpu()
        y_preds = torch.vstack(y_preds).cpu()
        return ys, xs, y_preds

    def objective(self, trial: optuna.Trial, plot_before=False):
        model = None
        exc = None
        try:
            # rprint(*list(self.study_name.split("_")))

            writer = self.setup_tb_writer(
                self.optuna_settings.study_name,
                f"{trial.number:0{len(str(self.optuna_settings.n_trials))}}",
            )

            model = self.define_model(trial, writer)
            n_params = sum(p.numel() for p in model.parameters())
            # n_nodes = trial.params.get("model_n_layers", self.model_settings.model_n_layers) * trial.params.get("model_hidden_dim", self.model_settings.unit_count)

            title_append, subtitle = self.build_title(trial)

            writer.add_figure(
                "fiber response",
                self.plot_model_response(
                    trial,
                    model=model,
                    title_append=title_append,
                    subtitle=subtitle,
                    show=plot_before,
                ),
                0,
            )

            train_loader, valid_loader = self.get_sliced_data(trial)

            optimizer_name = force_suggest_categorical(
                trial, "optimizer", self.optimizer_settings.optimizer
            )

            lr = force_suggest_float(
                trial, "lr", self.optimizer_settings.learning_rate, log=True
            )

            optimizer = getattr(optim, optimizer_name)(model.parameters(), lr=lr)
            if self.optimizer_settings.scheduler is not None:
                scheduler = getattr(optim.lr_scheduler, self.optimizer_settings.scheduler)(
                    optimizer, **self.optimizer_settings.scheduler_kwargs
                )

            for epoch in range(self.pytorch_settings.epochs):
                enable_progress = self.optuna_settings.n_threads == 1
                if enable_progress:
                    self.console.rule(
                        f"Epoch {epoch + 1}/{self.pytorch_settings.epochs}"
                    )
                self.train_model(
                    trial,
                    model,
                    optimizer,
                    train_loader,
                    epoch,
                    writer,
                    enable_progress=enable_progress,
                )
                error = self.eval_model(
                    trial,
                    model,
                    valid_loader,
                    epoch,
                    writer,
                    enable_progress=enable_progress,
                )
                if self.optimizer_settings.scheduler is not None:
                    scheduler.step(error)

            writer.close()

            if self.optuna_settings.multi_objective:
                return n_params, error
            trial.report(error, epoch)
            if trial.should_prune():
                raise optuna.exceptions.TrialPruned()
            return error

        except KeyboardInterrupt:
            ...
        # except Exception as e:
        #     exc = e
        finally:
            if model is not None:
                save_path = (
                    Path(self.pytorch_settings.model_dir)
                    / f"{self.optuna_settings.study_name}_{trial.number}.pth"
                )
                save_path.parent.mkdir(parents=True, exist_ok=True)
                torch.save(model, save_path)
            if exc is not None:
                raise exc

    def _plot_model_response_eye(
        self, *signals, labels=None, sps=None, title_append="", subtitle="", show=True
    ):
        if sps is None:
            raise ValueError("sps must be provided")
        if not hasattr(labels, "__iter__") or isinstance(labels, (str, type(None))):
            labels = [labels]
        else:
            labels = list(labels)

        while len(labels) < len(signals):
            labels.append(None)

        # check if there are any labels
        if not any(labels):
            labels = [f"signal {i + 1}" for i in range(len(signals))]

        fig, axs = plt.subplots(2, len(signals), sharex=True, sharey=True)
        fig.suptitle(
            f"Eye diagram{f' {title_append}' if title_append else ''}{f'\n{subtitle}' if subtitle else ''}"
        )
        xaxis = np.linspace(0, 2, 2 * sps, endpoint=False)
        for j, (label, signal) in enumerate(zip(labels, signals)):
            # signal = signal.cpu().numpy()
            for i in range(len(signal) // sps - 1):
                x, y = signal[i * sps : (i + 2) * sps].T
                axs[0, j].plot(xaxis, np.abs(x) ** 2, color="C0", alpha=0.02)
                axs[1, j].plot(xaxis, np.abs(y) ** 2, color="C0", alpha=0.02)
            axs[0, j].set_title(label + " x")
            axs[1, j].set_title(label + " y")
            axs[0, j].set_xlabel("Symbol")
            axs[1, j].set_xlabel("Symbol")
            axs[0, j].set_ylabel("normalized power")
            axs[1, j].set_ylabel("normalized power")

        if show:
            plt.show()

    def _plot_model_response_head(
        self, *signals, labels=None, sps=None, title_append="", subtitle="", show=True
    ):
        if not hasattr(labels, "__iter__") or isinstance(labels, (str, type(None))):
            labels = [labels]
        else:
            labels = list(labels)

        while len(labels) < len(signals):
            labels.append(None)

        # check if there are any labels
        if not any(labels):
            labels = [f"signal {i + 1}" for i in range(len(signals))]

        fig, axs = plt.subplots(1, 2, sharex=True, sharey=True)
        fig.set_size_inches(18, 6)
        fig.suptitle(
            f"Fiber response{f' {title_append}' if title_append else ''}{f'\n{subtitle}' if subtitle else ''}"
        )
        for i, ax in enumerate(axs):
            for signal, label in zip(signals, labels):
                if sps is not None:
                    xaxis = np.linspace(0, len(signal) / sps, len(signal), endpoint=False)
                else:
                    xaxis = np.arange(len(signal))
                ax.plot(xaxis, np.abs(signal[:, i]) ** 2, label=label)
            ax.set_xlabel("Sample" if sps is None else "Symbol")
            ax.set_ylabel("normalized power")
            ax.legend(loc="upper right")
        if show:
            plt.show()
        return fig

    def plot_model_response(
        self,
        trial,
        model=None,
        title_append="",
        subtitle="",
        mode: Literal["eye", "head"] = "head",
        show=True,
    ):
        data_settings_backup = copy.deepcopy(self.data_settings)
        pytorch_settings_backup = copy.deepcopy(self.pytorch_settings)
        self.data_settings.drop_first = 100
        self.data_settings.shuffle = False
        self.data_settings.train_split = 1.0
        self.pytorch_settings.batchsize = (
            self.pytorch_settings.eye_symbols if mode == "eye" else self.pytorch_settings.head_symbols
        )
        plot_loader, _ = self.get_sliced_data(
            trial, override={"num_symbols": self.pytorch_settings.batchsize}
        )
        self.data_settings = data_settings_backup
        self.pytorch_settings = pytorch_settings_backup

        fiber_in, fiber_out, regen = self.run_model(model, plot_loader)
        fiber_in = fiber_in.view(-1, 2)
        fiber_out = fiber_out.view(-1, 2)
        regen = regen.view(-1, 2)

        fiber_in = fiber_in.numpy()
        fiber_out = fiber_out.numpy()
        regen = regen.numpy()

        # https://github.com/matplotlib/matplotlib/issues/27713#issue-2104110987
        # https://github.com/matplotlib/matplotlib/issues/27713#issuecomment-1915497463
        import gc

        if mode == "head":
            fig = self._plot_model_response_head(
                fiber_in,
                fiber_out,
                regen,
                labels=("fiber in", "fiber out", "regen"),
                sps=plot_loader.dataset.samples_per_symbol,
                title_append=title_append,
                subtitle=subtitle,
                show=show,
            )
        elif mode == "eye":
            # raise NotImplementedError("Eye diagram not implemented")
            fig = self._plot_model_response_eye(
                fiber_in,
                fiber_out,
                regen,
                labels=("fiber in", "fiber out", "regen"),
                sps=plot_loader.dataset.samples_per_symbol,
                title_append=title_append,
                subtitle=subtitle,
                show=show,
            )
        else:
            raise ValueError(f"Unknown mode: {mode}")
        gc.collect()

        return fig

    @staticmethod
    def build_title(trial):
        title_append = f"for trial {trial.number}"
        subtitle = (
            f"{trial.params['model_n_layers']} layers, "
            f"{', '.join([str(trial.params[f'model_hidden_dim_{i}']) for i in range(trial.params['model_n_layers'])])} units, "
            f"{trial.params['model_activation_func']}, "
            f"{trial.params['model_dtype']}"
        )

        return title_append, subtitle
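Note on the `force_suggest_*` helpers: `util.optuna_helpers` is referenced throughout `HyperTraining` but is not part of this commit. The sketch below is a hypothetical reconstruction, inferred only from the call sites above: a tuple setting is treated as an Optuna search range or choice list, while a scalar setting is registered as a fixed suggestion so it still appears in `trial.params`.

# Hypothetical sketch of util/optuna_helpers.py -- not shown in this commit.
# Assumption: tuple values describe a search space, scalar values are fixed
# but still recorded on the trial so trial.params.get(...) works later.
import optuna


def force_suggest_int(trial: optuna.Trial, name: str, value, log=False):
    if isinstance(value, tuple):
        return trial.suggest_int(name, value[0], value[1], log=log)
    # fixed value: suggest from a degenerate range so it is stored in trial.params
    return trial.suggest_int(name, int(value), int(value))


def force_suggest_float(trial: optuna.Trial, name: str, value, log=False):
    if isinstance(value, tuple):
        return trial.suggest_float(name, value[0], value[1], log=log)
    return trial.suggest_float(name, float(value), float(value))


def force_suggest_categorical(trial: optuna.Trial, name: str, value):
    choices = list(value) if isinstance(value, (tuple, list)) else [value]
    return trial.suggest_categorical(name, choices)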
src/single-core-regen/hypertraining/settings.py (new file, 77 lines)
@@ -0,0 +1,77 @@
from dataclasses import dataclass
from datetime import datetime


# global settings
@dataclass(frozen=True)
class GlobalSettings:
    seed: int = 42


# data settings
@dataclass
class DataSettings:
    config_path: str  # = "data/*-128-16384-100000-0-0-17-0-PAM4-0.ini"
    dtype: tuple = ("complex64", "float64")
    symbols: tuple | float | int = 8
    model_input_dim: tuple | float | int = 64
    shuffle: bool = True
    in_out_delay: float = 0
    xy_delay: tuple | float | int = 0
    drop_first: int = 1000
    train_split: float = 0.8


# pytorch settings
@dataclass
class PytorchSettings:
    epochs: int = 1
    batchsize: int = 2**10

    device: str = "cuda"

    dataloader_workers: int = 2
    dataloader_prefetch: int = 2

    model_dir: str = ".models"

    summary_dir: str = ".runs"
    write_every: int = 10
    head_symbols: int = 40
    eye_symbols: int = 1000


# model settings
@dataclass
class ModelSettings:
    output_dim: int = 2
    model_n_layers: tuple | int = 3
    unit_count: tuple | int = 8
    # n_units_range: tuple | int = (2, 32)
    # activation_func_range: tuple = ("ModReLU", "ZReLU", "CReLU", "Mag", "Identity")
    model_activation_func: tuple = ("ModReLU",)


@dataclass
class OptimizerSettings:
    optimizer: tuple | str = ("Adam", "RMSprop", "SGD")
    learning_rate: tuple | float = (1e-5, 1e-1)
    scheduler: str | None = None
    scheduler_kwargs: dict | None = None


# optuna settings
@dataclass
class OptunaSettings:
    n_trials: int = 128
    n_threads: int = 4
    timeout: int = 600
    directions: tuple = ("minimize",)
    metrics_names: tuple = ("mse",)
    limit_examples: bool = True
    n_train_batches: int = 100
    n_valid_batches: int = 100
    storage: str = "sqlite:///example.db"
    study_name: str = (
        f"optuna_study_{datetime.now().strftime('%Y-%m-%d_%H:%M:%S')}"
    )
@@ -1,464 +1,107 @@
-import copy
-from dataclasses import dataclass
-from datetime import datetime
-import matplotlib.pyplot as plt
-
-import numpy as np
-import optuna
-import warnings
-
-import torch
-import torch.nn as nn
-# import torch.nn.functional as F # mse_loss doesn't support complex numbers
-import torch.optim as optim
-import torch.utils.data
-
-from torch.utils.tensorboard import SummaryWriter
-
-from rich.progress import Progress, TextColumn, BarColumn, TaskProgressColumn, TimeRemainingColumn, MofNCompleteColumn
-from rich.console import Console
-
-import multiprocessing
-
-from util.datasets import FiberRegenerationDataset
-from util.complexNN import complex_sse_loss
-from util.optuna_helpers import optional_suggest_categorical, optional_suggest_float, optional_suggest_int
-import util
+from datetime import datetime
+from hypertraining.hypertraining import HyperTraining
+from hypertraining.settings import (
+    GlobalSettings,
+    DataSettings,
+    PytorchSettings,
+    ModelSettings,
+    OptimizerSettings,
+    OptunaSettings,
+)
 
-# global settings
-@dataclass
-class GlobalSettings:
-    seed: int = 42
-
-
-# data settings
-@dataclass
-class DataSettings:
-    config_path: str = "data/*-128-16384-1000-0-0-17-0-PAM4-0.ini"
-    dtype: torch.dtype = torch.complex64
-    symbols_range: tuple|float|int = 16
-    data_size_range: tuple|float|int = 32
-    shuffle: bool = True
-    target_delay: float = 0
-    xy_delay_range: tuple|float|int = 0
-    drop_first: int = 10
-    train_split: float = 0.8
-
-
-# pytorch settings
-@dataclass
-class PytorchSettings:
-    device: str = "cuda"
-    batchsize: int = 1024
-    epochs: int = 10
-    summary_dir: str = ".runs"
-
-
-# model settings
-@dataclass
-class ModelSettings:
-    output_size: int = 2
-    n_layer_range: tuple|float|int = (2,8)
-    n_units_range: tuple|float|int = (2,32)
-    # activation_func_range: tuple = ("ReLU",)
-
-
-@dataclass
-class OptimizerSettings:
-    # optimizer_range: tuple|str = ("Adam", "RMSprop", "SGD")
-    optimizer_range: tuple|str = "RMSprop"
-    # lr_range: tuple|float = (1e-5, 1e-1)
-    lr_range: tuple|float = 2e-5
-
-
-# optuna settings
-@dataclass
-class OptunaSettings:
-    n_trials: int = 128
-    n_threads: int = 8
-    timeout: int = 600
-    directions: tuple = ("minimize",)
-    metrics_names: tuple = ("sse",)
-
-    limit_examples: bool = True
-    n_train_examples: int = PytorchSettings.batchsize * 50
-    # n_valid_examples: int = PytorchSettings.batchsize * 100
-    n_valid_examples: int = float("inf")
-    storage: str = "sqlite:///optuna_single_core_regen.db"
-    study_name: str = (
-        f"single_core_regen_{datetime.now().strftime('%Y-%m-%d_%H:%M:%S')}"
-    )
-
-
-class HyperTraining:
-    def __init__(self):
-        self.global_settings = GlobalSettings()
-        self.data_settings = DataSettings()
-        self.pytorch_settings = PytorchSettings()
-        self.model_settings = ModelSettings()
-        self.optimizer_settings = OptimizerSettings()
-        self.optuna_settings = OptunaSettings()
-
-        self.console = Console()
-
-        # set some extra settings to make the code more readable
-        self._extra_optuna_settings()
-
-    def setup_tb_writer(self, study_name=None, append=None):
-        log_dir = self.pytorch_settings.summary_dir + "/" + (study_name or self.optuna_settings.study_name)
-        if append is not None:
-            log_dir += "_" + str(append)
-
-        return SummaryWriter(log_dir)
-
-    def resume_latest_study(self, verbose=True):
-        study_name = hyper_training.get_latest_study()
-
-        if study_name:
-            print(f"Resuming study: {study_name}")
-            self.optuna_settings.study_name = study_name
-
-    def get_latest_study(self, verbose=True):
-        studies = self.get_studies()
-        for study in studies:
-            study.datetime_start = study.datetime_start or datetime.min
-        if studies:
-            study = sorted(studies, key = lambda x: x.datetime_start, reverse=True)[0]
-            if verbose:
-                print(f"Last study: {study.study_name}")
-            study_name = study.study_name
-        else:
-            if verbose:
-                print("No previous studies found")
-            study_name = None
-        return study_name
-
-    def get_studies(self):
-        return optuna.get_all_study_summaries(storage=self.optuna_settings.storage)
-
-    def setup_study(self):
-        self.study = optuna.create_study(
-            study_name=self.optuna_settings.study_name,
-            storage=self.optuna_settings.storage,
-            load_if_exists=True,
-            direction=self.optuna_settings.direction,
-            directions=self.optuna_settings.directions,
-        )
-
-        with warnings.catch_warnings(action="ignore"):
-            self.study.set_metric_names(self.optuna_settings.metrics_names)
-
-        self.n_threads = min(self.optuna_settings.n_trials, self.optuna_settings.n_threads)
-        self.processes = []
-        if self.n_threads > 1:
-            for _ in range(self.n_threads):
-                p = multiprocessing.Process(
-                    # target=lambda n_trials: self._run_optimize(self, n_trials),
-                    target = self._run_optimize,
-                    args = (self.optuna_settings.n_trials // self.n_threads,),
-                )
-                self.processes.append(p)
-
-    def run_study(self):
-        if self.processes:
-            for p in self.processes:
-                p.start()
-            for p in self.processes:
-                p.join()
-
-            remaining_trials = (
-                self.optuna_settings.n_trials
-                - self.optuna_settings.n_trials % self.optuna_settings.n_threads
-            )
-        else:
-            remaining_trials = self.optuna_settings.n_trials
-
-        if remaining_trials:
-            self._run_optimize(remaining_trials)
-
-    def _run_optimize(self, n_trials):
-        self.study.optimize(self.objective, n_trials=n_trials, timeout=self.optuna_settings.timeout)
-
-    def plot_eye(self, show=True):
-        if not hasattr(self, "eye_data"):
-            data, config = util.datasets.load_data(
-                self.data_settings.config_path, skipfirst=10, symbols=1000
-            )
-            self.eye_data = {"data": data, "sps": int(config["glova"]["sps"])}
-        return util.plot.eye(**self.eye_data, show=show)
-
-    def _extra_optuna_settings(self):
-        self.optuna_settings.multi_objective = len(self.optuna_settings.directions) > 1
-        if self.optuna_settings.multi_objective:
-            self.optuna_settings.direction = None
-        else:
-            self.optuna_settings.direction = self.optuna_settings.directions[0]
-            self.optuna_settings.directions = None
-
-        self.optuna_settings.n_train_examples = (
-            self.optuna_settings.n_train_examples
-            if self.optuna_settings.limit_examples
-            else float("inf")
-        )
-        self.optuna_settings.n_valid_examples = (
-            self.optuna_settings.n_valid_examples
-            if self.optuna_settings.limit_examples
-            else float("inf")
-        )
-
-    def define_model(self, trial: optuna.Trial, writer=None):
-        n_layers = optional_suggest_int(trial, "model_n_layers", self.model_settings.n_layer_range)
-
-        in_features = 2 * trial.params.get(
-            "dataset_data_size",
-            optional_suggest_int(trial, "dataset_data_size", self.data_settings.data_size_range),
-        )
-        trial.set_user_attr("input_dim", in_features)
-
-        layers = []
-        for i in range(n_layers):
-            out_features = optional_suggest_int(trial, f"model_n_units_l{i}", self.model_settings.n_units_range, log=True)
-
-            layers.append(nn.Linear(in_features, out_features, dtype=self.data_settings.dtype))
-            # layers.append(getattr(nn, activation_func)())
-            in_features = out_features
-
-        layers.append(nn.Linear(in_features, self.model_settings.output_size, dtype=self.data_settings.dtype))
-
-        if writer is not None:
-            writer.add_graph(nn.Sequential(*layers), torch.zeros(1, trial.user_attrs["input_dim"], dtype=self.data_settings.dtype))
-
-        return nn.Sequential(*layers)
-
-    def get_sliced_data(self, trial: optuna.Trial):
-        symbols = optional_suggest_float(trial, "dataset_symbols", self.data_settings.symbols_range)
-
-        xy_delay = optional_suggest_float(trial, "dataset_xy_delay", self.data_settings.xy_delay_range)
-
-        data_size = trial.params.get(
-            "dataset_data_size",
-            optional_suggest_int(trial, "dataset_data_size", self.data_settings.data_size_range)
-        )
-
-        # get dataset
-        dataset = FiberRegenerationDataset(
-            file_path=self.data_settings.config_path,
-            symbols=symbols,
-            data_size=data_size,
-            target_delay=self.data_settings.target_delay,
-            xy_delay=xy_delay,
-            drop_first=self.data_settings.drop_first,
-            dtype=self.data_settings.dtype,
-        )
-
-        dataset_size = len(dataset)
-        indices = list(range(dataset_size))
-        split = int(np.floor(self.data_settings.train_split * dataset_size))
-        if self.data_settings.shuffle:
-            np.random.seed(self.global_settings.seed)
-            np.random.shuffle(indices)
-
-        train_indices, valid_indices = indices[:split], indices[split:]
-
-        train_sampler = torch.utils.data.SubsetRandomSampler(train_indices)
-        valid_sampler = torch.utils.data.SubsetRandomSampler(valid_indices)
-
-        train_loader = torch.utils.data.DataLoader(
-            dataset, batch_size=self.pytorch_settings.batchsize, sampler=train_sampler, drop_last=True
-        )
-        valid_loader = torch.utils.data.DataLoader(
-            dataset, batch_size=self.pytorch_settings.batchsize, sampler=valid_sampler, drop_last=True
-        )
-
-        return train_loader, valid_loader
-
-    def train_model(self, model, optimizer, train_loader, epoch, writer=None, enable_progress=True):
-        if enable_progress:
-            progress = Progress(
-                TextColumn("[yellow] Training..."),
-                TextColumn(" Loss: {task.description}"),
-                BarColumn(),
-                TaskProgressColumn(),
-                TextColumn("[green]Batch"),
-                MofNCompleteColumn(),
-                TimeRemainingColumn(),
-                # description="Training",
-                transient=False,
-                console=self.console,
-                refresh_per_second=10,
-            )
-            task = progress.add_task("-.---e--", total=len(train_loader))
-
-        running_loss = 0.0
-        last_loss = 0.0
-        model.train()
-        for batch_idx, (x, y) in enumerate(train_loader):
-            if batch_idx * train_loader.batch_size >= self.optuna_settings.n_train_examples:
-                break
-            optimizer.zero_grad()
-            x, y = (
-                x.to(self.pytorch_settings.device),
-                y.to(self.pytorch_settings.device),
-            )
-            y_pred = model(x)
-            loss = complex_sse_loss(y_pred, y)
-            loss.backward()
-            optimizer.step()
-
-            # clamp weights to keep energy bounded
-            for p in model.parameters():
-                p.data.clamp_(-1.0, 1.0)
-
-            last_loss = loss.item()
-
-            if enable_progress:
-                progress.update(task, advance=1, description=f"{last_loss:.3e}")
-
-            running_loss += loss.item()
-            if writer is not None:
-                if batch_idx % 10 == 0:
-                    writer.add_scalar("training loss", running_loss/10, epoch*min(len(train_loader), self.optuna_settings.n_train_examples/train_loader.batch_size) + batch_idx)
-                    running_loss = 0.0
-
-        if enable_progress:
-            progress.update(task, description=f"{last_loss:.3e}")
-            progress.stop()
-
-    def eval_model(self, model, valid_loader, epoch, writer=None, enable_progress=True):
-        if enable_progress:
-            progress = Progress(
-                TextColumn("[green]Evaluating..."),
-                TextColumn("Error: {task.description}"),
-                BarColumn(),
-                TaskProgressColumn(),
-                TextColumn("[green]Batch"),
-                MofNCompleteColumn(),
-                TimeRemainingColumn(),
-                # description="Training",
-                transient=False,
-                console=self.console,
-                refresh_per_second=10,
-            )
-            task = progress.add_task("-.---e--", total=len(valid_loader))
-
-        model.eval()
-        running_error = 0
-        running_error_2 = 0
-        with torch.no_grad():
-            for batch_idx, (x, y) in enumerate(valid_loader):
-                if batch_idx * valid_loader.batch_size >= self.optuna_settings.n_valid_examples:
-                    break
-                x, y = (
-                    x.to(self.pytorch_settings.device),
-                    y.to(self.pytorch_settings.device),
-                )
-                y_pred = model(x)
-                error = complex_sse_loss(y_pred, y)
-                running_error += error.item()
-                running_error_2 += error.item()
-
-                if enable_progress:
-                    progress.update(task, advance=1, description=f"{error.item():.3e}")
-
-                if writer is not None:
-                    if batch_idx % 10 == 0:
-                        writer.add_scalar("sse", running_error_2/10, epoch*min(len(valid_loader), self.optuna_settings.n_valid_examples/valid_loader.batch_size) + batch_idx)
-                        running_error_2 = 0.0
-
-        running_error /= batch_idx + 1
-
-        if enable_progress:
-            progress.update(task, description=f"{running_error:.3e}")
-            progress.stop()
-
-        return running_error
-
-    def run_model(self, model, loader):
-        model.eval()
-        y_preds = []
-        with torch.no_grad():
-            for x, y in loader:
-                x, y = (
-                    x.to(self.pytorch_settings.device),
-                    y.to(self.pytorch_settings.device),
-                )
-                y_preds.append(model(x))
-        return torch.stack(y_preds)
-
-    def objective(self, trial: optuna.Trial):
-        writer = self.setup_tb_writer(self.optuna_settings.study_name, f"{trial.number:0>len(str(self.optuna_settings.n_trials))}")
-        train_loader, valid_loader = self.get_sliced_data(trial)
-
-        model = self.define_model(trial, writer).to(self.pytorch_settings.device)
-
-        optimizer_name = optional_suggest_categorical(trial, "optimizer", self.optimizer_settings.optimizer_range)
-
-        lr = optional_suggest_float(trial, "lr", self.optimizer_settings.lr_range, log=True)
-
-        optimizer = getattr(optim, optimizer_name)(model.parameters(), lr=lr)
-
-        for epoch in range(self.pytorch_settings.epochs):
-            enable_progress = self.optuna_settings.n_threads == 1
-            if enable_progress:
-                print(f"Epoch {epoch+1}/{self.pytorch_settings.epochs}")
-            self.train_model(model, optimizer, train_loader, epoch, writer, enable_progress=enable_progress)
-            sse = self.eval_model(model, valid_loader, epoch, writer, enable_progress=enable_progress)
-
-            if not self.optuna_settings.multi_objective:
-                trial.report(sse, epoch)
-                if trial.should_prune():
-                    raise optuna.exceptions.TrialPruned()
-
-        writer.close()
-
-        return sse
+
+global_settings = GlobalSettings(
+    seed = 42,
+)
+
+data_settings = DataSettings(
+    config_path = "data/*-128-16384-100000-0-0-17-0-PAM4-0.ini",
+    dtype = ("complex64", "float64", "complex32", "float32"),
+    symbols = (1, 16),
+    model_input_dim = (1, 32),
+    shuffle = True,
+    in_out_delay = 0,
+    xy_delay = 0,
+    drop_first = 1000,
+    train_split = 0.8,
+)
+
+pytorch_settings = PytorchSettings(
+    epochs = 25,
+    batchsize = 2**10,
+    device = "cuda",
+    dataloader_workers = 2,
+    dataloader_prefetch = 2,
+    summary_dir = ".runs",
+    write_every = 2**5,
+    model_dir = ".models",
+)
+
+model_settings = ModelSettings(
+    output_dim = 2,
+    model_n_layers = (2, 8),
+    unit_count = (2, 16),
+    model_activation_func = ("ModReLU")#, "ZReLU", "Mag")#, "CReLU", "Identity"),
+)
+
+optimizer_settings = OptimizerSettings(
+    optimizer = ("Adam", "RMSprop"),#, "SGD"),
+    # learning_rate = (1e-5, 1e-1),
+    learning_rate=1e-3,
+    # scheduler = "ReduceLROnPlateau",
+    # scheduler_kwargs = {"mode": "min", "factor": 0.5, "patience": 10}
+)
+
+optuna_settings = OptunaSettings(
+    n_trials = 4096,
+    n_threads = 16,
+    timeout = 600,
+    directions = ("minimize","minimize"),
+    metrics_names = ("n_params","mse"),
+
+    limit_examples = True,
+    n_train_batches = 100,
+    n_valid_batches = 100,
+    storage = "sqlite:///data/single_core_regen.db",
+    study_name = f"single_core_regen_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
+)
 
 
 if __name__ == "__main__":
-    hyper_training = HyperTraining()
+    hyper_training = HyperTraining(
+        global_settings=global_settings,
+        data_settings=data_settings,
+        pytorch_settings=pytorch_settings,
+        model_settings=model_settings,
+        optimizer_settings=optimizer_settings,
+        optuna_settings=optuna_settings,
+    )
+
+    hyper_training.setup_study()
 
     # hyper_training.resume_latest_study()
 
-    hyper_training.setup_study()
     hyper_training.run_study()
+    # best_trial = hyper_training.study.best_trial
 
-    best_model = hyper_training.define_model(hyper_training.study.best_trial).to(hyper_training.pytorch_settings.device)
-    data_settings_backup = copy.copy(hyper_training.data_settings)
-    hyper_training.data_settings.shuffle = False
-    hyper_training.data_settings.train_split = 0.01
-    plot_loader, _ = hyper_training.get_sliced_data(hyper_training.study.best_trial)
-
-    regen = hyper_training.run_model(best_model, plot_loader)
-    regen = regen.view(-1, 2)
-    # [batch_no, batch_size, 2] -> [no, 2]
-
-    original, _ = util.datasets.load_data(hyper_training.data_settings.config_path, skipfirst=hyper_training.data_settings.drop_first)
-    original = original[:len(regen)]
-
-    regen = regen.cpu().numpy()
-    _, axs = plt.subplots(2)
-    for i, ax in enumerate(axs):
-        ax.plot(np.abs(original[:, i])**2, label="original")
-        ax.plot(np.abs(regen[:, i])**2, label="regen")
-        ax.legend()
-    plt.show()
-
-    print(f"Best model: {best_model}")
+    # best_model = hyper_training.define_model(best_trial).to(
+    #     hyper_training.pytorch_settings.device
+    # )
+
+    # title_append, subtitle = hyper_training.build_title(best_trial)
+    # hyper_training.plot_model_response(
+    #     best_trial,
+    #     model=best_model,
+    #     title_append=title_append,
+    #     subtitle=subtitle,
+    #     mode="eye",
+    #     show=True,
+    # )
+
+    # print(f"Best model found for trial {best_trial.number}")
+    # print(f"Best model error: {best_trial.value}")
+    # print(f"Best model params: {best_trial.params}")
+    # print()
+    # print(best_model)
 
     # eye_fig = hyper_training.plot_eye()
     ...
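Because every worker process writes its trials to the shared SQLite storage configured in OptunaSettings, a running or finished search can be inspected from a separate session. A minimal sketch, assuming the storage URL above and a placeholder study name (substitute the name printed by resume_latest_study or the configured study_name):

# Sketch: inspect the study produced by the search above.
# The study name below is a placeholder, not a value from this commit.
import optuna

study = optuna.load_study(
    study_name="single_core_regen_20240101_000000",  # placeholder
    storage="sqlite:///data/single_core_regen.db",
)

# With directions=("minimize", "minimize") the study is multi-objective,
# so there is a Pareto front (study.best_trials) instead of a single best trial.
for trial in study.best_trials:
    n_params, mse = trial.values
    print(trial.number, n_params, mse, trial.params)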
@@ -95,11 +95,12 @@ class Training:
         self.writer = None
         self.console = Console()
 
-    def setup_tb_writer(self, study_name=None):
+    def setup_tb_writer(self, study_name=None, append=None):
         log_dir = (
-            self.pytorch_settings.summary_dir + "/" + (study_name or self.study_name)
+            self.pytorch_settings.summary_dir + "/" + (study_name or self.study_name) + (("_" + str(append)) if append else "")
         )
         self.writer = SummaryWriter(log_dir)
+        return self.writer
 
     def plot_eye(self, width=2, symbols=None, alpha=None, complex=False, show=True):
         if not hasattr(self, "eye_data"):
@@ -160,7 +161,7 @@ class Training:
         dataset = util.datasets.FiberRegenerationDataset(
             file_path=self.data_settings.config_path,
             symbols=symbols,
-            data_size=data_size,
+            output_dim=data_size,
             target_delay=self.data_settings.target_delay,
             xy_delay=xy_delay,
             drop_first=self.data_settings.drop_first,
@@ -212,7 +213,7 @@ class Training:
     def train_model(self, model, optimizer, train_loader, epoch):
         with Progress(
             TextColumn("[yellow] Training..."),
-            TextColumn("Loss: {task.description}"),
+            TextColumn("Error: {task.description}"),
             BarColumn(),
             TaskProgressColumn(),
             TextColumn("[green]Batch"),
@@ -256,7 +257,7 @@ class Training:
     def eval_model(self, model, valid_loader, epoch):
         with Progress(
             TextColumn("[green]Evaluating..."),
-            TextColumn("Loss: {task.description}"),
+            TextColumn("Error: {task.description}"),
             BarColumn(),
             TaskProgressColumn(),
             TextColumn("[green]Batch"),
@@ -325,18 +326,6 @@ class Training:
         ys = torch.vstack(ys).cpu()
         y_preds = torch.vstack(y_preds).cpu()
         return ys, xs, y_preds
-
-    def dummy_model(self, loader):
-        xs = []
-        ys = []
-        for x, y in loader:
-            y = y.cpu().view(y.shape[0], -1, 2)
-            x = x.cpu().view(x.shape[0], -1, 2)
-            xs.append(x[:, 0, :].squeeze())
-            ys.append(y.squeeze())
-        xs = torch.vstack(xs)
-        ys = torch.vstack(ys)
-        return xs, ys
 
     def objective(self, save=False, plot_before=False):
         try:
@@ -360,22 +349,18 @@ class Training:
             self.train_model(self.model, optimizer, train_loader, epoch)
             eval_loss = self.eval_model(self.model, valid_loader, epoch)
 
-            if save:
-                save_path = (
-                    Path(self.pytorch_settings.model_dir) / f"{self.study_name}.pth"
-                )
-                save_path.parent.mkdir(parents=True, exist_ok=True)
-                torch.save(self.model, save_path)
-
             return eval_loss
+
         except KeyboardInterrupt:
-            pass
+            ...
         finally:
             if hasattr(self, "model"):
-                except_save_path = Path(".models/exception") / f"{self.study_name}.pth"
-                except_save_path.parent.mkdir(parents=True, exist_ok=True)
-                torch.save(self.model, except_save_path)
+                save_path = (
+                    Path(self.pytorch_settings.model_dir) / f"{self.study_name}.pth"
+                )
+                save_path.parent.mkdir(parents=True, exist_ok=True)
+                torch.save(self.model, save_path)
 
     def _plot_model_response_plotter(self, fiber_in, fiber_out, regen, plot=True):
         fig, axs = plt.subplots(2)
         for i, ax in enumerate(axs):
@@ -15,3 +15,5 @@ from . import complexNN  # noqa: F401
 # from .complexNN import UnitaryLayer  # noqa: F401
 # from .complexNN import complex_mse_loss  # noqa: F401
 # from .complexNN import complex_sse_loss  # noqa: F401
+
+from . import misc  # noqa: F401
@@ -1,116 +1,160 @@
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
|
||||||
def complex_mse_loss(input, target):
|
def complex_mse_loss(input, target):
|
||||||
"""
|
"""
|
||||||
Compute the mean squared error between two complex tensors.
|
Compute the mean squared error between two complex tensors.
|
||||||
"""
|
"""
|
||||||
return torch.mean(torch.square(input.real - target.real) + torch.square(input.imag - target.imag))
|
if input.is_complex():
|
||||||
|
return torch.mean(
|
||||||
|
torch.square(input.real - target.real)
|
||||||
|
+ torch.square(input.imag - target.imag)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return F.mse_loss(input, target)
|
||||||
|
|
||||||
|
|
||||||
def complex_sse_loss(input, target):
    """
    Compute the sum squared error between two complex tensors.
    """
    if input.is_complex():
-        return torch.sum(torch.square(input.real - target.real) + torch.square(input.imag - target.imag))
+        return torch.sum(
+            torch.square(input.real - target.real)
+            + torch.square(input.imag - target.imag)
+        )
    else:
        return torch.sum(torch.square(input - target))

class UnitaryLayer(nn.Module):
-    def __init__(self, in_features, out_features):
+    def __init__(self, in_features, out_features, dtype=None):
-        super(UnitaryLayer, self).__init__()
        assert in_features >= out_features
+        super(UnitaryLayer, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
-        self.weight = nn.Parameter(torch.randn(in_features, out_features, dtype=torch.cfloat))
+        self.weight = nn.Parameter(torch.randn(in_features, out_features, dtype=dtype))
        self.reset_parameters()

    def reset_parameters(self):
        q, _ = torch.linalg.qr(self.weight)
        self.weight.data = q

+    def forward(self, x):
+        return torch.matmul(x, self.weight)
+
+    def __repr__(self):
+        return f"UnitaryLayer({self.in_features}, {self.out_features})"

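A small sanity check of the QR-based initialisation (illustrative sketch; the import path is assumed from the util package shown in the __init__.py hunk): with in_features >= out_features, the weight columns are orthonormal, so W^H W is the identity.

import torch
from util import complexNN  # import path assumed

layer = complexNN.UnitaryLayer(8, 4, dtype=torch.cfloat)  # in_features >= out_features
w = layer.weight
# orthonormal columns after the QR-based reset
assert torch.allclose(w.conj().T @ w, torch.eye(4, dtype=w.dtype), atol=1e-5)

x = torch.randn(2, 8, dtype=torch.cfloat)
y = layer(x)  # shape (2, 4)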
+class SemiUnitaryLayer(nn.Module):
+    def __init__(self, input_dim, output_dim, dtype=None):
+        super(SemiUnitaryLayer, self).__init__()
+        self.input_dim = input_dim
+        self.output_dim = output_dim
+
+        # Create a larger square matrix for QR decomposition
+        self.weight = nn.Parameter(torch.randn(max(input_dim, output_dim), max(input_dim, output_dim), dtype=dtype))
+        self.reset_parameters()

-    @staticmethod
-    @torch.jit.script
-    def _unitary_forward(x, weight):
-        out = torch.matmul(x, weight)
-        return out
+    def reset_parameters(self):
+        # Ensure the weights are semi-unitary by QR decomposition
+        q, _ = torch.linalg.qr(self.weight)
+        if self.input_dim > self.output_dim:
+            self.weight.data = q[:self.input_dim, :self.output_dim]
+        else:
+            self.weight.data = q[:self.output_dim, :self.input_dim].t()

    def forward(self, x):
-        return self._unitary_forward(x, self.weight)
+        out = torch.matmul(x, self.weight)
+        return out
+
+    def __repr__(self):
+        return f"SemiUnitaryLayer({self.input_dim}, {self.output_dim})"

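For the new SemiUnitaryLayer, a brief illustrative check (not part of the commit; import path assumed): the weight is cut out of a QR factor of a square random matrix, so its columns stay orthonormal and W^H W is approximately the identity on the smaller dimension.

import torch
from util import complexNN  # import path assumed

layer = complexNN.SemiUnitaryLayer(8, 3, dtype=torch.cfloat)  # input_dim > output_dim
w = layer.weight                                              # shape (8, 3) after reset_parameters
assert torch.allclose(w.conj().T @ w, torch.eye(3, dtype=w.dtype), atol=1e-5)

x = torch.randn(5, 8, dtype=torch.cfloat)
y = layer(x)  # shape (5, 3); projection through orthonormal columns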
+# class SpreadLayer(nn.Module):
+#     def __init__(self, in_features, out_features, dtype=None):
+#         super(SpreadLayer, self).__init__()
+#         self.in_features = in_features
+#         self.out_features = out_features
+#         self.mat = torch.ones(in_features, out_features, dtype=dtype)*torch.sqrt(torch.tensor(in_features/out_features))
+
+#     def forward(self, x):
+#         # N in_features -> M out_features, Energy is preserved (P = abs(x)^2)
+#         out = torch.matmul(x, self.mat)
+#         return out

#### as defined by zhang et al


class Identity(nn.Module):
    """
    implements the "activation" function
    M(z) = z
    """

    def __init__(self):
        super(Identity, self).__init__()

    def forward(self, x):
        return x

class Mag(nn.Module):
    """
    implements the activation function
    M(z) = ||z||
    """

    def __init__(self):
        super(Mag, self).__init__()

-    @torch.jit.script
    def forward(self, x):
-        return torch.abs(x.real**2 + x.imag**2)
+        return torch.abs(x).to(dtype=x.dtype)

-# class Tanh(nn.Module):
-#     """
-#     implements the activation function
-#     M(z) = tanh(z) = sinh(z)/cosh(z) = (exp(z)-exp(-z))/(exp(z)+exp(-z)) = (exp(2*z)-1)/(exp(2*z)+1)
-#     """
-#     def __init__(self):
-#         super(Tanh, self).__init__()
-
-#     def forward(self, x):
-#         return torch.tanh(x)

class ModReLU(nn.Module):
    """
    implements the activation function
    M(z) = ReLU(||z|| + b)*exp(j*theta_z)
         = ReLU(||z|| + b)*z/||z||
    """

    def __init__(self, b=0):
        super(ModReLU, self).__init__()
-        self.b = b
-        self.relu = nn.ReLU()
+        self.b = torch.tensor(b)

-    @staticmethod
-    # @torch.jit.script
-    def _mod_relu(x, b):
-        mod = torch.abs(x.real**2 + x.imag**2)
-        return torch.relu(mod + b) * x / mod

    def forward(self, x):
-        return self._mod_relu(x, self.b)
+        if x.is_complex():
+            mod = torch.abs(x.real**2 + x.imag**2)
+            return torch.relu(mod + self.b) * x / mod
+        else:
+            return torch.relu(x + self.b)

+    def __repr__(self):
+        return f"ModReLU(b={self.b})"

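For reference, a self-contained sketch of modReLU exactly as the docstring states it, ReLU(||z|| + b) * z/||z||, with ||z|| = sqrt(re^2 + im^2). This is an illustration of the formula only, not a copy of the commit's implementation; the function name and eps guard are made up.

import torch

def modrelu_reference(z: torch.Tensor, b: float = 0.0, eps: float = 1e-9) -> torch.Tensor:
    # |z| = sqrt(re^2 + im^2); eps avoids division by zero at the origin
    mod = torch.abs(z)
    return torch.relu(mod + b) * z / (mod + eps)

z = torch.randn(4, dtype=torch.cfloat)
print(modrelu_reference(z, b=-0.5))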
class CReLU(nn.Module):
    """
    implements the activation function
    M(z) = ReLU(Re(z)) + j*ReLU(Im(z))
    """

    def __init__(self):
        super(CReLU, self).__init__()
-        self.relu = nn.ReLU()

-    @torch.jit.script
    def forward(self, x):
-        return torch.relu(x.real) + 1j*torch.relu(x.imag)
+        if x.is_complex():
+            return torch.relu(x.real) + 1j * torch.relu(x.imag)
+        else:
+            return torch.relu(x)

class ZReLU(nn.Module):
    """
    implements the activation function
@@ -122,20 +166,8 @@ class ZReLU(nn.Module):
    def __init__(self):
        super(ZReLU, self).__init__()

-    @torch.jit.script
    def forward(self, x):
-        return x * (torch.angle(x) >= 0) * (torch.angle(x) <= torch.pi/2)
+        if x.is_complex():
+            return x * (torch.angle(x) >= 0) * (torch.angle(x) <= torch.pi / 2)
+        else:
+            return torch.relu(x)
-
-# class ComplexFeedForwardNN(nn.Module):
-#     def __init__(self, in_features, hidden_features, out_features):
-#         super(ComplexFeedForwardNN, self).__init__()
-#         self.in_features = in_features
-#         self.hidden_features = hidden_features
-#         self.out_features = out_features
-#         self.fc1 = UnitaryLayer(in_features, hidden_features)
-#         self.fc2 = UnitaryLayer(hidden_features, out_features)
-
-#     def forward(self, x):
-#         x = self.fc1(x)
-#         x = self.fc2(x)
-#         return x
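To see how the activations above differ, a short illustrative snippet (assumes the classes are importable from the util package as in the __init__.py hunk; outputs depend on the implementations in this commit):

import torch
from util import complexNN  # import path assumed

z = torch.tensor([1 + 1j, -1 + 1j, -1 - 1j, 1 - 1j], dtype=torch.cfloat)

print(complexNN.Mag()(z))          # modulus of each element
print(complexNN.CReLU()(z))        # ReLU on real and imaginary parts separately
print(complexNN.ZReLU()(z))        # keeps only elements with phase in [0, pi/2]
print(complexNN.ModReLU(b=-1)(z))  # modReLU with bias b = -1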
@@ -41,9 +41,10 @@ def load_data(config_path, skipfirst=0, symbols=None, real=False, normalize=Fals
    data = np.load(datapath)[skipfirst * sps : symbols * sps + skipfirst * sps]

    if normalize:
-        a, b, c, d = data.T
+        # square gets normalized to 1, as the power is (proportional to) the square of the amplitude
+        a, b, c, d = np.square(data.T)
        a, b, c, d = a/np.max(np.abs(a)), b/np.max(np.abs(b)), c/np.max(np.abs(c)), d/np.max(np.abs(d))
-        data = np.array([a, b, c, d]).T
+        data = np.sqrt(np.array([a, b, c, d]).T)

    if real:
        data = np.abs(data)
@@ -98,7 +99,7 @@ class FiberRegenerationDataset(Dataset):
        file_path: str | Path,
        symbols: int | float,
        *,
-        data_size: int = None,
+        output_dim: int = None,
        target_delay: float | int = 0,
        xy_delay: float | int = 0,
        drop_first: float | int = 0,
@@ -129,7 +130,7 @@ class FiberRegenerationDataset(Dataset):
        assert isinstance(symbols, (float, int)), (
            "symbols must be a float or an integer"
        )
-        assert data_size is None or isinstance(data_size, int), (
+        assert output_dim is None or isinstance(output_dim, int), (
            "output_len must be an integer"
        )
        assert isinstance(target_delay, (float, int)), (
@@ -142,7 +143,7 @@ class FiberRegenerationDataset(Dataset):

        # check values
        assert symbols > 0, "symbols must be positive"
-        assert data_size is None or data_size > 0, "output_len must be positive or None"
+        assert output_dim is None or output_dim > 0, "output_len must be positive or None"
        assert drop_first >= 0, "drop_first must be non-negative"

        faux = kwargs.pop("faux", False)
@@ -158,7 +159,7 @@ class FiberRegenerationDataset(Dataset):
                "glova": {"sps": 128},
            }
        else:
-            data_raw, self.config = load_data(file_path, skipfirst=drop_first, real=real, normalize=True, device=device, dtype=dtype)
+            data_raw, self.config = load_data(file_path, skipfirst=drop_first, symbols=kwargs.pop("num_symbols", None), real=real, normalize=True, device=device, dtype=dtype)

        self.device = data_raw.device

@@ -166,7 +167,7 @@ class FiberRegenerationDataset(Dataset):
        self.samples_per_slice = int(symbols * self.samples_per_symbol)
        self.symbols_per_slice = self.samples_per_slice / self.samples_per_symbol

-        self.data_size = data_size or self.samples_per_slice
+        self.output_dim = output_dim or self.samples_per_slice
        self.target_delay = target_delay or 0
        self.xy_delay = xy_delay or 0

@@ -261,13 +262,13 @@ class FiberRegenerationDataset(Dataset):
        data, target = self.data[idx, 1].squeeze(), self.data[idx, 0].squeeze()

        # reduce by taking self.output_dim equally spaced samples
-        data = data[:, : data.shape[1] // self.data_size * self.data_size]
-        data = data.view(data.shape[0], self.data_size, -1)
+        data = data[:, : data.shape[1] // self.output_dim * self.output_dim]
+        data = data.view(data.shape[0], self.output_dim, -1)
        data = data[:, :, 0]

        # the target corresponds to the middle of the data, as the output sample is influenced by the data before and after it
-        target = target[:, : target.shape[1] // self.data_size * self.data_size]
-        target = target.view(target.shape[0], self.data_size, -1)
+        target = target[:, : target.shape[1] // self.output_dim * self.output_dim]
+        target = target.view(target.shape[0], self.output_dim, -1)
        target = target[:, 0, target.shape[2] // 2]

        data = data.transpose(0, 1).flatten().squeeze()
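The indexing in this hunk takes output_dim equally spaced samples from each slice: truncate to a multiple of output_dim, reshape into output_dim groups, then keep the first sample of each group. A minimal illustration of the trick with made-up sizes:

import torch

output_dim = 4
x = torch.arange(22).view(1, 22)                    # one channel, 22 samples

x = x[:, : x.shape[1] // output_dim * output_dim]   # truncate to 20 samples
x = x.view(x.shape[0], output_dim, -1)              # (1, 4, 5): 4 groups of 5
x = x[:, :, 0]                                      # first sample of each group
print(x)                                            # tensor([[ 0,  5, 10, 15]])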
21 src/single-core-regen/util/misc.py Normal file
@@ -0,0 +1,21 @@
+def multi_getattr(objs, attr, fallback=None):
+    """
+    tries to get the attribute from a list of objects, returning the first hit
+    if no object has the attribute, it returns the fallback value if provided, otherwise raises AttributeError
+    """
+    try:
+        return _multi_getattr(objs, attr)
+    except AttributeError as e:
+        if fallback is not None:
+            return fallback
+        raise e
+
+def _multi_getattr(objs, attr):
+    if not isinstance(objs, (list, tuple)):
+        objs = [objs]
+    for obj in objs:
+        try:
+            return getattr(obj, attr)
+        except AttributeError:
+            pass
+    raise AttributeError(f"None of the objects has attribute {attr}")
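Illustrative use of the new helper (the objects below are made up; the import path is assumed from the file location):

from types import SimpleNamespace
from util.misc import multi_getattr  # import path assumed

opt = SimpleNamespace(lr=1e-3)
sched = SimpleNamespace(gamma=0.9)

print(multi_getattr([opt, sched], "gamma"))        # 0.9, first object that has it
print(multi_getattr([opt, sched], "momentum", 0))  # 0, the provided fallback
print(multi_getattr(opt, "lr"))                    # 0.001, a single object works too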
@@ -27,4 +27,19 @@ def optional_suggest_int(trial, name, range_or_value, step=None, log=False):
    return _optional_suggest(trial, name, range_or_value, step=step, log=log, type='int')

def optional_suggest_float(trial, name, range_or_value, step=None, log=False):
    return _optional_suggest(trial, name, range_or_value, step=step, log=log, type='float')

+def force_suggest_int(trial, name, range_or_value, step=1, log=False):
+    if not hasattr(range_or_value, '__iter__') or isinstance(range_or_value, str):
+        return trial.suggest_int(name, range_or_value, range_or_value, step=step, log=log)
+    return trial.suggest_int(name, *range_or_value, step=step, log=log)
+
+def force_suggest_float(trial, name, range_or_value, step=None, log=False):
+    if not hasattr(range_or_value, '__iter__') or isinstance(range_or_value, str):
+        return trial.suggest_float(name, range_or_value, range_or_value, step=step, log=log)
+    return trial.suggest_float(name, *range_or_value, step=step, log=log)
+
+def force_suggest_categorical(trial, name, range_or_value):
+    if not hasattr(range_or_value, '__iter__') or isinstance(range_or_value, str):
+        return trial.suggest_categorical(name, [range_or_value])
+    return trial.suggest_categorical(name, range_or_value)
@@ -38,7 +38,7 @@ def eye(*, path=None, data=None, sps=None, title=None, symbols=1000, skipfirst=0
    axs[0, 1].plot(xaxis, np.abs(outx), color="C0", alpha=alpha or 0.1)
    axs[1, 0].plot(xaxis, np.abs(iny), color="C0", alpha=alpha or 0.1)
    axs[1, 1].plot(xaxis, np.abs(outy), color="C0", alpha=alpha or 0.1)
-    axs[0,0].set_ylim(0, 1.1*np.max(np.abs(data)))
+    axs[0, 0].set_ylim(0, 1.1*np.max(np.abs(data)))

    axs2[0, 0].plot(xaxis, np.angle(inx), color="C1", alpha=alpha or 0.1)
    axs2[0, 1].plot(xaxis, np.angle(outx), color="C1", alpha=alpha or 0.1)