training loop speedup

Joseph Hopfmüller
2024-11-20 11:29:18 +01:00
parent 1622c38582
commit cdca5de473
11 changed files with 1026 additions and 151 deletions

.gitattributes

@@ -1,4 +1,5 @@
 data/**/* filter=lfs diff=lfs merge=lfs -text
+data/*.db filter=lfs diff=lfs merge=lfs -text
 data/*.ini filter=lfs diff=lfs merge=lfs text
 ## lfs setup

.gitignore

@@ -1,8 +1,5 @@
 src/**/*.ini
-.data
+.*
-# VSCode
-.vscode
 # Byte-compiled / optimized / DLL files
 __pycache__/


@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:72460af57347d35df91cd76982231bcf538a82fd7f1b8522795202fa298a2dcb
+size 696320


@@ -1,6 +1,6 @@
+import copy
 from dataclasses import dataclass
 from datetime import datetime
-import time
 import matplotlib.pyplot as plt
 import numpy as np
@@ -9,16 +9,21 @@ import warnings
 import torch
 import torch.nn as nn
-import torch.functional as F
+# import torch.nn.functional as F  # mse_loss doesn't support complex numbers
 import torch.optim as optim
 import torch.utils.data
+from torch.utils.tensorboard import SummaryWriter
+from rich.progress import Progress, TextColumn, BarColumn, TaskProgressColumn, TimeRemainingColumn, MofNCompleteColumn
+from rich.console import Console
 import multiprocessing
 from util.datasets import FiberRegenerationDataset
+from util.complexNN import complex_sse_loss
+from util.optuna_helpers import optional_suggest_categorical, optional_suggest_float, optional_suggest_int
 import util

 # global settings
 @dataclass
 class GlobalSettings:
@@ -28,12 +33,14 @@ class GlobalSettings:
 # data settings
 @dataclass
 class DataSettings:
-    config_path: str = "data/*-128-16384-10000-0-0-17-0-PAM4-0.ini"
+    config_path: str = "data/*-128-16384-1000-0-0-17-0-PAM4-0.ini"
-    symbols_range: tuple = (1, 100)
-    data_size_range: tuple = (1, 20)
+    dtype: torch.dtype = torch.complex64
+    symbols_range: tuple | float | int = 16
+    data_size_range: tuple | float | int = 32
+    shuffle: bool = True
     target_delay: float = 0
-    xy_delay_range: tuple = (0, 1)
+    xy_delay_range: tuple | float | int = 0
-    drop_first: int = 1000
+    drop_first: int = 10
     train_split: float = 0.8
@@ -41,41 +48,46 @@ class DataSettings:
 @dataclass
 class PytorchSettings:
     device: str = "cuda"
-    batchsize: int = 128
+    batchsize: int = 1024
-    epochs: int = 100
+    epochs: int = 10
+    summary_dir: str = ".runs"

 # model settings
 @dataclass
 class ModelSettings:
     output_size: int = 2
-    n_layer_range: tuple = (1, 3)
+    n_layer_range: tuple | float | int = (2, 8)
-    n_units_range: tuple = (4, 128)
+    n_units_range: tuple | float | int = (2, 32)
-    activation_func_range: tuple = ("ReLU",)
+    # activation_func_range: tuple = ("ReLU",)

 @dataclass
 class OptimizerSettings:
-    optimizer_range: tuple = ("Adam", "RMSprop", "SGD")
+    # optimizer_range: tuple | str = ("Adam", "RMSprop", "SGD")
-    lr_range: tuple = (1e-5, 1e-1)
+    optimizer_range: tuple | str = "RMSprop"
+    # lr_range: tuple | float = (1e-5, 1e-1)
+    lr_range: tuple | float = 2e-5

 # optuna settings
 @dataclass
 class OptunaSettings:
     n_trials: int = 128
-    n_threads: int = 16
+    n_threads: int = 8
     timeout: int = 600
-    directions: tuple = ("maximize",)
+    directions: tuple = ("minimize",)
+    metrics_names: tuple = ("sse",)
     limit_examples: bool = True
-    n_train_examples: int = PytorchSettings.batchsize * 30
+    n_train_examples: int = PytorchSettings.batchsize * 50
-    n_valid_examples: int = PytorchSettings.batchsize * 10
+    # n_valid_examples: int = PytorchSettings.batchsize * 100
+    n_valid_examples: int = float("inf")
     storage: str = "sqlite:///optuna_single_core_regen.db"
     study_name: str = (
-        f"single_core_regen_{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
+        f"single_core_regen_{datetime.now().strftime('%Y-%m-%d_%H:%M:%S')}"
     )
-    metrics_names: tuple = ("accuracy",)

 class HyperTraining:
     def __init__(self):
@@ -86,9 +98,43 @@ class HyperTraining:
         self.optimizer_settings = OptimizerSettings()
         self.optuna_settings = OptunaSettings()
+        self.console = Console()

         # set some extra settings to make the code more readable
         self._extra_optuna_settings()

+    def setup_tb_writer(self, study_name=None, append=None):
+        log_dir = self.pytorch_settings.summary_dir + "/" + (study_name or self.optuna_settings.study_name)
+        if append is not None:
+            log_dir += "_" + str(append)
+        return SummaryWriter(log_dir)
+
+    def resume_latest_study(self, verbose=True):
+        study_name = self.get_latest_study(verbose=verbose)
+        if study_name:
+            print(f"Resuming study: {study_name}")
+            self.optuna_settings.study_name = study_name
+
+    def get_latest_study(self, verbose=True):
+        studies = self.get_studies()
+        for study in studies:
+            study.datetime_start = study.datetime_start or datetime.min
+        if studies:
+            study = sorted(studies, key=lambda x: x.datetime_start, reverse=True)[0]
+            if verbose:
+                print(f"Last study: {study.study_name}")
+            study_name = study.study_name
+        else:
+            if verbose:
+                print("No previous studies found")
+            study_name = None
+        return study_name
+
+    def get_studies(self):
+        return optuna.get_all_study_summaries(storage=self.optuna_settings.storage)
+
     def setup_study(self):
         self.study = optuna.create_study(
             study_name=self.optuna_settings.study_name,
@@ -100,29 +146,49 @@ class HyperTraining:
         with warnings.catch_warnings(action="ignore"):
             self.study.set_metric_names(self.optuna_settings.metrics_names)

-        self.n_threads = min(self.optuna_settings.n_trials, self.optuna_settings.n_threads)
+        self.n_threads = min(
+            self.optuna_settings.n_trials, self.optuna_settings.n_threads
+        )
         self.processes = []
-        for _ in range(self.n_threads):
-            p = multiprocessing.Process(target=self._run_optimize)
-            self.processes.append(p)
+        if self.n_threads > 1:
+            for _ in range(self.n_threads):
+                p = multiprocessing.Process(
+                    # target=lambda n_trials: self._run_optimize(self, n_trials),
+                    target=self._run_optimize,
+                    args=(self.optuna_settings.n_trials // self.n_threads,),
+                )
+                self.processes.append(p)

     def run_study(self):
-        for p in self.processes:
-            p.start()
-        for p in self.processes:
-            p.join()
-        remaining_trials = self.optuna_settings.n_trials - self.optuna_settings.n_trials % self.optuna_settings.n_threads
+        if self.processes:
+            for p in self.processes:
+                p.start()
+            for p in self.processes:
+                p.join()
+            # leftover after each of the n_threads workers ran n_trials // n_threads
+            remaining_trials = (
+                self.optuna_settings.n_trials
+                % self.optuna_settings.n_threads
+            )
+        else:
+            remaining_trials = self.optuna_settings.n_trials
         if remaining_trials:
             self._run_optimize(remaining_trials)

-    def _run_optimize(self, n_trials):
-        self.study.optimize(self.objective, n_trials=n_trials, timeout=self.optuna_settings.timeout)
-
-    def eye(self, show=True):
-        util.plot.eye(self.data_settings.config_path, show=show)
+    def _run_optimize(self, n_trials):
+        self.study.optimize(
+            self.objective, n_trials=n_trials, timeout=self.optuna_settings.timeout
+        )
+
+    def plot_eye(self, show=True):
+        if not hasattr(self, "eye_data"):
+            data, config = util.datasets.load_data(
+                self.data_settings.config_path, skipfirst=10, symbols=1000
+            )
+            self.eye_data = {"data": data, "sps": int(config["glova"]["sps"])}
+        return util.plot.eye(**self.eye_data, show=show)

     def _extra_optuna_settings(self):
         self.optuna_settings.multi_objective = len(self.optuna_settings.directions) > 1
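The hunk above fans n_trials out over worker processes that all write into one SQLite-backed study. A minimal standalone sketch of that pattern, assuming only optuna is installed (the study name, storage URL, and toy objective are placeholders; each worker loads its own study handle instead of inheriting one through fork):

    import multiprocessing
    import optuna

    STORAGE = "sqlite:///demo.db"   # placeholder storage/study names
    STUDY = "demo"

    def objective(trial: optuna.Trial) -> float:
        x = trial.suggest_float("x", -10, 10)
        return (x - 2) ** 2

    def worker(n_trials: int) -> None:
        # each process opens its own handle onto the shared study
        study = optuna.load_study(study_name=STUDY, storage=STORAGE)
        study.optimize(objective, n_trials=n_trials)

    if __name__ == "__main__":
        optuna.create_study(study_name=STUDY, storage=STORAGE,
                            direction="minimize", load_if_exists=True)
        procs = [multiprocessing.Process(target=worker, args=(16,)) for _ in range(4)]
        for p in procs:
            p.start()
        for p in procs:
            p.join()

Loading the study inside each worker is the storage-safe variant; forking the parent's study object, as the commit does, also works on fork-based platforms because all trial state lives in the shared SQLite database anyway.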
@@ -143,145 +209,256 @@ class HyperTraining:
             else float("inf")
         )

-    def define_model(self, trial: optuna.Trial):
-        n_layers = trial.suggest_int(
-            "model_n_layers", *self.model_settings.n_layer_range
-        )
+    def define_model(self, trial: optuna.Trial, writer=None):
+        n_layers = optional_suggest_int(trial, "model_n_layers", self.model_settings.n_layer_range)
+        in_features = 2 * trial.params.get(
+            "dataset_data_size",
+            optional_suggest_int(trial, "dataset_data_size", self.data_settings.data_size_range),
+        )
+        trial.set_user_attr("input_dim", in_features)
         layers = []
-        # REVIEW does that work?
-        in_features = trial.params["dataset_data_size"] * 2
         for i in range(n_layers):
-            out_features = trial.suggest_int(
-                f"model_n_units_l{i}", *self.model_settings.n_units_range
-            )
-            activation_func = trial.suggest_categorical(
-                f"model_activation_func_l{i}", self.model_settings.activation_func_range
-            )
-            layers.append(nn.Linear(in_features, out_features))
-            layers.append(getattr(nn, activation_func))
+            out_features = optional_suggest_int(trial, f"model_n_units_l{i}", self.model_settings.n_units_range, log=True)
+            layers.append(nn.Linear(in_features, out_features, dtype=self.data_settings.dtype))
+            # layers.append(getattr(nn, activation_func)())
             in_features = out_features
-        layers.append(nn.Linear(in_features, self.model_settings.output_size))
+        layers.append(nn.Linear(in_features, self.model_settings.output_size, dtype=self.data_settings.dtype))
+        if writer is not None:
+            writer.add_graph(nn.Sequential(*layers), torch.zeros(1, trial.user_attrs["input_dim"], dtype=self.data_settings.dtype))
         return nn.Sequential(*layers)
     def get_sliced_data(self, trial: optuna.Trial):
-        assert ModelSettings.input_size % 2 == 0, "input_dim must be even"
-        symbols = trial.suggest_float(
-            "dataset_symbols", *self.data_settings.symbols_range, log=True
-        )
-        xy_delay = trial.suggest_float(
-            "dataset_xy_delay", *self.data_settings.xy_delay_range
-        )
-        data_size = trial.suggest_int(
-            "dataset_data_size", *self.data_settings.data_size_range
-        )
+        symbols = optional_suggest_float(trial, "dataset_symbols", self.data_settings.symbols_range)
+        xy_delay = optional_suggest_float(trial, "dataset_xy_delay", self.data_settings.xy_delay_range)
+        data_size = trial.params.get(
+            "dataset_data_size",
+            optional_suggest_int(trial, "dataset_data_size", self.data_settings.data_size_range)
+        )

         # get dataset
         dataset = FiberRegenerationDataset(
             file_path=self.data_settings.config_path,
             symbols=symbols,
-            data_size=data_size,  # two channels (x,y)
+            data_size=data_size,
             target_delay=self.data_settings.target_delay,
             xy_delay=xy_delay,
             drop_first=self.data_settings.drop_first,
+            dtype=self.data_settings.dtype,
         )

         dataset_size = len(dataset)
         indices = list(range(dataset_size))
         split = int(np.floor(self.data_settings.train_split * dataset_size))
-        np.random.seed(self.global_settings.seed)
-        np.random.shuffle(indices)
+        if self.data_settings.shuffle:
+            np.random.seed(self.global_settings.seed)
+            np.random.shuffle(indices)
         train_indices, valid_indices = indices[:split], indices[split:]

         train_sampler = torch.utils.data.SubsetRandomSampler(train_indices)
         valid_sampler = torch.utils.data.SubsetRandomSampler(valid_indices)

         train_loader = torch.utils.data.DataLoader(
-            dataset, batch_size=self.pytorch_settings.batchsize, sampler=train_sampler
+            dataset, batch_size=self.pytorch_settings.batchsize, sampler=train_sampler, drop_last=True
         )
         valid_loader = torch.utils.data.DataLoader(
-            dataset, batch_size=self.pytorch_settings.batchsize, sampler=valid_sampler
+            dataset, batch_size=self.pytorch_settings.batchsize, sampler=valid_sampler, drop_last=True
         )
         return train_loader, valid_loader
-    def train_model(self, model, optimizer, train_loader):
+    def train_model(self, model, optimizer, train_loader, epoch, writer=None, enable_progress=True):
+        if enable_progress:
+            progress = Progress(
+                TextColumn("[yellow] Training..."),
+                TextColumn(" Loss: {task.description}"),
+                BarColumn(),
+                TaskProgressColumn(),
+                TextColumn("[green]Batch"),
+                MofNCompleteColumn(),
+                TimeRemainingColumn(),
+                # description="Training",
+                transient=False,
+                console=self.console,
+                refresh_per_second=10,
+            )
+            task = progress.add_task("-.---e--", total=len(train_loader))
+        running_loss = 0.0
+        last_loss = 0.0
         model.train()
-        for batch_idx, (data, target) in enumerate(train_loader):
-            if (batch_idx * train_loader.batchsize
-                    >= self.optuna_settings.n_train_examples):
+        for batch_idx, (x, y) in enumerate(train_loader):
+            if (
+                batch_idx * train_loader.batch_size
+                >= self.optuna_settings.n_train_examples
+            ):
                 break
             optimizer.zero_grad()
-            data, target = (
-                data.to(self.pytorch_settings.device),
-                target.to(self.pytorch_settings.device),
-            )
-            target_pred = model(data)
-            loss = F.mean_squared_error(target_pred, target)
+            x, y = (
+                x.to(self.pytorch_settings.device),
+                y.to(self.pytorch_settings.device),
+            )
+            y_pred = model(x)
+            loss = complex_sse_loss(y_pred, y)
             loss.backward()
             optimizer.step()
+            # clamp weights to keep energy bounded
+            for p in model.parameters():
+                p.data.clamp_(-1.0, 1.0)
+            last_loss = loss.item()
+            if enable_progress:
+                progress.update(task, advance=1, description=f"{last_loss:.3e}")
+            running_loss += loss.item()
+            if writer is not None:
+                if batch_idx % 10 == 0:
+                    writer.add_scalar("training loss", running_loss / 10, epoch * min(len(train_loader), self.optuna_settings.n_train_examples / train_loader.batch_size) + batch_idx)
+                    running_loss = 0.0
+        if enable_progress:
+            progress.update(task, description=f"{last_loss:.3e}")
+            progress.stop()
-    def eval_model(self, model, valid_loader):
+    def eval_model(self, model, valid_loader, epoch, writer=None, enable_progress=True):
+        if enable_progress:
+            progress = Progress(
+                TextColumn("[green]Evaluating..."),
+                TextColumn("Error: {task.description}"),
+                BarColumn(),
+                TaskProgressColumn(),
+                TextColumn("[green]Batch"),
+                MofNCompleteColumn(),
+                TimeRemainingColumn(),
+                # description="Training",
+                transient=False,
+                console=self.console,
+                refresh_per_second=10,
+            )
+            task = progress.add_task("-.---e--", total=len(valid_loader))
         model.eval()
-        correct = 0
+        running_error = 0
+        running_error_2 = 0
         with torch.no_grad():
-            for batch_idx, (data, target) in enumerate(valid_loader):
+            for batch_idx, (x, y) in enumerate(valid_loader):
                 if (
-                    batch_idx * valid_loader.batchsize
+                    batch_idx * valid_loader.batch_size
                     >= self.optuna_settings.n_valid_examples
                 ):
                     break
-                data, target = (
-                    data.to(self.pytorch_settings.device),
-                    target.to(self.pytorch_settings.device),
-                )
-                target_pred = model(data)
-                pred = target_pred.argmax(dim=1, keepdim=True)
-                correct += pred.eq(target.view_as(pred)).sum().item()
-
-        accuracy = correct / len(valid_loader.dataset)
-        # num_params = sum(p.numel() for p in model.parameters())
-        return accuracy
+                x, y = (
+                    x.to(self.pytorch_settings.device),
+                    y.to(self.pytorch_settings.device),
+                )
+                y_pred = model(x)
+                error = complex_sse_loss(y_pred, y)
+                running_error += error.item()
+                running_error_2 += error.item()
+                if enable_progress:
+                    progress.update(task, advance=1, description=f"{error.item():.3e}")
+                if writer is not None:
+                    if batch_idx % 10 == 0:
+                        writer.add_scalar("sse", running_error_2 / 10, epoch * min(len(valid_loader), self.optuna_settings.n_valid_examples / valid_loader.batch_size) + batch_idx)
+                        running_error_2 = 0.0
+        running_error /= batch_idx + 1
+        if enable_progress:
+            progress.update(task, description=f"{running_error:.3e}")
+            progress.stop()
+        return running_error
+    def run_model(self, model, loader):
+        model.eval()
+        y_preds = []
+        with torch.no_grad():
+            for x, y in loader:
+                x, y = (
+                    x.to(self.pytorch_settings.device),
+                    y.to(self.pytorch_settings.device),
+                )
+                y_preds.append(model(x))
+        return torch.stack(y_preds)
     def objective(self, trial: optuna.Trial):
-        model = self.define_model(trial).to(self.pytorch_settings.device)
-        optimizer_name = trial.suggest_categorical(
-            "optimizer", self.optimizer_settings.optimizer_range
-        )
-        lr = trial.suggest_float("lr", *self.optimizer_settings.lr_range, log=True)
-        optimizer = getattr(optim, optimizer_name)(model.parameters(), lr=lr)
+        writer = self.setup_tb_writer(self.optuna_settings.study_name, f"{trial.number:0>{len(str(self.optuna_settings.n_trials))}}")
         train_loader, valid_loader = self.get_sliced_data(trial)
-        for epoch in range(self.pytorch_settings.epochs):
-            self.train_model(model, optimizer, train_loader)
-            accuracy = self.eval_model(model, valid_loader)
-            if len(self.optuna_settings.directions) == 1:
-                trial.report(accuracy, epoch)
+        model = self.define_model(trial, writer).to(self.pytorch_settings.device)
+        optimizer_name = optional_suggest_categorical(trial, "optimizer", self.optimizer_settings.optimizer_range)
+        lr = optional_suggest_float(trial, "lr", self.optimizer_settings.lr_range, log=True)
+        optimizer = getattr(optim, optimizer_name)(model.parameters(), lr=lr)
+        for epoch in range(self.pytorch_settings.epochs):
+            enable_progress = self.optuna_settings.n_threads == 1
+            if enable_progress:
+                print(f"Epoch {epoch+1}/{self.pytorch_settings.epochs}")
+            self.train_model(model, optimizer, train_loader, epoch, writer, enable_progress=enable_progress)
+            sse = self.eval_model(model, valid_loader, epoch, writer, enable_progress=enable_progress)
+            if not self.optuna_settings.multi_objective:
+                trial.report(sse, epoch)
                 if trial.should_prune():
                     raise optuna.exceptions.TrialPruned()
-        return accuracy
+        writer.close()
+        return sse
 if __name__ == "__main__":
+    # plt.ion()
     hyper_training = HyperTraining()
-    hyper_training.eye()
-    # hyper_training.setup_study()
-    # hyper_training.run_study()
-    for i in range(10):
-        #simulate some work
-        print(i)
-        time.sleep(0.2)
+    # hyper_training.resume_latest_study()
+    hyper_training.setup_study()
+    hyper_training.run_study()
+
+    best_model = hyper_training.define_model(hyper_training.study.best_trial).to(hyper_training.pytorch_settings.device)
+
+    data_settings_backup = copy.copy(hyper_training.data_settings)
+    hyper_training.data_settings.shuffle = False
+    hyper_training.data_settings.train_split = 0.01
+    plot_loader, _ = hyper_training.get_sliced_data(hyper_training.study.best_trial)
+
+    regen = hyper_training.run_model(best_model, plot_loader)
+    regen = regen.view(-1, 2)
+    # [batch_no, batch_size, 2] -> [no, 2]
+
+    original, _ = util.datasets.load_data(hyper_training.data_settings.config_path, skipfirst=hyper_training.data_settings.drop_first)
+    original = original[:len(regen)]
+    regen = regen.cpu().numpy()
+
+    _, axs = plt.subplots(2)
+    for i, ax in enumerate(axs):
+        ax.plot(np.abs(original[:, i])**2, label="original")
+        ax.plot(np.abs(regen[:, i])**2, label="regen")
+        ax.legend()
     plt.show()
-    ...
+    print(f"Best model: {best_model}")
+    # eye_fig = hyper_training.plot_eye()
+    ...
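One detail in objective() above: zero-padding the trial number to the width of n_trials needs a nested format spec, where the pad width is itself a braced expression. A quick standalone check (plain Python, values arbitrary):

    n_trials = 128
    for trial_number in (3, 42):
        # the inner {len(str(n_trials))} evaluates first (here: 3),
        # so the trial number prints as "003" and "042"
        print(f"{trial_number:0>{len(str(n_trials))}}")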


@@ -0,0 +1,429 @@
import copy
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
# import torch.nn.functional as F # mse_loss doesn't support complex numbers
import torch.optim as optim
import torch.utils.data
from torch.utils.tensorboard import SummaryWriter
from rich.progress import (
Progress,
TextColumn,
BarColumn,
TaskProgressColumn,
TimeRemainingColumn,
MofNCompleteColumn,
TimeElapsedColumn,
)
from rich.console import Console
from rich import print as rprint
# from util.optuna_helpers import optional_suggest_categorical, optional_suggest_float, optional_suggest_int
import util
# global settings
@dataclass
class GlobalSettings:
seed: int = 42
# data settings
@dataclass
class DataSettings:
config_path: str = "data/*-128-16384-100000-0-0-17-0-PAM4-0.ini"
dtype: torch.dtype = torch.complex64
symbols_range: float | int = 8
data_size_range: float | int = 64
shuffle: bool = True
target_delay: float = 0
xy_delay_range: float | int = 0
drop_first: int = 10
train_split: float = 0.8
# pytorch settings
@dataclass
class PytorchSettings:
epochs: int = 1000
batchsize: int = 2**12
device: str = "cuda"
summary_dir: str = ".runs"
model_dir: str = ".models"
# model settings
@dataclass
class ModelSettings:
output_size: int = 2
# n_layer_range: float|int = 2
# n_units_range: float|int = 32
n_layers: int = 3
n_units: int = 32
activation_func: tuple | str = "ModReLU"
@dataclass
class OptimizerSettings:
optimizer_range: str = "Adam"
lr_range: float = 2e-3
class Training:
def __init__(self):
self.global_settings = GlobalSettings()
self.data_settings = DataSettings()
self.pytorch_settings = PytorchSettings()
self.model_settings = ModelSettings()
self.optimizer_settings = OptimizerSettings()
self.study_name = (
f"single_core_regen_{datetime.now().strftime('%Y-%m-%d_%H:%M:%S')}"
)
if not hasattr(self.pytorch_settings, "model_dir"):
self.pytorch_settings.model_dir = ".models"
self.writer = None
self.console = Console()
def setup_tb_writer(self, study_name=None):
log_dir = (
self.pytorch_settings.summary_dir + "/" + (study_name or self.study_name)
)
self.writer = SummaryWriter(log_dir)
def plot_eye(self, width=2, symbols=None, alpha=None, complex=False, show=True):
if not hasattr(self, "eye_data"):
data, config = util.datasets.load_data(
self.data_settings.config_path,
skipfirst=10,
symbols=symbols or 1000,
real=not self.data_settings.dtype.is_complex,
normalize=True,
)
self.eye_data = {"data": data, "sps": int(config["glova"]["sps"])}
return util.plot.eye(
**self.eye_data,
width=width,
show=show,
alpha=alpha,
complex=complex,
symbols=symbols or 1000,
skipfirst=0,
)
def define_model(self):
n_layers = self.model_settings.n_layers
in_features = 2 * self.data_settings.data_size_range
layers = []
for i in range(n_layers):
out_features = self.model_settings.n_units
layers.append(util.complexNN.UnitaryLayer(in_features, out_features))
# layers.append(getattr(nn, self.model_settings.activation_func)())
layers.append(
getattr(util.complexNN, self.model_settings.activation_func)()
)
in_features = out_features
layers.append(
util.complexNN.UnitaryLayer(in_features, self.model_settings.output_size)
)
if self.writer is not None:
self.writer.add_graph(
nn.Sequential(*layers),
torch.zeros(1, layers[0].in_features, dtype=self.data_settings.dtype),
)
return nn.Sequential(*layers)
def get_sliced_data(self):
symbols = self.data_settings.symbols_range
xy_delay = self.data_settings.xy_delay_range
data_size = self.data_settings.data_size_range
# get dataset
dataset = util.datasets.FiberRegenerationDataset(
file_path=self.data_settings.config_path,
symbols=symbols,
data_size=data_size,
target_delay=self.data_settings.target_delay,
xy_delay=xy_delay,
drop_first=self.data_settings.drop_first,
dtype=self.data_settings.dtype,
real=not self.data_settings.dtype.is_complex,
# device=self.pytorch_settings.device,
)
dataset_size = len(dataset)
indices = list(range(dataset_size))
split = int(np.floor(self.data_settings.train_split * dataset_size))
if self.data_settings.shuffle:
np.random.seed(self.global_settings.seed)
np.random.shuffle(indices)
train_indices, valid_indices = indices[:split], indices[split:]
if self.data_settings.shuffle:
train_sampler = torch.utils.data.SubsetRandomSampler(train_indices)
valid_sampler = torch.utils.data.SubsetRandomSampler(valid_indices)
else:
train_sampler = train_indices
valid_sampler = valid_indices
train_loader = torch.utils.data.DataLoader(
dataset,
batch_size=self.pytorch_settings.batchsize,
sampler=train_sampler,
drop_last=True,
pin_memory=True,
num_workers=24,
prefetch_factor=4,
# persistent_workers=True
)
valid_loader = torch.utils.data.DataLoader(
dataset,
batch_size=self.pytorch_settings.batchsize,
sampler=valid_sampler,
drop_last=True,
pin_memory=True,
num_workers=24,
prefetch_factor=4,
# persistent_workers=True
)
return train_loader, valid_loader
def train_model(self, model, optimizer, train_loader, epoch):
with Progress(
TextColumn("[yellow] Training..."),
TextColumn("Loss: {task.description}"),
BarColumn(),
TaskProgressColumn(),
TextColumn("[green]Batch"),
MofNCompleteColumn(),
TimeRemainingColumn(),
TimeElapsedColumn(),
# description="Training",
transient=False,
console=self.console,
refresh_per_second=10,
) as progress:
task = progress.add_task("-.---e--", total=len(train_loader))
running_loss = 0.0
model.train()
for batch_idx, (x, y) in enumerate(train_loader):
model.zero_grad(set_to_none=True)
x, y = (
x.to(self.pytorch_settings.device),
y.to(self.pytorch_settings.device),
)
y_pred = model(x)
loss = util.complexNN.complex_mse_loss(y_pred, y)
loss.backward()
optimizer.step()
progress.update(task, advance=1, description=f"{loss.item():.3e}")
running_loss += loss.item()
if self.writer is not None:
if (batch_idx + 1) % 10 == 0:
self.writer.add_scalar(
"training loss",
running_loss / 10,
epoch * len(train_loader) + batch_idx,
)
running_loss = 0.0
return running_loss
def eval_model(self, model, valid_loader, epoch):
with Progress(
TextColumn("[green]Evaluating..."),
TextColumn("Loss: {task.description}"),
BarColumn(),
TaskProgressColumn(),
TextColumn("[green]Batch"),
MofNCompleteColumn(),
TimeRemainingColumn(),
TimeElapsedColumn(),
# description="Training",
transient=False,
console=self.console,
refresh_per_second=10,
) as progress:
task = progress.add_task("-.---e--", total=len(valid_loader))
model.eval()
running_loss = 0
running_loss2 = 0
with torch.no_grad():
for batch_idx, (x, y) in enumerate(valid_loader):
x, y = (
x.to(self.pytorch_settings.device),
y.to(self.pytorch_settings.device),
)
y_pred = model(x)
loss = util.complexNN.complex_mse_loss(y_pred, y)
running_loss += loss.item()
running_loss2 += loss.item()
progress.update(task, advance=1, description=f"{loss.item():.3e}")
if self.writer is not None:
if (batch_idx + 1) % 10 == 0:
self.writer.add_scalar(
"loss",
running_loss / 10,
epoch * len(valid_loader) + batch_idx,
)
running_loss = 0.0
if self.writer is not None:
self.writer.add_figure("fiber response", self.plot_model_response(model, plot=False), epoch+1)
return running_loss2 / len(valid_loader)
def run_model(self, model, loader):
model.eval()
xs = []
ys = []
y_preds = []
with torch.no_grad():
model = model.to(self.pytorch_settings.device)
for x, y in loader:
x, y = (
x.to(self.pytorch_settings.device),
y.to(self.pytorch_settings.device),
)
y_pred = model(x).cpu()
# x = x.cpu()
# y = y.cpu()
y_pred = y_pred.view(y_pred.shape[0], -1, 2)
y = y.view(y.shape[0], -1, 2)
x = x.view(x.shape[0], -1, 2)
xs.append(x[:, 0, :].squeeze())
ys.append(y.squeeze())
y_preds.append(y_pred.squeeze())
xs = torch.vstack(xs).cpu()
ys = torch.vstack(ys).cpu()
y_preds = torch.vstack(y_preds).cpu()
return ys, xs, y_preds
def dummy_model(self, loader):
xs = []
ys = []
for x, y in loader:
y = y.cpu().view(y.shape[0], -1, 2)
x = x.cpu().view(x.shape[0], -1, 2)
xs.append(x[:, 0, :].squeeze())
ys.append(y.squeeze())
xs = torch.vstack(xs)
ys = torch.vstack(ys)
return xs, ys
def objective(self, save=False, plot_before=False):
try:
rprint(*list(self.study_name.split("_")))
self.model = self.define_model().to(self.pytorch_settings.device)
if self.writer is not None:
self.writer.add_figure("fiber response", self.plot_model_response(plot=plot_before), 0)
train_loader, valid_loader = self.get_sliced_data()
optimizer_name = self.optimizer_settings.optimizer_range
lr = self.optimizer_settings.lr_range
optimizer = getattr(optim, optimizer_name)(self.model.parameters(), lr=lr)
for epoch in range(self.pytorch_settings.epochs):
self.console.rule(f"Epoch {epoch + 1}/{self.pytorch_settings.epochs}")
self.train_model(self.model, optimizer, train_loader, epoch)
eval_loss = self.eval_model(self.model, valid_loader, epoch)
if save:
save_path = (
Path(self.pytorch_settings.model_dir) / f"{self.study_name}.pth"
)
save_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(self.model, save_path)
return eval_loss
except KeyboardInterrupt:
pass
finally:
if hasattr(self, "model"):
except_save_path = Path(".models/exception") / f"{self.study_name}.pth"
except_save_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(self.model, except_save_path)
def _plot_model_response_plotter(self, fiber_in, fiber_out, regen, plot=True):
fig, axs = plt.subplots(2)
for i, ax in enumerate(axs):
ax.plot(np.abs(fiber_in[:, i]) ** 2, label="fiber in")
ax.plot(np.abs(fiber_out[:, i]) ** 2, label="fiber out")
ax.plot(np.abs(regen[:, i]) ** 2, label="regenerated")
ax.legend()
if plot:
plt.show()
return fig
def plot_model_response(self, model=None, plot=True):
data_settings_backup = copy.copy(self.data_settings)
self.data_settings.shuffle = False
self.data_settings.train_split = 0.01
self.data_settings.drop_first = 100
plot_loader, _ = self.get_sliced_data()
self.data_settings = data_settings_backup
fiber_in, fiber_out, regen = self.run_model(model or self.model, plot_loader)
fiber_in = fiber_in.view(-1, 2)
fiber_out = fiber_out.view(-1, 2)
regen = regen.view(-1, 2)
fiber_in = fiber_in.numpy()
fiber_out = fiber_out.numpy()
regen = regen.numpy()
# https://github.com/matplotlib/matplotlib/issues/27713#issue-2104110987
# https://github.com/matplotlib/matplotlib/issues/27713#issuecomment-1915497463
import gc
fig = self._plot_model_response_plotter(fiber_in, fiber_out, regen, plot=plot)
gc.collect()
return fig
if __name__ == "__main__":
trainer = Training()
# trainer.plot_eye()
trainer.setup_tb_writer()
trainer.objective(save=True)
best_model = trainer.model
# best_model = trainer.define_model(trainer.study.best_trial).to(trainer.pytorch_settings.device)
trainer.plot_model_response(best_model)
# print(f"Best model: {best_model}")
...
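Most of the commit's speedup comes from the loader settings in get_sliced_data above (bigger batches, drop_last=True, pin_memory=True, num_workers=24, prefetch_factor=4). A minimal sketch for timing those knobs in isolation; the dataset and sizes are placeholders, not the project's data:

    import time
    import torch
    from torch.utils.data import DataLoader, TensorDataset

    if __name__ == "__main__":  # guard needed for worker processes on spawn platforms
        # stand-in for FiberRegenerationDataset: shapes are made up
        data = TensorDataset(
            torch.randn(2**16, 64, dtype=torch.complex64),
            torch.randn(2**16, 2, dtype=torch.complex64),
        )
        for workers in (0, 4):
            kwargs = dict(batch_size=4096, shuffle=True, drop_last=True, pin_memory=True)
            if workers:
                kwargs.update(num_workers=workers, prefetch_factor=4)
            loader = DataLoader(data, **kwargs)
            t0 = time.perf_counter()
            for x, y in loader:
                pass  # .to(device, non_blocking=True) + forward/backward would go here
            print(f"num_workers={workers}: {time.perf_counter() - t0:.3f}s")

drop_last=True also keeps every batch the same size, which avoids recompilation/ragged-batch overhead and makes the per-epoch step count deterministic.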


@@ -1,2 +1,17 @@
-from .datasets import FiberRegenerationDataset  # noqa: F401
-from .plot import eye  # noqa: F401
+from . import datasets  # noqa: F401
+# from .datasets import FiberRegenerationDataset  # noqa: F401
+# from .datasets import load_data  # noqa: F401
+from . import plot  # noqa: F401
+# from .plot import eye  # noqa: F401
+from . import optuna_helpers  # noqa: F401
+# from .optuna_helpers import optional_suggest_categorical  # noqa: F401
+# from .optuna_helpers import optional_suggest_float  # noqa: F401
+# from .optuna_helpers import optional_suggest_int  # noqa: F401
+from . import complexNN  # noqa: F401
+# from .complexNN import UnitaryLayer  # noqa: F401
+# from .complexNN import complex_mse_loss  # noqa: F401
+# from .complexNN import complex_sse_loss  # noqa: F401


@@ -0,0 +1,141 @@
import torch
import torch.nn as nn
def complex_mse_loss(input, target):
"""
Compute the mean squared error between two complex tensors.
"""
return torch.mean(torch.square(input.real - target.real) + torch.square(input.imag - target.imag))
def complex_sse_loss(input, target):
"""
Compute the sum squared error between two complex tensors.
"""
if input.is_complex():
return torch.sum(torch.square(input.real - target.real) + torch.square(input.imag - target.imag))
else:
return torch.sum(torch.square(input - target))
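A quick sanity check of the two losses on toy tensors (the values are illustrative only; the import path matches the project layout used elsewhere in this commit):

    import torch
    from util.complexNN import complex_mse_loss, complex_sse_loss

    a = torch.tensor([1 + 1j, 2 + 2j], dtype=torch.complex64)
    b = torch.tensor([2 + 3j, 3 + 4j], dtype=torch.complex64)

    # each element differs by 1 in the real part and 2 in the imaginary part,
    # so the per-element squared error is 1**2 + 2**2 = 5
    print(complex_sse_loss(a, b))  # tensor(10.)
    print(complex_mse_loss(a, b))  # tensor(5.)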
class UnitaryLayer(nn.Module):
def __init__(self, in_features, out_features):
super(UnitaryLayer, self).__init__()
assert in_features >= out_features
self.in_features = in_features
self.out_features = out_features
self.weight = nn.Parameter(torch.randn(in_features, out_features, dtype=torch.cfloat))
self.reset_parameters()
def reset_parameters(self):
q, _ = torch.linalg.qr(self.weight)
self.weight.data = q
@staticmethod
@torch.jit.script
def _unitary_forward(x, weight):
out = torch.matmul(x, weight)
return out
def forward(self, x):
return self._unitary_forward(x, self.weight)
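The QR re-initialization leaves the weight with orthonormal columns, so at initialization the layer preserves inner products from the input subspace. A small check (dimensions arbitrary):

    import torch
    from util.complexNN import UnitaryLayer

    layer = UnitaryLayer(8, 4)
    w = layer.weight                 # (8, 4) complex, orthonormal columns after QR
    gram = w.conj().T @ w            # W^H W should be ~ I_4
    print(torch.allclose(gram, torch.eye(4, dtype=torch.cfloat), atol=1e-5))  # True

Note that plain gradient steps do not keep the weight on the unitary manifold; after the first optimizer update the layer is only approximately unitary.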
#### as defined by zhang et al
class Identity(nn.Module):
"""
implements the "activation" function
M(z) = z
"""
def __init__(self):
super(Identity, self).__init__()
def forward(self, x):
return x
class Mag(nn.Module):
"""
implements the activation function
M(z) = ||z||
"""
def __init__(self):
super(Mag, self).__init__()
def forward(self, x):
return torch.sqrt(x.real**2 + x.imag**2)  # ||z||; the squared sum alone would be ||z||**2
# class Tanh(nn.Module):
# """
# implements the activation function
# M(z) = tanh(z) = sinh(z)/cosh(z) = (exp(z)-exp(-z))/(exp(z)+exp(-z)) = (exp(2*z)-1)/(exp(2*z)+1)
# """
# def __init__(self):
# super(Tanh, self).__init__()
# def forward(self, x):
# return torch.tanh(x)
class ModReLU(nn.Module):
"""
implements the activation function
M(z) = ReLU(||z|| + b)*exp(j*theta_z)
= ReLU(||z|| + b)*z/||z||
"""
def __init__(self, b=0):
super(ModReLU, self).__init__()
self.b = b
self.relu = nn.ReLU()
@staticmethod
# @torch.jit.script
def _mod_relu(x, b):
mod = torch.sqrt(x.real**2 + x.imag**2)  # ||z||, so that x / mod has unit magnitude
return torch.relu(mod + b) * x / mod
def forward(self, x):
return self._mod_relu(x, self.b)
class CReLU(nn.Module):
"""
implements the activation function
M(z) = ReLU(Re(z)) + j*ReLU(Im(z))
"""
def __init__(self):
super(CReLU, self).__init__()
self.relu = nn.ReLU()
def forward(self, x):
return torch.relu(x.real) + 1j*torch.relu(x.imag)
class ZReLU(nn.Module):
"""
implements the activation function
M(z) = z if 0 <= angle(z) <= pi/2
= 0 otherwise
"""
def __init__(self):
super(ZReLU, self).__init__()
def forward(self, x):
return x * (torch.angle(x) >= 0) * (torch.angle(x) <= torch.pi/2)
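Behavior of these Zhang-style activations on two sample points, as I read the definitions above (with the sqrt-based magnitude; b=0 for ModReLU):

    import torch
    from util.complexNN import Mag, ModReLU, CReLU, ZReLU

    z = torch.tensor([3 + 4j, -1 + 1j], dtype=torch.complex64)

    print(Mag()(z))      # magnitudes: [5.0000, 1.4142]
    print(ModReLU()(z))  # b=0 keeps every nonzero point: [3+4j, -1+1j]
    print(CReLU()(z))    # per-component ReLU: [3+4j, 0+1j]
    print(ZReLU()(z))    # only the first-quadrant point survives: [3+4j, 0+0j]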
# class ComplexFeedForwardNN(nn.Module):
# def __init__(self, in_features, hidden_features, out_features):
# super(ComplexFeedForwardNN, self).__init__()
# self.in_features = in_features
# self.hidden_features = hidden_features
# self.out_features = out_features
# self.fc1 = UnitaryLayer(in_features, hidden_features)
# self.fc2 = UnitaryLayer(hidden_features, out_features)
# def forward(self, x):
# x = self.fc1(x)
# x = self.fc2(x)
# return x


@@ -1,11 +1,28 @@
 from pathlib import Path
 import torch
 from torch.utils.data import Dataset
+# from torch.utils.data import Sampler
 import numpy as np
 import configparser

+# class SubsetSampler(Sampler[int]):
+#     """
+#     Samples elements from a given list of indices.
+#     :param indices: List of indices to sample from.
+#     :type indices: list[int]
+#     """
+#     def __init__(self, indices):
+#         self.indices = indices
+#     def __iter__(self):
+#         return iter(self.indices)
+#     def __len__(self):
+#         return len(self.indices)

-def load_data(config_path, skipfirst=0, num_symbols=None):
+def load_data(config_path, skipfirst=0, symbols=None, real=False, normalize=False, device=None, dtype=None):
     filepath = Path(config_path)
     filepath = filepath.parent.glob(filepath.name)
     config = configparser.ConfigParser()
@@ -18,15 +35,25 @@ def load_data(config_path, skipfirst=0, num_symbols=None):
     datapath = Path("/".join(path_elements).replace('"', ""))

     sps = int(config["glova"]["sps"])
-    if num_symbols is None:
-        num_symbols = int(config["glova"]["nos"]) - skipfirst
-    data = np.load(datapath)[skipfirst * sps : num_symbols * sps + skipfirst * sps]
-    config["glova"]["nos"] = str(num_symbols)
+    if symbols is None:
+        symbols = int(config["glova"]["nos"]) - skipfirst
+    data = np.load(datapath)[skipfirst * sps : symbols * sps + skipfirst * sps]
+    if normalize:
+        a, b, c, d = data.T
+        a, b, c, d = a/np.max(np.abs(a)), b/np.max(np.abs(b)), c/np.max(np.abs(c)), d/np.max(np.abs(d))
+        data = np.array([a, b, c, d]).T
+    if real:
+        data = np.abs(data)
+    config["glova"]["nos"] = str(symbols)
+    data = torch.tensor(data, device=device, dtype=dtype)
     return data, config

 def roll_along(arr, shifts, dim):
     # https://stackoverflow.com/a/76920720
     # (c) Mateen Ulhaq, 2023
@@ -39,7 +66,6 @@ def roll_along(arr, shifts, dim):
     indices = (dim_indices - shifts.unsqueeze(dim)) % arr.shape[dim]
     return torch.gather(arr, dim, indices)

 class FiberRegenerationDataset(Dataset):
     """
     Dataset for fiber regeneration training.
@@ -76,6 +102,9 @@ class FiberRegenerationDataset(Dataset):
         target_delay: float | int = 0,
         xy_delay: float | int = 0,
         drop_first: float | int = 0,
+        dtype: torch.dtype = None,
+        real: bool = False,
+        device=None,
         **kwargs,
     ):
         """
@@ -123,13 +152,16 @@ class FiberRegenerationDataset(Dataset):
                 [[i + 0.1j, i + 0.2j, i + 1.1j, i + 1.2j] for i in range(12800)],
                 dtype=np.complex128,
             )
+            data_raw = torch.tensor(data_raw, device=device, dtype=dtype)
             self.config = {
                 "data": {"dir": '"."', "npy_dir": '"."', "file": "faux"},
                 "glova": {"sps": 128},
             }
         else:
-            data_raw, self.config = load_data(file_path)
+            data_raw, self.config = load_data(file_path, skipfirst=drop_first, real=real, normalize=True, device=device, dtype=dtype)
+        self.device = data_raw.device

         self.samples_per_symbol = int(self.config["glova"]["sps"])
         self.samples_per_slice = int(symbols * self.samples_per_symbol)
         self.symbols_per_slice = self.samples_per_slice / self.samples_per_symbol
@@ -140,7 +172,6 @@ class FiberRegenerationDataset(Dataset):
         ovrd_target_delay_samples = kwargs.pop("ovrd_target_delay_samples", None)
         ovrd_xy_delay_samples = kwargs.pop("ovrd_xy_delay_samples", None)
-        ovrd_drop_first_samples = kwargs.pop("ovrd_drop_first_samples", None)

         self.target_delay_samples = (
             ovrd_target_delay_samples
@@ -152,14 +183,8 @@ class FiberRegenerationDataset(Dataset):
             if ovrd_xy_delay_samples is not None
             else int(self.xy_delay * self.samples_per_symbol)
         )
-        drop_first_samples = (
-            ovrd_drop_first_samples
-            if ovrd_drop_first_samples is not None
-            else int(drop_first * self.samples_per_symbol)
-        )

-        # drop samples from the beginning
-        data_raw = data_raw[drop_first_samples:]
+        # data_raw = torch.tensor(data_raw, dtype=dtype)

         # data layout
         # [ [E_in_x0, E_in_y0, E_out_x0, E_out_y0],
@@ -240,10 +265,10 @@ class FiberRegenerationDataset(Dataset):
         data = data.view(data.shape[0], self.data_size, -1)
         data = data[:, :, 0]

-        # target corresponds to the latest data point -> try to regenerate that
+        # target corresponds to the middle of the data, since the output sample is influenced by the data before and after it
         target = target[:, : target.shape[1] // self.data_size * self.data_size]
         target = target.view(target.shape[0], self.data_size, -1)
-        target = target[:, 0, 0]
+        target = target[:, 0, target.shape[2] // 2]

         data = data.transpose(0, 1).flatten().squeeze()
         target = target.flatten().squeeze()
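The last change above is subtle: the target is now the middle sample of the slice rather than the first one. A toy illustration of the indexing, with made-up shapes:

    import torch

    # one slice, data_size=4 taps, 5 target samples per tap (placeholder numbers)
    target = torch.arange(20).view(1, 4, 5)
    print(target[:, 0, 0])                      # old behaviour -> tensor([0])
    print(target[:, 0, target.shape[2] // 2])   # new behaviour -> tensor([2])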


@@ -0,0 +1,30 @@
def _optional_suggest(trial, name, range_or_value, log=False, step=None, type='int'):
# not a range
if not hasattr(range_or_value, '__iter__') or isinstance(range_or_value, str):
return range_or_value
# range with only one value
if len(range_or_value) == 1:
return range_or_value[0]
if type == 'int':
step = step or 1
return trial.suggest_int(name, *range_or_value, step=step, log=log)
if type == 'float':
return trial.suggest_float(name, *range_or_value, step=step, log=log)
if type == 'categorical':
return trial.suggest_categorical(name, range_or_value)
raise ValueError(f"Unknown type: {type}")
def optional_suggest_categorical(trial, name, choices_or_value):
return _optional_suggest(trial, name, choices_or_value, type='categorical')
def optional_suggest_int(trial, name, range_or_value, step=None, log=False):
return _optional_suggest(trial, name, range_or_value, step=step, log=log, type='int')
def optional_suggest_float(trial, name, range_or_value, step=None, log=False):
return _optional_suggest(trial, name, range_or_value, step=step, log=log, type='float')
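These helpers let a setting be either a fixed value or a tunable range with the same call site. A short usage sketch (the objective is a toy, not project code):

    import optuna
    from util.optuna_helpers import optional_suggest_float, optional_suggest_int

    def objective(trial):
        lr = optional_suggest_float(trial, "lr", (1e-5, 1e-1), log=True)  # tuple -> tuned
        units = optional_suggest_int(trial, "units", 32)                  # scalar -> fixed
        return lr * units

    study = optuna.create_study()
    study.optimize(objective, n_trials=3)
    print(study.best_params)  # contains "lr" only; fixed values are never registered

Because fixed values never land in trial.params, callers that need them later must provide a fallback, which is exactly why define_model above uses trial.params.get("dataset_data_size", ...).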


@@ -0,0 +1,18 @@
from dash import Dash, dcc, html
import logging
import dash_bootstrap_components as dbc
def show_figures(*figures):
for figure in figures:
figure.layout.template = 'plotly_dark'
app = Dash(external_stylesheets=[dbc.themes.DARKLY])
app.layout = html.Div([
dcc.Graph(figure=figure) for figure in figures
])
log = logging.getLogger('werkzeug')
log.setLevel(logging.ERROR)
app.show = lambda *args, **kwargs: app.run_server(*args, **kwargs, debug=False)
return app
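Minimal usage sketch for the helper above, assuming plotly and dash 2.x are installed; the figure content and the module name util.webview are placeholders:

    import plotly.graph_objects as go
    from util.webview import show_figures  # hypothetical module name for the helper

    fig = go.Figure(go.Scatter(y=[1, 3, 2]))
    app = show_figures(fig)   # dark-themed Dash app wrapping the figure(s)
    app.show(port=8050)       # alias for app.run_server(..., debug=False)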


@@ -2,33 +2,72 @@ import matplotlib.pyplot as plt
 import numpy as np
 from .datasets import load_data

-def eye(path, title=None, head=1000, skipfirst=1000, show=True):
+def eye(*, path=None, data=None, sps=None, title=None, symbols=1000, skipfirst=0, width=2, alpha=None, complex=False, show=True):
     """Plot an eye diagram for the data given by filepath.
+    Either path or data and sps must be given.

     Args:
         path (str): Path to the data description file.
+        data (np.ndarray): Data to plot.
+        sps (int): Samples per symbol.
         title (str): Title of the plot.
-        head (int): Number of symbols to plot.
+        symbols (int): Number of symbols to plot.
         skipfirst (int): Number of symbols to skip.
         show (bool): Whether to call plt.show().
     """
-    data, config = load_data(path, skipfirst, head)
-    sps = int(config["glova"]["sps"])
+    if path is None and data is None:
+        raise ValueError("Either path or data and sps must be given.")
+    if path is not None:
+        data, config = load_data(path, skipfirst, symbols)
+        sps = int(config["glova"]["sps"])
+    if sps is None:
+        raise ValueError("sps not set.")

-    xaxis = np.linspace(0, 2, 2*sps, endpoint=False)
+    xaxis = np.linspace(0, width, width*sps, endpoint=False)
     fig, axs = plt.subplots(2, 2, figsize=(10, 10), sharex=True, sharey=True)
-    for i in range(head-1):
-        inx, iny, outx, outy = data[i*sps:(i+2)*sps].T
-        axs[0, 0].plot(xaxis, np.abs(inx)**2, color="C0", alpha=0.1)
-        axs[0, 1].plot(xaxis, np.abs(outx)**2, color="C0", alpha=0.1)
-        axs[1, 0].plot(xaxis, np.abs(iny)**2, color="C0", alpha=0.1)
-        axs[1, 1].plot(xaxis, np.abs(outy)**2, color="C0", alpha=0.1)
+    if complex:
+        # create secondary axis for phase
+        axs2 = axs[0, 0].twinx(), axs[0, 1].twinx(), axs[1, 0].twinx(), axs[1, 1].twinx()
+        axs2 = np.reshape(axs2, (2, 2))
+
+    for i in range(symbols-(width-1)):
+        inx, iny, outx, outy = data[i*sps:(i+width)*sps].T
+        if complex:
+            axs[0, 0].plot(xaxis, np.abs(inx), color="C0", alpha=alpha or 0.1)
+            axs[0, 1].plot(xaxis, np.abs(outx), color="C0", alpha=alpha or 0.1)
+            axs[1, 0].plot(xaxis, np.abs(iny), color="C0", alpha=alpha or 0.1)
+            axs[1, 1].plot(xaxis, np.abs(outy), color="C0", alpha=alpha or 0.1)
+            axs[0, 0].set_ylim(0, 1.1*np.max(np.abs(data)))
+            axs2[0, 0].plot(xaxis, np.angle(inx), color="C1", alpha=alpha or 0.1)
+            axs2[0, 1].plot(xaxis, np.angle(outx), color="C1", alpha=alpha or 0.1)
+            axs2[1, 0].plot(xaxis, np.angle(iny), color="C1", alpha=alpha or 0.1)
+            axs2[1, 1].plot(xaxis, np.angle(outy), color="C1", alpha=alpha or 0.1)
+        else:
+            axs[0, 0].plot(xaxis, np.abs(inx)**2, color="C0", alpha=alpha or 0.1)
+            axs[0, 1].plot(xaxis, np.abs(outx)**2, color="C0", alpha=alpha or 0.1)
+            axs[1, 0].plot(xaxis, np.abs(iny)**2, color="C0", alpha=alpha or 0.1)
+            axs[1, 1].plot(xaxis, np.abs(outy)**2, color="C0", alpha=alpha or 0.1)
+
+    if complex:
+        axs2[0, 0].sharey(axs2[0, 1])
+        axs2[0, 1].sharey(axs2[1, 0])
+        axs2[1, 0].sharey(axs2[1, 1])
+        # make y axis symmetric
+        ylim = np.max(np.abs(np.angle(data)))*1.1
+        if ylim != 0:
+            axs2[0, 0].set_ylim(-ylim, ylim)
+    else:
+        axs[0, 0].set_ylim(0, 1.1*np.max(np.abs(data))**2)

     axs[0, 0].set_title("Input x")
     axs[0, 1].set_title("Output x")
     axs[1, 0].set_title("Input y")
     axs[1, 1].set_title("Output y")
-    fig.suptitle(title)
+    fig.suptitle(title or "Eye diagram")
     if show:
-        plt.show(block=False)
+        plt.show()
+    return fig