Compare commits

...

23 Commits

Author SHA1 Message Date
Joseph Hopfmüller
487288c923 define new activation functions and parametrizations 2024-11-29 15:51:25 +01:00
Joseph Hopfmüller
bdf6f5bfb8 clean up regen_no_hyper.py 2024-11-29 15:50:34 +01:00
Joseph Hopfmüller
e02662ed4f new optuna studies 2024-11-29 15:49:59 +01:00
Joseph Hopfmüller
fd7a0b9c31 using latest knowledge for hyperparameter search 2024-11-29 15:49:46 +01:00
Joseph Hopfmüller
ff32aefd52 minor fixes and changes 2024-11-29 15:49:10 +01:00
Joseph Hopfmüller
b156b9ceaf refactor hypertraining.py to improve model layer handling and response plotting; adjust data settings for batch processing 2024-11-29 15:48:27 +01:00
Joseph Hopfmüller
cfa08aae4e add training.py for defining and running models without hyperparametertuning 2024-11-29 15:48:18 +01:00
Joseph Hopfmüller
0422c81f3b update single_core_regen settings new runs 2024-11-24 01:56:01 +01:00
Joseph Hopfmüller
7343ccb3a5 refactor complex loss functions for improved readability; update settings and dataset classes for consistency 2024-11-24 01:55:32 +01:00
Joseph Hopfmüller
9a16a5637d add optional parameter suggestion methods for Optuna trials 2024-11-24 01:55:12 +01:00
Joseph Hopfmüller
80e9a3379e add autosampler support 2024-11-20 23:10:14 +01:00
Joseph Hopfmüller
8d4d0468bd complexhalf (complex32) isn't supported by torch.linalg.qr 2024-11-20 22:56:26 +01:00
Joseph Hopfmüller
6358c95c42 new hyperparameter db 2024-11-20 22:49:40 +01:00
Joseph Hopfmüller
674033ac2e move hypertraining class into separate file;
move settings dataclasses into separate file;
add SemiUnitaryLayer;
clean up model response plotting code;
cnt hyperparameter search
2024-11-20 22:49:31 +01:00
Joseph Hopfmüller
cdca5de473 training loop speedup 2024-11-20 11:29:18 +01:00
Joseph Hopfmüller
1622c38582 refactor: remove unused Optuna visualization utility 2024-11-17 22:23:37 +01:00
Joseph Hopfmüller
2bba760378 add: implement Optuna visualization utility with Dash 2024-11-17 22:23:01 +01:00
Joseph Hopfmüller
9ec548757d add: regen.py (main hyperparameter training file)
feat: add utility functions for fiber dataset visualization and hyperparameter training;
housekeeping: rename dataset.py -> datasets.py
2024-11-17 22:22:37 +01:00
Joseph Hopfmüller
05a3ee9394 refactor: clean up .gitignore, remove unused scripts 2024-11-17 22:18:44 +01:00
Joseph Hopfmüller
086240489a minor edits on notes 2024-11-17 22:16:52 +01:00
Joseph Hopfmüller
87f40fc37c add SlicedDataset class and utility scripts; refactor: remove _path_fix.py and update imports; 2024-11-17 01:04:33 +01:00
Joseph Hopfmüller
90aa6dbaf8 housekeeping 2024-11-17 01:04:14 +01:00
Joseph Hopfmüller
744c5f5166 rename dir;
add torch import test script
2024-11-16 00:39:19 +01:00
28 changed files with 3508 additions and 65 deletions

1
.gitattributes vendored
View File

@@ -1,4 +1,5 @@
data/**/* filter=lfs diff=lfs merge=lfs -text
data/*.db filter=lfs diff=lfs merge=lfs -text
data/*.ini filter=lfs diff=lfs merge=lfs text
## lfs setup

4
.gitignore vendored
View File

@@ -1,7 +1,5 @@
src/**/*.ini
.*
# VSCode
.vscode
# Byte-compiled / optimized / DLL files
__pycache__/

View File

@@ -10,6 +10,7 @@ use is covered by a right of the copyright holder of the Work).
The Work is provided under the terms of this Licence when the Licensor (as
defined below) has placed the following notice immediately following the
copyright notice for the Work:
```raw
Licensed under the EUPL
```

View File

@@ -13,7 +13,7 @@ Full license text in LICENSE file
# optical-regeneration
## Notes on cloning
- `pypho` is added as a submodule -> `--recurse-submodules`
- This repo has about 7.5GB of datasets in it. The `git lfs fetch` step will take a while.
@@ -29,4 +29,5 @@ git lfs checkout
```
## License
This project is licensed under EUPL-1.2.

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e12f0c21fca93620a165fbb6ed58d0b313093e972ef4416694c29c9cea6dc867
size 831488

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:746ea83013870351296e01e294905ee027291ef79fd78c1e6b69dd9ebaa1cba0
size 10240000

View File

@@ -1,6 +1,6 @@
# CUDA 12.4 Install
> <https://unknowndba.blogspot.com/2024/04/cuda-getting-started-on-wsl.html> (@\_freddenis\_, 2024)
```bash
# get md5sums
@@ -49,10 +49,9 @@ make
./devicequery
```
### if the cuda-toolkit install fails with unmet dependencies
> <https://askubuntu.com/a/1493087> (jspinella, 2023, CC BY-SA 4.0)
1. Open the *new* file for storing the sources list
@@ -71,4 +70,3 @@ make
```
3. Save the file and run `sudo apt update` - now the install command for CUDA should work.

View File

@@ -1,4 +1,4 @@
# useful links
- [Optuna](https://optuna.org) Hyperparameter optimization framework
`pip install optuna`

View File

@@ -1,46 +1,42 @@
# pyenv installation
## pyenv
1. Install pyenv
```bash
curl https://pyenv.run | bash
```
2. setup zsh
add the following to `.zshrc`:
```bash
export PYENV_ROOT="$HOME/.pyenv"
[[ -d $PYENV_ROOT/bin ]] && export PATH="$PYENV_ROOT/bin:$PATH"
eval "$(pyenv init -)"
```
## python installation
1. prerequisites
```bash
sudo apt update
sudo apt install build-essential libssl-dev zlib1g-dev \
libbz2-dev libreadline-dev libsqlite3-dev curl git \
libncursesw5-dev xz-utils tk-dev libxml2-dev libxmlsec1-dev \
libffi-dev liblzma-dev python3-pip
```
2. install
```bash
# using python 3.12.7 as an example
pyenv install 3.12.7
# optional
pyenv global 3.12.7
pyenv versions
```

View File

@@ -8,7 +8,8 @@ source ./.venv/bin/activate
```
## install pytorch
> <https://pytorch.org/get-started/locally/>
```bash
pip install torch torchvision torchaudio

View File

@@ -0,0 +1,37 @@
# add_pypho.py
#
# This file is part of the repo "optical-regeneration"
# https://git.suuppl.dev/seppl/optical-regeneration.git
#
# (c) Joseph Hopfmüller, 2024
# Licensed under the EUPL
#
# Full license text in LICENSE file
###
# copy this file into the directory where you want to use pypho
import sys
from pathlib import Path
__log = []
# add the dir above the one where this file lives
__parent_dir = Path(__file__).parent
# search for a dir containing ./pypho/pypho, then add the lower ./pypho
while not (__parent_dir / "pypho" / "pypho").exists() and __parent_dir != Path("/"):
__parent_dir = __parent_dir.parent
if __parent_dir != Path("/"):
sys.path.append(str(__parent_dir / "pypho"))
__log.append(f"Added '{__parent_dir/ "pypho"}' to 'PATH'")
else:
__log.append('pypho not found')
def show_log():
for entry in __log:
print(entry)

View File

@@ -19,9 +19,8 @@ import time
from matplotlib import pyplot as plt # noqa: F401
import numpy as np
import add_pypho # noqa: F401
import pypho
default_config = f"""
[glova]
@@ -498,17 +497,18 @@ def plot_eye_diagram(
if __name__ == "__main__":
add_pypho.show_log()
config = get_config()
# length_ranges = [1000, 10000]
# length_scales = [1, 2, 3, 4, 5, 6, 7, 8, 9]
# lengths = [
# length_scale * length_range
# for length_range in length_ranges
# for length_scale in length_scales
# ]
# lengths.append(max(length_ranges)*10)
# length_loop(config, lengths)

View File

@@ -1,9 +0,0 @@
import sys
from pathlib import Path
# hack to add the parent directory to the path -> pypho doesn't have to be installed as package
parent_dir = Path(__file__).parent
while not (parent_dir / "pypho" / "pypho").exists() and parent_dir != Path("/"):
parent_dir = parent_dir.parent
print(f"Adding '{parent_dir / "pypho"}' to 'sys.path' to enable import of '{parent_dir / 'pypho' / 'pypho'}'")
sys.path.append(str(parent_dir / "pypho"))

View File

@@ -0,0 +1,763 @@
import copy
from datetime import datetime
from pathlib import Path
from typing import Literal
import matplotlib.pyplot as plt
import numpy as np
import optuna
# import optunahub
import warnings
import torch
import torch.nn as nn
# import torch.nn.functional as F # mse_loss doesn't support complex numbers
import torch.optim as optim
import torch.utils.data
from torch.utils.tensorboard import SummaryWriter
# from rich.progress import (
# Progress,
# TextColumn,
# BarColumn,
# TaskProgressColumn,
# TimeRemainingColumn,
# MofNCompleteColumn,
# TimeElapsedColumn,
# )
# from rich.console import Console
# from rich import print as rprint
import multiprocessing
from util.datasets import FiberRegenerationDataset
# from util.optuna_helpers import (
# suggest_categorical_optional, # noqa: F401
# suggest_float_optional, # noqa: F401
# suggest_int_optional, # noqa: F401
# )
from util.optuna_helpers import install_optional_suggests
import util
from .settings import (
GlobalSettings,
DataSettings,
ModelSettings,
OptunaSettings,
OptimizerSettings,
PytorchSettings,
)
install_optional_suggests()
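# HyperTraining wires the settings dataclasses to an Optuna study: each trial builds a model from the suggested hyperparameters, trains and evaluates it on the fiber dataset, and logs results to TensorBoard.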
class HyperTraining:
def __init__(
self,
*,
global_settings,
data_settings,
pytorch_settings,
model_settings,
optimizer_settings,
optuna_settings,
# console=None,
):
self.global_settings: GlobalSettings = global_settings
self.data_settings: DataSettings = data_settings
self.pytorch_settings: PytorchSettings = pytorch_settings
self.model_settings: ModelSettings = model_settings
self.optimizer_settings: OptimizerSettings = optimizer_settings
self.optuna_settings: OptunaSettings = optuna_settings
self.processes = None
# self.console = console or Console()
# set some extra settings to make the code more readable
self._extra_optuna_settings()
self.stop_study = False  # set to True (e.g. on KeyboardInterrupt) to stop scheduling further trials
def setup_tb_writer(self, study_name=None, append=None):
log_dir = self.pytorch_settings.summary_dir + "/" + (study_name or self.optuna_settings.study_name)
if append is not None:
log_dir += "_" + str(append)
return SummaryWriter(log_dir)
def resume_latest_study(self, verbose=True):
# get_latest_study returns a Study (or None); store its name so setup_study reloads it
study = self.get_latest_study()
if study is not None:
if verbose:
print(f"Resuming study: {study.study_name}")
self.optuna_settings.study_name = study.study_name
def get_latest_study(self, verbose=False) -> optuna.Study | None:
studies = self.get_studies()
if not studies:
if verbose:
print("No previous studies found")
return None
for study in studies:
study.datetime_start = study.datetime_start or datetime.min
study = sorted(studies, key=lambda x: x.datetime_start, reverse=True)[0]
if verbose:
print(f"Last study: {study.study_name}")
return optuna.load_study(study_name=study.study_name, storage=self.optuna_settings.storage)
# def study(self) -> optuna.Study:
# return optuna.load_study(self.optuna_settings.study_name, storage=self.optuna_settings.storage)
def get_studies(self):
return optuna.get_all_study_summaries(storage=self.optuna_settings.storage)
def setup_study(self):
# module = optunahub.load_module(package="samplers/auto_sampler")
if self.optuna_settings._parallel:
self.processes = []
pruner = getattr(optuna.pruners, self.optuna_settings.pruner, None)
if pruner and self.optuna_settings.pruner_kwargs is not None:
pruner = pruner(**self.optuna_settings.pruner_kwargs)
elif pruner:
pruner = pruner()
self.study = optuna.create_study(
study_name=self.optuna_settings.study_name,
storage=self.optuna_settings.storage,
load_if_exists=True,
direction=self.optuna_settings._direction,
directions=self.optuna_settings._directions,
pruner=pruner,
# sampler=module.AutoSampler(),
)
# print("using sampler:", self.study.sampler)
with warnings.catch_warnings(action="ignore"):
self.study.set_metric_names(self.optuna_settings.metrics_names)
def run_study(self):
try:
if self.optuna_settings._parallel:
self._run_parallel_study()
else:
self._run_study()
except KeyboardInterrupt:
print("Stopping. Please wait for the processes to finish.")
self.stop_study = True
def trials_left(self):
return self.optuna_settings.n_trials - len(self.study.get_trials(states=self.optuna_settings.n_trials_filter))
def remove_completed_processes(self):
if self.processes is None:
return
# iterate over a copy so that removing finished processes does not skip entries
for process in list(self.processes):
if not process.is_alive():
process.join()
self.processes.remove(process)
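# remove_outliers marks completed trials whose log(value) lies more than remove_outliers standard deviations above the mean of the log-values as failed and tags them with the "outlier" user attribute.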
def remove_outliers(self):
if self.optuna_settings.remove_outliers is not None:
trials = self.study.get_trials(states=(optuna.trial.TrialState.COMPLETE,))
if len(trials) == 0:
return
vals = [trial.value for trial in trials]
vals = np.log(vals)
mean = np.mean(vals)
std = np.std(vals)
outliers = [
trial for trial in trials if np.log(trial.value) > mean + self.optuna_settings.remove_outliers * std
]
for trial in outliers:
trial: optuna.trial.Trial = trial
trial.state = optuna.trial.TrialState.FAIL
trial.set_user_attr("outlier", True)
def _run_study(self):
while trials_left := self.trials_left():
self.remove_outliers()
self._run_optimize(n_trials=trials_left, timeout=self.optuna_settings.timeout)
def _run_parallel_study(self):
while trials_left := self.trials_left():
self.remove_outliers()
self.remove_completed_processes()
n_trials = max(trials_left, self.optuna_settings._n_threads) // self.optuna_settings._n_threads
def target_fun():
self._run_optimize(n_trials=n_trials, timeout=self.optuna_settings.timeout)
for _ in range(self.optuna_settings._n_threads - len(self.processes)):
self.processes.append(multiprocessing.Process(target=target_fun))
self.processes[-1].start()
def _run_optimize(self, **kwargs):
self.study.optimize(
self.objective,
**kwargs,
show_progress_bar=not self.optuna_settings._parallel,
)
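# _extra_optuna_settings derives the private helper fields (_multi_objective, _direction/_directions, batch limits, _n_threads, _parallel) from the public OptunaSettings values.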
def _extra_optuna_settings(self):
self.optuna_settings._multi_objective = len(self.optuna_settings.directions) > 1
if self.optuna_settings._multi_objective:
self.optuna_settings._direction = None
self.optuna_settings._directions = self.optuna_settings.directions
else:
self.optuna_settings._direction = self.optuna_settings.directions[0]
self.optuna_settings._directions = None
self.optuna_settings._n_train_batches = (
self.optuna_settings.n_train_batches if self.optuna_settings.limit_examples else float("inf")
)
self.optuna_settings._n_valid_batches = (
self.optuna_settings.n_valid_batches if self.optuna_settings.limit_examples else float("inf")
)
self.optuna_settings._n_threads = self.optuna_settings.n_workers
self.optuna_settings._parallel = self.optuna_settings._n_threads > 1
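# define_model builds a Sequential stack of SemiUnitaryLayers whose depth, widths, dtype and activation come from the trial's (optional) suggestions, and records parameter/node counts as user attributes.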
def define_model(self, trial: optuna.Trial, writer=None):
n_layers = trial.suggest_int_optional("model_n_hidden_layers", self.model_settings.n_hidden_layers)
input_dim = trial.suggest_int_optional(
"model_input_dim",
self.data_settings.output_size,
step=2,
multiply=2,
set_new=False,
)
# trial.set_user_attr("model_input_dim", input_dim)
dtype = trial.suggest_categorical_optional("model_dtype", self.data_settings.dtype, set_new=False)
dtype = getattr(torch, dtype)
afunc = trial.suggest_categorical_optional("model_activation_func", self.model_settings.model_activation_func)
# T0 = trial.suggest_float_optional("T0", self.model_settings.satabsT0 , log=True)
layers = []
last_dim = input_dim
n_nodes = last_dim
for i in range(n_layers):
if hidden_dim_override := self.model_settings.overrides.get(f"n_hidden_nodes_{i}", False):
hidden_dim = trial.suggest_int_optional(f"model_hidden_dim_{i}", hidden_dim_override)
else:
hidden_dim = trial.suggest_int_optional(
f"model_hidden_dim_{i}",
self.model_settings.n_hidden_nodes,
)
layers.append(util.complexNN.SemiUnitaryLayer(last_dim, hidden_dim, dtype=dtype))
last_dim = hidden_dim
layers.append(getattr(util.complexNN, afunc)())
n_nodes += last_dim
layers.append(util.complexNN.SemiUnitaryLayer(last_dim, self.model_settings.output_dim, dtype=dtype))
model = nn.Sequential(*layers)
if writer is not None:
writer.add_graph(model, torch.zeros(1, input_dim, dtype=dtype), use_strict_trace=False)
n_params = sum(p.numel() for p in model.parameters())
trial.set_user_attr("model_n_params", n_params)
trial.set_user_attr("model_n_nodes", n_nodes)
return model.to(self.pytorch_settings.device)
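# get_sliced_data constructs the FiberRegenerationDataset with the trial's window/delay parameters and returns train/validation DataLoaders (optionally shuffled with SubsetRandomSampler).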
def get_sliced_data(self, trial: optuna.Trial, override=None):
symbols = trial.suggest_float_optional("dataset_symbols", self.data_settings.symbols, set_new=False)
in_out_delay = trial.suggest_float_optional(
"dataset_in_out_delay", self.data_settings.in_out_delay, set_new=False
)
xy_delay = trial.suggest_float_optional("dataset_xy_delay", self.data_settings.xy_delay, set_new=False)
data_size = int(
0.5
* trial.suggest_int_optional(
"model_input_dim",
self.data_settings.output_size,
step=2,
multiply=2,
set_new=False,
)
)
dtype = trial.suggest_categorical_optional("model_dtype", self.data_settings.dtype, set_new=False)
dtype = getattr(torch, dtype)
num_symbols = None
if override is not None:
num_symbols = override.get("num_symbols", None)
# get dataset
dataset = FiberRegenerationDataset(
file_path=self.data_settings.config_path,
symbols=symbols,
output_dim=data_size,
target_delay=in_out_delay,
xy_delay=xy_delay,
drop_first=self.data_settings.drop_first,
dtype=dtype,
real=not dtype.is_complex,
num_symbols=num_symbols,
)
dataset_size = len(dataset)
indices = list(range(dataset_size))
split = int(np.floor(self.data_settings.train_split * dataset_size))
if self.data_settings.shuffle:
np.random.seed(self.global_settings.seed)
np.random.shuffle(indices)
train_indices, valid_indices = indices[:split], indices[split:]
if self.data_settings.shuffle:
train_sampler = torch.utils.data.SubsetRandomSampler(train_indices)
valid_sampler = torch.utils.data.SubsetRandomSampler(valid_indices)
else:
train_sampler = train_indices
valid_sampler = valid_indices
train_loader = torch.utils.data.DataLoader(
dataset,
batch_size=self.pytorch_settings.batchsize,
sampler=train_sampler,
drop_last=True,
pin_memory=True,
num_workers=self.pytorch_settings.dataloader_workers,
prefetch_factor=self.pytorch_settings.dataloader_prefetch,
)
valid_loader = torch.utils.data.DataLoader(
dataset,
batch_size=self.pytorch_settings.batchsize,
sampler=valid_sampler,
drop_last=True,
pin_memory=True,
num_workers=self.pytorch_settings.dataloader_workers,
prefetch_factor=self.pytorch_settings.dataloader_prefetch,
)
return train_loader, valid_loader
def train_model(
self,
trial,
model,
optimizer,
train_loader,
epoch,
writer=None,
# enable_progress=False,
):
# if enable_progress:
# progress = Progress(
# TextColumn("[yellow] Training..."),
# TextColumn("Error: {task.description}"),
# BarColumn(),
# TaskProgressColumn(),
# TextColumn("[green]Batch"),
# MofNCompleteColumn(),
# TimeRemainingColumn(),
# TimeElapsedColumn(),
# # description="Training",
# transient=False,
# console=self.console,
# refresh_per_second=10,
# )
# task = progress.add_task("-.---e--", total=len(train_loader))
# progress.start()
running_loss2 = 0.0
running_loss = 0.0
model.train()
for batch_idx, (x, y) in enumerate(train_loader):
if batch_idx >= self.optuna_settings._n_train_batches:
break
model.zero_grad(set_to_none=True)
x, y = (
x.to(self.pytorch_settings.device),
y.to(self.pytorch_settings.device),
)
y_pred = model(x)
loss = util.complexNN.complex_mse_loss(y_pred, y)
loss_value = loss.item()
loss.backward()
optimizer.step()
running_loss2 += loss_value
running_loss += loss_value
# if enable_progress:
# progress.update(task, advance=1, description=f"{loss_value:.3e}")
if writer is not None:
if batch_idx % self.pytorch_settings.write_every == 0:
writer.add_scalar(
"training loss",
running_loss2 / (self.pytorch_settings.write_every if batch_idx > 0 else 1),
epoch * min(len(train_loader), self.optuna_settings._n_train_batches) + batch_idx,
)
running_loss2 = 0.0
# if enable_progress:
# progress.stop()
return running_loss / min(len(train_loader), self.optuna_settings._n_train_batches)
def eval_model(
self,
trial,
model,
valid_loader,
epoch,
writer=None,
# enable_progress=True
):
# if enable_progress:
# progress = Progress(
# TextColumn("[green]Evaluating..."),
# TextColumn("Error: {task.description}"),
# BarColumn(),
# TaskProgressColumn(),
# TextColumn("[green]Batch"),
# MofNCompleteColumn(),
# TimeRemainingColumn(),
# TimeElapsedColumn(),
# # description="Training",
# transient=False,
# console=self.console,
# refresh_per_second=10,
# )
# progress.start()
# task = progress.add_task("-.---e--", total=len(valid_loader))
model.eval()
running_error = 0
running_error_2 = 0
with torch.no_grad():
for batch_idx, (x, y) in enumerate(valid_loader):
if batch_idx >= self.optuna_settings._n_valid_batches:
break
x, y = (
x.to(self.pytorch_settings.device),
y.to(self.pytorch_settings.device),
)
y_pred = model(x)
error = util.complexNN.complex_mse_loss(y_pred, y)
error_value = error.item()
running_error += error_value
running_error_2 += error_value
# if enable_progress:
# progress.update(task, advance=1, description=f"{error_value:.3e}")
if writer is not None:
if batch_idx % self.pytorch_settings.write_every == 0:
writer.add_scalar(
"eval loss",
running_error_2 / (self.pytorch_settings.write_every if batch_idx > 0 else 1),
epoch * min(len(valid_loader), self.optuna_settings._n_valid_batches) + batch_idx,
)
running_error_2 = 0.0
running_error /= min(len(valid_loader), self.optuna_settings._n_valid_batches)
if writer is not None:
title_append, subtitle = self.build_title(trial)
writer.add_figure(
"fiber response",
self.plot_model_response(
trial,
model=model,
title_append=title_append,
subtitle=subtitle,
show=False,
),
epoch + 1,
)
# if enable_progress:
# progress.stop()
return running_error
def run_model(self, model, loader):
model.eval()
xs = []
ys = []
y_preds = []
with torch.no_grad():
model = model.to(self.pytorch_settings.device)
for x, y in loader:
x, y = (
x.to(self.pytorch_settings.device),
y.to(self.pytorch_settings.device),
)
y_pred = model(x).cpu()
# x = x.cpu()
# y = y.cpu()
y_pred = y_pred.view(y_pred.shape[0], -1, 2)
y = y.view(y.shape[0], -1, 2)
x = x.view(x.shape[0], -1, 2)
xs.append(x[:, 0, :].squeeze())
ys.append(y.squeeze())
y_preds.append(y_pred.squeeze())
xs = torch.vstack(xs).cpu()
ys = torch.vstack(ys).cpu()
y_preds = torch.vstack(y_preds).cpu()
return ys, xs, y_preds
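# objective: build and train a model for this trial, log figures and losses to TensorBoard, report intermediate errors for pruning (single-objective), optionally save the model, and return the final validation error (or the multi-objective tuple).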
def objective(self, trial: optuna.Trial, plot_before=False):
if self.stop_study:
trial.study.stop()
model = None
writer = self.setup_tb_writer(
self.optuna_settings.study_name,
f"{trial.number:0{len(str(self.optuna_settings.n_trials)) + 2}}",
)
model = self.define_model(trial, writer)
# n_nodes = trial.params.get("model_n_hidden_layers", self.model_settings.model_n_layers) * trial.params.get("model_hidden_dim", self.model_settings.unit_count)
title_append, subtitle = self.build_title(trial)
writer.add_figure(
"fiber response",
self.plot_model_response(
trial,
model=model,
title_append=title_append,
subtitle=subtitle,
show=plot_before,
),
0,
)
train_loader, valid_loader = self.get_sliced_data(trial)
optimizer_name = trial.suggest_categorical_optional("optimizer", self.optimizer_settings.optimizer)
lr = trial.suggest_float_optional("lr", self.optimizer_settings.learning_rate, log=True)
optimizer = getattr(optim, optimizer_name)(model.parameters(), lr=lr)
if self.optimizer_settings.scheduler is not None:
scheduler = getattr(optim.lr_scheduler, self.optimizer_settings.scheduler)(
optimizer, **self.optimizer_settings.scheduler_kwargs
)
for epoch in range(self.pytorch_settings.epochs):
trial.set_user_attr("epoch", epoch)
# enable_progress = self.optuna_settings.n_threads == 1
# if enable_progress:
# self.console.rule(
# f"Epoch {epoch + 1}/{self.pytorch_settings.epochs}"
# )
self.train_model(
trial,
model,
optimizer,
train_loader,
epoch,
writer,
# enable_progress=enable_progress,
)
error = self.eval_model(
trial,
model,
valid_loader,
epoch,
writer,
# enable_progress=enable_progress,
)
if self.optimizer_settings.scheduler is not None:
scheduler.step(error)
trial.set_user_attr("mse", error)
trial.set_user_attr("log_mse", np.log10(error + np.finfo(float).eps))
trial.set_user_attr("neg_mse", -error)
trial.set_user_attr("neg_log_mse", -np.log10(error + np.finfo(float).eps))
if not self.optuna_settings._multi_objective:
trial.report(error, epoch)
if trial.should_prune():
raise optuna.exceptions.TrialPruned()
writer.close()
if self.optuna_settings._multi_objective:
return -np.log10(error + np.finfo(float).eps), trial.user_attrs.get("model_n_nodes", -1)
if self.pytorch_settings.save_models and model is not None:
save_path = Path(self.pytorch_settings.model_dir) / f"{self.optuna_settings.study_name}_{trial.number}.pth"
save_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(model, save_path)
return error
def _plot_model_response_eye(self, *signals, labels=None, sps=None, title_append="", subtitle="", show=True):
if sps is None:
raise ValueError("sps must be provided")
if not hasattr(labels, "__iter__") or isinstance(labels, (str, type(None))):
labels = [labels]
else:
labels = list(labels)
while len(labels) < len(signals):
labels.append(None)
# check if there are any labels
if not any(labels):
labels = [f"signal {i + 1}" for i in range(len(signals))]
fig, axs = plt.subplots(2, len(signals), sharex=True, sharey=True)
fig.suptitle(f"Eye diagram{f' {title_append}' if title_append else ''}{f'\n{subtitle}' if subtitle else ''}")
xaxis = np.linspace(0, 2, 2 * sps, endpoint=False)
for j, (label, signal) in enumerate(zip(labels, signals)):
# signal = signal.cpu().numpy()
for i in range(len(signal) // sps - 1):
x, y = signal[i * sps : (i + 2) * sps].T
axs[0, j].plot(xaxis, np.abs(x) ** 2, color="C0", alpha=0.02)
axs[1, j].plot(xaxis, np.abs(y) ** 2, color="C0", alpha=0.02)
axs[0, j].set_title(label + " x")
axs[1, j].set_title(label + " y")
axs[0, j].set_xlabel("Symbol")
axs[1, j].set_xlabel("Symbol")
axs[0, j].set_ylabel("normalized power")
axs[1, j].set_ylabel("normalized power")
if show:
plt.show()
return fig
def _plot_model_response_head(self, *signals, labels=None, sps=None, title_append="", subtitle="", show=True):
if not hasattr(labels, "__iter__") or isinstance(labels, (str, type(None))):
labels = [labels]
else:
labels = list(labels)
while len(labels) < len(signals):
labels.append(None)
# check if there are any labels
if not any(labels):
labels = [f"signal {i + 1}" for i in range(len(signals))]
fig, axs = plt.subplots(1, 2, sharex=True, sharey=True)
fig.set_size_inches(18, 6)
fig.suptitle(f"Fiber response{f' {title_append}' if title_append else ''}{f'\n{subtitle}' if subtitle else ''}")
for i, ax in enumerate(axs):
for signal, label in zip(signals, labels):
if sps is not None:
xaxis = np.linspace(0, len(signal) / sps, len(signal), endpoint=False)
else:
xaxis = np.arange(len(signal))
ax.plot(xaxis, np.abs(signal[:, i]) ** 2, label=label)
ax.set_xlabel("Sample" if sps is None else "Symbol")
ax.set_ylabel("normalized power")
ax.legend(loc="upper right")
if show:
plt.show()
return fig
def plot_model_response(
self,
trial,
model=None,
title_append="",
subtitle="",
mode: Literal["eye", "head"] = "head",
show=True,
):
data_settings_backup = copy.deepcopy(self.data_settings)
pytorch_settings_backup = copy.deepcopy(self.pytorch_settings)
self.data_settings.drop_first = 100*128
self.data_settings.shuffle = False
self.data_settings.train_split = 1.0
self.pytorch_settings.batchsize = (
self.pytorch_settings.eye_symbols if mode == "eye" else self.pytorch_settings.head_symbols
)
plot_loader, _ = self.get_sliced_data(trial, override={"num_symbols": self.pytorch_settings.batchsize})
self.data_settings = data_settings_backup
self.pytorch_settings = pytorch_settings_backup
fiber_in, fiber_out, regen = self.run_model(model, plot_loader)
fiber_in = fiber_in.view(-1, 2)
fiber_out = fiber_out.view(-1, 2)
regen = regen.view(-1, 2)
fiber_in = fiber_in.numpy()
fiber_out = fiber_out.numpy()
regen = regen.numpy()
# https://github.com/matplotlib/matplotlib/issues/27713#issue-2104110987
# https://github.com/matplotlib/matplotlib/issues/27713#issuecomment-1915497463
import gc
if mode == "head":
fig = self._plot_model_response_head(
fiber_in,
fiber_out,
regen,
labels=("fiber in", "fiber out", "regen"),
sps=plot_loader.dataset.samples_per_symbol,
title_append=title_append,
subtitle=subtitle,
show=show,
)
elif mode == "eye":
# raise NotImplementedError("Eye diagram not implemented")
fig = self._plot_model_response_eye(
fiber_in,
fiber_out,
regen,
labels=("fiber in", "fiber out", "regen"),
sps=plot_loader.dataset.samples_per_symbol,
title_append=title_append,
subtitle=subtitle,
show=show,
)
else:
raise ValueError(f"Unknown mode: {mode}")
gc.collect()
return fig
@staticmethod
def build_title(trial: optuna.trial.Trial):
title_append = f"for trial {trial.number}"
model_n_hidden_layers = util.misc.multi_getattr((trial.params, trial.user_attrs), "model_n_hidden_layers", 0)
input_dim = util.misc.multi_getattr((trial.params, trial.user_attrs), "model_input_dim", 0)
model_dims = [
util.misc.multi_getattr((trial.params, trial.user_attrs), f"model_hidden_dim_{i}", 0)
for i in range(model_n_hidden_layers)
]
model_dims.insert(0, input_dim)
model_dims.append(2)
model_dims = [str(dim) for dim in model_dims]
model_activation_func = util.misc.multi_getattr(
(trial.params, trial.user_attrs),
"model_activation_func",
"unknown act. fun",
)
model_dtype = util.misc.multi_getattr((trial.params, trial.user_attrs), "model_dtype", "unknown dtype")
subtitle = (
f"{model_n_hidden_layers+2} layers à ({', '.join(model_dims)}) units, {model_activation_func}, {model_dtype}"
)
return title_append, subtitle

View File

@@ -0,0 +1,98 @@
from dataclasses import dataclass, field
from datetime import datetime
# global settings
@dataclass(frozen=True)
class GlobalSettings:
seed: int = 42
# data settings
@dataclass
class DataSettings:
config_path: str # = "data/*-128-16384-100000-0-0-17-0-PAM4-0.ini"
dtype: tuple = ("complex64", "float64")
symbols: tuple | float | int = 8
output_size: tuple | float | int = 64
shuffle: bool = True
in_out_delay: float = 0
xy_delay: tuple | float | int = 0
drop_first: int = 1000
train_split: float = 0.8
# pytorch settings
@dataclass
class PytorchSettings:
epochs: int = 1
batchsize: int = 2**10
device: str = "cuda"
dataloader_workers: int = 2
dataloader_prefetch: int = 2
save_models: bool = True
model_dir: str = ".models"
summary_dir: str = ".runs"
write_every: int = 10
head_symbols: int = 40
eye_symbols: int = 400
# model settings
@dataclass
class ModelSettings:
output_dim: int = 2
n_hidden_layers: tuple | int = 3
n_hidden_nodes: tuple | int = 8
model_activation_func: tuple | str = "ModReLU"
overrides: dict = field(default_factory=dict)
dropout_prob: float | None = None
model_layer_function: str | None = None
model_layer_parametrizations: list= field(default_factory=list)
@dataclass
class OptimizerSettings:
optimizer: tuple | str = ("Adam", "RMSprop", "SGD")
learning_rate: tuple | float = (1e-5, 1e-1)
scheduler: str | None = None
scheduler_kwargs: dict | None = None
def _pruner_default_kwargs():
# MedianPruner
return {
"n_startup_trials": 0,
"n_warmup_steps": 5,
}
# optuna settings
@dataclass
class OptunaSettings:
n_trials: int = 128
n_workers: int = 1
timeout: int = None
pruner: str = "MedianPruner"
pruner_kwargs: dict = field(default_factory=_pruner_default_kwargs)
directions: tuple = ("minimize",)
metrics_names: tuple = ("mse",)
limit_examples: bool = True
n_train_batches: int = float("inf")
n_valid_batches: int = float("inf")
storage: str = "sqlite:///example.db"
study_name: str = (
f"optuna_study_{datetime.now().strftime('%Y-%m-%d_%H:%M:%S')}"
)
n_trials_filter: tuple | list | None = None  # e.g. (optuna.trial.TrialState.COMPLETE,)
remove_outliers: float | int | None = None
## reserved, set by HyperTraining
_multi_objective = None
_parallel = None
_n_threads = None
_directions = None
_direction = None
_n_train_batches = None
_n_valid_batches = None

View File

@@ -0,0 +1,739 @@
import copy
from datetime import datetime
from pathlib import Path
from typing import Literal
import matplotlib
import torch.nn.utils.parametrize
try:
matplotlib.use("cairo")
except ImportError:
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
# import torch.nn.functional as F # mse_loss doesn't support complex numbers
import torch.optim as optim
import torch.utils.data
from torch.utils.tensorboard import SummaryWriter
from rich.progress import (
Progress,
TextColumn,
BarColumn,
TaskProgressColumn,
TimeRemainingColumn,
MofNCompleteColumn,
TimeElapsedColumn,
)
from rich.console import Console
from util.datasets import FiberRegenerationDataset
import util
from .settings import (
GlobalSettings,
DataSettings,
ModelSettings,
OptimizerSettings,
PytorchSettings,
)
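# regenerator: a feed-forward stack of complex-valued layers (layer_function), each hidden layer optionally followed by dropout and an activation, with optional parametrizations (e.g. orthogonal, clamp) registered on named tensors.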
class regenerator(nn.Module):
def __init__(
self,
*dims,
layer_function=util.complexNN.ONN,
layer_parametrizations: list[dict] = None,
# [
# {
# "tensor_name": "weight",
# "parametrization": util.complexNN.Unitary,
# },
# {
# "tensor_name": "scale",
# "parametrization": util.complexNN.Clamp,
# },
# ],
activation_function=util.complexNN.Pow,
dtype=torch.float64,
dropout_prob=0.01,
**kwargs,
):
super(regenerator, self).__init__()
if len(dims) == 0:
try:
dims = kwargs["dims"]
except KeyError:
raise ValueError("dims must be provided")
self._n_hidden_layers = len(dims) - 2
self._layers = nn.Sequential()
for i in range(self._n_hidden_layers + 1):
self._layers.append(layer_function(dims[i], dims[i + 1], dtype=dtype))
if i < self._n_hidden_layers:
if dropout_prob is not None:
self._layers.append(util.complexNN.DropoutComplex(p=dropout_prob))
self._layers.append(activation_function())
# add parametrizations
if layer_parametrizations is not None:
for layer_parametrization in layer_parametrizations:
tensor_name = layer_parametrization.get("tensor_name", None)
parametrization = layer_parametrization.get("parametrization", None)
param_kwargs = layer_parametrization.get("kwargs", {})
if (
tensor_name is not None
and tensor_name in self._layers[-1]._parameters
and parametrization is not None
):
parametrization(self._layers[-1], tensor_name, **param_kwargs)
def forward(self, input_x):
x = input_x
# check if tracing
if torch.jit.is_tracing():
for layer in self._layers:
x = layer(x)
else:
# with torch.nn.utils.parametrize.cached():
for layer in self._layers:
x = layer(x)
return x
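# traverse_dict_update recursively copies values from source into target, descending into nested dicts and falling back to attribute assignment when item assignment is not supported (e.g. for dataclass instances).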
def traverse_dict_update(target, source):
for k, v in source.items():
if isinstance(v, dict):
if k not in target:
target[k] = {}
traverse_dict_update(target[k], v)
else:
try:
target[k] = v
except TypeError:
target.__dict__[k] = v
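# Trainer runs a single fixed configuration without Optuna: it can start fresh from the settings dataclasses or resume from a saved checkpoint, optionally overriding individual settings and resetting the epoch counter.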
class Trainer:
def __init__(
self,
*,
global_settings=None,
data_settings=None,
pytorch_settings=None,
model_settings=None,
optimizer_settings=None,
console=None,
checkpoint_path=None,
settings_override=None,
reset_epoch=False,
):
self.resume = checkpoint_path is not None
torch.serialization.add_safe_globals([
*util.complexNN.__all__,
GlobalSettings,
DataSettings,
ModelSettings,
OptimizerSettings,
PytorchSettings,
regenerator,
torch.nn.utils.parametrizations.orthogonal
])
if self.resume:
self.checkpoint_dict = torch.load(checkpoint_path, weights_only=True)
if settings_override is not None:
traverse_dict_update(self.checkpoint_dict["settings"], settings_override)
if reset_epoch:
self.checkpoint_dict["epoch"] = -1
self.global_settings: GlobalSettings = self.checkpoint_dict["settings"]["global_settings"]
self.data_settings: DataSettings = self.checkpoint_dict["settings"]["data_settings"]
self.pytorch_settings: PytorchSettings = self.checkpoint_dict["settings"]["pytorch_settings"]
self.model_settings: ModelSettings = self.checkpoint_dict["settings"]["model_settings"]
self.optimizer_settings: OptimizerSettings = self.checkpoint_dict["settings"]["optimizer_settings"]
else:
if global_settings is None:
raise ValueError("global_settings must be provided")
if data_settings is None:
raise ValueError("data_settings must be provided")
if pytorch_settings is None:
raise ValueError("pytorch_settings must be provided")
if model_settings is None:
raise ValueError("model_settings must be provided")
if optimizer_settings is None:
raise ValueError("optimizer_settings must be provided")
self.global_settings: GlobalSettings = global_settings
self.data_settings: DataSettings = data_settings
self.pytorch_settings: PytorchSettings = pytorch_settings
self.model_settings: ModelSettings = model_settings
self.optimizer_settings: OptimizerSettings = optimizer_settings
self.console = console or Console()
self.writer = None
def setup_tb_writer(self, append=None):
log_dir = self.pytorch_settings.summary_dir + "/" + (datetime.now().strftime("%Y%m%d_%H%M%S"))
if append is not None:
log_dir += "_" + str(append)
print(f"Logging to {log_dir}")
self.writer = SummaryWriter(log_dir=log_dir)
def save_checkpoint(self, save_dict, filename):
torch.save(save_dict, filename)
def build_checkpoint_dict(self, loss=None, epoch=None):
return {
"epoch": -1 if epoch is None else epoch,
"loss": float("inf") if loss is None else loss,
"model_state_dict": copy.deepcopy(self.model.state_dict()),
"optimizer_state_dict": copy.deepcopy(self.optimizer.state_dict()),
"scheduler_state_dict": copy.deepcopy(self.scheduler.state_dict()) if hasattr(self, "scheduler") else None,
"model_kwargs": copy.deepcopy(self.model_kwargs),
"settings": {
"global_settings": copy.deepcopy(self.global_settings),
"data_settings": copy.deepcopy(self.data_settings),
"pytorch_settings": copy.deepcopy(self.pytorch_settings),
"model_settings": copy.deepcopy(self.model_settings),
"optimizer_settings": copy.deepcopy(self.optimizer_settings),
},
}
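# define_model instantiates the regenerator either from the checkpoint's stored kwargs or from ModelSettings/DataSettings, traces it once for the TensorBoard graph, and moves it to the configured device.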
def define_model(self, model_kwargs=None):
if model_kwargs is None:
n_hidden_layers = self.model_settings.n_hidden_layers
input_dim = 2 * self.data_settings.output_size
dtype = getattr(torch, self.data_settings.dtype)
afunc = getattr(util.complexNN, self.model_settings.model_activation_func)
layer_func = getattr(util.complexNN, self.model_settings.model_layer_function)
layer_parametrizations = self.model_settings.model_layer_parametrizations
hidden_dims = [self.model_settings.overrides.get(f"n_hidden_nodes_{i}") for i in range(n_hidden_layers)]
self.model_kwargs = {
"dims": (input_dim, *hidden_dims, self.model_settings.output_dim),
"layer_function": layer_func,
"layer_parametrizations": layer_parametrizations,
"activation_function": afunc,
"dtype": dtype,
"dropout_prob": self.model_settings.dropout_prob,
}
else:
self.model_kwargs = model_kwargs
input_dim = self.model_kwargs["dims"][0]
dtype = self.model_kwargs["dtype"]
# dims = self.model_kwargs.pop("dims")
self.model = regenerator(**self.model_kwargs)
self.writer.add_graph(self.model, torch.zeros(1, input_dim, dtype=dtype))
self.model = self.model.to(self.pytorch_settings.device)
def get_sliced_data(self, override=None):
symbols = self.data_settings.symbols
in_out_delay = self.data_settings.in_out_delay
xy_delay = self.data_settings.xy_delay
data_size = self.data_settings.output_size
dtype = getattr(torch, self.data_settings.dtype)
num_symbols = None
if override is not None:
num_symbols = override.get("num_symbols", None)
# get dataset
dataset = FiberRegenerationDataset(
file_path=self.data_settings.config_path,
symbols=symbols,
output_dim=data_size,
target_delay=in_out_delay,
xy_delay=xy_delay,
drop_first=self.data_settings.drop_first,
dtype=dtype,
real=not dtype.is_complex,
num_symbols=num_symbols,
)
dataset_size = len(dataset)
indices = list(range(dataset_size))
split = int(np.floor(self.data_settings.train_split * dataset_size))
if self.data_settings.shuffle:
np.random.seed(self.global_settings.seed)
np.random.shuffle(indices)
train_indices, valid_indices = indices[:split], indices[split:]
if self.data_settings.shuffle:
train_sampler = torch.utils.data.SubsetRandomSampler(train_indices)
valid_sampler = torch.utils.data.SubsetRandomSampler(valid_indices)
else:
train_sampler = train_indices
valid_sampler = valid_indices
train_loader = torch.utils.data.DataLoader(
dataset,
batch_size=self.pytorch_settings.batchsize,
sampler=train_sampler,
drop_last=True,
pin_memory=True,
num_workers=self.pytorch_settings.dataloader_workers,
prefetch_factor=self.pytorch_settings.dataloader_prefetch,
)
valid_loader = torch.utils.data.DataLoader(
dataset,
batch_size=self.pytorch_settings.batchsize,
sampler=valid_sampler,
drop_last=True,
pin_memory=True,
num_workers=self.pytorch_settings.dataloader_workers,
prefetch_factor=self.pytorch_settings.dataloader_prefetch,
)
return train_loader, valid_loader
def train_model(
self,
optimizer,
train_loader,
epoch,
enable_progress=False,
):
if enable_progress:
progress = Progress(
TextColumn("[yellow] Training..."),
TextColumn("Error: {task.description}"),
BarColumn(),
TaskProgressColumn(),
TextColumn("[green]Batch"),
MofNCompleteColumn(),
TimeRemainingColumn(),
TimeElapsedColumn(),
transient=False,
console=self.console,
refresh_per_second=10,
)
task = progress.add_task("-.---e--", total=len(train_loader))
progress.start()
running_loss2 = 0.0
running_loss = 0.0
self.model.train()
for batch_idx, (x, y) in enumerate(train_loader):
self.model.zero_grad(set_to_none=True)
x, y = (
x.to(self.pytorch_settings.device),
y.to(self.pytorch_settings.device),
)
y_pred = self.model(x)
loss = util.complexNN.complex_mse_loss(y_pred, y, power=True)
loss_value = loss.item()
loss.backward()
optimizer.step()
running_loss2 += loss_value
running_loss += loss_value
if enable_progress:
progress.update(task, advance=1, description=f"{loss_value:.3e}")
if batch_idx % self.pytorch_settings.write_every == 0:
self.writer.add_scalar(
"training loss",
running_loss2 / (self.pytorch_settings.write_every if batch_idx > 0 else 1),
epoch * len(train_loader) + batch_idx,
)
running_loss2 = 0.0
if enable_progress:
progress.stop()
return running_loss / len(train_loader)
def eval_model(self, valid_loader, epoch, enable_progress=True):
if enable_progress:
progress = Progress(
TextColumn("[green]Evaluating..."),
TextColumn("Error: {task.description}"),
BarColumn(),
TaskProgressColumn(),
TextColumn("[green]Batch"),
MofNCompleteColumn(),
TimeRemainingColumn(),
TimeElapsedColumn(),
transient=False,
console=self.console,
refresh_per_second=10,
)
progress.start()
task = progress.add_task("-.---e--", total=len(valid_loader))
self.model.eval()
running_error = 0
with torch.no_grad():
for batch_idx, (x, y) in enumerate(valid_loader):
x, y = (
x.to(self.pytorch_settings.device),
y.to(self.pytorch_settings.device),
)
y_pred = self.model(x)
error = util.complexNN.complex_mse_loss(y_pred, y, power=True)
error_value = error.item()
running_error += error_value
if enable_progress:
progress.update(task, advance=1, description=f"{error_value:.3e}")
running_error /= len(valid_loader)
self.writer.add_scalar(
"eval loss",
running_error,
epoch,
)
title_append, subtitle = self.build_title(epoch + 1)
self.writer.add_figure(
"fiber response",
self.plot_model_response(
model=self.model,
title_append=title_append,
subtitle=subtitle,
show=False,
),
epoch + 1,
)
self.writer.add_figure(
"eye diagram",
self.plot_model_response(
model=self.model,
title_append=title_append,
subtitle=subtitle,
show=False,
mode="eye",
),
epoch + 1,
)
self.writer_histograms(epoch + 1)
if enable_progress:
progress.stop()
return running_error
def run_model(self, model, loader):
model.eval()
xs = []
ys = []
y_preds = []
with torch.no_grad():
model = model.to(self.pytorch_settings.device)
for x, y in loader:
x, y = (
x.to(self.pytorch_settings.device),
y.to(self.pytorch_settings.device),
)
y_pred = model(x).cpu()
# x = x.cpu()
# y = y.cpu()
y_pred = y_pred.view(y_pred.shape[0], -1, 2)
y = y.view(y.shape[0], -1, 2)
x = x.view(x.shape[0], -1, 2)
xs.append(x[:, 0, :].squeeze())
ys.append(y.squeeze())
y_preds.append(y_pred.squeeze())
xs = torch.vstack(xs).cpu()
ys = torch.vstack(ys).cpu()
y_preds = torch.vstack(y_preds).cpu()
return ys, xs, y_preds
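# writer_histograms logs the listed per-layer tensors to TensorBoard, splitting complex values into magnitude and phase and using scalars for single-element tensors.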
def writer_histograms(self, epoch, attributes=["weight", "weight_U", "weight_V", "bias", "sigma", "scale"]):
for i, layer in enumerate(self.model._layers):
tag = f"layer {i}"
for attribute in attributes:
if hasattr(layer, attribute):
vals: np.ndarray = getattr(layer, attribute).detach().cpu().numpy().flatten()
if vals.ndim <= 1 and len(vals) == 1:
if np.iscomplexobj(vals):
self.writer.add_scalar(f"{tag} {attribute} (Mag)", np.abs(vals), epoch)
self.writer.add_scalar(f"{tag} {attribute} (Phase)", np.angle(vals), epoch)
else:
self.writer.add_scalar(f"{tag} {attribute}", vals, epoch)
else:
if np.iscomplexobj(vals):
self.writer.add_histogram(f"{tag} {attribute} (Mag)", np.abs(vals), epoch, bins="fd")
self.writer.add_histogram(f"{tag} {attribute} (Phase)", np.angle(vals), epoch, bins="fd")
else:
self.writer.add_histogram(f"{tag} {attribute}", vals, epoch, bins="fd")
def train(self):
if self.writer is None:
self.setup_tb_writer()
if self.resume:
model_kwargs = self.checkpoint_dict["model_kwargs"]
else:
model_kwargs = None
self.define_model(model_kwargs=model_kwargs)
print(f"number of parameters (trainable): {sum(p.numel() for p in self.model.parameters())} ({sum(p.numel() for p in self.model.parameters() if p.requires_grad)})")
title_append, subtitle = self.build_title(0)
self.writer.add_figure(
"fiber response",
self.plot_model_response(
model=self.model,
title_append=title_append,
subtitle=subtitle,
show=False,
),
0,
)
self.writer.add_figure(
"eye diagram",
self.plot_model_response(
model=self.model,
title_append=title_append,
subtitle=subtitle,
mode="eye",
show=False,
),
0,
)
self.writer_histograms(0)
train_loader, valid_loader = self.get_sliced_data()
optimizer_name = self.optimizer_settings.optimizer
lr = self.optimizer_settings.learning_rate
self.optimizer: optim.Optimizer = getattr(optim, optimizer_name)(self.model.parameters(), lr=lr)
if self.optimizer_settings.scheduler is not None:
self.scheduler = getattr(optim.lr_scheduler, self.optimizer_settings.scheduler)(
self.optimizer, **self.optimizer_settings.scheduler_kwargs
)
if self.resume:
try:
self.scheduler.load_state_dict(self.checkpoint_dict["scheduler_state_dict"])
except ValueError:
pass
self.writer.add_scalar("learning rate", self.scheduler.get_last_lr()[0], -1)
if not self.resume:
self.best = self.build_checkpoint_dict()
else:
self.best = self.checkpoint_dict
self.model.load_state_dict(self.best["model_state_dict"], strict=False)
try:
self.optimizer.load_state_dict(self.best["optimizer_state_dict"])
except ValueError:
pass
for epoch in range(self.best["epoch"] + 1, self.pytorch_settings.epochs):
enable_progress = True
if enable_progress:
self.console.rule(f"Epoch {epoch + 1}/{self.pytorch_settings.epochs}")
self.train_model(
self.optimizer,
train_loader,
epoch,
enable_progress=enable_progress,
)
loss = self.eval_model(
valid_loader,
epoch,
enable_progress=enable_progress,
)
if self.optimizer_settings.scheduler is not None:
lr_old = self.scheduler.get_last_lr()
self.scheduler.step(loss)
lr_new = self.scheduler.get_last_lr()
if lr_old[0] != lr_new[0]:
self.writer.add_scalar("learning rate", lr_new[0], epoch)
if self.pytorch_settings.save_models and self.model is not None:
save_path = (
Path(self.pytorch_settings.model_dir) / f"{self.writer.get_logdir().split('/')[-1]}_{epoch}.tar"
)
save_path.parent.mkdir(parents=True, exist_ok=True)
checkpoint = self.build_checkpoint_dict(loss, epoch)
self.save_checkpoint(checkpoint, save_path)
if loss < self.best["loss"]:
self.best = checkpoint
save_path = (
Path(self.pytorch_settings.model_dir) / f"best_{self.writer.get_logdir().split('/')[-1]}.tar"
)
save_path.parent.mkdir(parents=True, exist_ok=True)
self.save_checkpoint(self.best, save_path)
self.writer.flush()
self.writer.close()
return self.best
def _plot_model_response_eye(self, *signals, labels=None, sps=None, title_append="", subtitle="", show=True):
if sps is None:
raise ValueError("sps must be provided")
if not hasattr(labels, "__iter__") or isinstance(labels, (str, type(None))):
labels = [labels]
else:
labels = list(labels)
while len(labels) < len(signals):
labels.append(None)
# check if there are any labels
if not any(labels):
labels = [f"signal {i + 1}" for i in range(len(signals))]
fig, axs = plt.subplots(1, 2 * len(signals), sharex=True, sharey=True)
fig.set_figwidth(18)
fig.suptitle(f"Eye diagram{f' {title_append}' if title_append else ''}{f'\n{subtitle}' if subtitle else ''}")
xaxis = np.linspace(0, 2, 2 * sps, endpoint=False)
for j, (label, signal) in enumerate(zip(labels, signals)):
# signal = signal.cpu().numpy()
for i in range(len(signal) // sps - 1):
x, y = signal[i * sps : (i + 2) * sps].T
axs[0 + 2 * j].plot(xaxis, np.abs(x) ** 2, color=f"C{j}", alpha=1 / (len(signal) // sps) * 10)
axs[1 + 2 * j].plot(xaxis, np.abs(y) ** 2, color=f"C{j}", alpha=1 / (len(signal) // sps) * 10)
axs[0 + 2 * j].set_title(label + " x")
axs[1 + 2 * j].set_title(label + " y")
axs[0 + 2 * j].set_xlabel("Symbol")
axs[1 + 2 * j].set_xlabel("Symbol")
axs[0 + 2 * j].set_box_aspect(1)
axs[1 + 2 * j].set_box_aspect(1)
axs[0].set_ylabel("normalized power")
fig.tight_layout()
# axs[1+2*len(labels)-1].set_ylabel("normalized power")
if show:
plt.show()
return fig
def _plot_model_response_head(self, *signals, labels=None, sps=None, title_append="", subtitle="", show=True):
if not hasattr(labels, "__iter__") or isinstance(labels, (str, type(None))):
labels = [labels]
else:
labels = list(labels)
while len(labels) < len(signals):
labels.append(None)
# check if there are any labels
if not any(labels):
labels = [f"signal {i + 1}" for i in range(len(signals))]
fig, axs = plt.subplots(1, 2, sharex=True, sharey=True)
fig.set_figwidth(18)
fig.set_figheight(4)
fig.suptitle(f"Fiber response{f' {title_append}' if title_append else ''}{f'\n{subtitle}' if subtitle else ''}")
for i, ax in enumerate(axs):
for signal, label in zip(signals, labels):
if sps is not None:
xaxis = np.linspace(0, len(signal) / sps, len(signal), endpoint=False)
else:
xaxis = np.arange(len(signal))
ax.plot(xaxis, np.abs(signal[:, i]) ** 2, label=label)
ax.set_xlabel("Sample" if sps is None else "Symbol")
ax.set_ylabel("normalized power")
ax.legend(loc="upper right")
fig.tight_layout()
if show:
plt.show()
return fig
def plot_model_response(
self,
model=None,
title_append="",
subtitle="",
mode: Literal["eye", "head"] = "head",
show=False,
):
data_settings_backup = copy.deepcopy(self.data_settings)
pytorch_settings_backup = copy.deepcopy(self.pytorch_settings)
self.data_settings.drop_first = 100 * 128
self.data_settings.shuffle = False
self.data_settings.train_split = 1.0
self.pytorch_settings.batchsize = (
self.pytorch_settings.eye_symbols if mode == "eye" else self.pytorch_settings.head_symbols
)
plot_loader, _ = self.get_sliced_data(override={"num_symbols": self.pytorch_settings.batchsize})
self.data_settings = data_settings_backup
self.pytorch_settings = pytorch_settings_backup
fiber_in, fiber_out, regen = self.run_model(model, plot_loader)
fiber_in = fiber_in.view(-1, 2)
fiber_out = fiber_out.view(-1, 2)
regen = regen.view(-1, 2)
fiber_in = fiber_in.numpy()
fiber_out = fiber_out.numpy()
regen = regen.numpy()
# https://github.com/matplotlib/matplotlib/issues/27713#issue-2104110987
# https://github.com/matplotlib/matplotlib/issues/27713#issuecomment-1915497463
import gc
if mode == "head":
fig = self._plot_model_response_head(
fiber_in,
fiber_out,
regen,
labels=("fiber in", "fiber out", "regen"),
sps=plot_loader.dataset.samples_per_symbol,
title_append=title_append,
subtitle=subtitle,
show=show,
)
elif mode == "eye":
# raise NotImplementedError("Eye diagram not implemented")
fig = self._plot_model_response_eye(
fiber_in,
fiber_out,
regen,
labels=("fiber in", "fiber out", "regen"),
sps=plot_loader.dataset.samples_per_symbol,
title_append=title_append,
subtitle=subtitle,
show=show,
)
else:
raise ValueError(f"Unknown mode: {mode}")
gc.collect()
return fig
def build_title(self, number: int):
title_append = f"epoch {number}"
model_n_hidden_layers = self.model_settings.n_hidden_layers
input_dim = 2 * self.data_settings.output_size
model_dims = [
self.model_settings.overrides.get(f"n_hidden_nodes_{i}", -1) for i in range(model_n_hidden_layers)
]
model_dims.insert(0, input_dim)
model_dims.append(2)
model_dims = [str(dim) for dim in model_dims]
model_activation_func = self.model_settings.model_activation_func
model_dtype = self.data_settings.dtype
subtitle = f"{model_n_hidden_layers + 2} layers à ({', '.join(model_dims)}) units, {model_activation_func}, {model_dtype}"
return title_append, subtitle

View File

@@ -0,0 +1,118 @@
from datetime import datetime
import optuna
from hypertraining.hypertraining import HyperTraining
from hypertraining.settings import (
GlobalSettings,
DataSettings,
PytorchSettings,
ModelSettings,
OptimizerSettings,
OptunaSettings,
)
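# Entry point for the Optuna hyperparameter search: the settings below fix most values (the studies they came from are noted in the comments) and configure the study storage, pruner and metrics.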
global_settings = GlobalSettings(
seed=42,
)
data_settings = DataSettings(
config_path="data/*-128-16384-100000-0-0-17-0-PAM4-0.ini",
dtype="complex64",
# symbols = (9, 20), # 13 symbol @ 10GBd <-> 1.3ns <-> 0.26m of fiber
symbols=13, # study: single_core_regen_20241123_011232
# output_size = (11, 32), # ballpark 26 taps -> 2 taps per input symbol -> 1 tap every 0.01m (model has 52 inputs)
output_size=26, # study: single_core_regen_20241123_011232 (model_input_dim/2)
shuffle=True,
in_out_delay=0,
xy_delay=0,
drop_first=128 * 100,
train_split=0.8,
)
pytorch_settings = PytorchSettings(
epochs=10000,
batchsize=2**10,
device="cuda",
dataloader_workers=12,
dataloader_prefetch=4,
summary_dir=".runs",
write_every=2**5,
save_models=True,
model_dir=".models",
)
model_settings = ModelSettings(
output_dim=2,
# n_hidden_layers = (3, 8),
n_hidden_layers=4,
overrides={
"n_hidden_nodes_0": 8,
"n_hidden_nodes_1": 6,
"n_hidden_nodes_2": 4,
"n_hidden_nodes_3": 8,
},
model_activation_func="Mag",
# satabsT0=(1e-6, 1),
)
optimizer_settings = OptimizerSettings(
optimizer="Adam",
# learning_rate = (1e-5, 1e-1),
learning_rate=5e-3
# learning_rate=5e-4,
)
optuna_settings = OptunaSettings(
n_trials=1,
n_workers=1,
timeout=3600,
directions=("minimize",),
metrics_names=("mse",),
limit_examples=False,
n_train_batches=500,
# n_valid_batches = 100,
storage="sqlite:///data/single_core_regen.db",
study_name=f"single_core_regen_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
n_trials_filter=(optuna.trial.TrialState.COMPLETE, optuna.trial.TrialState.PRUNED),
pruner="MedianPruner",
pruner_kwargs=None
)
if __name__ == "__main__":
hyper_training = HyperTraining(
global_settings=global_settings,
data_settings=data_settings,
pytorch_settings=pytorch_settings,
model_settings=model_settings,
optimizer_settings=optimizer_settings,
optuna_settings=optuna_settings,
)
hyper_training.setup_study()
hyper_training.run_study()
# best_trial = hyper_training.study.best_trial
# best_model = hyper_training.define_model(best_trial).to(
# hyper_training.pytorch_settings.device
# )
# title_append, subtitle = hyper_training.build_title(best_trial)
# hyper_training.plot_model_response(
# best_trial,
# model=best_model,
# title_append=title_append,
# subtitle=subtitle,
# mode="eye",
# show=True,
# )
# print(f"Best model found for trial {best_trial.number}")
# print(f"Best model error: {best_trial.value}")
# print(f"Best model params: {best_trial.params}")
# print()
# print(best_model)
# eye_fig = hyper_training.plot_eye()
...

View File

@@ -0,0 +1,130 @@
from hypertraining.settings import (
GlobalSettings,
DataSettings,
PytorchSettings,
ModelSettings,
OptimizerSettings,
)
from hypertraining.training import Trainer
import torch
import json
import util
global_settings = GlobalSettings(
seed=42,
)
data_settings = DataSettings(
config_path="data/*-128-16384-100000-0-0-17-0-PAM4-0.ini",
dtype="complex64",
# symbols = (9, 20), # 13 symbol @ 10GBd <-> 1.3ns <-> 0.26m of fiber
symbols=13, # study: single_core_regen_20241123_011232
# output_size = (11, 32), # 26 taps -> 2 taps per input symbol -> 1 tap every 1cm (model has 52 inputs (x/y))
output_size=26, # study: single_core_regen_20241123_011232 (model_input_dim/2)
shuffle=True,
in_out_delay=0,
xy_delay=0,
drop_first=128*64,
train_split=0.8,
)
pytorch_settings = PytorchSettings(
epochs=10000,
batchsize=2**12,
device="cuda",
dataloader_workers=12,
dataloader_prefetch=8,
summary_dir=".runs",
write_every=2**5,
save_models=True,
model_dir=".models",
)
model_settings = ModelSettings(
output_dim=2,
n_hidden_layers=4,
overrides={
"n_hidden_nodes_0": 8,
"n_hidden_nodes_1": 8,
"n_hidden_nodes_2": 4,
"n_hidden_nodes_3": 6,
},
model_activation_func="PowScale",
# dropout_prob=0.01,
model_layer_function="ONN",
model_layer_parametrizations=[
{
"tensor_name": "weight",
"parametrization": torch.nn.utils.parametrizations.orthogonal,
},
{
"tensor_name": "scales",
"parametrization": util.complexNN.clamp,
},
{
"tensor_name": "scale",
"parametrization": util.complexNN.clamp,
},
{
"tensor_name": "bias",
"parametrization": util.complexNN.clamp,
},
# {
# "tensor_name": "V",
# "parametrization": torch.nn.utils.parametrizations.orthogonal,
# },
# {
# "tensor_name": "S",
# "parametrization": util.complexNN.clamp,
# },
],
)
optimizer_settings = OptimizerSettings(
optimizer="Adam",
learning_rate=0.05,
scheduler="ReduceLROnPlateau",
scheduler_kwargs={
"patience": 2**6,
"factor": 0.9,
# "threshold": 1e-3,
"min_lr": 1e-6,
"cooldown": 10,
},
)
def save_dict_to_file(dictionary, filename):
"""
Save a dictionary to a JSON file.
:param dictionary: Dictionary containing the training results to save.
:type dictionary: dict
:param filename: Path to the JSON file where the dictionary will be saved.
:type filename: str
"""
with open(filename, 'w') as f:
json.dump(dictionary, f, indent=4)
if __name__ == "__main__":
trainer = Trainer(
global_settings=global_settings,
data_settings=data_settings,
pytorch_settings=pytorch_settings,
model_settings=model_settings,
optimizer_settings=optimizer_settings,
checkpoint_path='.models/20241128_084935_8885.tar',
settings_override={
"model_settings": {
# "model_activation_func": "PowScale",
"dropout_prob": 0,
}
},
reset_epoch=True,
)
best = trainer.train()
save_dict_to_file(best, ".models/best_results.json")
...

View File

@@ -0,0 +1,194 @@
from datetime import datetime
from pathlib import Path
import optuna
import warnings
from util.optuna_vis import show_figures
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data
from torchvision import datasets
from torchvision import transforms
import multiprocessing
# from util.dataset import SlicedDataset
DEVICE = torch.device("cuda")
BATCHSIZE = 128
CLASSES = 10
DIR = Path(__file__).parent
EPOCHS = 100
N_TRAIN_EXAMPLES = BATCHSIZE * 30
N_VALID_EXAMPLES = BATCHSIZE * 10
n_trials = 128
n_threads = 16
def define_model(trial):
n_layers = trial.suggest_int("n_layers", 1, 3)
layers = []
in_features = 28 * 28
for i in range(n_layers):
out_features = trial.suggest_int(f"n_units_l{i}", 4, 128)
layers.append(nn.Linear(in_features, out_features))
layers.append(nn.ReLU())
p = trial.suggest_float(f"dropout_l{i}", 0.2, 0.5)
layers.append(nn.Dropout(p))
in_features = out_features
layers.append(nn.Linear(in_features, CLASSES))
layers.append(nn.LogSoftmax(dim=1))
return nn.Sequential(*layers)
def get_mnist():
# Load FashionMNIST dataset.
train_loader = torch.utils.data.DataLoader(
datasets.FashionMNIST(
DIR / ".data", train=True, download=True, transform=transforms.ToTensor()
),
batch_size=BATCHSIZE,
shuffle=True,
)
valid_loader = torch.utils.data.DataLoader(
datasets.FashionMNIST(
DIR / ".data", train=False, transform=transforms.ToTensor()
),
batch_size=BATCHSIZE,
shuffle=True,
)
return train_loader, valid_loader
def objective(trial):
model = define_model(trial).to(DEVICE)
optimizer_name = trial.suggest_categorical("optimizer", ["Adam", "RMSprop", "SGD"])
lr = trial.suggest_float("lr", 1e-5, 1e-1, log=True)
optimizer = getattr(optim, optimizer_name)(model.parameters(), lr=lr)
train_loader, valid_loader = get_mnist()
for epoch in range(EPOCHS):
train_model(model, optimizer, train_loader)
accuracy, num_params = eval_model(model, valid_loader)
return accuracy, num_params
def eval_model(model, valid_loader):
model.eval()
correct = 0
with torch.no_grad():
for batch_idx, (data, target) in enumerate(valid_loader):
if batch_idx * BATCHSIZE >= N_VALID_EXAMPLES:
break
data, target = data.view(data.size(0), -1).to(DEVICE), target.to(DEVICE)
output = model(data)
pred = output.argmax(dim=1, keepdim=True)
correct += pred.eq(target.view_as(pred)).sum().item()
accuracy = correct / min(len(valid_loader.dataset), N_VALID_EXAMPLES)
num_params = sum(p.numel() for p in model.parameters())
return accuracy, num_params
def train_model(model, optimizer, train_loader):
model.train()
for batch_idx, (data, target) in enumerate(train_loader):
if batch_idx * BATCHSIZE >= N_TRAIN_EXAMPLES:
break
data, target = data.view(data.size(0), -1).to(DEVICE), target.to(DEVICE)
optimizer.zero_grad()
output = model(data)
loss = F.nll_loss(output, target)
loss.backward()
optimizer.step()
def run_optimize(n_trials, study):
study.optimize(objective, n_trials=n_trials, timeout=600)
if __name__ == "__main__":
study_name = f"{datetime.now().strftime('%Y-%m-%d %H:%M:%S')} mnist example"
storage = "sqlite:///db.sqlite3"
directions = ["maximize", "minimize"]
study = optuna.create_study(
directions=directions,
storage=storage,
study_name=study_name,
)
with warnings.catch_warnings(action="ignore"):
study.set_metric_names(["accuracy", "num params"])
n_threads = min(n_trials, n_threads)
processes = []
for _ in range(n_threads):
p = multiprocessing.Process(
target=run_optimize, args=(n_trials // n_threads, study)
)
p.start()
processes.append(p)
for p in processes:
p.join()
remaining_trials = n_trials - ((n_trials // n_threads) * n_threads)
if remaining_trials:
print(
f"\nRunning last {remaining_trials} trial{'s' if remaining_trials > 1 else ''}:"
)
run_optimize(remaining_trials, study)
print(f"Number of trials on the Pareto front: {len(study.best_trials)}")
trial_with_highest_accuracy = max(study.best_trials, key=lambda t: t.values[0])
print("Trial with highest accuracy: ")
print(f"\tnumber: {trial_with_highest_accuracy.number}")
print(f"\tparams: {trial_with_highest_accuracy.params}")
print(f"\tvalues: {trial_with_highest_accuracy.values}")
# for trial in trials:
# print(f"Trial {trial.number}")
# print(f" Accuracy: {trial.values[0]}")
# print(f" n_params: {int(trial.values[1])}")
# print( " Params: ")
# for key, value in trial.params.items():
# print(" {}: {}".format(key, value))
# print()
# print(" Value: ", trial.value)
# print(" Params: ")
# for key, value in trial.params.items():
# print(" {}: {}".format(key, value))
figures = []
figures.append(
optuna.visualization.plot_pareto_front(
study, target_names=["accuracy", "num_params"]
)
)
figures.append(optuna.visualization.plot_timeline(study))
app = show_figures(*figures)
print()
# app.show()

View File

@@ -0,0 +1,51 @@
# move into dir single-core-regen before running
from util.dataset import SlicedDataset
from torch.utils.data import DataLoader
from matplotlib import pyplot as plt
import numpy as np
def eye_dataset(dataset, no_symbols=None, offset=False, show=True):
if no_symbols is None:
no_symbols = len(dataset)
_, axs = plt.subplots(2,2, sharex=True, sharey=True)
xaxis = np.linspace(0,dataset.symbols_per_slice,dataset.samples_per_slice)
roll = dataset.samples_per_symbol//2 if offset else 0
for E_out, E_in in dataset[roll:dataset.samples_per_symbol*no_symbols+roll:dataset.samples_per_symbol]:
E_in_x, E_in_y, E_out_x, E_out_y = E_in[0], E_in[1], E_out[0], E_out[1]
axs[0,0].plot(xaxis, np.abs( E_in_x.numpy())**2, alpha=0.05, color='C0')
axs[1,0].plot(xaxis, np.abs( E_in_y.numpy())**2, alpha=0.05, color='C0')
axs[0,1].plot(xaxis, np.abs(E_out_x.numpy())**2, alpha=0.05, color='C0')
axs[1,1].plot(xaxis, np.abs(E_out_y.numpy())**2, alpha=0.05, color='C0')
if show:
plt.show()
# def plt_dataloader(dataloader, show=True):
# _, axs = plt.subplots(2,2, sharex=True, sharey=True)
# E_outs, E_ins = next(iter(dataloader))
# for i, (E_out, E_in) in enumerate(zip(E_outs, E_ins)):
# xaxis = np.linspace(dataset.symbols_per_slice*i,dataset.symbols_per_slice+dataset.symbols_per_slice*i,dataset.samples_per_slice)
# E_in_x, E_in_y, E_out_x, E_out_y = E_in[0], E_in[1], E_out[0], E_out[1]
# axs[0,0].plot(xaxis, np.abs(E_in_x.numpy())**2)
# axs[1,0].plot(xaxis, np.abs(E_in_y.numpy())**2)
# axs[0,1].plot(xaxis, np.abs(E_out_x.numpy())**2)
# axs[1,1].plot(xaxis, np.abs(E_out_y.numpy())**2)
# if show:
# plt.show()
if __name__ == "__main__":
dataset = SlicedDataset("data/20241115-175517-128-16384-10000-0-0-17-0-PAM4-0.ini", symbols=1, drop_first=100)
print(dataset[0][0].shape)
eye_dataset(dataset, 1000, offset=True, show=False)
train_loader = DataLoader(dataset, batch_size=10, shuffle=False)
# plt_dataloader(train_loader, show=False)
plt.show()

View File

@@ -0,0 +1,68 @@
import torch
import time
def print_torch_env():
print("Torch version: ", torch.__version__)
print("CUDA available: ", torch.cuda.is_available())
print("CUDA version: ", torch.version.cuda)
print("CUDNN version: ", torch.backends.cudnn.version())
print("Device count: ", torch.cuda.device_count())
print("Current device: ", torch.cuda.current_device())
print("Device name: ", torch.cuda.get_device_name(0))
print("Device capability: ", torch.cuda.get_device_capability(0))
print("Device memory: ", torch.cuda.get_device_properties(0).total_memory)
def measure_runtime(func):
"""
Measure the runtime of a function.
:param func: Function to measure
:type func: function
:return: Wrapped function with runtime measurement
:rtype: function
"""
def wrapper(*args, **kwargs):
start_time = time.time()
result = func(*args, **kwargs)
end_time = time.time()
print(f"Runtime: {end_time - start_time:.6f} seconds")
return result, end_time - start_time
return wrapper
@measure_runtime
def tensor_addition(a, b):
"""
Perform tensor addition.
:param a: First tensor
:type a: torch.Tensor
:param b: Second tensor
:type b: torch.Tensor
:return: Sum of tensors
:rtype: torch.Tensor
"""
return a + b
def runtime_test():
x = torch.rand(2**18, 2**10)
y = torch.rand(2**18, 2**10)
print("Tensor addition on CPU")
_, cpu_time = tensor_addition(x, y)
print()
print("Tensor addition on GPU")
if not torch.cuda.is_available():
print("CUDA is not available")
return
_, gpu_time = tensor_addition(x.cuda(), y.cuda())
print()
print(f"Speedup: {cpu_time / gpu_time *100:.2f}%")
if __name__ == "__main__":
print_torch_env()
print()
runtime_test()

View File

@@ -0,0 +1,19 @@
from . import datasets # noqa: F401
# from .datasets import FiberRegenerationDataset # noqa: F401
# from .datasets import load_data # noqa: F401
from . import plot # noqa: F401
# from .plot import eye # noqa: F401
from . import optuna_helpers # noqa: F401
# from .optuna_helpers import optional_suggest_categorical # noqa: F401
# from .optuna_helpers import optional_suggest_float # noqa: F401
# from .optuna_helpers import optional_suggest_int # noqa: F401
from . import complexNN # noqa: F401
# from .complexNN import UnitaryLayer # noqa: F401
# from .complexNN import complex_mse_loss # noqa: F401
# from .complexNN import complex_sse_loss # noqa: F401
from . import misc # noqa: F401

View File

@@ -0,0 +1,504 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
# from torchlambertw.special import lambertw
def complex_mse_loss(input, target, power=False, reduction="mean"):
"""
Compute the mean squared error between two complex tensors.
If power is set to True, the loss is computed as |input|^2 - |target|^2
"""
reduce = getattr(torch, reduction)
if power:
input = (input * input.conj()).real.to(dtype=input.dtype.to_real())
target = (target * target.conj()).real.to(dtype=target.dtype.to_real())
if input.is_complex() and target.is_complex():
return reduce(torch.square(input.real - target.real) + torch.square(input.imag - target.imag))
elif input.is_complex() or target.is_complex():
raise ValueError("Input and target must have the same type (real or complex)")
else:
return F.mse_loss(input, target, reduction=reduction)
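# Hedged usage sketch (added for illustration; the helper below is not part of the original file):
def _demo_complex_mse_loss():
    # two toy complex fields that differ only in the imaginary part of the first element
    x = torch.tensor([1 + 1j, 0 + 1j], dtype=torch.complex64)
    y = torch.tensor([1 + 0j, 0 + 1j], dtype=torch.complex64)
    field_loss = complex_mse_loss(x, y)              # tensor(0.5000): mean of squared re/im errors
    power_loss = complex_mse_loss(x, y, power=True)  # compares the optical powers |x|^2 vs |y|^2
    return field_loss, power_loss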
def complex_sse_loss(input, target):
"""
Compute the sum squared error between two complex tensors.
"""
if input.is_complex():
return torch.sum(torch.square(input.real - target.real) + torch.square(input.imag - target.imag))
else:
return torch.sum(torch.square(input - target))
class UnitaryLayer(nn.Module):
def __init__(self, in_features, out_features, dtype=None):
assert in_features >= out_features
super(UnitaryLayer, self).__init__()
self.in_features = in_features
self.out_features = out_features
self.weight = nn.Parameter(torch.randn(in_features, out_features, dtype=dtype))
self.reset_parameters()
def reset_parameters(self):
q, _ = torch.linalg.qr(self.weight)
self.weight.data = q
def forward(self, x):
return torch.matmul(x, self.weight)
def __repr__(self):
return f"UnitaryLayer({self.in_features}, {self.out_features})"
class _Unitary(nn.Module):
def forward(self, X:torch.Tensor):
if X.ndim < 2:
raise ValueError(
"Only tensors with 2 or more dimensions are supported. "
f"Got a tensor of shape {X.shape}"
)
n, k = X.size(-2), X.size(-1)
transpose = n<k
if transpose:
X = X.transpose(-2, -1)
q, r = torch.linalg.qr(X)
# q: torch.Tensor = q
# r: torch.Tensor = r
d = r.diagonal(dim1=-2, dim2=-1).sgn()
q*=d.unsqueeze(-2)
if transpose:
q = q.transpose(-2, -1)
if n == k:
mask = (torch.linalg.det(q).abs() >= 0).to(q.dtype.to_real())
mask[mask == 0] = -1
mask = mask.unsqueeze(-1)
q[..., 0] *= mask
# X.copy_(q)
return q
def unitary(module: nn.Module, name: str = "weight") -> nn.Module:
weight = getattr(module, name, None)
if not isinstance(weight, torch.Tensor):
raise ValueError(f"Module '{module}' has no parameter or buffer '{name}'")
if weight.ndim < 2:
raise ValueError(f"Expected a matrix or batch of matrices. Got a tensor of {weight.ndim} dimensions.")
if weight.shape[-2] != weight.shape[-1]:
raise ValueError(f"Expected a square matrix or batch of square matrices. Got a tensor of shape {weight.shape}")
unit = _Unitary()
nn.utils.parametrize.register_parametrization(module, name, unit)
return module
class _SpecialUnitary(nn.Module):
def __init__(self):
super().__init__()
def forward(self, X:torch.Tensor):
n, k = X.size(-2), X.size(-1)
if n != k:
raise ValueError(f"Expected a square matrix. Got a tensor of shape {X.shape}")
q, _ = torch.linalg.qr(X)
q = q / torch.linalg.det(q).pow(1/n)
return q
def special_unitary(module: nn.Module, name: str = "weight") -> nn.Module:
weight = getattr(module, name, None)
if not isinstance(weight, torch.Tensor):
raise ValueError(f"Module '{module}' has no parameter or buffer '{name}'")
if weight.ndim < 2:
raise ValueError(f"Expected a matrix or batch of matrices. Got a tensor of {weight.ndim} dimensions.")
if weight.shape[-2] != weight.shape[-1]:
raise ValueError(f"Expected a square matrix or batch of square matrices. Got a tensor of shape {weight.shape}")
unit = _SpecialUnitary()
nn.utils.parametrize.register_parametrization(module, name, unit)
return module
class _Clamp(nn.Module):
def __init__(self, min, max):
super(_Clamp, self).__init__()
self.min = min
self.max = max
def forward(self, x):
if x.is_complex():
# clamp magnitude, ignore phase
return torch.clamp(x.abs(), self.min, self.max) * x / x.abs()
return torch.clamp(x, self.min, self.max)
def clamp(module: nn.Module, name: str = "scale", min=0, max=1) -> nn.Module:
scale = getattr(module, name, None)
if not isinstance(scale, torch.Tensor):
raise ValueError(f"Module '{module}' has no parameter or buffer '{name}'")
cl = _Clamp(min, max)
nn.utils.parametrize.register_parametrization(module, name, cl)
return module
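# Hedged example (added for illustration, not part of the original file). Both torch's orthogonal
# parametrization and the clamp helper above follow the register_parametrization calling convention,
# mirroring how they are combined in the training configs.
def _demo_parametrizations():
    layer = nn.Linear(4, 4, bias=True)
    torch.nn.utils.parametrizations.orthogonal(layer, "weight")  # weight is re-orthogonalised on access
    clamp(layer, "bias", min=0, max=1)                           # bias entries are kept inside [0, 1]
    return layer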
class ONNMiller(nn.Module):
def __init__(self, input_dim, output_dim, dtype=None) -> None:
super(ONNMiller, self).__init__()
self.input_dim = input_dim
self.output_dim = output_dim
self.dtype = dtype
self.dim = max(input_dim, output_dim)
# zero pad input to internal size if smaller
if self.input_dim < self.dim:
self.pad = lambda x: F.pad(x, ((self.dim - self.input_dim) // 2, (self.dim - self.input_dim + 1) // 2))
else:
self.pad = lambda x: x
self.pad.__doc__ = f"Zero pad input from {self.input_dim} to {self.dim}"
# crop output to desired size
if self.output_dim < self.dim:
self.crop = lambda x: x[:, (self.dim - self.output_dim) // 2 : (x.shape[1] - (self.dim - self.output_dim + 1) // 2)]
else:
self.crop = lambda x: x
self.crop.__doc__ = f"Crop output from {self.dim} to {self.output_dim}"
self.U = nn.Parameter(torch.randn(self.dim, self.dim, dtype=self.dtype)) # -> parametrization: Unitary
self.S = nn.Parameter(torch.randn(self.dim, dtype=self.dtype)) # -> parametrization: Clamp (magnitude 0..1)
self.V = nn.Parameter(torch.randn(self.dim, self.dim, dtype=self.dtype)) # -> parametrization: Unitary
self.register_buffer("MZI_scale", torch.tensor(2, dtype=self.dtype.to_real()).sqrt())
# V is actually V.H, but since it is a freely learned parameter the distinction does not matter here
def forward(self, x_in):
x = x_in
x = self.pad(x)
x = x @ self.U
x = x * (self.S.squeeze() / self.MZI_scale)
x = x @ self.V
x = self.crop(x)
return x
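# Hedged shape check (added for illustration, not part of the original file): ONNMiller pads/crops
# around a square U * diag(S)/sqrt(2) * V core, so it maps input_dim inputs to output_dim outputs.
def _demo_onn_miller():
    layer = ONNMiller(52, 2, dtype=torch.complex64)
    x = torch.randn(16, 52, dtype=torch.complex64)
    return layer(x).shape  # torch.Size([16, 2])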
class ONN(nn.Module):
def __init__(self, input_dim, output_dim, dtype=None) -> None:
super(ONN, self).__init__()
self.input_dim = input_dim
self.output_dim = output_dim
self.dtype = dtype
self.dim = max(input_dim, output_dim)
# zero pad input to internal size if smaller
if self.input_dim < self.dim:
self.pad = lambda x: F.pad(x, ((self.dim - self.input_dim) // 2, (self.dim - self.input_dim + 1) // 2))
self.pad.__doc__ = f"Zero pad input from {self.input_dim} to {self.dim}"
else:
self.pad = lambda x: x
self.pad.__doc__ = f"Input size equals internal size {self.dim}"
# crop output to desired size
if self.output_dim < self.dim:
self.crop = lambda x: x[:, (self.dim - self.output_dim) // 2 : (x.shape[1] - (self.dim - self.output_dim + 1) // 2)]
self.crop.__doc__ = f"Crop output from {self.dim} to {self.output_dim}"
else:
self.crop = lambda x: x
self.crop.__doc__ = f"Output size equals internal size {self.dim}"
self.weight = nn.Parameter(torch.randn(self.dim, self.dim, dtype=self.dtype))
def reset_parameters(self):
q, _ = torch.linalg.qr(self.weight)
self.weight.data = q
# def get_M(self):
# return self.U @ self.sigma @ self.V
def forward(self, x):
return self.crop(self.pad(x) @ self.weight)
class SemiUnitaryLayer(nn.Module):
def __init__(self, input_dim, output_dim, dtype=None):
super(SemiUnitaryLayer, self).__init__()
self.input_dim = input_dim
self.output_dim = output_dim
# Create a larger square matrix for QR decomposition
self.weight = nn.Parameter(torch.randn(max(input_dim, output_dim), max(input_dim, output_dim), dtype=dtype))
self.scale = nn.Parameter(torch.tensor(1.0, dtype=dtype.to_real()))
self.reset_parameters()
def reset_parameters(self):
# Ensure the weights are unitary by QR decomposition
q, _ = torch.linalg.qr(self.weight)
# A = QR with A being a complex square matrix -> Q is unitary, R is upper triangular
# truncate the matrix to the desired size
if self.input_dim > self.output_dim:
self.weight.data = q[: self.input_dim, : self.output_dim]
else:
self.weight.data = q[: self.output_dim, : self.input_dim].t()
...
def forward(self, x):
with torch.no_grad():
scale = torch.clamp(self.scale, 0.0, 1.0)
out = torch.matmul(x, scale * self.weight)
return out
def __repr__(self):
return f"SemiUnitaryLayer({self.input_dim}, {self.output_dim})"
# class SaturableAbsorberLambertW(nn.Module):
# """
# Implements the activation function for an optical saturable absorber
# base eqn: sigma*tau*I0 = 0.5*(log(Tm/T0))/(1-Tm),
# where: sigma is the absorption cross section
# tau is the radiative lifetime of the absorber material
# T0 is the initial transmittance
# I0 is the input intensity
# Tm is the transmittance of the absorber
# The activation function is defined as:
# Iout = I0 * Tm(I0)
# where Tm(I0) is the transmittance of the absorber as a function of the input intensity I0
# for a unit sigma*tau product, the solution Tm(I0) is given by:
# Tm(I0) = (W(2*exp(2*I0)*I0*T0))/(2*I0),
# where W is the Lambert W function
# if sigma*tau is not 1, I0 has to be scaled by sigma*tau
# (-> x has to be scaled by sqrt(sigma*tau))
# """
# def __init__(self, T0):
# super(SaturableAbsorberLambertW, self).__init__()
# self.register_buffer("T0", torch.tensor(T0))
# def forward(self, x: torch.Tensor):
# xc = x.conj()
# two_x_xc = (2 * x * xc).real
# return (lambertw(2 * torch.exp(two_x_xc) * (x * self.T0 * xc).real) / two_x_xc).to(dtype=x.dtype)
# def backward(self, x):
# xc = x.conj()
# lambert_eval = lambertw(2 * torch.exp(2 * x * xc).real * (x * self.T0 * xc).real)
# return (((xc * (-2 * lambert_eval + 2 * torch.square(x) - 1) + 2 * x * torch.square(xc) + x) * lambert_eval) / (
# 2 * torch.pow(x, 3) * xc * (lambert_eval + 1)
# )).to(dtype=x.dtype)
# class SaturableAbsorber(nn.Module):
# def __init__(self, alpha, I0):
# super(SaturableAbsorber, self).__init__()
# self.register_buffer("alpha", torch.tensor(alpha))
# self.register_buffer("I0", torch.tensor(I0))
# def forward(self, x):
# I = (x*x.conj()).to(dtype=x.dtype.to_real())
# A = self.alpha/(1+I/self.I0)
# class SpreadLayer(nn.Module):
# def __init__(self, in_features, out_features, dtype=None):
# super(SpreadLayer, self).__init__()
# self.in_features = in_features
# self.out_features = out_features
# self.mat = torch.ones(in_features, out_features, dtype=dtype)*torch.sqrt(torch.tensor(in_features/out_features))
# def forward(self, x):
# # N in_features -> M out_features, Energy is preserved (P = abs(x)^2)
# out = torch.matmul(x, self.mat)
# return out
#### as defined by zhang et al
class DropoutComplex(nn.Module):
def __init__(self, p=0.5):
super(DropoutComplex, self).__init__()
self.dropout = nn.Dropout(p=p)
def forward(self, x):
if x.is_complex():
mask = self.dropout(torch.ones_like(x.real))
return x * mask
else:
return self.dropout(x)
class Identity(nn.Module):
"""
implements the "activation" function
M(z) = z
"""
def __init__(self):
super(Identity, self).__init__()
def forward(self, x):
return x
class PowRot(nn.Module):
def __init__(self, bias=False):
super(PowRot, self).__init__()
self.scale = nn.Parameter(torch.tensor(1.0))
if bias:
self.bias = nn.Parameter(torch.tensor(0.0))
else:
self.register_buffer("bias", torch.tensor(0.0))
def forward(self, x: torch.Tensor):
if x.is_complex():
return x * torch.exp(-self.scale*1j*x.abs().square()+self.bias.to(dtype=x.dtype))
else:
return x
class Pow(nn.Module):
"""
implements the activation function
M(z) = ||z||^2 + b
"""
def __init__(self, bias=False):
super(Pow, self).__init__()
if bias:
self.bias = nn.Parameter(torch.tensor(0.0))
else:
self.register_buffer("bias", torch.tensor(0.0))
def forward(self, x: torch.Tensor):
return x.abs().square().add(self.bias).to(dtype=x.dtype)
class Mag(nn.Module):
"""
implements the activation function
M(z) = ||z||+b
"""
def __init__(self, bias=False):
super(Mag, self).__init__()
if bias:
self.bias = nn.Parameter(torch.tensor(0.0))
else:
self.register_buffer("bias", torch.tensor(0.0))
def forward(self, x: torch.Tensor):
return x.abs().add(self.bias).to(dtype=x.dtype)
class MagScale(nn.Module):
def __init__(self, bias=False):
super(MagScale, self).__init__()
if bias:
self.bias = nn.Parameter(torch.tensor(0.0))
else:
self.register_buffer("bias", torch.tensor(0.0))
def forward(self, x: torch.Tensor):
return x.abs().add(self.bias).to(dtype=x.dtype).sin().mul(x)
class PowScale(nn.Module):
def __init__(self, bias=False):
super(PowScale, self).__init__()
if bias:
self.bias = nn.Parameter(torch.tensor(0.0))
else:
self.register_buffer("bias", torch.tensor(0.0))
def forward(self, x: torch.Tensor):
return x.mul(x.abs().square().add(self.bias).to(dtype=x.dtype).sin())
class ModReLU(nn.Module):
"""
implements the activation function
M(z) = ReLU(||z|| + b)*exp(j*theta_z)
= ReLU(||z|| + b)*z/||z||
"""
def __init__(self, bias=True):
super(ModReLU, self).__init__()
if bias:
self.bias = nn.Parameter(torch.tensor(0.0))
else:
self.register_buffer("bias", torch.tensor(0.0))
def forward(self, x):
if x.is_complex():
mod = x.abs()
out = torch.relu(mod + self.bias) * x / mod
return out.to(dtype=x.dtype)
else:
return torch.relu(x + self.bias).to(dtype=x.dtype)
def __repr__(self):
return f"ModReLU(b={self.b})"
class CReLU(nn.Module):
"""
implements the activation function
M(z) = ReLU(Re(z)) + j*ReLU(Im(z))
"""
def __init__(self):
super(CReLU, self).__init__()
def forward(self, x):
if x.is_complex():
return torch.relu(x.real) + 1j * torch.relu(x.imag)
else:
return torch.relu(x)
class ZReLU(nn.Module):
"""
implements the activation function
M(z) = z if 0 <= angle(z) <= pi/2
= 0 otherwise
"""
def __init__(self):
super(ZReLU, self).__init__()
def forward(self, x):
if x.is_complex():
return x * (torch.angle(x) >= 0) * (torch.angle(x) <= torch.pi / 2)
else:
return torch.relu(x)
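# Hedged sketch (added for illustration, not part of the original file): quick sanity check of two of
# the activations referenced by name in the training configs ("Mag" and "PowScale").
def _demo_activations():
    z = torch.tensor([0.6 + 0.8j, -1.0 + 0.0j], dtype=torch.complex64)
    mag = Mag()(z)          # -> [1+0j, 1+0j]: the magnitude, kept in the complex dtype
    scaled = PowScale()(z)  # -> z * sin(|z|^2): power-dependent scaling of the field
    return mag, scaled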
__all__ = [
"complex_sse_loss",
"complex_mse_loss",
"UnitaryLayer",
"unitary",
"special_unitary",
"clamp",
"ONN",
"ONNMiller",
"SemiUnitaryLayer",
"DropoutComplex",
"Identity",
"Pow",
"PowRot",
"Mag",
"MagScale",
"PowScale",
"ModReLU",
"CReLU",
"ZReLU",
# "SaturableAbsorberLambertW",
# "SaturableAbsorber",
# "SpreadLayer",
]

View File

@@ -0,0 +1,282 @@
from pathlib import Path
import torch
from torch.utils.data import Dataset
# from torch.utils.data import Sampler
import numpy as np
import configparser
# class SubsetSampler(Sampler[int]):
# """
# Samples elements from a given list of indices.
# :param indices: List of indices to sample from.
# :type indices: list[int]
# """
# def __init__(self, indices):
# self.indices = indices
# def __iter__(self):
# return iter(self.indices)
# def __len__(self):
# return len(self.indices)
def load_data(config_path, skipfirst=0, symbols=None, real=False, normalize=False, device=None, dtype=None):
filepath = Path(config_path)
filepath = filepath.parent.glob(filepath.name)
config = configparser.ConfigParser()
config.read(filepath)
path_elements = (
config["data"]["dir"],
config["data"]["npy_dir"],
config["data"]["file"],
)
datapath = Path("/".join(path_elements).replace('"', ""))
sps = int(config["glova"]["sps"])
if symbols is None:
symbols = int(config["glova"]["nos"]) - skipfirst
data = np.load(datapath)[skipfirst * sps : symbols * sps + skipfirst * sps]
if normalize:
# square gets normalized to 1, as the power is (proportional to) the square of the amplitude
a, b, c, d = np.square(data.T)
a, b, c, d = a / np.max(np.abs(a)), b / np.max(np.abs(b)), c / np.max(np.abs(c)), d / np.max(np.abs(d))
data = np.sqrt(np.array([a, b, c, d]).T)
if real:
data = np.abs(data)
config["glova"]["nos"] = str(symbols)
data = torch.tensor(data, device=device, dtype=dtype)
return data, config
def roll_along(arr, shifts, dim):
# https://stackoverflow.com/a/76920720
# (c) Mateen Ulhaq, 2023
# CC BY-SA 4.0
shifts = torch.tensor(shifts)
assert arr.ndim - 1 == shifts.ndim
dim %= arr.ndim
shape = (1,) * dim + (-1,) + (1,) * (arr.ndim - dim - 1)
dim_indices = torch.arange(arr.shape[dim]).reshape(shape)
indices = (dim_indices - shifts.unsqueeze(dim)) % arr.shape[dim]
return torch.gather(arr, dim, indices)
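# Hedged example of roll_along (added for illustration, not part of the original file): every row is
# rolled independently by the matching entry in `shifts`.
def _demo_roll_along():
    arr = torch.arange(8).reshape(2, 4)    # [[0, 1, 2, 3], [4, 5, 6, 7]]
    return roll_along(arr, [1, 2], dim=1)  # [[3, 0, 1, 2], [6, 7, 4, 5]]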
class FiberRegenerationDataset(Dataset):
"""
Dataset for fiber regeneration training.
The dataset is loaded from a configuration file, which must contain (at least) the following sections:
```
[data]
dir = <data_dir>
npy_dir = <npy_dir>
file = <data_file>
[glova]
sps = <samples per symbol>
```
The data is loaded from the file `<data_dir>/<npy_dir>/<data_file>` and is assumed to be in the following format:
```
[ E_in_x,
E_in_y,
E_out_x,
E_out_y ]
```
The data is split into overlapping slices, where each slice spans a (fractional) number of symbols.
The target can be delayed relative to the input data by a (fractional) number of symbols.
The x and y channels can be delayed relative to each other by a (fractional) number of symbols.
"""
def __init__(
self,
file_path: str | Path,
symbols: int | float,
*,
output_dim: int = None,
target_delay: float | int = 0,
xy_delay: float | int = 0,
drop_first: float | int = 0,
dtype: torch.dtype = None,
real: bool = False,
device=None,
**kwargs,
):
"""
Initialize the dataset.
:param file_path: Path to the data file. Can contain wildcards (*). The first matching configuration file is used.
:type file_path: str | pathlib.Path
:param symbols: Number of symbols in each slice. Can be a float to specify a fraction of a symbol.
:type symbols: float | int
:param output_dim: Number of samples per channel in each slice. The data is reduced by taking equally spaced samples. If unset, each slice will contain symbols*samples_per_symbol samples.
:type output_dim: int, optional
:param target_delay: Delay (in fractional symbols) between data and target. A positive delay means the target is delayed relative to the data. Default is 0.
:type target_delay: float | int, optional
:param xy_delay: Delay (in fractional symbols) between the x and y channels. A positive delay means the y channel is delayed relative to the x channel. Default is 0.
:type xy_delay: float | int, optional
:param drop_first: Number of symbols to drop from the beginning. Default is 0.
:type drop_first: int, optional
"""
# check types
assert isinstance(file_path, (str, Path)), "file_path must be a string or Path"
assert isinstance(symbols, (float, int)), "symbols must be a float or an integer"
assert output_dim is None or isinstance(output_dim, int), "output_dim must be an integer"
assert isinstance(target_delay, (float, int)), "target_delay must be a float or an integer"
assert isinstance(xy_delay, (float, int)), "xy_delay must be a float or an integer"
assert isinstance(drop_first, int), "drop_first must be an integer"
# check values
assert symbols > 0, "symbols must be positive"
assert output_dim is None or output_dim > 0, "output_dim must be positive or None"
assert drop_first >= 0, "drop_first must be non-negative"
faux = kwargs.pop("faux", False)
if faux:
data_raw = np.array(
[[i + 0.1j, i + 0.2j, i + 1.1j, i + 1.2j] for i in range(12800)],
dtype=np.complex128,
)
data_raw = torch.tensor(data_raw, device=device, dtype=dtype)
self.config = {
"data": {"dir": '"."', "npy_dir": '"."', "file": "faux"},
"glova": {"sps": 128},
}
else:
data_raw, self.config = load_data(
file_path,
skipfirst=drop_first,
symbols=kwargs.pop("num_symbols", None),
real=real,
normalize=True,
device=device,
dtype=dtype,
)
self.device = data_raw.device
self.samples_per_symbol = int(self.config["glova"]["sps"])
self.samples_per_slice = int(symbols * self.samples_per_symbol)
self.symbols_per_slice = self.samples_per_slice / self.samples_per_symbol
self.output_dim = output_dim or self.samples_per_slice
self.target_delay = target_delay or 0
self.xy_delay = xy_delay or 0
ovrd_target_delay_samples = kwargs.pop("ovrd_target_delay_samples", None)
ovrd_xy_delay_samples = kwargs.pop("ovrd_xy_delay_samples", None)
self.target_delay_samples = (
ovrd_target_delay_samples
if ovrd_target_delay_samples is not None
else int(self.target_delay * self.samples_per_symbol)
)
self.xy_delay_samples = (
ovrd_xy_delay_samples if ovrd_xy_delay_samples is not None else int(self.xy_delay * self.samples_per_symbol)
)
# data_raw = torch.tensor(data_raw, dtype=dtype)
# data layout
# [ [E_in_x0, E_in_y0, E_out_x0, E_out_y0],
# [E_in_x1, E_in_y1, E_out_x1, E_out_y1],
# ...
# [E_in_xN, E_in_yN, E_out_xN, E_out_yN] ]
data_raw = data_raw.transpose(0, 1)
# data layout
# [ E_in_x[0:N],
# E_in_y[0:N],
# E_out_x[0:N],
# E_out_y[0:N] ]
# shift x data by xy_delay_samples relative to the y data (example value: 3)
# [ E_in_x [0:N], [ E_in_x [ 0:N ], [ E_in_x [3:N ],
# E_in_y [0:N], -> E_in_y [-3:N-3], -> E_in_y [0:N-3],
# E_out_x[0:N], E_out_x[ 0:N ], E_out_x[3:N ],
# E_out_y[0:N] ] E_out_y[-3:N-3] ] E_out_y[0:N-3] ]
if self.xy_delay_samples != 0:
data_raw = roll_along(data_raw, [0, self.xy_delay_samples, 0, self.xy_delay_samples], dim=1)
if self.xy_delay_samples > 0:
data_raw = data_raw[:, self.xy_delay_samples :]
elif self.xy_delay_samples < 0:
data_raw = data_raw[:, : self.xy_delay_samples]
# shift fiber input data (target) by target_delay_samples relative to the fiber output data (input)
# (example value: 5)
# [ E_in_x [0:N], [ E_in_x [-5:N-5], [ E_in_x [0:N-5],
# E_in_y [0:N], -> E_in_y [-5:N-5], -> E_in_y [0:N-5],
# E_out_x[0:N], E_out_x[ 0:N ], E_out_x[5:N ],
# E_out_y[0:N] ] E_out_y[ 0:N ] ] E_out_y[5:N ] ]
if self.target_delay_samples != 0:
data_raw = roll_along(
data_raw,
[self.target_delay_samples, self.target_delay_samples, 0, 0],
dim=1,
)
if self.target_delay_samples > 0:
data_raw = data_raw[:, self.target_delay_samples :]
elif self.target_delay_samples < 0:
data_raw = data_raw[:, : self.target_delay_samples]
data_raw = data_raw.view(2, 2, -1)
# data layout
# [ [E_in_x, E_in_y],
# [E_out_x, E_out_y] ]
self.data = data_raw.unfold(dimension=-1, size=self.samples_per_slice, step=1)
self.data = self.data.movedim(-2, 0)
# -> [no_slices, 2, 2, samples_per_slice]
# data layout
# [
# [ [E_in_x[0:N+0], E_in_y[0:N+0] ], [ E_out_x[0:N+0], E_out_y[0:N+0] ] ],
# [ [E_in_x[1:N+1], E_in_y[1:N+1] ], [ E_out_x[1:N+1], E_out_y[1:N+1] ] ],
# ...
# ] -> [no_slices, 2, 2, samples_per_slice]
...
def __len__(self):
return self.data.shape[0]
def __getitem__(self, idx):
if isinstance(idx, slice):
return [self.__getitem__(i) for i in range(*idx.indices(len(self)))]
else:
data, target = self.data[idx, 1].squeeze(), self.data[idx, 0].squeeze()
# reduce by taking self.output_dim equally spaced samples
data = data[:, : data.shape[1] // self.output_dim * self.output_dim]
data = data.view(data.shape[0], self.output_dim, -1)
data = data[:, :, 0]
# the target corresponds to the middle of the data window, since the output sample is influenced by the samples before and after it
target = target[:, : target.shape[1] // self.output_dim * self.output_dim]
target = target.view(target.shape[0], self.output_dim, -1)
target = target[:, 0, target.shape[2] // 2]
data = data.transpose(0, 1).flatten().squeeze()
target = target.flatten().squeeze()
# data layout:
# [sample_x0, sample_y0, sample_x1, sample_y1, ...]
# target layout:
# [sample_x0, sample_y0]
return data, target
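# Hedged usage sketch (added for illustration, not part of the original file): the faux=True branch
# above builds a small synthetic dataset, which makes the slice/target shapes easy to check without
# any data files on disk.
def _demo_fiber_dataset():
    ds = FiberRegenerationDataset("unused.ini", symbols=2, output_dim=8, faux=True, dtype=torch.complex64)
    data, target = ds[0]
    # data: interleaved x/y samples -> (2 * output_dim,); target: one x/y pair -> (2,)
    return data.shape, target.shape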

View File

@@ -0,0 +1,21 @@
def multi_getattr(objs, attr, fallback=None):
"""
tries to get the attribute from a list of objects, returning the first hit
if no object has the attribute, it returns the fallback value if provided, otherwise raises AttributeError
"""
try:
return _multi_getattr(objs, attr)
except AttributeError as e:
if fallback is not None:
return fallback
raise e
def _multi_getattr(objs, attr):
if not isinstance(objs, (list, tuple)):
objs = [objs]
for obj in objs:
try:
return getattr(obj, attr)
except AttributeError:
pass
raise AttributeError(f"None of the objects has attribute {attr}")

View File

@@ -0,0 +1,335 @@
from typing import Any
import warnings
from optuna import trial
def install_optional_suggests():
trial.Trial.suggest_categorical_optional = suggest_categorical_optional_wrapper
trial.Trial.suggest_int_optional = suggest_int_optional_wrapper
trial.Trial.suggest_float_optional = suggest_float_optional_wrapper
def _is_listlike(obj: Any) -> bool:
return hasattr(obj, "__iter__") and not isinstance(obj, str)
def _optional_suggest(
*,
trial: trial.Trial,
name: str,
range_or_value: Any,
type: str,
log: bool = False,
step: int | float | None = None,
add_user: bool = True,
force: bool = False,
multiply: float | int = 1,
set_new: bool = True,
):
"""
Suggest a value for a parameter with more control over the process
Parameters
----------
type : str
The type of the parameter
trial : optuna.trial.Trial
The trial object
name : str
The name of the parameter
range_or_value : Any
The range of values or a single value
log : bool, optional
Whether to use a logarithmic scale, by default False
step : int|float|None, optional
The step size, by default None
add_user : bool, optional
Whether to add the suggested value to the user attributes if not added as a parameter, by default True
force : bool, optional
Whether to force a single value to be suggested, by default False
multiply : float| int, optional
A multiplier to apply to the range or value, by default 1. Ignored for type "categorical".
set_new : bool, optional
Whether to override the parameter if it already exists, by default True
"""
# value should be retrieved from trial
if not set_new and name in trial.params:
return trial.params[name]
# value is not a list or tuple
if not _is_listlike(range_or_value):
range_or_value = (range_or_value,)
# range with only one value
if len(range_or_value) == 1 and not force:
if add_user:
trial.set_user_attr(name, range_or_value[0])
return range_or_value[0]
# normal operation
if type == "categorical":
return trial.suggest_categorical(name, range_or_value)
# multiply range
range_or_value = tuple(multiply * x for x in range_or_value)
#
if len(range_or_value) > 2:
warnings.warn("More than two values in range, using highest and lowest")
low = min(range_or_value)
high = max(range_or_value)
if type == "float":
return trial.suggest_float(name, low, high, step=step, log=log)
if type == "int":
step = step or 1
lowi = int(low)
highi = int(high)
if lowi != low or highi != high:
raise ValueError(f"Range {low} to {high} (using multiplier {multiply}) is not valid for int")
return trial.suggest_int(name, lowi, highi, step=step, log=log)
raise ValueError(f"Unknown type: {type}")
def suggest_categorical_optional(
trial: trial.Trial,
name: str,
choices_or_value: tuple[Any] | list[Any] | Any,
add_user: bool = True,
force: bool = False,
set_new: bool = True,
):
"""
Suggest a value for a categorical parameter with more control over the process
Parameters
----------
trial : optuna.trial.Trial
The trial object
name : str
The name of the parameter
choices_or_value : tuple|list|Any
The choices or a single value
add_user : bool, optional
Whether to add the suggested value to the user attributes if not added as a parameter, by default True
force : bool, optional
Whether to suggest a single value as a parameter, by default False
set_new : bool, optional
Whether to override the parameter if it already exists, by default True
"""
return _optional_suggest(
trial=trial, name=name, range_or_value=choices_or_value, type="categorical", add_user=add_user, force=force, set_new=set_new
)
def suggest_int_optional(
trial: trial.Trial,
name: str,
range_or_value: tuple[int] | list[int] | int,
step: int = 1,
log: bool = False,
add_user: bool = True,
force: bool = False,
multiply: int = 1,
set_new: bool = True,
):
"""
Suggest a value for an integer parameter with more control over the process
Parameters
----------
trial : optuna.trial.Trial
The trial object
name : str
The name of the parameter
range_or_value : tuple|list|int
The range of values or a single value.
step : int, optional
The step size, by default 1
log : bool, optional
Whether to use a logarithmic scale, by default False
add_user : bool, optional
Whether to add the suggested value to the user attributes if not added as a parameter, by default True
force : bool, optional
Whether to suggest a single value as a parameter, by default False
"""
return _optional_suggest(
trial=trial,
name=name,
range_or_value=range_or_value,
step=step,
log=log,
type="int",
add_user=add_user,
force=force,
multiply=multiply,
set_new=set_new,
)
def suggest_float_optional(
trial: trial.Trial,
name: str,
range_or_value: tuple[float] | list[float] | float,
step: float | None = None,
log: bool = False,
add_user: bool = True,
force: bool = False,
multiply: float = 1,
set_new: bool = True,
):
"""
Suggest a value for a float parameter with more control over the process
Parameters
----------
trial : optuna.trial.Trial
The trial object
name : str
The name of the parameter
range_or_value : tuple|list|float
The range of values or a single value
step : float|None, optional
The step size, by default None
log : bool, optional
Whether to use a logarithmic scale, by default False
add_user : bool, optional
Whether to add the suggested value to the user attributes if not added as a parameter, by default True
force : bool, optional
Whether to suggest a single value as a parameter, by default False
multiply : float, optional
A multiplier to apply to the range or value, by default 1
set_new : bool, optional
Whether to override the parameter if it already exists, by default True
"""
return _optional_suggest(
trial=trial,
name=name,
range_or_value=range_or_value,
step=step,
log=log,
type="float",
add_user=add_user,
force=force,
multiply=multiply,
set_new=set_new,
)
def suggest_categorical_optional_wrapper(
self: trial.Trial,
name: str,
choices_or_value: tuple[Any] | list[Any] | Any,
add_user: bool = True,
force: bool = False,
set_new: bool = True,
):
"""
Suggest a value for a categorical parameter with more control over the process
Parameters
----------
name : str
The name of the parameter
choices_or_value : tuple|list|Any
The choices or a single value
add_user : bool, optional
Whether to add the suggested value to the user attributes if not added as a parameter, by default True
force : bool, optional
Whether to suggest a single value as a parameter, by default False
set_new : bool, optional
Whether to override the parameter if it already exists, by default True
"""
return suggest_categorical_optional(
trial=self, name=name, choices_or_value=choices_or_value, add_user=add_user, force=force, set_new=set_new
)
def suggest_int_optional_wrapper(
self: trial.Trial,
name: str,
range_or_value: tuple[int] | list[int] | int,
step: int = 1,
log: bool = False,
add_user: bool = True,
force: bool = False,
multiply: int = 1,
set_new: bool = True,
):
"""
Suggest a value for an integer parameter with more control over the process
Parameters
----------
name : str
The name of the parameter
range_or_value : tuple|list|int
The range of values or a single value.
step : int, optional
The step size, by default 1
log : bool, optional
Whether to use a logarithmic scale, by default False
add_user : bool, optional
Whether to add the suggested value to the user attributes if not added as a parameter, by default True
force : bool, optional
Whether to suggest a single value as a parameter, by default False
"""
return suggest_int_optional(
trial=self,
name=name,
range_or_value=range_or_value,
step=step,
log=log,
add_user=add_user,
force=force,
multiply=multiply,
set_new=set_new,
)
def suggest_float_optional_wrapper(
self: trial.Trial,
name: str,
range_or_value: tuple[float] | list[float] | float,
step: float | None = None,
log: bool = False,
add_user: bool = True,
force: bool = False,
multiply: float = 1,
set_new: bool = True,
):
"""
Suggest a value for a float parameter with more control over the process
Parameters
----------
name : str
The name of the parameter
range_or_value : tuple|list|float
The range of values or a single value
step : float|None, optional
The step size, by default None
log : bool, optional
Whether to use a logarithmic scale, by default False
add_user : bool, optional
Whether to add the suggested value to the user attributes if not added as a parameter, by default True
force : bool, optional
Whether to suggest a single value as a parameter, by default False
multiply : float, optional
A multiplier to apply to the range or value, by default 1
set_new : bool, optional
Whether to override the parameter if it already exists, by default True
"""
return suggest_float_optional(
trial=self,
name=name,
range_or_value=range_or_value,
step=step,
log=log,
add_user=add_user,
force=force,
multiply=multiply,
set_new=set_new,
)
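# Hedged usage sketch (added for illustration, not part of the original file): a fixed single value
# skips the search space and is only recorded as a user attribute, while a (low, high) pair is
# forwarded to the regular trial.suggest_* call.
def _demo_optional_suggest(trial):
    n_layers = suggest_int_optional(trial, "n_hidden_layers", 4)                 # fixed -> user attr, returns 4
    lr = suggest_float_optional(trial, "learning_rate", (1e-5, 1e-1), log=True)  # range -> suggest_float
    return n_layers, lr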

View File

@@ -0,0 +1,18 @@
from dash import Dash, dcc, html
import logging
import dash_bootstrap_components as dbc
def show_figures(*figures):
for figure in figures:
figure.layout.template = 'plotly_dark'
app = Dash(external_stylesheets=[dbc.themes.DARKLY])
app.layout = html.Div([
dcc.Graph(figure=figure) for figure in figures
])
log = logging.getLogger('werkzeug')
log.setLevel(logging.ERROR)
app.show = lambda *args, **kwargs: app.run_server(*args, **kwargs, debug=False)
return app
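# Hedged usage sketch (added for illustration, not part of the original file): any plotly figures can
# be passed, e.g. the Optuna visualisations from the studies above.
def _demo_show_figures():
    import plotly.graph_objects as go
    fig = go.Figure(data=go.Scatter(y=[1, 3, 2]))
    app = show_figures(fig)
    # app.show() would start the (blocking) Dash server
    return app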

View File

@@ -0,0 +1,73 @@
import matplotlib.pyplot as plt
import numpy as np
from .datasets import load_data
def eye(*, path=None, data=None, sps=None, title=None, symbols=1000, skipfirst=0, width=2, alpha=None, complex=False, show=True):
"""Plot an eye diagram for the data given by filepath.
Either path or data and sps must be given.
Args:
path (str): Path to the data description file.
data (np.ndarray): Data to plot.
sps (int): Samples per symbol.
title (str): Title of the plot.
symbols (int): Number of symbols to plot.
skipfirst (int): Number of symbols to skip.
width (int): Width of the eye window in symbols.
alpha (float): Line transparency; defaults to 0.1.
complex (bool): Also plot the phase on a secondary axis.
show (bool): Whether to call plt.show().
"""
if path is None and data is None:
raise ValueError("Either path or data and sps must be given.")
if path is not None:
data, config = load_data(path, skipfirst, symbols)
sps = int(config["glova"]["sps"])
if sps is None:
raise ValueError("sps not set.")
xaxis = np.linspace(0, width, width*sps, endpoint=False)
fig, axs = plt.subplots(2, 2, figsize=(10, 10), sharex=True, sharey=True)
if complex:
# create secondary axis for phase
axs2 = axs[0, 0].twinx(), axs[0, 1].twinx(), axs[1, 0].twinx(), axs[1, 1].twinx()
axs2 = np.reshape(axs2, (2, 2))
for i in range(symbols-(width-1)):
inx, iny, outx, outy = data[i*sps:(i+width)*sps].T
if complex:
axs[0, 0].plot(xaxis, np.abs(inx), color="C0", alpha=alpha or 0.1)
axs[0, 1].plot(xaxis, np.abs(outx), color="C0", alpha=alpha or 0.1)
axs[1, 0].plot(xaxis, np.abs(iny), color="C0", alpha=alpha or 0.1)
axs[1, 1].plot(xaxis, np.abs(outy), color="C0", alpha=alpha or 0.1)
axs[0, 0].set_ylim(0, 1.1*np.max(np.abs(data)))
axs2[0, 0].plot(xaxis, np.angle(inx), color="C1", alpha=alpha or 0.1)
axs2[0, 1].plot(xaxis, np.angle(outx), color="C1", alpha=alpha or 0.1)
axs2[1, 0].plot(xaxis, np.angle(iny), color="C1", alpha=alpha or 0.1)
axs2[1, 1].plot(xaxis, np.angle(outy), color="C1", alpha=alpha or 0.1)
else:
axs[0, 0].plot(xaxis, np.abs(inx)**2, color="C0", alpha=alpha or 0.1)
axs[0, 1].plot(xaxis, np.abs(outx)**2, color="C0", alpha=alpha or 0.1)
axs[1, 0].plot(xaxis, np.abs(iny)**2, color="C0", alpha=alpha or 0.1)
axs[1, 1].plot(xaxis, np.abs(outy)**2, color="C0", alpha=alpha or 0.1)
if complex:
axs2[0, 0].sharey(axs2[0, 1])
axs2[0, 1].sharey(axs2[1, 0])
axs2[1, 0].sharey(axs2[1, 1])
# make y axis symmetric
ylim = np.max(np.abs(np.angle(data)))*1.1
if ylim != 0:
axs2[0, 0].set_ylim(-ylim, ylim)
else:
axs[0,0].set_ylim(0, 1.1*np.max(np.abs(data))**2)
axs[0, 0].set_title("Input x")
axs[0, 1].set_title("Output x")
axs[1, 0].set_title("Input y")
axs[1, 1].set_title("Output y")
fig.suptitle(title or "Eye diagram")
if show:
plt.show()
return fig
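# Hedged usage sketch (added for illustration, not part of the original file): synthetic data shaped
# (samples, 4) = [E_in_x, E_in_y, E_out_x, E_out_y] stands in for a real recording here.
def _demo_eye():
    sps = 16
    t = np.arange(50 * sps)
    fake = np.stack([np.sin(2 * np.pi * t / sps)] * 4, axis=1).astype(np.complex64)
    return eye(data=fake, sps=sps, symbols=50, title="demo eye", show=False)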