training loop speedup

Joseph Hopfmüller
2024-11-20 11:29:18 +01:00
parent 1622c38582
commit cdca5de473
11 changed files with 1026 additions and 151 deletions

.gitattributes vendored

@@ -1,4 +1,5 @@
data/**/* filter=lfs diff=lfs merge=lfs -text
data/*.db filter=lfs diff=lfs merge=lfs -text
data/*.ini filter=lfs diff=lfs merge=lfs text
## lfs setup

.gitignore vendored

@@ -1,8 +1,5 @@
src/**/*.ini
.data
# VSCode
.vscode
.*
# Byte-compiled / optimized / DLL files
__pycache__/


@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:72460af57347d35df91cd76982231bcf538a82fd7f1b8522795202fa298a2dcb
size 696320


@@ -1,6 +1,6 @@
import copy
from dataclasses import dataclass
from datetime import datetime
import time
import matplotlib.pyplot as plt
import numpy as np
@@ -9,16 +9,21 @@ import warnings
import torch
import torch.nn as nn
import torch.functional as F
# import torch.nn.functional as F # mse_loss doesn't support complex numbers
import torch.optim as optim
import torch.utils.data
from torch.utils.tensorboard import SummaryWriter
from rich.progress import Progress, TextColumn, BarColumn, TaskProgressColumn, TimeRemainingColumn, MofNCompleteColumn
from rich.console import Console
import multiprocessing
from util.datasets import FiberRegenerationDataset
from util.complexNN import complex_sse_loss
from util.optuna_helpers import optional_suggest_categorical, optional_suggest_float, optional_suggest_int
import util
# global settings
@dataclass
class GlobalSettings:
@@ -28,12 +33,14 @@ class GlobalSettings:
# data settings
@dataclass
class DataSettings:
config_path: str = "data/*-128-16384-10000-0-0-17-0-PAM4-0.ini"
symbols_range: tuple = (1, 100)
data_size_range: tuple = (1, 20)
config_path: str = "data/*-128-16384-1000-0-0-17-0-PAM4-0.ini"
dtype: torch.dtype = torch.complex64
symbols_range: tuple|float|int = 16
data_size_range: tuple|float|int = 32
shuffle: bool = True
target_delay: float = 0
xy_delay_range: tuple = (0, 1)
drop_first: int = 1000
xy_delay_range: tuple|float|int = 0
drop_first: int = 10
train_split: float = 0.8
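# Note on the *_range fields, as handled by util.optuna_helpers (illustrative values):
# a tuple is passed to Optuna as a search range, while a bare scalar is used as-is
# and no suggest call is made, e.g.
#   symbols_range = (1, 100)  ->  searched via trial.suggest_float
#   symbols_range = 16        ->  fixed at 16, nothing suggested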
@@ -41,41 +48,46 @@ class DataSettings:
@dataclass
class PytorchSettings:
device: str = "cuda"
batchsize: int = 128
epochs: int = 100
batchsize: int = 1024
epochs: int = 10
summary_dir: str = ".runs"
# model settings
@dataclass
class ModelSettings:
output_size: int = 2
n_layer_range: tuple = (1, 3)
n_units_range: tuple = (4, 128)
activation_func_range: tuple = ("ReLU",)
n_layer_range: tuple|float|int = (2, 8)
n_units_range: tuple|float|int = (2, 32)
# activation_func_range: tuple = ("ReLU",)
@dataclass
class OptimizerSettings:
optimizer_range: tuple = ("Adam", "RMSprop", "SGD")
lr_range: tuple = (1e-5, 1e-1)
# optimizer_range: tuple|str = ("Adam", "RMSprop", "SGD")
optimizer_range: tuple|str = "RMSprop"
# lr_range: tuple|float = (1e-5, 1e-1)
lr_range: tuple|float = 2e-5
# optuna settings
@dataclass
class OptunaSettings:
n_trials: int = 128
n_threads: int = 16
n_threads: int = 8
timeout: int = 600
directions: tuple = ("maximize",)
directions: tuple = ("minimize",)
metrics_names: tuple = ("sse",)
limit_examples: bool = True
n_train_examples: int = PytorchSettings.batchsize * 30
n_valid_examples: int = PytorchSettings.batchsize * 10
n_train_examples: int = PytorchSettings.batchsize * 50
# n_valid_examples: int = PytorchSettings.batchsize * 100
n_valid_examples: int = float("inf")
storage: str = "sqlite:///optuna_single_core_regen.db"
study_name: str = (
f"single_core_regen_{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
f"single_core_regen_{datetime.now().strftime('%Y-%m-%d_%H:%M:%S')}"
)
metrics_names: tuple = ("accuracy",)
class HyperTraining:
def __init__(self):
@@ -86,9 +98,43 @@ class HyperTraining:
self.optimizer_settings = OptimizerSettings()
self.optuna_settings = OptunaSettings()
self.console = Console()
# set some extra settings to make the code more readable
self._extra_optuna_settings()
def setup_tb_writer(self, study_name=None, append=None):
log_dir = self.pytorch_settings.summary_dir + "/" + (study_name or self.optuna_settings.study_name)
if append is not None:
log_dir += "_" + str(append)
return SummaryWriter(log_dir)
def resume_latest_study(self, verbose=True):
study_name = self.get_latest_study(verbose=verbose)
if study_name:
print(f"Resuming study: {study_name}")
self.optuna_settings.study_name = study_name
def get_latest_study(self, verbose=True):
studies = self.get_studies()
for study in studies:
study.datetime_start = study.datetime_start or datetime.min
if studies:
study = sorted(studies, key = lambda x: x.datetime_start, reverse=True)[0]
if verbose:
print(f"Last study: {study.study_name}")
study_name = study.study_name
else:
if verbose:
print("No previous studies found")
study_name = None
return study_name
def get_studies(self):
return optuna.get_all_study_summaries(storage=self.optuna_settings.storage)
def setup_study(self):
self.study = optuna.create_study(
study_name=self.optuna_settings.study_name,
@@ -101,28 +147,48 @@ class HyperTraining:
with warnings.catch_warnings(action="ignore"):
self.study.set_metric_names(self.optuna_settings.metrics_names)
self.n_threads = min(self.optuna_settings.n_trials, self.optuna_settings.n_threads)
self.n_threads = min(
self.optuna_settings.n_trials, self.optuna_settings.n_threads
)
self.processes = []
if self.n_threads > 1:
for _ in range(self.n_threads):
p = multiprocessing.Process(target=self._run_optimize)
p = multiprocessing.Process(
# target=lambda n_trials: self._run_optimize(self, n_trials),
target = self._run_optimize,
args = (self.optuna_settings.n_trials // self.n_threads,),
)
self.processes.append(p)
def run_study(self):
if self.processes:
for p in self.processes:
p.start()
for p in self.processes:
p.join()
remaining_trials = self.optuna_settings.n_trials - self.optuna_settings.n_trials % self.optuna_settings.n_threads
remaining_trials = (
self.optuna_settings.n_trials
- self.optuna_settings.n_trials % self.optuna_settings.n_threads
)
else:
remaining_trials = self.optuna_settings.n_trials
if remaining_trials:
self._run_optimize(remaining_trials)
def _run_optimize(self, n_trials):
self.study.optimize(self.objective, n_trials=n_trials, timeout=self.optuna_settings.timeout)
self.study.optimize(
self.objective, n_trials=n_trials, timeout=self.optuna_settings.timeout
)
def eye(self, show=True):
util.plot.eye(self.data_settings.config_path, show=show)
def plot_eye(self, show=True):
if not hasattr(self, "eye_data"):
data, config = util.datasets.load_data(
self.data_settings.config_path, skipfirst=10, symbols=1000
)
self.eye_data = {"data": data, "sps": int(config["glova"]["sps"])}
return util.plot.eye(**self.eye_data, show=show)
def _extra_optuna_settings(self):
self.optuna_settings.multi_objective = len(self.optuna_settings.directions) > 1
@@ -143,145 +209,256 @@ class HyperTraining:
else float("inf")
)
def define_model(self, trial: optuna.Trial):
n_layers = trial.suggest_int(
"model_n_layers", *self.model_settings.n_layer_range
def define_model(self, trial: optuna.Trial, writer=None):
n_layers = optional_suggest_int(trial, "model_n_layers", self.model_settings.n_layer_range)
in_features = 2 * trial.params.get(
"dataset_data_size",
optional_suggest_int(trial, "dataset_data_size", self.data_settings.data_size_range),
)
trial.set_user_attr("input_dim", in_features)
layers = []
# REVIEW does that work?
in_features = trial.params["dataset_data_size"] * 2
for i in range(n_layers):
out_features = trial.suggest_int(
f"model_n_units_l{i}", *self.model_settings.n_units_range
)
activation_func = trial.suggest_categorical(
f"model_activation_func_l{i}", self.model_settings.activation_func_range
)
out_features = optional_suggest_int(trial, f"model_n_units_l{i}", self.model_settings.n_units_range, log=True)
layers.append(nn.Linear(in_features, out_features))
layers.append(getattr(nn, activation_func))
layers.append(nn.Linear(in_features, out_features, dtype=self.data_settings.dtype))
# layers.append(getattr(nn, activation_func)())
in_features = out_features
layers.append(nn.Linear(in_features, self.model_settings.output_size))
layers.append(nn.Linear(in_features, self.model_settings.output_size, dtype=self.data_settings.dtype))
if writer is not None:
writer.add_graph(nn.Sequential(*layers), torch.zeros(1, trial.user_attrs["input_dim"], dtype=self.data_settings.dtype))
return nn.Sequential(*layers)
def get_sliced_data(self, trial: optuna.Trial):
assert ModelSettings.input_size % 2 == 0, "input_dim must be even"
symbols = trial.suggest_float(
"dataset_symbols", *self.data_settings.symbols_range, log=True
)
xy_delay = trial.suggest_float(
"dataset_xy_delay", *self.data_settings.xy_delay_range
)
data_size = trial.suggest_int(
"dataset_data_size", *self.data_settings.data_size_range
symbols = optional_suggest_float(trial, "dataset_symbols", self.data_settings.symbols_range)
xy_delay = optional_suggest_float(trial, "dataset_xy_delay", self.data_settings.xy_delay_range)
data_size = trial.params.get(
"dataset_data_size",
optional_suggest_int(trial, "dataset_data_size", self.data_settings.data_size_range)
)
# get dataset
dataset = FiberRegenerationDataset(
file_path=self.data_settings.config_path,
symbols=symbols,
data_size=data_size, # two channels (x,y)
data_size=data_size,
target_delay=self.data_settings.target_delay,
xy_delay=xy_delay,
drop_first=self.data_settings.drop_first,
dtype=self.data_settings.dtype,
)
dataset_size = len(dataset)
indices = list(range(dataset_size))
split = int(np.floor(self.data_settings.train_split * dataset_size))
if self.data_settings.shuffle:
np.random.seed(self.global_settings.seed)
np.random.shuffle(indices)
train_indices, valid_indices = indices[:split], indices[split:]
train_sampler = torch.utils.data.SubsetRandomSampler(train_indices)
valid_sampler = torch.utils.data.SubsetRandomSampler(valid_indices)
train_loader = torch.utils.data.DataLoader(
dataset, batch_size=self.pytorch_settings.batchsize, sampler=train_sampler
dataset, batch_size=self.pytorch_settings.batchsize, sampler=train_sampler, drop_last=True
)
valid_loader = torch.utils.data.DataLoader(
dataset, batch_size=self.pytorch_settings.batchsize, sampler=valid_sampler
dataset, batch_size=self.pytorch_settings.batchsize, sampler=valid_sampler, drop_last=True
)
return train_loader, valid_loader
def train_model(self, model, optimizer, train_loader, epoch, writer=None, enable_progress=True):
if enable_progress:
progress = Progress(
TextColumn("[yellow] Training..."),
TextColumn(" Loss: {task.description}"),
BarColumn(),
TaskProgressColumn(),
TextColumn("[green]Batch"),
MofNCompleteColumn(),
TimeRemainingColumn(),
# description="Training",
transient=False,
console=self.console,
refresh_per_second=10,
)
task = progress.add_task("-.---e--", total=len(train_loader))
def train_model(self, model, optimizer, train_loader):
running_loss = 0.0
last_loss = 0.0
model.train()
for batch_idx, (data, target) in enumerate(train_loader):
if (batch_idx * train_loader.batchsize
>= self.optuna_settings.n_train_examples):
for batch_idx, (x, y) in enumerate(train_loader):
if (
batch_idx * train_loader.batch_size
>= self.optuna_settings.n_train_examples
):
break
optimizer.zero_grad()
data, target = (
data.to(self.pytorch_settings.device),
target.to(self.pytorch_settings.device),
x, y = (
x.to(self.pytorch_settings.device),
y.to(self.pytorch_settings.device),
)
target_pred = model(data)
loss = F.mean_squared_error(target_pred, target)
y_pred = model(x)
loss = complex_sse_loss(y_pred, y)
loss.backward()
optimizer.step()
# clamp weights to keep energy bounded
# (clamp_ is not implemented for complex tensors, so clamp the real and
# imaginary parts separately through their views)
for p in model.parameters():
    if p.is_complex():
        p.data.real.clamp_(-1.0, 1.0)
        p.data.imag.clamp_(-1.0, 1.0)
    else:
        p.data.clamp_(-1.0, 1.0)
last_loss = loss.item()
if enable_progress:
progress.update(task, advance=1, description=f"{last_loss:.3e}")
running_loss += loss.item()
if writer is not None:
if batch_idx % 10 == 0:
writer.add_scalar("training loss", running_loss/10, epoch*min(len(train_loader), self.optuna_settings.n_train_examples/train_loader.batch_size) + batch_idx)
running_loss = 0.0
if enable_progress:
progress.update(task, description=f"{last_loss:.3e}")
progress.stop()
def eval_model(self, model, valid_loader, epoch, writer=None, enable_progress=True):
if enable_progress:
progress = Progress(
TextColumn("[green]Evaluating..."),
TextColumn("Error: {task.description}"),
BarColumn(),
TaskProgressColumn(),
TextColumn("[green]Batch"),
MofNCompleteColumn(),
TimeRemainingColumn(),
# description="Training",
transient=False,
console=self.console,
refresh_per_second=10,
)
task = progress.add_task("-.---e--", total=len(valid_loader))
def eval_model(self, model, valid_loader):
model.eval()
correct = 0
running_error = 0
running_error_2 = 0
with torch.no_grad():
for batch_idx, (data, target) in enumerate(valid_loader):
for batch_idx, (x, y) in enumerate(valid_loader):
if (
batch_idx * valid_loader.batchsize
batch_idx * valid_loader.batch_size
>= self.optuna_settings.n_valid_examples
):
break
data, target = (
data.to(self.pytorch_settings.device),
target.to(self.pytorch_settings.device),
x, y = (
x.to(self.pytorch_settings.device),
y.to(self.pytorch_settings.device),
)
target_pred = model(data)
pred = target_pred.argmax(dim=1, keepdim=True)
correct += pred.eq(target.view_as(pred)).sum().item()
y_pred = model(x)
error = complex_sse_loss(y_pred, y)
running_error += error.item()
running_error_2 += error.item()
if enable_progress:
progress.update(task, advance=1, description=f"{error.item():.3e}")
if writer is not None:
if batch_idx % 10 == 0:
writer.add_scalar("sse", running_error_2/10, epoch*min(len(valid_loader), self.optuna_settings.n_valid_examples/valid_loader.batch_size) + batch_idx)
running_error_2 = 0.0
running_error /= batch_idx + 1
if enable_progress:
progress.update(task, description=f"{running_error:.3e}")
progress.stop()
return running_error
def run_model(self, model, loader):
model.eval()
y_preds = []
with torch.no_grad():
for x, y in loader:
x, y = (
x.to(self.pytorch_settings.device),
y.to(self.pytorch_settings.device),
)
y_preds.append(model(x))
return torch.stack(y_preds)
accuracy = correct / len(valid_loader.dataset)
# num_params = sum(p.numel() for p in model.parameters())
return accuracy
def objective(self, trial: optuna.Trial):
model = self.define_model(trial).to(self.pytorch_settings.device)
optimizer_name = trial.suggest_categorical(
"optimizer", self.optimizer_settings.optimizer_range
)
lr = trial.suggest_float("lr", *self.optimizer_settings.lr_range, log=True)
optimizer = getattr(optim, optimizer_name)(model.parameters(), lr=lr)
writer = self.setup_tb_writer(self.optuna_settings.study_name, f"{trial.number:0>{len(str(self.optuna_settings.n_trials))}}")
train_loader, valid_loader = self.get_sliced_data(trial)
for epoch in range(self.pytorch_settings.epochs):
self.train_model(model, optimizer, train_loader)
accuracy = self.eval_model(model, valid_loader)
model = self.define_model(trial, writer).to(self.pytorch_settings.device)
if len(self.optuna_settings.directions) == 1:
trial.report(accuracy, epoch)
optimizer_name = optional_suggest_categorical(trial, "optimizer", self.optimizer_settings.optimizer_range)
lr = optional_suggest_float(trial, "lr", self.optimizer_settings.lr_range, log=True)
optimizer = getattr(optim, optimizer_name)(model.parameters(), lr=lr)
for epoch in range(self.pytorch_settings.epochs):
enable_progress = self.optuna_settings.n_threads == 1
if enable_progress:
print(f"Epoch {epoch+1}/{self.pytorch_settings.epochs}")
self.train_model(model, optimizer, train_loader, epoch, writer, enable_progress=enable_progress)
sse = self.eval_model(model, valid_loader, epoch, writer, enable_progress=enable_progress)
if not self.optuna_settings.multi_objective:
trial.report(sse, epoch)
if trial.should_prune():
raise optuna.exceptions.TrialPruned()
return accuracy
writer.close()
return sse
if __name__ == "__main__":
# plt.ion()
hyper_training = HyperTraining()
hyper_training.eye()
# hyper_training.setup_study()
# hyper_training.run_study()
for i in range(10):
# simulate some work
print(i)
time.sleep(0.2)
# hyper_training.resume_latest_study()
hyper_training.setup_study()
hyper_training.run_study()
best_model = hyper_training.define_model(hyper_training.study.best_trial).to(hyper_training.pytorch_settings.device)
data_settings_backup = copy.copy(hyper_training.data_settings)
hyper_training.data_settings.shuffle = False
hyper_training.data_settings.train_split = 0.01
plot_loader, _ = hyper_training.get_sliced_data(hyper_training.study.best_trial)
regen = hyper_training.run_model(best_model, plot_loader)
regen = regen.view(-1, 2)
# [batch_no, batch_size, 2] -> [no, 2]
original, _ = util.datasets.load_data(hyper_training.data_settings.config_path, skipfirst=hyper_training.data_settings.drop_first)
original = original[:len(regen)]
regen = regen.cpu().numpy()
_, axs = plt.subplots(2)
for i, ax in enumerate(axs):
ax.plot(np.abs(original[:, i])**2, label="original")
ax.plot(np.abs(regen[:, i])**2, label="regen")
ax.legend()
plt.show()
...
print(f"Best model: {best_model}")
# eye_fig = hyper_training.plot_eye()
...


@@ -0,0 +1,429 @@
import copy
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
# import torch.nn.functional as F # mse_loss doesn't support complex numbers
import torch.optim as optim
import torch.utils.data
from torch.utils.tensorboard import SummaryWriter
from rich.progress import (
Progress,
TextColumn,
BarColumn,
TaskProgressColumn,
TimeRemainingColumn,
MofNCompleteColumn,
TimeElapsedColumn,
)
from rich.console import Console
from rich import print as rprint
# from util.optuna_helpers import optional_suggest_categorical, optional_suggest_float, optional_suggest_int
import util
# global settings
@dataclass
class GlobalSettings:
seed: int = 42
# data settings
@dataclass
class DataSettings:
config_path: str = "data/*-128-16384-100000-0-0-17-0-PAM4-0.ini"
dtype: torch.dtype = torch.complex64
symbols_range: float | int = 8
data_size_range: float | int = 64
shuffle: bool = True
target_delay: float = 0
xy_delay_range: float | int = 0
drop_first: int = 10
train_split: float = 0.8
# pytorch settings
@dataclass
class PytorchSettings:
epochs: int = 1000
batchsize: int = 2**12
device: str = "cuda"
summary_dir: str = ".runs"
model_dir: str = ".models"
# model settings
@dataclass
class ModelSettings:
output_size: int = 2
# n_layer_range: float|int = 2
# n_units_range: float|int = 32
n_layers: int = 3
n_units: int = 32
activation_func: tuple | str = "ModReLU"
@dataclass
class OptimizerSettings:
optimizer_range: str = "Adam"
lr_range: float = 2e-3
class Training:
def __init__(self):
self.global_settings = GlobalSettings()
self.data_settings = DataSettings()
self.pytorch_settings = PytorchSettings()
self.model_settings = ModelSettings()
self.optimizer_settings = OptimizerSettings()
self.study_name = (
f"single_core_regen_{datetime.now().strftime('%Y-%m-%d_%H:%M:%S')}"
)
if not hasattr(self.pytorch_settings, "model_dir"):
self.pytorch_settings.model_dir = ".models"
self.writer = None
self.console = Console()
def setup_tb_writer(self, study_name=None):
log_dir = (
self.pytorch_settings.summary_dir + "/" + (study_name or self.study_name)
)
self.writer = SummaryWriter(log_dir)
def plot_eye(self, width=2, symbols=None, alpha=None, complex=False, show=True):
if not hasattr(self, "eye_data"):
data, config = util.datasets.load_data(
self.data_settings.config_path,
skipfirst=10,
symbols=symbols or 1000,
real=not self.data_settings.dtype.is_complex,
normalize=True,
)
self.eye_data = {"data": data, "sps": int(config["glova"]["sps"])}
return util.plot.eye(
**self.eye_data,
width=width,
show=show,
alpha=alpha,
complex=complex,
symbols=symbols or 1000,
skipfirst=0,
)
def define_model(self):
n_layers = self.model_settings.n_layers
in_features = 2 * self.data_settings.data_size_range
layers = []
for i in range(n_layers):
out_features = self.model_settings.n_units
layers.append(util.complexNN.UnitaryLayer(in_features, out_features))
# layers.append(getattr(nn, self.model_settings.activation_func)())
layers.append(
getattr(util.complexNN, self.model_settings.activation_func)()
)
in_features = out_features
layers.append(
util.complexNN.UnitaryLayer(in_features, self.model_settings.output_size)
)
if self.writer is not None:
self.writer.add_graph(
nn.Sequential(*layers),
torch.zeros(1, layers[0].in_features, dtype=self.data_settings.dtype),
)
return nn.Sequential(*layers)
def get_sliced_data(self):
symbols = self.data_settings.symbols_range
xy_delay = self.data_settings.xy_delay_range
data_size = self.data_settings.data_size_range
# get dataset
dataset = util.datasets.FiberRegenerationDataset(
file_path=self.data_settings.config_path,
symbols=symbols,
data_size=data_size,
target_delay=self.data_settings.target_delay,
xy_delay=xy_delay,
drop_first=self.data_settings.drop_first,
dtype=self.data_settings.dtype,
real=not self.data_settings.dtype.is_complex,
# device=self.pytorch_settings.device,
)
dataset_size = len(dataset)
indices = list(range(dataset_size))
split = int(np.floor(self.data_settings.train_split * dataset_size))
if self.data_settings.shuffle:
np.random.seed(self.global_settings.seed)
np.random.shuffle(indices)
train_indices, valid_indices = indices[:split], indices[split:]
if self.data_settings.shuffle:
train_sampler = torch.utils.data.SubsetRandomSampler(train_indices)
valid_sampler = torch.utils.data.SubsetRandomSampler(valid_indices)
else:
train_sampler = train_indices
valid_sampler = valid_indices
train_loader = torch.utils.data.DataLoader(
dataset,
batch_size=self.pytorch_settings.batchsize,
sampler=train_sampler,
drop_last=True,
pin_memory=True,
num_workers=24,
prefetch_factor=4,
# persistent_workers=True
)
valid_loader = torch.utils.data.DataLoader(
dataset,
batch_size=self.pytorch_settings.batchsize,
sampler=valid_sampler,
drop_last=True,
pin_memory=True,
num_workers=24,
prefetch_factor=4,
# persistent_workers=True
)
return train_loader, valid_loader
def train_model(self, model, optimizer, train_loader, epoch):
with Progress(
TextColumn("[yellow] Training..."),
TextColumn("Loss: {task.description}"),
BarColumn(),
TaskProgressColumn(),
TextColumn("[green]Batch"),
MofNCompleteColumn(),
TimeRemainingColumn(),
TimeElapsedColumn(),
# description="Training",
transient=False,
console=self.console,
refresh_per_second=10,
) as progress:
task = progress.add_task("-.---e--", total=len(train_loader))
running_loss = 0.0
model.train()
for batch_idx, (x, y) in enumerate(train_loader):
model.zero_grad(set_to_none=True)
x, y = (
x.to(self.pytorch_settings.device),
y.to(self.pytorch_settings.device),
)
y_pred = model(x)
loss = util.complexNN.complex_mse_loss(y_pred, y)
loss.backward()
optimizer.step()
progress.update(task, advance=1, description=f"{loss.item():.3e}")
running_loss += loss.item()
if self.writer is not None:
if (batch_idx + 1) % 10 == 0:
self.writer.add_scalar(
"training loss",
running_loss / 10,
epoch * len(train_loader) + batch_idx,
)
running_loss = 0.0
return running_loss
def eval_model(self, model, valid_loader, epoch):
with Progress(
TextColumn("[green]Evaluating..."),
TextColumn("Loss: {task.description}"),
BarColumn(),
TaskProgressColumn(),
TextColumn("[green]Batch"),
MofNCompleteColumn(),
TimeRemainingColumn(),
TimeElapsedColumn(),
# description="Training",
transient=False,
console=self.console,
refresh_per_second=10,
) as progress:
task = progress.add_task("-.---e--", total=len(valid_loader))
model.eval()
running_loss = 0
running_loss2 = 0
with torch.no_grad():
for batch_idx, (x, y) in enumerate(valid_loader):
x, y = (
x.to(self.pytorch_settings.device),
y.to(self.pytorch_settings.device),
)
y_pred = model(x)
loss = util.complexNN.complex_mse_loss(y_pred, y)
running_loss += loss.item()
running_loss2 += loss.item()
progress.update(task, advance=1, description=f"{loss.item():.3e}")
if self.writer is not None:
if (batch_idx + 1) % 10 == 0:
self.writer.add_scalar(
"loss",
running_loss / 10,
epoch * len(valid_loader) + batch_idx,
)
running_loss = 0.0
if self.writer is not None:
self.writer.add_figure("fiber response", self.plot_model_response(model, plot=False), epoch+1)
return running_loss2 / len(valid_loader)
def run_model(self, model, loader):
model.eval()
xs = []
ys = []
y_preds = []
with torch.no_grad():
model = model.to(self.pytorch_settings.device)
for x, y in loader:
x, y = (
x.to(self.pytorch_settings.device),
y.to(self.pytorch_settings.device),
)
y_pred = model(x).cpu()
# x = x.cpu()
# y = y.cpu()
y_pred = y_pred.view(y_pred.shape[0], -1, 2)
y = y.view(y.shape[0], -1, 2)
x = x.view(x.shape[0], -1, 2)
xs.append(x[:, 0, :].squeeze())
ys.append(y.squeeze())
y_preds.append(y_pred.squeeze())
xs = torch.vstack(xs).cpu()
ys = torch.vstack(ys).cpu()
y_preds = torch.vstack(y_preds).cpu()
return ys, xs, y_preds
def dummy_model(self, loader):
xs = []
ys = []
for x, y in loader:
y = y.cpu().view(y.shape[0], -1, 2)
x = x.cpu().view(x.shape[0], -1, 2)
xs.append(x[:, 0, :].squeeze())
ys.append(y.squeeze())
xs = torch.vstack(xs)
ys = torch.vstack(ys)
return xs, ys
def objective(self, save=False, plot_before=False):
try:
rprint(*list(self.study_name.split("_")))
self.model = self.define_model().to(self.pytorch_settings.device)
if self.writer is not None:
self.writer.add_figure("fiber response", self.plot_model_response(plot=plot_before), 0)
train_loader, valid_loader = self.get_sliced_data()
optimizer_name = self.optimizer_settings.optimizer_range
lr = self.optimizer_settings.lr_range
optimizer = getattr(optim, optimizer_name)(self.model.parameters(), lr=lr)
for epoch in range(self.pytorch_settings.epochs):
self.console.rule(f"Epoch {epoch + 1}/{self.pytorch_settings.epochs}")
self.train_model(self.model, optimizer, train_loader, epoch)
eval_loss = self.eval_model(self.model, valid_loader, epoch)
if save:
save_path = (
Path(self.pytorch_settings.model_dir) / f"{self.study_name}.pth"
)
save_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(self.model, save_path)
return eval_loss
except KeyboardInterrupt:
pass
finally:
if hasattr(self, "model"):
except_save_path = Path(".models/exception") / f"{self.study_name}.pth"
except_save_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(self.model, except_save_path)
def _plot_model_response_plotter(self, fiber_in, fiber_out, regen, plot=True):
fig, axs = plt.subplots(2)
for i, ax in enumerate(axs):
ax.plot(np.abs(fiber_in[:, i]) ** 2, label="fiber in")
ax.plot(np.abs(fiber_out[:, i]) ** 2, label="fiber out")
ax.plot(np.abs(regen[:, i]) ** 2, label="regenerated")
ax.legend()
if plot:
plt.show()
return fig
def plot_model_response(self, model=None, plot=True):
data_settings_backup = copy.copy(self.data_settings)
self.data_settings.shuffle = False
self.data_settings.train_split = 0.01
self.data_settings.drop_first = 100
plot_loader, _ = self.get_sliced_data()
self.data_settings = data_settings_backup
fiber_in, fiber_out, regen = self.run_model(model or self.model, plot_loader)
fiber_in = fiber_in.view(-1, 2)
fiber_out = fiber_out.view(-1, 2)
regen = regen.view(-1, 2)
fiber_in = fiber_in.numpy()
fiber_out = fiber_out.numpy()
regen = regen.numpy()
# https://github.com/matplotlib/matplotlib/issues/27713#issue-2104110987
# https://github.com/matplotlib/matplotlib/issues/27713#issuecomment-1915497463
import gc
fig = self._plot_model_response_plotter(fiber_in, fiber_out, regen, plot=plot)
gc.collect()
return fig
if __name__ == "__main__":
trainer = Training()
# trainer.plot_eye()
trainer.setup_tb_writer()
trainer.objective(save=True)
best_model = trainer.model
# best_model = trainer.define_model(trainer.study.best_trial).to(trainer.pytorch_settings.device)
trainer.plot_model_response(best_model)
# print(f"Best model: {best_model}")
...


@@ -1,2 +1,17 @@
from .datasets import FiberRegenerationDataset # noqa: F401
from .plot import eye # noqa: F401
from . import datasets # noqa: F401
# from .datasets import FiberRegenerationDataset # noqa: F401
# from .datasets import load_data # noqa: F401
from . import plot # noqa: F401
# from .plot import eye # noqa: F401
from . import optuna_helpers # noqa: F401
# from .optuna_helpers import optional_suggest_categorical # noqa: F401
# from .optuna_helpers import optional_suggest_float # noqa: F401
# from .optuna_helpers import optional_suggest_int # noqa: F401
from . import complexNN # noqa: F401
# from .complexNN import UnitaryLayer # noqa: F401
# from .complexNN import complex_mse_loss # noqa: F401
# from .complexNN import complex_sse_loss # noqa: F401


@@ -0,0 +1,141 @@
import torch
import torch.nn as nn
def complex_mse_loss(input, target):
"""
Compute the mean squared error between two complex tensors.
"""
return torch.mean(torch.square(input.real - target.real) + torch.square(input.imag - target.imag))
def complex_sse_loss(input, target):
"""
Compute the sum squared error between two complex tensors.
"""
if input.is_complex():
return torch.sum(torch.square(input.real - target.real) + torch.square(input.imag - target.imag))
else:
return torch.sum(torch.square(input - target))
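# Minimal usage sketch for the losses above (values chosen for illustration);
# torch.nn.functional.mse_loss rejects complex inputs, hence these helpers:
# >>> y_pred = torch.tensor([1 + 1j, 0 + 0j], dtype=torch.complex64)
# >>> y = torch.tensor([1 + 0j, 0 + 1j], dtype=torch.complex64)
# >>> complex_sse_loss(y_pred, y)  # (0 + 1) + (0 + 1)
# tensor(2.)
# >>> complex_mse_loss(y_pred, y)  # 2 / 2 elements
# tensor(1.)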
class UnitaryLayer(nn.Module):
def __init__(self, in_features, out_features):
super(UnitaryLayer, self).__init__()
assert in_features >= out_features
self.in_features = in_features
self.out_features = out_features
self.weight = nn.Parameter(torch.randn(in_features, out_features, dtype=torch.cfloat))
self.reset_parameters()
def reset_parameters(self):
q, _ = torch.linalg.qr(self.weight)
self.weight.data = q
@staticmethod
@torch.jit.script
def _unitary_forward(x, weight):
out = torch.matmul(x, weight)
return out
def forward(self, x):
return self._unitary_forward(x, self.weight)
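# Sanity sketch for UnitaryLayer (shapes assumed): after reset_parameters()
# the weight has orthonormal columns, so W^H W is numerically the identity:
# >>> layer = UnitaryLayer(8, 4)
# >>> w = layer.weight.detach()
# >>> torch.allclose(w.conj().T @ w, torch.eye(4, dtype=torch.cfloat), atol=1e-5)
# True
# >>> layer(torch.randn(2, 8, dtype=torch.cfloat)).shape
# torch.Size([2, 4])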
#### as defined by zhang et al
class Identity(nn.Module):
"""
implements the "activation" function
M(z) = z
"""
def __init__(self):
super(Identity, self).__init__()
def forward(self, x):
return x
class Mag(nn.Module):
"""
implements the activation function
M(z) = ||z||
"""
def __init__(self):
super(Mag, self).__init__()
# note: @torch.jit.script cannot be applied directly to a bound method,
# so forward stays a plain Python method
def forward(self, x):
    # ||z|| = sqrt(Re(z)^2 + Im(z)^2), not the squared magnitude
    return torch.sqrt(x.real**2 + x.imag**2)
# class Tanh(nn.Module):
# """
# implements the activation function
# M(z) = tanh(z) = sinh(z)/cosh(z) = (exp(z)-exp(-z))/(exp(z)+exp(-z)) = (exp(2*z)-1)/(exp(2*z)+1)
# """
# def __init__(self):
# super(Tanh, self).__init__()
# def forward(self, x):
# return torch.tanh(x)
class ModReLU(nn.Module):
"""
implements the activation function
M(z) = ReLU(||z|| + b)*exp(j*theta_z)
= ReLU(||z|| + b)*z/||z||
"""
def __init__(self, b=0):
super(ModReLU, self).__init__()
self.b = b
self.relu = nn.ReLU()
@staticmethod
# @torch.jit.script
def _mod_relu(x, b):
    # ||z||, plus a small epsilon so z = 0 does not divide by zero
    mod = torch.sqrt(x.real**2 + x.imag**2) + 1e-9
    return torch.relu(mod + b) * x / mod
def forward(self, x):
return self._mod_relu(x, self.b)
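# Behaviour sketch for ModReLU (values assumed): with b = -1, inputs with
# magnitude below 1 are zeroed; the rest shrink in magnitude but keep their phase:
# >>> m = ModReLU(b=-1)
# >>> m(torch.tensor([3 + 4j, 0.3 + 0.4j]))  # magnitudes 5 and 0.5
# tensor([2.4000+3.2000j, 0.0000+0.0000j])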
class CReLU(nn.Module):
"""
implements the activation function
M(z) = ReLU(Re(z)) + j*ReLU(Im(z))
"""
def __init__(self):
super(CReLU, self).__init__()
self.relu = nn.ReLU()
def forward(self, x):
    return torch.relu(x.real) + 1j * torch.relu(x.imag)
class ZReLU(nn.Module):
"""
implements the activation function
M(z) = z if 0 <= angle(z) <= pi/2
= 0 otherwise
"""
def __init__(self):
super(ZReLU, self).__init__()
def forward(self, x):
    angle = torch.angle(x)
    return x * (angle >= 0) * (angle <= torch.pi / 2)
# class ComplexFeedForwardNN(nn.Module):
# def __init__(self, in_features, hidden_features, out_features):
# super(ComplexFeedForwardNN, self).__init__()
# self.in_features = in_features
# self.hidden_features = hidden_features
# self.out_features = out_features
# self.fc1 = UnitaryLayer(in_features, hidden_features)
# self.fc2 = UnitaryLayer(hidden_features, out_features)
# def forward(self, x):
# x = self.fc1(x)
# x = self.fc2(x)
# return x


@@ -1,11 +1,28 @@
from pathlib import Path
import torch
from torch.utils.data import Dataset
# from torch.utils.data import Sampler
import numpy as np
import configparser
# class SubsetSampler(Sampler[int]):
# """
# Samples elements from a given list of indices.
# :param indices: List of indices to sample from.
# :type indices: list[int]
# """
# def __init__(self, indices):
# self.indices = indices
# def __iter__(self):
# return iter(self.indices)
# def __len__(self):
# return len(self.indices)
def load_data(config_path, skipfirst=0, num_symbols=None):
def load_data(config_path, skipfirst=0, symbols=None, real=False, normalize=False, device=None, dtype=None):
filepath = Path(config_path)
filepath = filepath.parent.glob(filepath.name)
config = configparser.ConfigParser()
@@ -18,15 +35,25 @@ def load_data(config_path, skipfirst=0, num_symbols=None):
datapath = Path("/".join(path_elements).replace('"', ""))
sps = int(config["glova"]["sps"])
if num_symbols is None:
num_symbols = int(config["glova"]["nos"]) - skipfirst
if symbols is None:
symbols = int(config["glova"]["nos"]) - skipfirst
data = np.load(datapath)[skipfirst * sps : num_symbols * sps + skipfirst * sps]
config["glova"]["nos"] = str(num_symbols)
data = np.load(datapath)[skipfirst * sps : symbols * sps + skipfirst * sps]
if normalize:
a, b, c, d = data.T
a, b, c, d = a/np.max(np.abs(a)), b/np.max(np.abs(b)), c/np.max(np.abs(c)), d/np.max(np.abs(d))
data = np.array([a, b, c, d]).T
if real:
data = np.abs(data)
config["glova"]["nos"] = str(symbols)
data = torch.tensor(data, device=device, dtype=dtype)
return data, config
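# Usage sketch (the glob pattern is assumed to match a config file in data/):
# >>> data, config = load_data("data/*-PAM4-0.ini", skipfirst=10, symbols=1000,
# ...                          normalize=True, dtype=torch.complex64)
# >>> data.shape  # [symbols * sps, 4]: E_in_x, E_in_y, E_out_x, E_out_y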
def roll_along(arr, shifts, dim):
# https://stackoverflow.com/a/76920720
# (c) Mateen Ulhaq, 2023
@@ -39,7 +66,6 @@ def roll_along(arr, shifts, dim):
indices = (dim_indices - shifts.unsqueeze(dim)) % arr.shape[dim]
return torch.gather(arr, dim, indices)
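# Usage sketch for roll_along (values assumed): every row gets its own shift,
# unlike torch.roll, which applies one shift to the whole dimension:
# >>> arr = torch.arange(6).reshape(2, 3)
# >>> roll_along(arr, shifts=torch.tensor([1, 2]), dim=1)
# tensor([[2, 0, 1],
#         [4, 5, 3]])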
class FiberRegenerationDataset(Dataset):
"""
Dataset for fiber regeneration training.
@@ -76,6 +102,9 @@ class FiberRegenerationDataset(Dataset):
target_delay: float | int = 0,
xy_delay: float | int = 0,
drop_first: float | int = 0,
dtype: torch.dtype = None,
real: bool = False,
device = None,
**kwargs,
):
"""
@@ -123,12 +152,15 @@ class FiberRegenerationDataset(Dataset):
[[i + 0.1j, i + 0.2j, i + 1.1j, i + 1.2j] for i in range(12800)],
dtype=np.complex128,
)
data_raw = torch.tensor(data_raw, device=device, dtype=dtype)
self.config = {
"data": {"dir": '"."', "npy_dir": '"."', "file": "faux"},
"glova": {"sps": 128},
}
else:
data_raw, self.config = load_data(file_path)
data_raw, self.config = load_data(file_path, skipfirst=drop_first, real=real, normalize=True, device=device, dtype=dtype)
self.device = data_raw.device
self.samples_per_symbol = int(self.config["glova"]["sps"])
self.samples_per_slice = int(symbols * self.samples_per_symbol)
@@ -140,7 +172,6 @@ class FiberRegenerationDataset(Dataset):
ovrd_target_delay_samples = kwargs.pop("ovrd_target_delay_samples", None)
ovrd_xy_delay_samples = kwargs.pop("ovrd_xy_delay_samples", None)
ovrd_drop_first_samples = kwargs.pop("ovrd_drop_first_samples", None)
self.target_delay_samples = (
ovrd_target_delay_samples
@@ -152,14 +183,8 @@ class FiberRegenerationDataset(Dataset):
if ovrd_xy_delay_samples is not None
else int(self.xy_delay * self.samples_per_symbol)
)
drop_first_samples = (
ovrd_drop_first_samples
if ovrd_drop_first_samples is not None
else int(drop_first * self.samples_per_symbol)
)
# drop samples from the beginning
data_raw = data_raw[drop_first_samples:]
# data_raw = torch.tensor(data_raw, dtype=dtype)
# data layout
# [ [E_in_x0, E_in_y0, E_out_x0, E_out_y0],
@@ -240,10 +265,10 @@ class FiberRegenerationDataset(Dataset):
data = data.view(data.shape[0], self.data_size, -1)
data = data[:, :, 0]
# target is corresponding to the latest data point -> try to regenerate that
# target corresponds to the middle of the data window, since the output sample is influenced by the samples before and after it
target = target[:, : target.shape[1] // self.data_size * self.data_size]
target = target.view(target.shape[0], self.data_size, -1)
target = target[:, 0, 0]
target = target[:, 0, target.shape[2] // 2]
data = data.transpose(0, 1).flatten().squeeze()
target = target.flatten().squeeze()


@@ -0,0 +1,30 @@
def _optional_suggest(trial, name, range_or_value, log=False, step=None, type='int'):
# not a range
if not hasattr(range_or_value, '__iter__') or isinstance(range_or_value, str):
return range_or_value
# range with only one value
if len(range_or_value) == 1:
return range_or_value[0]
if type == 'int':
step = step or 1
return trial.suggest_int(name, *range_or_value, step=step, log=log)
if type == 'float':
return trial.suggest_float(name, *range_or_value, step=step, log=log)
if type == 'categorical':
return trial.suggest_categorical(name, range_or_value)
raise ValueError(f"Unknown type: {type}")
def optional_suggest_categorical(trial, name, choices_or_value):
return _optional_suggest(trial, name, choices_or_value, type='categorical')
def optional_suggest_int(trial, name, range_or_value, step=None, log=False):
return _optional_suggest(trial, name, range_or_value, step=step, log=log, type='int')
def optional_suggest_float(trial, name, range_or_value, step=None, log=False):
return _optional_suggest(trial, name, range_or_value, step=step, log=log, type='float')
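# Usage sketch (trial is an optuna.Trial; values are illustrative):
# lr = optional_suggest_float(trial, "lr", (1e-5, 1e-1), log=True)   # searched range
# lr = optional_suggest_float(trial, "lr", 2e-5)                     # fixed value, no suggest call
# opt = optional_suggest_categorical(trial, "optimizer", "RMSprop")  # fixed choice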


@@ -0,0 +1,18 @@
from dash import Dash, dcc, html
import logging
import dash_bootstrap_components as dbc
def show_figures(*figures):
for figure in figures:
figure.layout.template = 'plotly_dark'
app = Dash(external_stylesheets=[dbc.themes.DARKLY])
app.layout = html.Div([
dcc.Graph(figure=figure) for figure in figures
])
log = logging.getLogger('werkzeug')
log.setLevel(logging.ERROR)
app.show = lambda *args, **kwargs: app.run_server(*args, **kwargs, debug=False)
return app
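# Usage sketch (assumes plotly figures, e.g. from plotly.express):
# import plotly.express as px
# app = show_figures(px.line(y=[1, 2, 3]), px.scatter(x=[1, 2], y=[3, 4]))
# app.show(port=8050)  # wraps app.run_server(debug=False)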


@@ -2,33 +2,72 @@ import matplotlib.pyplot as plt
import numpy as np
from .datasets import load_data
def eye(path, title=None, head=1000, skipfirst=1000, show=True):
def eye(*, path=None, data=None, sps=None, title=None, symbols=1000, skipfirst=0, width=2, alpha=None, complex=False, show=True):
"""Plot an eye diagram for the data given by filepath.
Either path or data and sps must be given.
Args:
path (str): Path to the data description file.
data (np.ndarray): Data to plot.
sps (int): Samples per symbol.
title (str): Title of the plot.
symbols (int): Number of symbols to plot.
width (int): Width of the eye window in symbols.
alpha (float): Opacity of the traces (defaults to 0.1).
complex (bool): Also plot the phase on a secondary axis.
skipfirst (int): Number of symbols to skip.
show (bool): Whether to call plt.show().
"""
data, config = load_data(path, skipfirst, head)
if path is None and data is None:
raise ValueError("Either path or data and sps must be given.")
if path is not None:
data, config = load_data(path, skipfirst, symbols)
sps = int(config["glova"]["sps"])
if sps is None:
raise ValueError("sps not set.")
xaxis = np.linspace(0, 2, 2*sps, endpoint=False)
xaxis = np.linspace(0, width, width*sps, endpoint=False)
fig, axs = plt.subplots(2, 2, figsize=(10, 10), sharex=True, sharey=True)
for i in range(head-1):
inx, iny, outx, outy = data[i*sps:(i+2)*sps].T
axs[0, 0].plot(xaxis, np.abs(inx)**2, color="C0", alpha=0.1)
axs[0, 1].plot(xaxis, np.abs(outx)**2, color="C0", alpha=0.1)
axs[1, 0].plot(xaxis, np.abs(iny)**2, color="C0", alpha=0.1)
axs[1, 1].plot(xaxis, np.abs(outy)**2, color="C0", alpha=0.1)
if complex:
# create secondary axis for phase
axs2 = axs[0, 0].twinx(), axs[0, 1].twinx(), axs[1, 0].twinx(), axs[1, 1].twinx()
axs2 = np.reshape(axs2, (2, 2))
for i in range(symbols-(width-1)):
inx, iny, outx, outy = data[i*sps:(i+width)*sps].T
if complex:
axs[0, 0].plot(xaxis, np.abs(inx), color="C0", alpha=alpha or 0.1)
axs[0, 1].plot(xaxis, np.abs(outx), color="C0", alpha=alpha or 0.1)
axs[1, 0].plot(xaxis, np.abs(iny), color="C0", alpha=alpha or 0.1)
axs[1, 1].plot(xaxis, np.abs(outy), color="C0", alpha=alpha or 0.1)
axs[0,0].set_ylim(0, 1.1*np.max(np.abs(data)))
axs2[0, 0].plot(xaxis, np.angle(inx), color="C1", alpha=alpha or 0.1)
axs2[0, 1].plot(xaxis, np.angle(outx), color="C1", alpha=alpha or 0.1)
axs2[1, 0].plot(xaxis, np.angle(iny), color="C1", alpha=alpha or 0.1)
axs2[1, 1].plot(xaxis, np.angle(outy), color="C1", alpha=alpha or 0.1)
else:
axs[0, 0].plot(xaxis, np.abs(inx)**2, color="C0", alpha=alpha or 0.1)
axs[0, 1].plot(xaxis, np.abs(outx)**2, color="C0", alpha=alpha or 0.1)
axs[1, 0].plot(xaxis, np.abs(iny)**2, color="C0", alpha=alpha or 0.1)
axs[1, 1].plot(xaxis, np.abs(outy)**2, color="C0", alpha=alpha or 0.1)
if complex:
axs2[0, 0].sharey(axs2[0, 1])
axs2[0, 1].sharey(axs2[1, 0])
axs2[1, 0].sharey(axs2[1, 1])
# make y axis symmetric
ylim = np.max(np.abs(np.angle(data)))*1.1
if ylim != 0:
axs2[0, 0].set_ylim(-ylim, ylim)
else:
axs[0,0].set_ylim(0, 1.1*np.max(np.abs(data))**2)
axs[0, 0].set_title("Input x")
axs[0, 1].set_title("Output x")
axs[1, 0].set_title("Input y")
axs[1, 1].set_title("Output y")
fig.suptitle(title)
fig.suptitle(title or "Eye diagram")
if show:
plt.show(block=False)
plt.show()
return fig
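# Usage sketch (the config path is illustrative; either path or data+sps is required):
# fig = eye(path="data/*-PAM4-0.ini", symbols=500, width=2, alpha=0.05)
# data, config = load_data("data/*-PAM4-0.ini", symbols=500)
# fig = eye(data=data, sps=int(config["glova"]["sps"]), complex=True, show=False)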