Compare commits

...

2 Commits

Author SHA1 Message Date
Joseph Hopfmüller
6358c95c42 new hyperparameter db 2024-11-20 22:49:40 +01:00
Joseph Hopfmüller
674033ac2e move hypertraining class into separate file;
move settings dataclasses into separate file;
add SemiUnitaryLayer;
clean up model response plotting code;
continue hyperparameter search
2024-11-20 22:49:31 +01:00
12 changed files with 1067 additions and 553 deletions

View File

@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:72460af57347d35df91cd76982231bcf538a82fd7f1b8522795202fa298a2dcb
size 696320
oid sha256:e12f0c21fca93620a165fbb6ed58d0b313093e972ef4416694c29c9cea6dc867
size 831488

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:7231dea2c9107f443de9122fdc971d9ce6df93db2ee27a9d68a5e22c986373eb
size 937984

View File

@@ -0,0 +1,735 @@
import copy
from datetime import datetime
from pathlib import Path
from typing import Literal
import matplotlib.pyplot as plt
import numpy as np
import optuna
import warnings
import torch
import torch.nn as nn
# import torch.nn.functional as F # mse_loss doesn't support complex numbers
import torch.optim as optim
import torch.utils.data
from torch.utils.tensorboard import SummaryWriter
from rich.progress import (
Progress,
TextColumn,
BarColumn,
TaskProgressColumn,
TimeRemainingColumn,
MofNCompleteColumn,
TimeElapsedColumn,
)
from rich.console import Console
# from rich import print as rprint
import multiprocessing
from util.datasets import FiberRegenerationDataset
from util.optuna_helpers import (
force_suggest_categorical,
force_suggest_float,
force_suggest_int,
)
import util
from .settings import (
GlobalSettings,
DataSettings,
ModelSettings,
OptunaSettings,
OptimizerSettings,
PytorchSettings,
)
class HyperTraining:
def __init__(
self,
*,
global_settings,
data_settings,
pytorch_settings,
model_settings,
optimizer_settings,
optuna_settings,
console=None,
):
self.global_settings: GlobalSettings = global_settings
self.data_settings: DataSettings = data_settings
self.pytorch_settings: PytorchSettings = pytorch_settings
self.model_settings: ModelSettings = model_settings
self.optimizer_settings: OptimizerSettings = optimizer_settings
self.optuna_settings: OptunaSettings = optuna_settings
self.console = console or Console()
# set some extra settings to make the code more readable
self._extra_optuna_settings()
def setup_tb_writer(self, study_name=None, append=None):
log_dir = (
self.pytorch_settings.summary_dir
+ "/"
+ (study_name or self.optuna_settings.study_name)
)
if append is not None:
log_dir += "_" + str(append)
return SummaryWriter(log_dir)
def resume_latest_study(self, verbose=True):
study_name = self.get_latest_study(verbose=verbose)
if study_name:
print(f"Resuming study: {study_name}")
self.optuna_settings.study_name = study_name
def get_latest_study(self, verbose=True):
studies = self.get_studies()
for study in studies:
study.datetime_start = study.datetime_start or datetime.min
if studies:
study = sorted(studies, key=lambda x: x.datetime_start, reverse=True)[0]
if verbose:
print(f"Last study: {study.study_name}")
study_name = study.study_name
else:
if verbose:
print("No previous studies found")
study_name = None
return study_name
def get_studies(self):
return optuna.get_all_study_summaries(storage=self.optuna_settings.storage)
def setup_study(self):
self.study = optuna.create_study(
study_name=self.optuna_settings.study_name,
storage=self.optuna_settings.storage,
load_if_exists=True,
direction=self.optuna_settings.direction,
directions=self.optuna_settings.directions,
)
with warnings.catch_warnings(action="ignore"):
self.study.set_metric_names(self.optuna_settings.metrics_names)
self.n_threads = min(
self.optuna_settings.n_trials, self.optuna_settings.n_threads
)
self.processes = []
if self.n_threads > 1:
for _ in range(self.n_threads):
p = multiprocessing.Process(
# target=lambda n_trials: self._run_optimize(self, n_trials),
target=self._run_optimize,
args=(self.optuna_settings.n_trials // self.n_threads,),
)
self.processes.append(p)
# def plot_eye(self, width=2, symbols=None, alpha=None, complex=False, show=True):
# data, config = util.datasets.load_data(
# self.data_settings.config_path,
# skipfirst=10,
# symbols=symbols or 1000,
# real=not complex,
# normalize=True,
# )
# eye_data = {"data": data.numpy(), "sps": int(config["glova"]["sps"])}
# return util.plot.eye(
# **eye_data,
# width=width,
# show=show,
# alpha=alpha,
# complex=complex,
# symbols=symbols or 1000,
# skipfirst=0,
# )
def run_study(self):
if self.processes:
for p in self.processes:
p.start()
for p in self.processes:
p.join()
remaining_trials = self.optuna_settings.n_trials % self.n_threads
else:
remaining_trials = self.optuna_settings.n_trials
if remaining_trials:
self._run_optimize(remaining_trials)
def _run_optimize(self, n_trials):
self.study.optimize(
self.objective, n_trials=n_trials, timeout=self.optuna_settings.timeout
)
def _extra_optuna_settings(self):
self.optuna_settings.multi_objective = len(self.optuna_settings.directions) > 1
if self.optuna_settings.multi_objective:
self.optuna_settings.direction = None
else:
self.optuna_settings.direction = self.optuna_settings.directions[0]
self.optuna_settings.directions = None
self.optuna_settings.n_train_batches = (
self.optuna_settings.n_train_batches
if self.optuna_settings.limit_examples
else float("inf")
)
self.optuna_settings.n_valid_batches = (
self.optuna_settings.n_valid_batches
if self.optuna_settings.limit_examples
else float("inf")
)
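# Worked example of the normalisation above: directions=("minimize",) collapses
# to direction="minimize" and directions=None (single-objective), while
# directions=("minimize", "minimize") keeps directions and sets
# multi_objective=True; with limit_examples=False both batch limits become
# float("inf"), i.e. no batches are skipped.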
def define_model(self, trial: optuna.Trial, writer=None):
n_layers = force_suggest_int(
trial, "model_n_layers", self.model_settings.model_n_layers
)
input_dim = 2 * trial.params.get(
"model_input_dim",
force_suggest_int(trial, "model_input_dim", self.data_settings.model_input_dim),
)
dtype = trial.params.get(
"model_dtype",
force_suggest_categorical(trial, "model_dtype", self.data_settings.dtype),
)
dtype = getattr(torch, dtype)
afunc = force_suggest_categorical(
trial, "model_activation_func", self.model_settings.model_activation_func
)
layers = []
last_dim = input_dim
for i in range(n_layers):
hidden_dim = force_suggest_int(
trial, f"model_hidden_dim_{i}", self.model_settings.unit_count
)
layers.append(
util.complexNN.SemiUnitaryLayer(last_dim, hidden_dim, dtype=dtype)
)
last_dim = hidden_dim
layers.append(getattr(util.complexNN, afunc)())
layers.append(
util.complexNN.UnitaryLayer(
hidden_dim, self.model_settings.output_dim, dtype=dtype
)
)
model = nn.Sequential(*layers)
if writer is not None:
writer.add_graph(
model, torch.zeros(1, input_dim, dtype=dtype), use_strict_trace=False
)
return model.to(self.pytorch_settings.device)
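# Illustration with hypothetical trial values (not from the diff):
# model_n_layers=2, model_input_dim=32 (so input_dim=64), hidden dims 8 and 8,
# model_activation_func="ModReLU" would build roughly
#   Sequential(SemiUnitaryLayer(64, 8), ModReLU(),
#              SemiUnitaryLayer(8, 8), ModReLU(),
#              UnitaryLayer(8, 2))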
def get_sliced_data(self, trial: optuna.Trial, override=None):
symbols = trial.params.get(
"dataset_symbols",
force_suggest_float(trial, "dataset_symbols", self.data_settings.symbols),
)
xy_delay = trial.params.get(
"dataset_xy_delay",
force_suggest_float(trial, "dataset_xy_delay", self.data_settings.xy_delay),
)
data_size = trial.params.get(
"model_input_dim",
force_suggest_int(trial, "model_input_dim", self.data_settings.model_input_dim),
)
dtype = trial.params.get(
"model_dtype",
force_suggest_categorical(trial, "model_dtype", self.data_settings.dtype),
)
dtype = getattr(torch, dtype)
num_symbols = None
if override is not None:
num_symbols = override.get("num_symbols", None)
# get dataset
dataset = FiberRegenerationDataset(
file_path=self.data_settings.config_path,
symbols=symbols,
output_dim=data_size,
target_delay=self.data_settings.in_out_delay,
xy_delay=xy_delay,
drop_first=self.data_settings.drop_first,
dtype=dtype,
real=not dtype.is_complex,
num_symbols=num_symbols,
)
dataset_size = len(dataset)
indices = list(range(dataset_size))
split = int(np.floor(self.data_settings.train_split * dataset_size))
if self.data_settings.shuffle:
np.random.seed(self.global_settings.seed)
np.random.shuffle(indices)
train_indices, valid_indices = indices[:split], indices[split:]
if self.data_settings.shuffle:
train_sampler = torch.utils.data.SubsetRandomSampler(train_indices)
valid_sampler = torch.utils.data.SubsetRandomSampler(valid_indices)
else:
train_sampler = train_indices
valid_sampler = valid_indices
train_loader = torch.utils.data.DataLoader(
dataset,
batch_size=self.pytorch_settings.batchsize,
sampler=train_sampler,
drop_last=True,
pin_memory=True,
num_workers=self.pytorch_settings.dataloader_workers,
prefetch_factor=self.pytorch_settings.dataloader_prefetch,
)
valid_loader = torch.utils.data.DataLoader(
dataset,
batch_size=self.pytorch_settings.batchsize,
sampler=valid_sampler,
drop_last=True,
pin_memory=True,
num_workers=self.pytorch_settings.dataloader_workers,
prefetch_factor=self.pytorch_settings.dataloader_prefetch,
)
return train_loader, valid_loader
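# Split sketch: with dataset_size=10000 and train_split=0.8, split=8000, so
# indices[:8000] feed the train loader and indices[8000:] the valid loader;
# with shuffle=True the index list is permuted (seeded) first and wrapped in
# SubsetRandomSampler, otherwise the plain index lists are used as samplers.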
def train_model(
self,
trial,
model,
optimizer,
train_loader,
epoch,
writer=None,
enable_progress=False,
):
if enable_progress:
progress = Progress(
TextColumn("[yellow] Training..."),
TextColumn("Error: {task.description}"),
BarColumn(),
TaskProgressColumn(),
TextColumn("[green]Batch"),
MofNCompleteColumn(),
TimeRemainingColumn(),
TimeElapsedColumn(),
# description="Training",
transient=False,
console=self.console,
refresh_per_second=10,
)
task = progress.add_task("-.---e--", total=len(train_loader))
progress.start()
running_loss2 = 0.0
running_loss = 0.0
model.train()
for batch_idx, (x, y) in enumerate(train_loader):
if batch_idx >= self.optuna_settings.n_train_batches:
break
model.zero_grad(set_to_none=True)
x, y = (
x.to(self.pytorch_settings.device),
y.to(self.pytorch_settings.device),
)
y_pred = model(x)
loss = util.complexNN.complex_mse_loss(y_pred, y)
loss_value = loss.item()
loss.backward()
optimizer.step()
running_loss2 += loss_value
running_loss += loss_value
if enable_progress:
progress.update(task, advance=1, description=f"{loss_value:.3e}")
if writer is not None:
if batch_idx % self.pytorch_settings.write_every == 0:
writer.add_scalar(
"training loss",
running_loss2
/ (self.pytorch_settings.write_every if batch_idx > 0 else 1),
epoch
* min(len(train_loader), self.optuna_settings.n_train_batches)
+ batch_idx,
)
running_loss2 = 0.0
if enable_progress:
progress.stop()
return running_loss / min(
len(train_loader), self.optuna_settings.n_train_batches
)
def eval_model(
self, trial, model, valid_loader, epoch, writer=None, enable_progress=True
):
if enable_progress:
progress = Progress(
TextColumn("[green]Evaluating..."),
TextColumn("Error: {task.description}"),
BarColumn(),
TaskProgressColumn(),
TextColumn("[green]Batch"),
MofNCompleteColumn(),
TimeRemainingColumn(),
TimeElapsedColumn(),
# description="Training",
transient=False,
console=self.console,
refresh_per_second=10,
)
progress.start()
task = progress.add_task("-.---e--", total=len(valid_loader))
model.eval()
running_error = 0
running_error_2 = 0
with torch.no_grad():
for batch_idx, (x, y) in enumerate(valid_loader):
if batch_idx >= self.optuna_settings.n_valid_batches:
break
x, y = (
x.to(self.pytorch_settings.device),
y.to(self.pytorch_settings.device),
)
y_pred = model(x)
error = util.complexNN.complex_mse_loss(y_pred, y)
error_value = error.item()
running_error += error_value
running_error_2 += error_value
if enable_progress:
progress.update(task, advance=1, description=f"{error_value:.3e}")
if writer is not None:
if batch_idx % self.pytorch_settings.write_every == 0:
writer.add_scalar(
"eval loss",
running_error_2
/ (
self.pytorch_settings.write_every
if batch_idx > 0
else 1
),
epoch
* min(
len(valid_loader), self.optuna_settings.n_valid_batches
)
+ batch_idx,
)
running_error_2 = 0.0
running_error /= min(len(valid_loader), self.optuna_settings.n_valid_batches)
if writer is not None:
title_append, subtitle = self.build_title(trial)
writer.add_figure(
"fiber response",
self.plot_model_response(
trial,
model=model,
title_append=title_append,
subtitle=subtitle,
show=False,
),
epoch + 1,
)
if enable_progress:
progress.stop()
return running_error
def run_model(self, model, loader):
model.eval()
xs = []
ys = []
y_preds = []
with torch.no_grad():
model = model.to(self.pytorch_settings.device)
for x, y in loader:
x, y = (
x.to(self.pytorch_settings.device),
y.to(self.pytorch_settings.device),
)
y_pred = model(x).cpu()
# x = x.cpu()
# y = y.cpu()
y_pred = y_pred.view(y_pred.shape[0], -1, 2)
y = y.view(y.shape[0], -1, 2)
x = x.view(x.shape[0], -1, 2)
xs.append(x[:, 0, :].squeeze())
ys.append(y.squeeze())
y_preds.append(y_pred.squeeze())
xs = torch.vstack(xs).cpu()
ys = torch.vstack(ys).cpu()
y_preds = torch.vstack(y_preds).cpu()
return ys, xs, y_preds
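# Shape sketch (an assumption based on the dataset layout): x arrives as
# [batch, 2 * model_input_dim] and is viewed as [batch, model_input_dim, 2],
# of which only the first slice x[:, 0, :] is kept; y and y_pred are viewed as
# [batch, 1, 2] and squeezed, so xs, ys and y_preds each stack to [N, 2].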
def objective(self, trial: optuna.Trial, plot_before=False):
model = None
exc = None
try:
# rprint(*list(self.study_name.split("_")))
writer = self.setup_tb_writer(
self.optuna_settings.study_name,
f"{trial.number:0{len(str(self.optuna_settings.n_trials))}}",
)
model = self.define_model(trial, writer)
n_params = sum(p.numel() for p in model.parameters())
# n_nodes = trial.params.get("model_n_layers", self.model_settings.model_n_layers) * trial.params.get("model_hidden_dim", self.model_settings.unit_count)
title_append, subtitle = self.build_title(trial)
writer.add_figure(
"fiber response",
self.plot_model_response(
trial,
model=model,
title_append=title_append,
subtitle=subtitle,
show=plot_before,
),
0,
)
train_loader, valid_loader = self.get_sliced_data(trial)
optimizer_name = force_suggest_categorical(
trial, "optimizer", self.optimizer_settings.optimizer
)
lr = force_suggest_float(
trial, "lr", self.optimizer_settings.learning_rate, log=True
)
optimizer = getattr(optim, optimizer_name)(model.parameters(), lr=lr)
if self.optimizer_settings.scheduler is not None:
scheduler = getattr(optim.lr_scheduler, self.optimizer_settings.scheduler)(
optimizer, **self.optimizer_settings.scheduler_kwargs)
for epoch in range(self.pytorch_settings.epochs):
enable_progress = self.optuna_settings.n_threads == 1
if enable_progress:
self.console.rule(
f"Epoch {epoch + 1}/{self.pytorch_settings.epochs}"
)
self.train_model(
trial,
model,
optimizer,
train_loader,
epoch,
writer,
enable_progress=enable_progress,
)
error = self.eval_model(
trial,
model,
valid_loader,
epoch,
writer,
enable_progress=enable_progress,
)
if self.optimizer_settings.scheduler is not None:
scheduler.step(error)
writer.close()
if self.optuna_settings.multi_objective:
return n_params, error
trial.report(error, epoch)
if trial.should_prune():
raise optuna.exceptions.TrialPruned()
return error
except KeyboardInterrupt:
...
# except Exception as e:
# exc = e
finally:
if model is not None:
save_path = (
Path(self.pytorch_settings.model_dir)
/ f"{self.optuna_settings.study_name}_{trial.number}.pth"
)
save_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(model, save_path)
if exc is not None:
raise exc
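# Note: in the multi-objective case the objective returns (n_params, error)
# and skips trial.report()/pruning, since Optuna pruners only support
# single-objective studies.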
def _plot_model_response_eye(
self, *signals, labels=None, sps=None, title_append="", subtitle="", show=True
):
if sps is None:
raise ValueError("sps must be provided")
if not hasattr(labels, "__iter__") or isinstance(labels, (str, type(None))):
labels = [labels]
else:
labels = list(labels)
while len(labels) < len(signals):
labels.append(None)
# check if there are any labels
if not any(labels):
labels = [f"signal {i + 1}" for i in range(len(signals))]
fig, axs = plt.subplots(2, len(signals), sharex=True, sharey=True)
fig.suptitle(
f"Eye diagram{f' {title_append}' if title_append else ''}{f'\n{subtitle}' if subtitle else ''}"
)
xaxis = np.linspace(0, 2, 2 * sps, endpoint=False)
for j, (label, signal) in enumerate(zip(labels, signals)):
# signal = signal.cpu().numpy()
for i in range(len(signal) // sps - 1):
x, y = signal[i * sps : (i + 2) * sps].T
axs[0, j].plot(xaxis, np.abs(x) ** 2, color="C0", alpha=0.02)
axs[1, j].plot(xaxis, np.abs(y) ** 2, color="C0", alpha=0.02)
axs[0, j].set_title(label + " x")
axs[1, j].set_title(label + " y")
axs[0, j].set_xlabel("Symbol")
axs[1, j].set_xlabel("Symbol")
axs[0, j].set_ylabel("normalized power")
axs[1, j].set_ylabel("normalized power")
if show:
plt.show()
return fig
def _plot_model_response_head(
self, *signals, labels=None, sps=None, title_append="", subtitle="", show=True
):
if not hasattr(labels, "__iter__") or isinstance(labels, (str, type(None))):
labels = [labels]
else:
labels = list(labels)
while len(labels) < len(signals):
labels.append(None)
# check if there are any labels
if not any(labels):
labels = [f"signal {i + 1}" for i in range(len(signals))]
fig, axs = plt.subplots(1, 2, sharex=True, sharey=True)
fig.set_size_inches(18, 6)
fig.suptitle(
f"Fiber response{f' {title_append}' if title_append else ''}{f'\n{subtitle}' if subtitle else ''}"
)
for i, ax in enumerate(axs):
for signal, label in zip(signals, labels):
if sps is not None:
xaxis = np.linspace(
0, len(signal) / sps, len(signal), endpoint=False
)
else:
xaxis = np.arange(len(signal))
ax.plot(xaxis, np.abs(signal[:, i]) ** 2, label=label)
ax.set_xlabel("Sample" if sps is None else "Symbol")
ax.set_ylabel("normalized power")
ax.legend(loc="upper right")
if show:
plt.show()
return fig
def plot_model_response(
self,
trial,
model=None,
title_append="",
subtitle="",
mode: Literal["eye", "head"] = "head",
show=True,
):
data_settings_backup = copy.deepcopy(self.data_settings)
pytorch_settings_backup = copy.deepcopy(self.pytorch_settings)
self.data_settings.drop_first = 100
self.data_settings.shuffle = False
self.data_settings.train_split = 1.0
self.pytorch_settings.batchsize = self.pytorch_settings.eye_symbols if mode == "eye" else self.pytorch_settings.head_symbols
plot_loader, _ = self.get_sliced_data(
trial, override={"num_symbols": self.pytorch_settings.batchsize}
)
self.data_settings = data_settings_backup
self.pytorch_settings = pytorch_settings_backup
fiber_in, fiber_out, regen = self.run_model(model, plot_loader)
fiber_in = fiber_in.view(-1, 2)
fiber_out = fiber_out.view(-1, 2)
regen = regen.view(-1, 2)
fiber_in = fiber_in.numpy()
fiber_out = fiber_out.numpy()
regen = regen.numpy()
# https://github.com/matplotlib/matplotlib/issues/27713#issue-2104110987
# https://github.com/matplotlib/matplotlib/issues/27713#issuecomment-1915497463
import gc
if mode == "head":
fig = self._plot_model_response_head(
fiber_in,
fiber_out,
regen,
labels=("fiber in", "fiber out", "regen"),
sps=plot_loader.dataset.samples_per_symbol,
title_append=title_append,
subtitle=subtitle,
show=show,
)
elif mode == "eye":
# raise NotImplementedError("Eye diagram not implemented")
fig = self._plot_model_response_eye(
fiber_in,
fiber_out,
regen,
labels=("fiber in", "fiber out", "regen"),
sps=plot_loader.dataset.samples_per_symbol,
title_append=title_append,
subtitle=subtitle,
show=show,
)
else:
raise ValueError(f"Unknown mode: {mode}")
gc.collect()
return fig
@staticmethod
def build_title(trial):
title_append = f"for trial {trial.number}"
subtitle = (
f"{trial.params['model_n_layers']} layers, "
f"{', '.join([str(trial.params[f'model_hidden_dim_{i}']) for i in range(trial.params['model_n_layers'])])} units, "
f"{trial.params['model_activation_func']}, "
f"{trial.params['model_dtype']}"
)
return title_append, subtitle
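# Minimal usage sketch (mirrors the driver script later in this compare;
# settings values are placeholders):
#   ht = HyperTraining(
#       global_settings=GlobalSettings(),
#       data_settings=DataSettings(config_path="..."),
#       pytorch_settings=PytorchSettings(),
#       model_settings=ModelSettings(),
#       optimizer_settings=OptimizerSettings(),
#       optuna_settings=OptunaSettings(),
#   )
#   ht.setup_study()
#   ht.run_study()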

View File

@@ -0,0 +1,77 @@
from dataclasses import dataclass
from datetime import datetime
# global settings
@dataclass(frozen=True)
class GlobalSettings:
seed: int = 42
# data settings
@dataclass
class DataSettings:
config_path: str # = "data/*-128-16384-100000-0-0-17-0-PAM4-0.ini"
dtype: tuple = ("complex64", "float64")
symbols: tuple | float | int = 8
model_input_dim: tuple | float | int = 64
shuffle: bool = True
in_out_delay: float = 0
xy_delay: tuple | float | int = 0
drop_first: int = 1000
train_split: float = 0.8
# pytorch settings
@dataclass
class PytorchSettings:
epochs: int = 1
batchsize: int = 2**10
device: str = "cuda"
dataloader_workers: int = 2
dataloader_prefetch: int = 2
model_dir: str = ".models"
summary_dir: str = ".runs"
write_every: int = 10
head_symbols: int = 40
eye_symbols: int = 1000
# model settings
@dataclass
class ModelSettings:
output_dim: int = 2
model_n_layers: tuple | int = 3
unit_count: tuple | int = 8
# n_units_range: tuple | int = (2, 32)
# activation_func_range: tuple = ("ModReLU", "ZReLU", "CReLU", "Mag", "Identity")
model_activation_func: tuple = ("ModReLU",)
@dataclass
class OptimizerSettings:
optimizer: tuple | str = ("Adam", "RMSprop", "SGD")
learning_rate: tuple | float = (1e-5, 1e-1)
scheduler: str | None = None
scheduler_kwargs: dict | None = None
# optuna settings
@dataclass
class OptunaSettings:
n_trials: int = 128
n_threads: int = 4
timeout: int = 600
directions: tuple = ("minimize",)
metrics_names: tuple = ("mse",)
limit_examples: bool = True
n_train_batches: int = 100
n_valid_batches: int = 100
storage: str = "sqlite:///example.db"
study_name: str = (
f"optuna_study_{datetime.now().strftime('%Y-%m-%d_%H:%M:%S')}"
)
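# Note: dataclass defaults are evaluated once at class-definition time, so
# every OptunaSettings() created in the same process shares this timestamp;
# pass study_name explicitly (as the driver script below does) to get a fresh
# name per run.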

View File

@@ -1,464 +1,107 @@
import copy
from dataclasses import dataclass
from datetime import datetime
import matplotlib.pyplot as plt
import numpy as np
import optuna
import warnings
import torch
import torch.nn as nn
# import torch.nn.functional as F # mse_loss doesn't support complex numbers
import torch.optim as optim
import torch.utils.data
from torch.utils.tensorboard import SummaryWriter
from rich.progress import Progress, TextColumn, BarColumn, TaskProgressColumn, TimeRemainingColumn, MofNCompleteColumn
from rich.console import Console
import multiprocessing
from util.datasets import FiberRegenerationDataset
from util.complexNN import complex_sse_loss
from util.optuna_helpers import optional_suggest_categorical, optional_suggest_float, optional_suggest_int
import util
# global settings
@dataclass
class GlobalSettings:
seed: int = 42
# data settings
@dataclass
class DataSettings:
config_path: str = "data/*-128-16384-1000-0-0-17-0-PAM4-0.ini"
dtype: torch.dtype = torch.complex64
symbols_range: tuple|float|int = 16
data_size_range: tuple|float|int = 32
shuffle: bool = True
target_delay: float = 0
xy_delay_range: tuple|float|int = 0
drop_first: int = 10
train_split: float = 0.8
# pytorch settings
@dataclass
class PytorchSettings:
device: str = "cuda"
batchsize: int = 1024
epochs: int = 10
summary_dir: str = ".runs"
# model settings
@dataclass
class ModelSettings:
output_size: int = 2
n_layer_range: tuple|float|int = (2,8)
n_units_range: tuple|float|int = (2,32)
# activation_func_range: tuple = ("ReLU",)
@dataclass
class OptimizerSettings:
# optimizer_range: tuple|str = ("Adam", "RMSprop", "SGD")
optimizer_range: tuple|str = "RMSprop"
# lr_range: tuple|float = (1e-5, 1e-1)
lr_range: tuple|float = 2e-5
# optuna settings
@dataclass
class OptunaSettings:
n_trials: int = 128
n_threads: int = 8
timeout: int = 600
directions: tuple = ("minimize",)
metrics_names: tuple = ("sse",)
limit_examples: bool = True
n_train_examples: int = PytorchSettings.batchsize * 50
# n_valid_examples: int = PytorchSettings.batchsize * 100
n_valid_examples: int = float("inf")
storage: str = "sqlite:///optuna_single_core_regen.db"
study_name: str = (
f"single_core_regen_{datetime.now().strftime('%Y-%m-%d_%H:%M:%S')}"
)
class HyperTraining:
def __init__(self):
self.global_settings = GlobalSettings()
self.data_settings = DataSettings()
self.pytorch_settings = PytorchSettings()
self.model_settings = ModelSettings()
self.optimizer_settings = OptimizerSettings()
self.optuna_settings = OptunaSettings()
self.console = Console()
# set some extra settings to make the code more readable
self._extra_optuna_settings()
def setup_tb_writer(self, study_name=None, append=None):
log_dir = self.pytorch_settings.summary_dir + "/" + (study_name or self.optuna_settings.study_name)
if append is not None:
log_dir += "_" + str(append)
return SummaryWriter(log_dir)
def resume_latest_study(self, verbose=True):
study_name = hyper_training.get_latest_study()
if study_name:
print(f"Resuming study: {study_name}")
self.optuna_settings.study_name = study_name
def get_latest_study(self, verbose=True):
studies = self.get_studies()
for study in studies:
study.datetime_start = study.datetime_start or datetime.min
if studies:
study = sorted(studies, key = lambda x: x.datetime_start, reverse=True)[0]
if verbose:
print(f"Last study: {study.study_name}")
study_name = study.study_name
else:
if verbose:
print("No previous studies found")
study_name = None
return study_name
def get_studies(self):
return optuna.get_all_study_summaries(storage=self.optuna_settings.storage)
def setup_study(self):
self.study = optuna.create_study(
study_name=self.optuna_settings.study_name,
storage=self.optuna_settings.storage,
load_if_exists=True,
direction=self.optuna_settings.direction,
directions=self.optuna_settings.directions,
)
with warnings.catch_warnings(action="ignore"):
self.study.set_metric_names(self.optuna_settings.metrics_names)
self.n_threads = min(
self.optuna_settings.n_trials, self.optuna_settings.n_threads
)
self.processes = []
if self.n_threads > 1:
for _ in range(self.n_threads):
p = multiprocessing.Process(
# target=lambda n_trials: self._run_optimize(self, n_trials),
target = self._run_optimize,
args = (self.optuna_settings.n_trials // self.n_threads,),
)
self.processes.append(p)
def run_study(self):
if self.processes:
for p in self.processes:
p.start()
for p in self.processes:
p.join()
remaining_trials = (
self.optuna_settings.n_trials
- self.optuna_settings.n_trials % self.optuna_settings.n_threads
)
else:
remaining_trials = self.optuna_settings.n_trials
if remaining_trials:
self._run_optimize(remaining_trials)
def _run_optimize(self, n_trials):
self.study.optimize(
self.objective, n_trials=n_trials, timeout=self.optuna_settings.timeout
)
def plot_eye(self, show=True):
if not hasattr(self, "eye_data"):
data, config = util.datasets.load_data(
self.data_settings.config_path, skipfirst=10, symbols=1000
)
self.eye_data = {"data": data, "sps": int(config["glova"]["sps"])}
return util.plot.eye(**self.eye_data, show=show)
def _extra_optuna_settings(self):
self.optuna_settings.multi_objective = len(self.optuna_settings.directions) > 1
if self.optuna_settings.multi_objective:
self.optuna_settings.direction = None
else:
self.optuna_settings.direction = self.optuna_settings.directions[0]
self.optuna_settings.directions = None
self.optuna_settings.n_train_examples = (
self.optuna_settings.n_train_examples
if self.optuna_settings.limit_examples
else float("inf")
)
self.optuna_settings.n_valid_examples = (
self.optuna_settings.n_valid_examples
if self.optuna_settings.limit_examples
else float("inf")
)
def define_model(self, trial: optuna.Trial, writer=None):
n_layers = optional_suggest_int(trial, "model_n_layers", self.model_settings.n_layer_range)
in_features = 2 * trial.params.get(
"dataset_data_size",
optional_suggest_int(trial, "dataset_data_size", self.data_settings.data_size_range),
)
trial.set_user_attr("input_dim", in_features)
layers = []
for i in range(n_layers):
out_features = optional_suggest_int(trial, f"model_n_units_l{i}", self.model_settings.n_units_range, log=True)
layers.append(nn.Linear(in_features, out_features, dtype=self.data_settings.dtype))
# layers.append(getattr(nn, activation_func)())
in_features = out_features
layers.append(nn.Linear(in_features, self.model_settings.output_size, dtype=self.data_settings.dtype))
if writer is not None:
writer.add_graph(nn.Sequential(*layers), torch.zeros(1, trial.user_attrs["input_dim"], dtype=self.data_settings.dtype))
return nn.Sequential(*layers)
def get_sliced_data(self, trial: optuna.Trial):
symbols = optional_suggest_float(trial, "dataset_symbols", self.data_settings.symbols_range)
xy_delay = optional_suggest_float(trial, "dataset_xy_delay", self.data_settings.xy_delay_range)
data_size = trial.params.get(
"dataset_data_size",
optional_suggest_int(trial, "dataset_data_size", self.data_settings.data_size_range)
)
# get dataset
dataset = FiberRegenerationDataset(
file_path=self.data_settings.config_path,
symbols=symbols,
data_size=data_size,
target_delay=self.data_settings.target_delay,
xy_delay=xy_delay,
drop_first=self.data_settings.drop_first,
dtype=self.data_settings.dtype,
)
dataset_size = len(dataset)
indices = list(range(dataset_size))
split = int(np.floor(self.data_settings.train_split * dataset_size))
if self.data_settings.shuffle:
np.random.seed(self.global_settings.seed)
np.random.shuffle(indices)
train_indices, valid_indices = indices[:split], indices[split:]
train_sampler = torch.utils.data.SubsetRandomSampler(train_indices)
valid_sampler = torch.utils.data.SubsetRandomSampler(valid_indices)
train_loader = torch.utils.data.DataLoader(
dataset, batch_size=self.pytorch_settings.batchsize, sampler=train_sampler, drop_last=True
)
valid_loader = torch.utils.data.DataLoader(
dataset, batch_size=self.pytorch_settings.batchsize, sampler=valid_sampler, drop_last=True
)
return train_loader, valid_loader
def train_model(self, model, optimizer, train_loader, epoch, writer=None, enable_progress=True):
if enable_progress:
progress = Progress(
TextColumn("[yellow] Training..."),
TextColumn(" Loss: {task.description}"),
BarColumn(),
TaskProgressColumn(),
TextColumn("[green]Batch"),
MofNCompleteColumn(),
TimeRemainingColumn(),
# description="Training",
transient=False,
console=self.console,
refresh_per_second=10,
)
task = progress.add_task("-.---e--", total=len(train_loader))
running_loss = 0.0
last_loss = 0.0
model.train()
for batch_idx, (x, y) in enumerate(train_loader):
if (
batch_idx * train_loader.batch_size
>= self.optuna_settings.n_train_examples
):
break
optimizer.zero_grad()
x, y = (
x.to(self.pytorch_settings.device),
y.to(self.pytorch_settings.device),
)
y_pred = model(x)
loss = complex_sse_loss(y_pred, y)
loss.backward()
optimizer.step()
# clamp weights to keep energy bounded
for p in model.parameters():
p.data.clamp_(-1.0, 1.0)
last_loss = loss.item()
if enable_progress:
progress.update(task, advance=1, description=f"{last_loss:.3e}")
running_loss += loss.item()
if writer is not None:
if batch_idx % 10 == 0:
writer.add_scalar("training loss", running_loss/10, epoch*min(len(train_loader), self.optuna_settings.n_train_examples/train_loader.batch_size) + batch_idx)
running_loss = 0.0
if enable_progress:
progress.update(task, description=f"{last_loss:.3e}")
progress.stop()
def eval_model(self, model, valid_loader, epoch, writer=None, enable_progress=True):
if enable_progress:
progress = Progress(
TextColumn("[green]Evaluating..."),
TextColumn("Error: {task.description}"),
BarColumn(),
TaskProgressColumn(),
TextColumn("[green]Batch"),
MofNCompleteColumn(),
TimeRemainingColumn(),
# description="Training",
transient=False,
console=self.console,
refresh_per_second=10,
)
task = progress.add_task("-.---e--", total=len(valid_loader))
model.eval()
running_error = 0
running_error_2 = 0
with torch.no_grad():
for batch_idx, (x, y) in enumerate(valid_loader):
if (
batch_idx * valid_loader.batch_size
>= self.optuna_settings.n_valid_examples
):
break
x, y = (
x.to(self.pytorch_settings.device),
y.to(self.pytorch_settings.device),
)
y_pred = model(x)
error = complex_sse_loss(y_pred, y)
running_error += error.item()
running_error_2 += error.item()
if enable_progress:
progress.update(task, advance=1, description=f"{error.item():.3e}")
if writer is not None:
if batch_idx % 10 == 0:
writer.add_scalar("sse", running_error_2/10, epoch*min(len(valid_loader), self.optuna_settings.n_valid_examples/valid_loader.batch_size) + batch_idx)
running_error_2 = 0.0
running_error /= batch_idx + 1
if enable_progress:
progress.update(task, description=f"{running_error:.3e}")
progress.stop()
return running_error
def run_model(self, model, loader):
model.eval()
y_preds = []
with torch.no_grad():
for x, y in loader:
x, y = (
x.to(self.pytorch_settings.device),
y.to(self.pytorch_settings.device),
)
y_preds.append(model(x))
return torch.stack(y_preds)
def objective(self, trial: optuna.Trial):
writer = self.setup_tb_writer(self.optuna_settings.study_name, f"{trial.number:0>len(str(self.optuna_settings.n_trials))}")
train_loader, valid_loader = self.get_sliced_data(trial)
model = self.define_model(trial, writer).to(self.pytorch_settings.device)
optimizer_name = optional_suggest_categorical(trial, "optimizer", self.optimizer_settings.optimizer_range)
lr = optional_suggest_float(trial, "lr", self.optimizer_settings.lr_range, log=True)
optimizer = getattr(optim, optimizer_name)(model.parameters(), lr=lr)
for epoch in range(self.pytorch_settings.epochs):
enable_progress = self.optuna_settings.n_threads == 1
if enable_progress:
print(f"Epoch {epoch+1}/{self.pytorch_settings.epochs}")
self.train_model(model, optimizer, train_loader, epoch, writer, enable_progress=enable_progress)
sse = self.eval_model(model, valid_loader, epoch, writer, enable_progress=enable_progress)
if not self.optuna_settings.multi_objective:
trial.report(sse, epoch)
if trial.should_prune():
raise optuna.exceptions.TrialPruned()
writer.close()
return sse
from hypertraining.hypertraining import HyperTraining
from hypertraining.settings import (
GlobalSettings,
DataSettings,
PytorchSettings,
ModelSettings,
OptimizerSettings,
OptunaSettings,
)
global_settings = GlobalSettings(
seed = 42,
)
data_settings = DataSettings(
config_path = "data/*-128-16384-100000-0-0-17-0-PAM4-0.ini",
dtype = ("complex64", "float64", "complex32", "float32"),
symbols = (1, 16),
model_input_dim = (1, 32),
shuffle = True,
in_out_delay = 0,
xy_delay = 0,
drop_first = 1000,
train_split = 0.8,
)
pytorch_settings = PytorchSettings(
epochs = 25,
batchsize = 2**10,
device = "cuda",
dataloader_workers = 2,
dataloader_prefetch = 2,
summary_dir = ".runs",
write_every = 2**5,
model_dir = ".models",
)
model_settings = ModelSettings(
output_dim = 2,
model_n_layers = (2, 8),
unit_count = (2, 16),
model_activation_func = ("ModReLU",),  # others: "ZReLU", "Mag", "CReLU", "Identity"
)
optimizer_settings = OptimizerSettings(
optimizer = ("Adam", "RMSprop"),#, "SGD"),
# learning_rate = (1e-5, 1e-1),
learning_rate=1e-3,
# scheduler = "ReduceLROnPlateau",
# scheduler_kwargs = {"mode": "min", "factor": 0.5, "patience": 10}
)
optuna_settings = OptunaSettings(
n_trials = 4096,
n_threads = 16,
timeout = 600,
directions = ("minimize","minimize"),
metrics_names = ("n_params","mse"),
limit_examples = True,
n_train_batches = 100,
n_valid_batches = 100,
storage = "sqlite:///data/single_core_regen.db",
study_name = f"single_core_regen_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
)
if __name__ == "__main__":
hyper_training = HyperTraining(
global_settings=global_settings,
data_settings=data_settings,
pytorch_settings=pytorch_settings,
model_settings=model_settings,
optimizer_settings=optimizer_settings,
optuna_settings=optuna_settings,
)
# hyper_training.resume_latest_study()
hyper_training.setup_study()
hyper_training.run_study()
# best_trial = hyper_training.study.best_trial
best_model = hyper_training.define_model(hyper_training.study.best_trial).to(hyper_training.pytorch_settings.device)
data_settings_backup = copy.copy(hyper_training.data_settings)
hyper_training.data_settings.shuffle = False
hyper_training.data_settings.train_split = 0.01
plot_loader, _ = hyper_training.get_sliced_data(hyper_training.study.best_trial)
hyper_training.data_settings = data_settings_backup
# best_model = hyper_training.define_model(best_trial).to(
# hyper_training.pytorch_settings.device
# )
_, _, regen = hyper_training.run_model(best_model, plot_loader)  # run_model returns (ys, xs, y_preds)
regen = regen.view(-1, 2)
# [batch_no, batch_size, 2] -> [no, 2]
original, _ = util.datasets.load_data(hyper_training.data_settings.config_path, skipfirst=hyper_training.data_settings.drop_first)
original = original[:len(regen)]
regen = regen.cpu().numpy()
_, axs = plt.subplots(2)
for i, ax in enumerate(axs):
ax.plot(np.abs(original[:, i])**2, label="original")
ax.plot(np.abs(regen[:, i])**2, label="regen")
ax.legend()
plt.show()
print(f"Best model: {best_model}")
# title_append, subtitle = hyper_training.build_title(best_trial)
# hyper_training.plot_model_response(
# best_trial,
# model=best_model,
# title_append=title_append,
# subtitle=subtitle,
# mode="eye",
# show=True,
# )
# print(f"Best model found for trial {best_trial.number}")
# print(f"Best model error: {best_trial.value}")
# print(f"Best model params: {best_trial.params}")
# print()
# print(best_model)
# eye_fig = hyper_training.plot_eye()
...

View File

@@ -95,11 +95,12 @@ class Training:
self.writer = None
self.console = Console()
def setup_tb_writer(self, study_name=None):
def setup_tb_writer(self, study_name=None, append=None):
log_dir = (
self.pytorch_settings.summary_dir + "/" + (study_name or self.study_name)
self.pytorch_settings.summary_dir + "/" + (study_name or self.study_name) + (("_" + str(append)) if append else "")  # parentheses keep the conditional on the suffix only
)
self.writer = SummaryWriter(log_dir)
return self.writer
def plot_eye(self, width=2, symbols=None, alpha=None, complex=False, show=True):
if not hasattr(self, "eye_data"):
@@ -160,7 +161,7 @@ class Training:
dataset = util.datasets.FiberRegenerationDataset(
file_path=self.data_settings.config_path,
symbols=symbols,
data_size=data_size,
output_dim=data_size,
target_delay=self.data_settings.target_delay,
xy_delay=xy_delay,
drop_first=self.data_settings.drop_first,
@@ -212,7 +213,7 @@ class Training:
def train_model(self, model, optimizer, train_loader, epoch):
with Progress(
TextColumn("[yellow] Training..."),
TextColumn("Loss: {task.description}"),
TextColumn("Error: {task.description}"),
BarColumn(),
TaskProgressColumn(),
TextColumn("[green]Batch"),
@@ -256,7 +257,7 @@ class Training:
def eval_model(self, model, valid_loader, epoch):
with Progress(
TextColumn("[green]Evaluating..."),
TextColumn("Loss: {task.description}"),
TextColumn("Error: {task.description}"),
BarColumn(),
TaskProgressColumn(),
TextColumn("[green]Batch"),
@@ -326,18 +327,6 @@ class Training:
y_preds = torch.vstack(y_preds).cpu()
return ys, xs, y_preds
def dummy_model(self, loader):
xs = []
ys = []
for x, y in loader:
y = y.cpu().view(y.shape[0], -1, 2)
x = x.cpu().view(x.shape[0], -1, 2)
xs.append(x[:, 0, :].squeeze())
ys.append(y.squeeze())
xs = torch.vstack(xs)
ys = torch.vstack(ys)
return xs, ys
def objective(self, save=False, plot_before=False):
try:
rprint(*list(self.study_name.split("_")))
@@ -360,22 +349,18 @@ class Training:
self.train_model(self.model, optimizer, train_loader, epoch)
eval_loss = self.eval_model(self.model, valid_loader, epoch)
if save:
return eval_loss
except KeyboardInterrupt:
...
finally:
if hasattr(self, "model"):
save_path = (
Path(self.pytorch_settings.model_dir) / f"{self.study_name}.pth"
)
save_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(self.model, save_path)
return eval_loss
except KeyboardInterrupt:
pass
finally:
if hasattr(self, "model"):
except_save_path = Path(".models/exception") / f"{self.study_name}.pth"
except_save_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(self.model, except_save_path)
def _plot_model_response_plotter(self, fiber_in, fiber_out, regen, plot=True):
fig, axs = plt.subplots(2)
for i, ax in enumerate(axs):

View File

@@ -15,3 +15,5 @@ from . import complexNN # noqa: F401
# from .complexNN import UnitaryLayer # noqa: F401
# from .complexNN import complex_mse_loss # noqa: F401
# from .complexNN import complex_sse_loss # noqa: F401
from . import misc # noqa: F401

View File

@@ -1,82 +1,120 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
def complex_mse_loss(input, target):
"""
Compute the mean squared error between two complex tensors.
"""
return torch.mean(torch.square(input.real - target.real) + torch.square(input.imag - target.imag))
if input.is_complex():
return torch.mean(
torch.square(input.real - target.real)
+ torch.square(input.imag - target.imag)
)
else:
return F.mse_loss(input, target)
def complex_sse_loss(input, target):
"""
Compute the sum squared error between two complex tensors.
"""
if input.is_complex():
return torch.sum(torch.square(input.real - target.real) + torch.square(input.imag - target.imag))
return torch.sum(
torch.square(input.real - target.real)
+ torch.square(input.imag - target.imag)
)
else:
return torch.sum(torch.square(input - target))
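# Worked example: complex_mse_loss(torch.tensor([1 + 1j]), torch.tensor([0 + 0j]))
# = mean((1 - 0)**2 + (1 - 0)**2) = tensor(2.); complex_sse_loss sums instead of
# averaging, so it grows with the number of elements.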
class UnitaryLayer(nn.Module):
def __init__(self, in_features, out_features):
super(UnitaryLayer, self).__init__()
def __init__(self, in_features, out_features, dtype=None):
assert in_features >= out_features
super(UnitaryLayer, self).__init__()
self.in_features = in_features
self.out_features = out_features
self.weight = nn.Parameter(torch.randn(in_features, out_features, dtype=torch.cfloat))
self.weight = nn.Parameter(torch.randn(in_features, out_features, dtype=dtype))
self.reset_parameters()
def reset_parameters(self):
q, _ = torch.linalg.qr(self.weight)
self.weight.data = q
@staticmethod
@torch.jit.script
def _unitary_forward(x, weight):
out = torch.matmul(x, weight)
return out
def forward(self, x):
return torch.matmul(x, self.weight)
def __repr__(self):
return f"UnitaryLayer({self.in_features}, {self.out_features})"
class SemiUnitaryLayer(nn.Module):
def __init__(self, input_dim, output_dim, dtype=None):
super(SemiUnitaryLayer, self).__init__()
self.input_dim = input_dim
self.output_dim = output_dim
# Create a larger square matrix for QR decomposition
self.weight = nn.Parameter(torch.randn(max(input_dim, output_dim), max(input_dim, output_dim), dtype=dtype))
self.reset_parameters()
def reset_parameters(self):
# Ensure the weights are semi-unitary by QR decomposition
q, _ = torch.linalg.qr(self.weight)
if self.input_dim > self.output_dim:
self.weight.data = q[:self.input_dim, :self.output_dim]
else:
self.weight.data = q[:self.output_dim, :self.input_dim].t()
def forward(self, x):
return self._unitary_forward(x, self.weight)
out = torch.matmul(x, self.weight)
return out
def __repr__(self):
return f"SemiUnitaryLayer({self.input_dim}, {self.output_dim})"
# class SpreadLayer(nn.Module):
# def __init__(self, in_features, out_features, dtype=None):
# super(SpreadLayer, self).__init__()
# self.in_features = in_features
# self.out_features = out_features
# self.mat = torch.ones(in_features, out_features, dtype=dtype)*torch.sqrt(torch.tensor(in_features/out_features))
# def forward(self, x):
# # N in_features -> M out_features, Enery is preserved (P = abs(x)^2)
# out = torch.matmul(x, self.mat)
# return out
#### as defined by zhang et al
class Identity(nn.Module):
"""
implements the "activation" function
M(z) = z
"""
def __init__(self):
super(Identity, self).__init__()
def forward(self, x):
return x
class Mag(nn.Module):
"""
implements the activation function
M(z) = ||z||
"""
def __init__(self):
super(Mag, self).__init__()
@torch.jit.script
def forward(self, x):
return torch.abs(x.real**2 + x.imag**2)
return torch.abs(x).to(dtype=x.dtype)
# class Tanh(nn.Module):
# """
# implements the activation function
# M(z) = tanh(z) = sinh(z)/cosh(z) = (exp(z)-exp(-z))/(exp(z)+exp(-z)) = (exp(2*z)-1)/(exp(2*z)+1)
# """
# def __init__(self):
# super(Tanh, self).__init__()
# def forward(self, x):
# return torch.tanh(x)
class ModReLU(nn.Module):
"""
@@ -84,32 +122,38 @@ class ModReLU(nn.Module):
M(z) = ReLU(||z|| + b)*exp(j*theta_z)
= ReLU(||z|| + b)*z/||z||
"""
def __init__(self, b=0):
super(ModReLU, self).__init__()
self.b = b
self.relu = nn.ReLU()
@staticmethod
# @torch.jit.script
def _mod_relu(x, b):
mod = torch.abs(x.real**2 + x.imag**2)
return torch.relu(mod + b) * x / mod
self.b = torch.tensor(b)
def forward(self, x):
return self._mod_relu(x, self.b)
if x.is_complex():
mod = torch.abs(x)  # modulus ||z|| per the docstring; x.real**2 + x.imag**2 is ||z||**2
return torch.relu(mod + self.b) * x / mod
else:
return torch.relu(x + self.b)
def __repr__(self):
return f"ModReLU(b={self.b})"
class CReLU(nn.Module):
"""
implements the activation function
M(z) = ReLU(Re(z)) + j*ReLU(Im(z))
"""
def __init__(self):
super(CReLU, self).__init__()
self.relu = nn.ReLU()
@torch.jit.script
def forward(self, x):
return torch.relu(x.real) + 1j*torch.relu(x.imag)
if x.is_complex():
return torch.relu(x.real) + 1j * torch.relu(x.imag)
else:
return torch.relu(x)
class ZReLU(nn.Module):
"""
@@ -122,20 +166,8 @@ class ZReLU(nn.Module):
def __init__(self):
super(ZReLU, self).__init__()
@torch.jit.script
def forward(self, x):
return x * (torch.angle(x) >= 0) * (torch.angle(x) <= torch.pi/2)
# class ComplexFeedForwardNN(nn.Module):
# def __init__(self, in_features, hidden_features, out_features):
# super(ComplexFeedForwardNN, self).__init__()
# self.in_features = in_features
# self.hidden_features = hidden_features
# self.out_features = out_features
# self.fc1 = UnitaryLayer(in_features, hidden_features)
# self.fc2 = UnitaryLayer(hidden_features, out_features)
# def forward(self, x):
# x = self.fc1(x)
# x = self.fc2(x)
# return x
if x.is_complex():
return x * (torch.angle(x) >= 0) * (torch.angle(x) <= torch.pi / 2)
else:
return torch.relu(x)

View File

@@ -41,9 +41,10 @@ def load_data(config_path, skipfirst=0, symbols=None, real=False, normalize=Fals
data = np.load(datapath)[skipfirst * sps : symbols * sps + skipfirst * sps]
if normalize:
a, b, c, d = data.T
# square gets normalized to 1, as the power is (proportional to) the square of the amplitude
a, b, c, d = np.square(data.T)
a, b, c, d = a/np.max(np.abs(a)), b/np.max(np.abs(b)), c/np.max(np.abs(c)), d/np.max(np.abs(d))
data = np.array([a, b, c, d]).T
data = np.sqrt(np.array([a, b, c, d]).T)
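# note: for complex fields np.sqrt returns the principal root, so this
# square/normalize/sqrt round trip recovers each sample's phase only up to sign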
if real:
data = np.abs(data)
@@ -98,7 +99,7 @@ class FiberRegenerationDataset(Dataset):
file_path: str | Path,
symbols: int | float,
*,
data_size: int = None,
output_dim: int = None,
target_delay: float | int = 0,
xy_delay: float | int = 0,
drop_first: float | int = 0,
@@ -129,7 +130,7 @@ class FiberRegenerationDataset(Dataset):
assert isinstance(symbols, (float, int)), (
"symbols must be a float or an integer"
)
assert data_size is None or isinstance(data_size, int), (
assert output_dim is None or isinstance(output_dim, int), (
"output_len must be an integer"
)
assert isinstance(target_delay, (float, int)), (
@@ -142,7 +143,7 @@ class FiberRegenerationDataset(Dataset):
# check values
assert symbols > 0, "symbols must be positive"
assert data_size is None or data_size > 0, "output_len must be positive or None"
assert output_dim is None or output_dim > 0, "output_dim must be positive or None"
assert drop_first >= 0, "drop_first must be non-negative"
faux = kwargs.pop("faux", False)
@@ -158,7 +159,7 @@ class FiberRegenerationDataset(Dataset):
"glova": {"sps": 128},
}
else:
data_raw, self.config = load_data(file_path, skipfirst=drop_first, real=real, normalize=True, device=device, dtype=dtype)
data_raw, self.config = load_data(file_path, skipfirst=drop_first, symbols=kwargs.pop("num_symbols", None), real=real, normalize=True, device=device, dtype=dtype)
self.device = data_raw.device
@@ -166,7 +167,7 @@ class FiberRegenerationDataset(Dataset):
self.samples_per_slice = int(symbols * self.samples_per_symbol)
self.symbols_per_slice = self.samples_per_slice / self.samples_per_symbol
self.data_size = data_size or self.samples_per_slice
self.output_dim = output_dim or self.samples_per_slice
self.target_delay = target_delay or 0
self.xy_delay = xy_delay or 0
@@ -261,13 +262,13 @@ class FiberRegenerationDataset(Dataset):
data, target = self.data[idx, 1].squeeze(), self.data[idx, 0].squeeze()
# reduce by taking self.output_dim equally spaced samples
data = data[:, : data.shape[1] // self.data_size * self.data_size]
data = data.view(data.shape[0], self.data_size, -1)
data = data[:, : data.shape[1] // self.output_dim * self.output_dim]
data = data.view(data.shape[0], self.output_dim, -1)
data = data[:, :, 0]
# the target corresponds to the middle of the data window, since the output sample is influenced by the samples before and after it
target = target[:, : target.shape[1] // self.data_size * self.data_size]
target = target.view(target.shape[0], self.data_size, -1)
target = target[:, : target.shape[1] // self.output_dim * self.output_dim]
target = target.view(target.shape[0], self.output_dim, -1)
target = target[:, 0, target.shape[2] // 2]
data = data.transpose(0, 1).flatten().squeeze()

View File

@@ -0,0 +1,21 @@
def multi_getattr(objs, attr, fallback=None):
"""
Tries to get the attribute from a list of objects, returning the first hit.
If no object has the attribute, returns the fallback value if provided; otherwise raises AttributeError.
"""
try:
return _multi_getattr(objs, attr)
except AttributeError as e:
if fallback is not None:
return fallback
raise e
def _multi_getattr(objs, attr):
if not isinstance(objs, (list, tuple)):
objs = [objs]
for obj in objs:
try:
return getattr(obj, attr)
except AttributeError:
pass
raise AttributeError(f"None of the objects has attribute {attr}")
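# Usage sketch (hypothetical objects): multi_getattr([trial_params, defaults],
# "batchsize", fallback=1024) returns the attribute from the first object that
# has it, else 1024; note an explicit fallback of None is indistinguishable
# from "no fallback" and still raises AttributeError.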

View File

@@ -28,3 +28,18 @@ def optional_suggest_int(trial, name, range_or_value, step=None, log=False):
def optional_suggest_float(trial, name, range_or_value, step=None, log=False):
return _optional_suggest(trial, name, range_or_value, step=step, log=log, type='float')
def force_suggest_int(trial, name, range_or_value, step=1, log=False):
if not hasattr(range_or_value, '__iter__') or isinstance(range_or_value, str):
return trial.suggest_int(name, range_or_value, range_or_value, step=step, log=log)
return trial.suggest_int(name, *range_or_value, step=step, log=log)
def force_suggest_float(trial, name, range_or_value, step=None, log=False):
if not hasattr(range_or_value, '__iter__') or isinstance(range_or_value, str):
return trial.suggest_float(name, range_or_value, range_or_value, step=step, log=log)
return trial.suggest_float(name, *range_or_value, step=step, log=log)
def force_suggest_categorical(trial, name, range_or_value):
if not hasattr(range_or_value, '__iter__') or isinstance(range_or_value, str):
return trial.suggest_categorical(name, [range_or_value])
return trial.suggest_categorical(name, range_or_value)
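# Usage sketch: force_suggest_int(trial, "n_layers", 3) pins the value by
# suggesting from the degenerate range [3, 3], while force_suggest_int(trial,
# "n_layers", (2, 8)) searches the range; the same scalar-vs-iterable
# convention applies to the float and categorical variants.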

View File

@@ -38,7 +38,7 @@ def eye(*, path=None, data=None, sps=None, title=None, symbols=1000, skipfirst=0
axs[0, 1].plot(xaxis, np.abs(outx), color="C0", alpha=alpha or 0.1)
axs[1, 0].plot(xaxis, np.abs(iny), color="C0", alpha=alpha or 0.1)
axs[1, 1].plot(xaxis, np.abs(outy), color="C0", alpha=alpha or 0.1)
axs[0,0].set_ylim(0, 1.1*np.max(np.abs(data)))
axs[0, 0].set_ylim(0, 1.1*np.max(np.abs(data)))
axs2[0, 0].plot(xaxis, np.angle(inx), color="C1", alpha=alpha or 0.1)
axs2[0, 1].plot(xaxis, np.angle(outx), color="C1", alpha=alpha or 0.1)