Compare commits
23 Commits
main
...
487288c923
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
487288c923 | ||
|
|
bdf6f5bfb8 | ||
|
|
e02662ed4f | ||
|
|
fd7a0b9c31 | ||
|
|
ff32aefd52 | ||
|
|
b156b9ceaf | ||
|
|
cfa08aae4e | ||
|
|
0422c81f3b | ||
|
|
7343ccb3a5 | ||
|
|
9a16a5637d | ||
|
|
80e9a3379e | ||
|
|
8d4d0468bd | ||
|
|
6358c95c42 | ||
|
|
674033ac2e | ||
|
|
cdca5de473 | ||
|
|
1622c38582 | ||
|
|
2bba760378 | ||
|
|
9ec548757d | ||
|
|
05a3ee9394 | ||
|
|
086240489a | ||
|
|
87f40fc37c | ||
|
|
90aa6dbaf8 | ||
|
|
744c5f5166 |
1
.gitattributes
vendored
1
.gitattributes
vendored
@@ -1,4 +1,5 @@
|
|||||||
data/**/* filter=lfs diff=lfs merge=lfs -text
|
data/**/* filter=lfs diff=lfs merge=lfs -text
|
||||||
|
data/*.db filter=lfs diff=lfs merge=lfs -text
|
||||||
data/*.ini filter=lfs diff=lfs merge=lfs text
|
data/*.ini filter=lfs diff=lfs merge=lfs text
|
||||||
|
|
||||||
## lfs setup
|
## lfs setup
|
||||||
|
|||||||
4
.gitignore
vendored
4
.gitignore
vendored
@@ -1,7 +1,5 @@
|
|||||||
src/**/*.ini
|
src/**/*.ini
|
||||||
|
.*
|
||||||
# VSCode
|
|
||||||
.vscode
|
|
||||||
|
|
||||||
# Byte-compiled / optimized / DLL files
|
# Byte-compiled / optimized / DLL files
|
||||||
__pycache__/
|
__pycache__/
|
||||||
|
|||||||
1
LICENSE
1
LICENSE
@@ -10,6 +10,7 @@ use is covered by a right of the copyright holder of the Work).
|
|||||||
The Work is provided under the terms of this Licence when the Licensor (as
|
The Work is provided under the terms of this Licence when the Licensor (as
|
||||||
defined below) has placed the following notice immediately following the
|
defined below) has placed the following notice immediately following the
|
||||||
copyright notice for the Work:
|
copyright notice for the Work:
|
||||||
|
|
||||||
```raw
|
```raw
|
||||||
Licensed under the EUPL
|
Licensed under the EUPL
|
||||||
```
|
```
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ Full license text in LICENSE file
|
|||||||
|
|
||||||
# optical-regeneration
|
# optical-regeneration
|
||||||
|
|
||||||
## Notes on cloning:
|
## Notes on cloning
|
||||||
|
|
||||||
- `pypho` is added as a submodule -> `--recurse-submodules`
|
- `pypho` is added as a submodule -> `--recurse-submodules`
|
||||||
- This repo has about 7.5GB of datasets in it. The `git lfs fetch` step will take a while.
|
- This repo has about 7.5GB of datasets in it. The `git lfs fetch` step will take a while.
|
||||||
@@ -29,4 +29,5 @@ git lfs checkout
|
|||||||
```
|
```
|
||||||
|
|
||||||
## License
|
## License
|
||||||
This project is licensed under EUPL-1.2.
|
|
||||||
|
This project is licensed under EUPL-1.2.
|
||||||
|
|||||||
3
data/optuna_single_core_regen.db
Normal file
3
data/optuna_single_core_regen.db
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
version https://git-lfs.github.com/spec/v1
|
||||||
|
oid sha256:e12f0c21fca93620a165fbb6ed58d0b313093e972ef4416694c29c9cea6dc867
|
||||||
|
size 831488
|
||||||
3
data/single_core_regen.db
Normal file
3
data/single_core_regen.db
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
version https://git-lfs.github.com/spec/v1
|
||||||
|
oid sha256:746ea83013870351296e01e294905ee027291ef79fd78c1e6b69dd9ebaa1cba0
|
||||||
|
size 10240000
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
# CUDA 12.4 Install
|
# CUDA 12.4 Install
|
||||||
|
|
||||||
> https://unknowndba.blogspot.com/2024/04/cuda-getting-started-on-wsl.html (@\_freddenis\_, 2024)
|
> <https://unknowndba.blogspot.com/2024/04/cuda-getting-started-on-wsl.html> (@\_freddenis\_, 2024)
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# get md5sums
|
# get md5sums
|
||||||
@@ -49,10 +49,9 @@ make
|
|||||||
./devicequery
|
./devicequery
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
||||||
### if the cuda-toolkit install fails with unmet dependencies
|
### if the cuda-toolkit install fails with unmet dependencies
|
||||||
|
|
||||||
>https://askubuntu.com/a/1493087 (jspinella, 2023, CC BY-SA 4.0)
|
><https://askubuntu.com/a/1493087> (jspinella, 2023, CC BY-SA 4.0)
|
||||||
|
|
||||||
1. Open the *new* file for storing the sources list
|
1. Open the *new* file for storing the sources list
|
||||||
|
|
||||||
@@ -60,7 +59,7 @@ make
|
|||||||
sudo nano /etc/apt/sources.list.d/ubuntu.sources
|
sudo nano /etc/apt/sources.list.d/ubuntu.sources
|
||||||
```
|
```
|
||||||
|
|
||||||
2. Paste in the following at the end of the file:
|
2. Paste in the following at the end of the file:
|
||||||
|
|
||||||
```raw
|
```raw
|
||||||
Types: deb
|
Types: deb
|
||||||
@@ -71,4 +70,3 @@ make
|
|||||||
```
|
```
|
||||||
|
|
||||||
3. Save the file and run `sudo apt update` - now the install command for CUDA should work.
|
3. Save the file and run `sudo apt update` - now the install command for CUDA should work.
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
# useful links
|
# useful links
|
||||||
|
|
||||||
- (Optuna)[https://optuna.org] Hyperparameter optimization framework
|
- [Optuna](https://optuna.org) Hyperparameter optimization framework
|
||||||
`pip install optuna`
|
`pip install optuna`
|
||||||
|
|||||||
@@ -1,46 +1,42 @@
|
|||||||
# pyenv install
|
# pyenv installation
|
||||||
|
|
||||||
## install
|
## pyenv
|
||||||
|
|
||||||
nice to have:
|
1. Install pyenv
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
sudo apt install python-is-python3
|
curl https://pyenv.run | bash
|
||||||
```
|
```
|
||||||
|
|
||||||
```bash
|
2. setup zsh
|
||||||
curl https://pyenv.run | bash
|
|
||||||
```
|
|
||||||
|
|
||||||
## setup zsh
|
add the following to `.zshrc`:
|
||||||
|
|
||||||
add the following to `.zshrc`:
|
```bash
|
||||||
|
export PYENV_ROOT="$HOME/.pyenv"
|
||||||
|
[[ -d $PYENV_ROOT/bin ]] && export PATH="$PYENV_ROOT/bin:$PATH"
|
||||||
|
eval "$(pyenv init -)"
|
||||||
|
```
|
||||||
|
|
||||||
```bash
|
## python installation
|
||||||
export PYENV_ROOT="$HOME/.pyenv"
|
|
||||||
[[ -d $PYENV_ROOT/bin ]] && export PATH="$PYENV_ROOT/bin:$PATH"
|
|
||||||
eval "$(pyenv init -)"
|
|
||||||
```
|
|
||||||
|
|
||||||
## pyenv install
|
1. prerequisites
|
||||||
|
|
||||||
prerequisites:
|
```bash
|
||||||
|
sudo apt update
|
||||||
|
sudo apt install build-essential libssl-dev zlib1g-dev \
|
||||||
|
libbz2-dev libreadline-dev libsqlite3-dev curl git \
|
||||||
|
libncursesw5-dev xz-utils tk-dev libxml2-dev libxmlsec1-dev \
|
||||||
|
libffi-dev liblzma-dev python3-pip
|
||||||
|
```
|
||||||
|
|
||||||
```bash
|
2. install
|
||||||
sudo apt update
|
|
||||||
sudo apt install build-essential libssl-dev zlib1g-dev \
|
|
||||||
libbz2-dev libreadline-dev libsqlite3-dev curl git \
|
|
||||||
libncursesw5-dev xz-utils tk-dev libxml2-dev libxmlsec1-dev \
|
|
||||||
libffi-dev liblzma-dev python3-pip
|
|
||||||
```
|
|
||||||
|
|
||||||
install:
|
```bash
|
||||||
|
# using python 3.12.7 as an example
|
||||||
|
pyenv install 3.12.7
|
||||||
|
|
||||||
```bash
|
# optional
|
||||||
# using python 3.12.7 as an example
|
pyenv global 3.12.7
|
||||||
pyenv install 3.12.7
|
pyenv versions
|
||||||
|
```
|
||||||
# optional
|
|
||||||
pyenv global 3.12.7
|
|
||||||
pyenv versions
|
|
||||||
```
|
|
||||||
|
|||||||
@@ -8,7 +8,8 @@ source ./.venv/bin/activate
|
|||||||
```
|
```
|
||||||
|
|
||||||
## install pytorch
|
## install pytorch
|
||||||
> https://pytorch.org/get-started/locally/
|
|
||||||
|
> <https://pytorch.org/get-started/locally/>
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
pip install torch torchvision torchaudio
|
pip install torch torchvision torchaudio
|
||||||
|
|||||||
37
src/single-core-data-gen/add_pypho.py
Normal file
37
src/single-core-data-gen/add_pypho.py
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
|
||||||
|
# add_pypho.py
|
||||||
|
#
|
||||||
|
# This file is part of the repo "optical-regeneration"
|
||||||
|
# https://git.suuppl.dev/seppl/optical-regeneration.git
|
||||||
|
#
|
||||||
|
# (c) Joseph Hopfmüller, 2024
|
||||||
|
# Licensed under the EUPL
|
||||||
|
#
|
||||||
|
# Full license text in LICENSE file
|
||||||
|
|
||||||
|
###
|
||||||
|
|
||||||
|
# copy this file into the directory where you want to use pypho
|
||||||
|
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
__log = []
|
||||||
|
|
||||||
|
# add the dir above the one where this file lives
|
||||||
|
__parent_dir = Path(__file__).parent
|
||||||
|
|
||||||
|
# search for a dir containing ./pypho/pypho, then add the lower ./pypho
|
||||||
|
while not (__parent_dir / "pypho" / "pypho").exists() and __parent_dir != Path("/"):
|
||||||
|
__parent_dir = __parent_dir.parent
|
||||||
|
|
||||||
|
if __parent_dir != Path("/"):
|
||||||
|
sys.path.append(str(__parent_dir / "pypho"))
|
||||||
|
__log.append(f"Added '{__parent_dir/ "pypho"}' to 'PATH'")
|
||||||
|
else:
|
||||||
|
__log.append('pypho not found')
|
||||||
|
|
||||||
|
|
||||||
|
def show_log():
|
||||||
|
for entry in __log:
|
||||||
|
print(entry)
|
||||||
@@ -19,9 +19,8 @@ import time
|
|||||||
from matplotlib import pyplot as plt # noqa: F401
|
from matplotlib import pyplot as plt # noqa: F401
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
import _path_fix # noqa: F401
|
import add_pypho # noqa: F401
|
||||||
import pypho
|
import pypho
|
||||||
# import inspect
|
|
||||||
|
|
||||||
default_config = f"""
|
default_config = f"""
|
||||||
[glova]
|
[glova]
|
||||||
@@ -498,17 +497,18 @@ def plot_eye_diagram(
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
add_pypho.show_log()
|
||||||
config = get_config()
|
config = get_config()
|
||||||
|
|
||||||
length_ranges = [1000, 10000]
|
# length_ranges = [1000, 10000]
|
||||||
length_scales = [1, 2, 3, 4, 5, 6, 7, 8, 9]
|
# length_scales = [1, 2, 3, 4, 5, 6, 7, 8, 9]
|
||||||
|
|
||||||
lengths = [
|
# lengths = [
|
||||||
length_scale * length_range
|
# length_scale * length_range
|
||||||
for length_range in length_ranges
|
# for length_range in length_ranges
|
||||||
for length_scale in length_scales
|
# for length_scale in length_scales
|
||||||
]
|
# ]
|
||||||
lengths.append(max(length_ranges)*10)
|
# lengths.append(max(length_ranges)*10)
|
||||||
|
|
||||||
# length_loop(config, lengths)
|
# length_loop(config, lengths)
|
||||||
|
|
||||||
@@ -1,9 +0,0 @@
|
|||||||
import sys
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
# hack to add the parent directory to the path -> pypho doesn't have to be installed as package
|
|
||||||
parent_dir = Path(__file__).parent
|
|
||||||
while not (parent_dir / "pypho" / "pypho").exists() and parent_dir != Path("/"):
|
|
||||||
parent_dir = parent_dir.parent
|
|
||||||
print(f"Adding '{parent_dir / "pypho"}' to 'sys.path' to enable import of '{parent_dir / 'pypho' / 'pypho'}'")
|
|
||||||
sys.path.append(str(parent_dir / "pypho"))
|
|
||||||
763
src/single-core-regen/hypertraining/hypertraining.py
Normal file
763
src/single-core-regen/hypertraining/hypertraining.py
Normal file
@@ -0,0 +1,763 @@
|
|||||||
|
import copy
|
||||||
|
from datetime import datetime
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Literal
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import optuna
|
||||||
|
|
||||||
|
# import optunahub
|
||||||
|
import warnings
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
# import torch.nn.functional as F # mse_loss doesn't support complex numbers
|
||||||
|
import torch.optim as optim
|
||||||
|
import torch.utils.data
|
||||||
|
|
||||||
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
|
||||||
|
# from rich.progress import (
|
||||||
|
# Progress,
|
||||||
|
# TextColumn,
|
||||||
|
# BarColumn,
|
||||||
|
# TaskProgressColumn,
|
||||||
|
# TimeRemainingColumn,
|
||||||
|
# MofNCompleteColumn,
|
||||||
|
# TimeElapsedColumn,
|
||||||
|
# )
|
||||||
|
# from rich.console import Console
|
||||||
|
# from rich import print as rprint
|
||||||
|
|
||||||
|
import multiprocessing
|
||||||
|
|
||||||
|
from util.datasets import FiberRegenerationDataset
|
||||||
|
|
||||||
|
# from util.optuna_helpers import (
|
||||||
|
# suggest_categorical_optional, # noqa: F401
|
||||||
|
# suggest_float_optional, # noqa: F401
|
||||||
|
# suggest_int_optional, # noqa: F401
|
||||||
|
# )
|
||||||
|
from util.optuna_helpers import install_optional_suggests
|
||||||
|
import util
|
||||||
|
|
||||||
|
from .settings import (
|
||||||
|
GlobalSettings,
|
||||||
|
DataSettings,
|
||||||
|
ModelSettings,
|
||||||
|
OptunaSettings,
|
||||||
|
OptimizerSettings,
|
||||||
|
PytorchSettings,
|
||||||
|
)
|
||||||
|
|
||||||
|
install_optional_suggests()
|
||||||
|
|
||||||
|
|
||||||
|
class HyperTraining:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
global_settings,
|
||||||
|
data_settings,
|
||||||
|
pytorch_settings,
|
||||||
|
model_settings,
|
||||||
|
optimizer_settings,
|
||||||
|
optuna_settings,
|
||||||
|
# console=None,
|
||||||
|
):
|
||||||
|
self.global_settings: GlobalSettings = global_settings
|
||||||
|
self.data_settings: DataSettings = data_settings
|
||||||
|
self.pytorch_settings: PytorchSettings = pytorch_settings
|
||||||
|
self.model_settings: ModelSettings = model_settings
|
||||||
|
self.optimizer_settings: OptimizerSettings = optimizer_settings
|
||||||
|
self.optuna_settings: OptunaSettings = optuna_settings
|
||||||
|
self.processes = None
|
||||||
|
|
||||||
|
# self.console = console or Console()
|
||||||
|
|
||||||
|
# set some extra settings to make the code more readable
|
||||||
|
self._extra_optuna_settings()
|
||||||
|
self.stop_study = True
|
||||||
|
|
||||||
|
def setup_tb_writer(self, study_name=None, append=None):
|
||||||
|
log_dir = self.pytorch_settings.summary_dir + "/" + (study_name or self.optuna_settings.study_name)
|
||||||
|
if append is not None:
|
||||||
|
log_dir += "_" + str(append)
|
||||||
|
|
||||||
|
return SummaryWriter(log_dir)
|
||||||
|
|
||||||
|
def resume_latest_study(self, verbose=True):
|
||||||
|
study_name = self.get_latest_study()
|
||||||
|
|
||||||
|
if study_name:
|
||||||
|
if verbose:
|
||||||
|
print(f"Resuming study: {study_name}")
|
||||||
|
self.optuna_settings.study_name = study_name
|
||||||
|
|
||||||
|
def get_latest_study(self, verbose=False) -> optuna.Study:
|
||||||
|
studies = self.get_studies()
|
||||||
|
study = None
|
||||||
|
for study in studies:
|
||||||
|
study.datetime_start = study.datetime_start or datetime.min
|
||||||
|
if studies:
|
||||||
|
study = sorted(studies, key=lambda x: x.datetime_start, reverse=True)[0]
|
||||||
|
if verbose:
|
||||||
|
print(f"Last study: {study.study_name}")
|
||||||
|
else:
|
||||||
|
if verbose:
|
||||||
|
print("No previous studies found")
|
||||||
|
return optuna.load_study(study_name=study.study_name, storage=self.optuna_settings.storage)
|
||||||
|
|
||||||
|
# def study(self) -> optuna.Study:
|
||||||
|
# return optuna.load_study(self.optuna_settings.study_name, storage=self.optuna_settings.storage)
|
||||||
|
|
||||||
|
def get_studies(self):
|
||||||
|
return optuna.get_all_study_summaries(storage=self.optuna_settings.storage)
|
||||||
|
|
||||||
|
def setup_study(self):
|
||||||
|
# module = optunahub.load_module(package="samplers/auto_sampler")
|
||||||
|
if self.optuna_settings._parallel:
|
||||||
|
self.processes = []
|
||||||
|
|
||||||
|
pruner = getattr(optuna.pruners, self.optuna_settings.pruner, None)
|
||||||
|
|
||||||
|
if pruner and self.optuna_settings.pruner_kwargs is not None:
|
||||||
|
pruner = pruner(**self.optuna_settings.pruner_kwargs)
|
||||||
|
elif pruner:
|
||||||
|
pruner = pruner()
|
||||||
|
|
||||||
|
self.study = optuna.create_study(
|
||||||
|
study_name=self.optuna_settings.study_name,
|
||||||
|
storage=self.optuna_settings.storage,
|
||||||
|
load_if_exists=True,
|
||||||
|
direction=self.optuna_settings._direction,
|
||||||
|
directions=self.optuna_settings._directions,
|
||||||
|
pruner=pruner,
|
||||||
|
# sampler=module.AutoSampler(),
|
||||||
|
)
|
||||||
|
|
||||||
|
# print("using sampler:", self.study.sampler)
|
||||||
|
|
||||||
|
with warnings.catch_warnings(action="ignore"):
|
||||||
|
self.study.set_metric_names(self.optuna_settings.metrics_names)
|
||||||
|
|
||||||
|
def run_study(self):
|
||||||
|
try:
|
||||||
|
if self.optuna_settings._parallel:
|
||||||
|
self._run_parallel_study()
|
||||||
|
else:
|
||||||
|
self._run_study()
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print("Stopping. Please wait for the processes to finish.")
|
||||||
|
self.stop_study = True
|
||||||
|
|
||||||
|
def trials_left(self):
|
||||||
|
return self.optuna_settings.n_trials - len(self.study.get_trials(states=self.optuna_settings.n_trials_filter))
|
||||||
|
|
||||||
|
def remove_completed_processes(self):
|
||||||
|
if self.processes is None:
|
||||||
|
return
|
||||||
|
for p, process in enumerate(self.processes):
|
||||||
|
if not process.is_alive():
|
||||||
|
process.join()
|
||||||
|
self.processes.pop(p)
|
||||||
|
|
||||||
|
def remove_outliers(self):
|
||||||
|
if self.optuna_settings.remove_outliers is not None:
|
||||||
|
trials = self.study.get_trials(states=(optuna.trial.TrialState.COMPLETE,))
|
||||||
|
if len(trials) == 0:
|
||||||
|
return
|
||||||
|
vals = [trial.value for trial in trials]
|
||||||
|
vals = np.log(vals)
|
||||||
|
mean = np.mean(vals)
|
||||||
|
std = np.std(vals)
|
||||||
|
outliers = [
|
||||||
|
trial for trial in trials if np.log(trial.value) > mean + self.optuna_settings.remove_outliers * std
|
||||||
|
]
|
||||||
|
for trial in outliers:
|
||||||
|
trial: optuna.trial.Trial = trial
|
||||||
|
trial.state = optuna.trial.TrialState.FAIL
|
||||||
|
trial.set_user_attr("outlier", True)
|
||||||
|
|
||||||
|
def _run_study(self):
|
||||||
|
while trials_left := self.trials_left():
|
||||||
|
self.remove_outliers()
|
||||||
|
self._run_optimize(n_trials=trials_left, timeout=self.optuna_settings.timeout)
|
||||||
|
|
||||||
|
def _run_parallel_study(self):
|
||||||
|
while trials_left := self.trials_left():
|
||||||
|
self.remove_outliers()
|
||||||
|
self.remove_completed_processes()
|
||||||
|
|
||||||
|
n_trials = max(trials_left, self.optuna_settings._n_threads) // self.optuna_settings._n_threads
|
||||||
|
|
||||||
|
def target_fun():
|
||||||
|
self._run_optimize(n_trials=n_trials, timeout=self.optuna_settings.timeout)
|
||||||
|
|
||||||
|
for _ in range(self.optuna_settings._n_threads - len(self.processes)):
|
||||||
|
self.processes.append(multiprocessing.Process(target=target_fun))
|
||||||
|
self.processes[-1].start()
|
||||||
|
|
||||||
|
def _run_optimize(self, **kwargs):
|
||||||
|
self.study.optimize(
|
||||||
|
self.objective,
|
||||||
|
**kwargs,
|
||||||
|
show_progress_bar=not self.optuna_settings._parallel,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _extra_optuna_settings(self):
|
||||||
|
self.optuna_settings._multi_objective = len(self.optuna_settings.directions) > 1
|
||||||
|
|
||||||
|
if self.optuna_settings._multi_objective:
|
||||||
|
self.optuna_settings._direction = None
|
||||||
|
self.optuna_settings._directions = self.optuna_settings.directions
|
||||||
|
else:
|
||||||
|
self.optuna_settings._direction = self.optuna_settings.directions[0]
|
||||||
|
self.optuna_settings._directions = None
|
||||||
|
|
||||||
|
self.optuna_settings._n_train_batches = (
|
||||||
|
self.optuna_settings.n_train_batches if self.optuna_settings.limit_examples else float("inf")
|
||||||
|
)
|
||||||
|
self.optuna_settings._n_valid_batches = (
|
||||||
|
self.optuna_settings.n_valid_batches if self.optuna_settings.limit_examples else float("inf")
|
||||||
|
)
|
||||||
|
|
||||||
|
self.optuna_settings._n_threads = self.optuna_settings.n_workers
|
||||||
|
|
||||||
|
self.optuna_settings._parallel = self.optuna_settings._n_threads > 1
|
||||||
|
|
||||||
|
def define_model(self, trial: optuna.Trial, writer=None):
|
||||||
|
n_layers = trial.suggest_int_optional("model_n_hidden_layers", self.model_settings.n_hidden_layers)
|
||||||
|
|
||||||
|
input_dim = trial.suggest_int_optional(
|
||||||
|
"model_input_dim",
|
||||||
|
self.data_settings.output_size,
|
||||||
|
step=2,
|
||||||
|
multiply=2,
|
||||||
|
set_new=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# trial.set_user_attr("model_input_dim", input_dim)
|
||||||
|
|
||||||
|
dtype = trial.suggest_categorical_optional("model_dtype", self.data_settings.dtype, set_new=False)
|
||||||
|
dtype = getattr(torch, dtype)
|
||||||
|
|
||||||
|
afunc = trial.suggest_categorical_optional("model_activation_func", self.model_settings.model_activation_func)
|
||||||
|
# T0 = trial.suggest_float_optional("T0", self.model_settings.satabsT0 , log=True)
|
||||||
|
|
||||||
|
layers = []
|
||||||
|
last_dim = input_dim
|
||||||
|
n_nodes = last_dim
|
||||||
|
for i in range(n_layers):
|
||||||
|
if hidden_dim_override := self.model_settings.overrides.get(f"n_hidden_nodes_{i}", False):
|
||||||
|
hidden_dim = trial.suggest_int_optional(f"model_hidden_dim_{i}", hidden_dim_override)
|
||||||
|
else:
|
||||||
|
hidden_dim = trial.suggest_int_optional(
|
||||||
|
f"model_hidden_dim_{i}",
|
||||||
|
self.model_settings.n_hidden_nodes,
|
||||||
|
)
|
||||||
|
layers.append(util.complexNN.SemiUnitaryLayer(last_dim, hidden_dim, dtype=dtype))
|
||||||
|
last_dim = hidden_dim
|
||||||
|
layers.append(getattr(util.complexNN, afunc)())
|
||||||
|
n_nodes += last_dim
|
||||||
|
|
||||||
|
layers.append(util.complexNN.SemiUnitaryLayer(last_dim, self.model_settings.output_dim, dtype=dtype))
|
||||||
|
|
||||||
|
model = nn.Sequential(*layers)
|
||||||
|
|
||||||
|
if writer is not None:
|
||||||
|
writer.add_graph(model, torch.zeros(1, input_dim, dtype=dtype), use_strict_trace=False)
|
||||||
|
|
||||||
|
n_params = sum(p.numel() for p in model.parameters())
|
||||||
|
trial.set_user_attr("model_n_params", n_params)
|
||||||
|
trial.set_user_attr("model_n_nodes", n_nodes)
|
||||||
|
|
||||||
|
return model.to(self.pytorch_settings.device)
|
||||||
|
|
||||||
|
def get_sliced_data(self, trial: optuna.Trial, override=None):
|
||||||
|
symbols = trial.suggest_float_optional("dataset_symbols", self.data_settings.symbols, set_new=False)
|
||||||
|
|
||||||
|
in_out_delay = trial.suggest_float_optional(
|
||||||
|
"dataset_in_out_delay", self.data_settings.in_out_delay, set_new=False
|
||||||
|
)
|
||||||
|
|
||||||
|
xy_delay = trial.suggest_float_optional("dataset_xy_delay", self.data_settings.xy_delay, set_new=False)
|
||||||
|
|
||||||
|
data_size = int(
|
||||||
|
0.5
|
||||||
|
* trial.suggest_int_optional(
|
||||||
|
"model_input_dim",
|
||||||
|
self.data_settings.output_size,
|
||||||
|
step=2,
|
||||||
|
multiply=2,
|
||||||
|
set_new=False,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
dtype = trial.suggest_categorical_optional("model_dtype", self.data_settings.dtype, set_new=False)
|
||||||
|
dtype = getattr(torch, dtype)
|
||||||
|
|
||||||
|
num_symbols = None
|
||||||
|
if override is not None:
|
||||||
|
num_symbols = override.get("num_symbols", None)
|
||||||
|
# get dataset
|
||||||
|
dataset = FiberRegenerationDataset(
|
||||||
|
file_path=self.data_settings.config_path,
|
||||||
|
symbols=symbols,
|
||||||
|
output_dim=data_size,
|
||||||
|
target_delay=in_out_delay,
|
||||||
|
xy_delay=xy_delay,
|
||||||
|
drop_first=self.data_settings.drop_first,
|
||||||
|
dtype=dtype,
|
||||||
|
real=not dtype.is_complex,
|
||||||
|
num_symbols=num_symbols,
|
||||||
|
)
|
||||||
|
|
||||||
|
dataset_size = len(dataset)
|
||||||
|
indices = list(range(dataset_size))
|
||||||
|
split = int(np.floor(self.data_settings.train_split * dataset_size))
|
||||||
|
if self.data_settings.shuffle:
|
||||||
|
np.random.seed(self.global_settings.seed)
|
||||||
|
np.random.shuffle(indices)
|
||||||
|
|
||||||
|
train_indices, valid_indices = indices[:split], indices[split:]
|
||||||
|
|
||||||
|
if self.data_settings.shuffle:
|
||||||
|
train_sampler = torch.utils.data.SubsetRandomSampler(train_indices)
|
||||||
|
valid_sampler = torch.utils.data.SubsetRandomSampler(valid_indices)
|
||||||
|
else:
|
||||||
|
train_sampler = train_indices
|
||||||
|
valid_sampler = valid_indices
|
||||||
|
|
||||||
|
train_loader = torch.utils.data.DataLoader(
|
||||||
|
dataset,
|
||||||
|
batch_size=self.pytorch_settings.batchsize,
|
||||||
|
sampler=train_sampler,
|
||||||
|
drop_last=True,
|
||||||
|
pin_memory=True,
|
||||||
|
num_workers=self.pytorch_settings.dataloader_workers,
|
||||||
|
prefetch_factor=self.pytorch_settings.dataloader_prefetch,
|
||||||
|
)
|
||||||
|
|
||||||
|
valid_loader = torch.utils.data.DataLoader(
|
||||||
|
dataset,
|
||||||
|
batch_size=self.pytorch_settings.batchsize,
|
||||||
|
sampler=valid_sampler,
|
||||||
|
drop_last=True,
|
||||||
|
pin_memory=True,
|
||||||
|
num_workers=self.pytorch_settings.dataloader_workers,
|
||||||
|
prefetch_factor=self.pytorch_settings.dataloader_prefetch,
|
||||||
|
)
|
||||||
|
|
||||||
|
return train_loader, valid_loader
|
||||||
|
|
||||||
|
def train_model(
|
||||||
|
self,
|
||||||
|
trial,
|
||||||
|
model,
|
||||||
|
optimizer,
|
||||||
|
train_loader,
|
||||||
|
epoch,
|
||||||
|
writer=None,
|
||||||
|
# enable_progress=False,
|
||||||
|
):
|
||||||
|
# if enable_progress:
|
||||||
|
# progress = Progress(
|
||||||
|
# TextColumn("[yellow] Training..."),
|
||||||
|
# TextColumn("Error: {task.description}"),
|
||||||
|
# BarColumn(),
|
||||||
|
# TaskProgressColumn(),
|
||||||
|
# TextColumn("[green]Batch"),
|
||||||
|
# MofNCompleteColumn(),
|
||||||
|
# TimeRemainingColumn(),
|
||||||
|
# TimeElapsedColumn(),
|
||||||
|
# # description="Training",
|
||||||
|
# transient=False,
|
||||||
|
# console=self.console,
|
||||||
|
# refresh_per_second=10,
|
||||||
|
# )
|
||||||
|
# task = progress.add_task("-.---e--", total=len(train_loader))
|
||||||
|
# progress.start()
|
||||||
|
|
||||||
|
running_loss2 = 0.0
|
||||||
|
running_loss = 0.0
|
||||||
|
model.train()
|
||||||
|
for batch_idx, (x, y) in enumerate(train_loader):
|
||||||
|
if batch_idx >= self.optuna_settings._n_train_batches:
|
||||||
|
break
|
||||||
|
model.zero_grad(set_to_none=True)
|
||||||
|
x, y = (
|
||||||
|
x.to(self.pytorch_settings.device),
|
||||||
|
y.to(self.pytorch_settings.device),
|
||||||
|
)
|
||||||
|
y_pred = model(x)
|
||||||
|
loss = util.complexNN.complex_mse_loss(y_pred, y)
|
||||||
|
loss_value = loss.item()
|
||||||
|
loss.backward()
|
||||||
|
optimizer.step()
|
||||||
|
running_loss2 += loss_value
|
||||||
|
running_loss += loss_value
|
||||||
|
|
||||||
|
# if enable_progress:
|
||||||
|
# progress.update(task, advance=1, description=f"{loss_value:.3e}")
|
||||||
|
|
||||||
|
if writer is not None:
|
||||||
|
if batch_idx % self.pytorch_settings.write_every == 0:
|
||||||
|
writer.add_scalar(
|
||||||
|
"training loss",
|
||||||
|
running_loss2 / (self.pytorch_settings.write_every if batch_idx > 0 else 1),
|
||||||
|
epoch * min(len(train_loader), self.optuna_settings._n_train_batches) + batch_idx,
|
||||||
|
)
|
||||||
|
running_loss2 = 0.0
|
||||||
|
|
||||||
|
# if enable_progress:
|
||||||
|
# progress.stop()
|
||||||
|
|
||||||
|
return running_loss / min(len(train_loader), self.optuna_settings._n_train_batches)
|
||||||
|
|
||||||
|
def eval_model(
|
||||||
|
self,
|
||||||
|
trial,
|
||||||
|
model,
|
||||||
|
valid_loader,
|
||||||
|
epoch,
|
||||||
|
writer=None,
|
||||||
|
# enable_progress=True
|
||||||
|
):
|
||||||
|
# if enable_progress:
|
||||||
|
# progress = Progress(
|
||||||
|
# TextColumn("[green]Evaluating..."),
|
||||||
|
# TextColumn("Error: {task.description}"),
|
||||||
|
# BarColumn(),
|
||||||
|
# TaskProgressColumn(),
|
||||||
|
# TextColumn("[green]Batch"),
|
||||||
|
# MofNCompleteColumn(),
|
||||||
|
# TimeRemainingColumn(),
|
||||||
|
# TimeElapsedColumn(),
|
||||||
|
# # description="Training",
|
||||||
|
# transient=False,
|
||||||
|
# console=self.console,
|
||||||
|
# refresh_per_second=10,
|
||||||
|
# )
|
||||||
|
# progress.start()
|
||||||
|
# task = progress.add_task("-.---e--", total=len(valid_loader))
|
||||||
|
|
||||||
|
model.eval()
|
||||||
|
running_error = 0
|
||||||
|
running_error_2 = 0
|
||||||
|
with torch.no_grad():
|
||||||
|
for batch_idx, (x, y) in enumerate(valid_loader):
|
||||||
|
if batch_idx >= self.optuna_settings._n_valid_batches:
|
||||||
|
break
|
||||||
|
x, y = (
|
||||||
|
x.to(self.pytorch_settings.device),
|
||||||
|
y.to(self.pytorch_settings.device),
|
||||||
|
)
|
||||||
|
y_pred = model(x)
|
||||||
|
error = util.complexNN.complex_mse_loss(y_pred, y)
|
||||||
|
error_value = error.item()
|
||||||
|
running_error += error_value
|
||||||
|
running_error_2 += error_value
|
||||||
|
|
||||||
|
# if enable_progress:
|
||||||
|
# progress.update(task, advance=1, description=f"{error_value:.3e}")
|
||||||
|
|
||||||
|
if writer is not None:
|
||||||
|
if batch_idx % self.pytorch_settings.write_every == 0:
|
||||||
|
writer.add_scalar(
|
||||||
|
"eval loss",
|
||||||
|
running_error_2 / (self.pytorch_settings.write_every if batch_idx > 0 else 1),
|
||||||
|
epoch * min(len(valid_loader), self.optuna_settings._n_valid_batches) + batch_idx,
|
||||||
|
)
|
||||||
|
running_error_2 = 0.0
|
||||||
|
|
||||||
|
running_error /= min(len(valid_loader), self.optuna_settings._n_valid_batches)
|
||||||
|
|
||||||
|
if writer is not None:
|
||||||
|
title_append, subtitle = self.build_title(trial)
|
||||||
|
writer.add_figure(
|
||||||
|
"fiber response",
|
||||||
|
self.plot_model_response(
|
||||||
|
trial,
|
||||||
|
model=model,
|
||||||
|
title_append=title_append,
|
||||||
|
subtitle=subtitle,
|
||||||
|
show=False,
|
||||||
|
),
|
||||||
|
epoch + 1,
|
||||||
|
)
|
||||||
|
|
||||||
|
# if enable_progress:
|
||||||
|
# progress.stop()
|
||||||
|
|
||||||
|
return running_error
|
||||||
|
|
||||||
|
def run_model(self, model, loader):
|
||||||
|
model.eval()
|
||||||
|
xs = []
|
||||||
|
ys = []
|
||||||
|
y_preds = []
|
||||||
|
with torch.no_grad():
|
||||||
|
model = model.to(self.pytorch_settings.device)
|
||||||
|
for x, y in loader:
|
||||||
|
x, y = (
|
||||||
|
x.to(self.pytorch_settings.device),
|
||||||
|
y.to(self.pytorch_settings.device),
|
||||||
|
)
|
||||||
|
y_pred = model(x).cpu()
|
||||||
|
# x = x.cpu()
|
||||||
|
# y = y.cpu()
|
||||||
|
y_pred = y_pred.view(y_pred.shape[0], -1, 2)
|
||||||
|
y = y.view(y.shape[0], -1, 2)
|
||||||
|
x = x.view(x.shape[0], -1, 2)
|
||||||
|
xs.append(x[:, 0, :].squeeze())
|
||||||
|
ys.append(y.squeeze())
|
||||||
|
y_preds.append(y_pred.squeeze())
|
||||||
|
|
||||||
|
xs = torch.vstack(xs).cpu()
|
||||||
|
ys = torch.vstack(ys).cpu()
|
||||||
|
y_preds = torch.vstack(y_preds).cpu()
|
||||||
|
return ys, xs, y_preds
|
||||||
|
|
||||||
|
def objective(self, trial: optuna.Trial, plot_before=False):
|
||||||
|
if self.stop_study:
|
||||||
|
trial.study.stop()
|
||||||
|
model = None
|
||||||
|
|
||||||
|
writer = self.setup_tb_writer(
|
||||||
|
self.optuna_settings.study_name,
|
||||||
|
f"{trial.number:0{len(str(self.optuna_settings.n_trials)) + 2}}",
|
||||||
|
)
|
||||||
|
|
||||||
|
model = self.define_model(trial, writer)
|
||||||
|
|
||||||
|
# n_nodes = trial.params.get("model_n_hidden_layers", self.model_settings.model_n_layers) * trial.params.get("model_hidden_dim", self.model_settings.unit_count)
|
||||||
|
|
||||||
|
title_append, subtitle = self.build_title(trial)
|
||||||
|
|
||||||
|
writer.add_figure(
|
||||||
|
"fiber response",
|
||||||
|
self.plot_model_response(
|
||||||
|
trial,
|
||||||
|
model=model,
|
||||||
|
title_append=title_append,
|
||||||
|
subtitle=subtitle,
|
||||||
|
show=plot_before,
|
||||||
|
),
|
||||||
|
0,
|
||||||
|
)
|
||||||
|
|
||||||
|
train_loader, valid_loader = self.get_sliced_data(trial)
|
||||||
|
|
||||||
|
optimizer_name = trial.suggest_categorical_optional("optimizer", self.optimizer_settings.optimizer)
|
||||||
|
|
||||||
|
lr = trial.suggest_float_optional("lr", self.optimizer_settings.learning_rate, log=True)
|
||||||
|
|
||||||
|
optimizer = getattr(optim, optimizer_name)(model.parameters(), lr=lr)
|
||||||
|
if self.optimizer_settings.scheduler is not None:
|
||||||
|
scheduler = getattr(optim.lr_scheduler, self.optimizer_settings.scheduler)(
|
||||||
|
optimizer, **self.optimizer_settings.scheduler_kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
for epoch in range(self.pytorch_settings.epochs):
|
||||||
|
trial.set_user_attr("epoch", epoch)
|
||||||
|
# enable_progress = self.optuna_settings.n_threads == 1
|
||||||
|
# if enable_progress:
|
||||||
|
# self.console.rule(
|
||||||
|
# f"Epoch {epoch + 1}/{self.pytorch_settings.epochs}"
|
||||||
|
# )
|
||||||
|
self.train_model(
|
||||||
|
trial,
|
||||||
|
model,
|
||||||
|
optimizer,
|
||||||
|
train_loader,
|
||||||
|
epoch,
|
||||||
|
writer,
|
||||||
|
# enable_progress=enable_progress,
|
||||||
|
)
|
||||||
|
error = self.eval_model(
|
||||||
|
trial,
|
||||||
|
model,
|
||||||
|
valid_loader,
|
||||||
|
epoch,
|
||||||
|
writer,
|
||||||
|
# enable_progress=enable_progress,
|
||||||
|
)
|
||||||
|
if self.optimizer_settings.scheduler is not None:
|
||||||
|
scheduler.step(error)
|
||||||
|
|
||||||
|
trial.set_user_attr("mse", error)
|
||||||
|
trial.set_user_attr("log_mse", np.log10(error + np.finfo(float).eps))
|
||||||
|
trial.set_user_attr("neg_mse", -error)
|
||||||
|
trial.set_user_attr("neg_log_mse", -np.log10(error + np.finfo(float).eps))
|
||||||
|
if not self.optuna_settings._multi_objective:
|
||||||
|
trial.report(error, epoch)
|
||||||
|
if trial.should_prune():
|
||||||
|
raise optuna.exceptions.TrialPruned()
|
||||||
|
|
||||||
|
writer.close()
|
||||||
|
|
||||||
|
if self.optuna_settings._multi_objective:
|
||||||
|
return -np.log10(error + np.finfo(float).eps), trial.user_attrs.get("model_n_nodes", -1)
|
||||||
|
|
||||||
|
if self.pytorch_settings.save_models and model is not None:
|
||||||
|
save_path = Path(self.pytorch_settings.model_dir) / f"{self.optuna_settings.study_name}_{trial.number}.pth"
|
||||||
|
save_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
torch.save(model, save_path)
|
||||||
|
|
||||||
|
return error
|
||||||
|
|
||||||
|
def _plot_model_response_eye(self, *signals, labels=None, sps=None, title_append="", subtitle="", show=True):
|
||||||
|
if sps is None:
|
||||||
|
raise ValueError("sps must be provided")
|
||||||
|
if not hasattr(labels, "__iter__") or isinstance(labels, (str, type(None))):
|
||||||
|
labels = [labels]
|
||||||
|
else:
|
||||||
|
labels = list(labels)
|
||||||
|
|
||||||
|
while len(labels) < len(signals):
|
||||||
|
labels.append(None)
|
||||||
|
|
||||||
|
# check if there are any labels
|
||||||
|
if not any(labels):
|
||||||
|
labels = [f"signal {i + 1}" for i in range(len(signals))]
|
||||||
|
|
||||||
|
fig, axs = plt.subplots(2, len(signals), sharex=True, sharey=True)
|
||||||
|
fig.suptitle(f"Eye diagram{f' {title_append}' if title_append else ''}{f'\n{subtitle}' if subtitle else ''}")
|
||||||
|
xaxis = np.linspace(0, 2, 2 * sps, endpoint=False)
|
||||||
|
for j, (label, signal) in enumerate(zip(labels, signals)):
|
||||||
|
# signal = signal.cpu().numpy()
|
||||||
|
for i in range(len(signal) // sps - 1):
|
||||||
|
x, y = signal[i * sps : (i + 2) * sps].T
|
||||||
|
axs[0, j].plot(xaxis, np.abs(x) ** 2, color="C0", alpha=0.02)
|
||||||
|
axs[1, j].plot(xaxis, np.abs(y) ** 2, color="C0", alpha=0.02)
|
||||||
|
axs[0, j].set_title(label + " x")
|
||||||
|
axs[1, j].set_title(label + " y")
|
||||||
|
axs[0, j].set_xlabel("Symbol")
|
||||||
|
axs[1, j].set_xlabel("Symbol")
|
||||||
|
axs[0, j].set_ylabel("normalized power")
|
||||||
|
axs[1, j].set_ylabel("normalized power")
|
||||||
|
|
||||||
|
if show:
|
||||||
|
plt.show()
|
||||||
|
return fig
|
||||||
|
|
||||||
|
def _plot_model_response_head(self, *signals, labels=None, sps=None, title_append="", subtitle="", show=True):
|
||||||
|
if not hasattr(labels, "__iter__") or isinstance(labels, (str, type(None))):
|
||||||
|
labels = [labels]
|
||||||
|
else:
|
||||||
|
labels = list(labels)
|
||||||
|
|
||||||
|
while len(labels) < len(signals):
|
||||||
|
labels.append(None)
|
||||||
|
|
||||||
|
# check if there are any labels
|
||||||
|
if not any(labels):
|
||||||
|
labels = [f"signal {i + 1}" for i in range(len(signals))]
|
||||||
|
|
||||||
|
fig, axs = plt.subplots(1, 2, sharex=True, sharey=True)
|
||||||
|
fig.set_size_inches(18, 6)
|
||||||
|
fig.suptitle(f"Fiber response{f' {title_append}' if title_append else ''}{f'\n{subtitle}' if subtitle else ''}")
|
||||||
|
for i, ax in enumerate(axs):
|
||||||
|
for signal, label in zip(signals, labels):
|
||||||
|
if sps is not None:
|
||||||
|
xaxis = np.linspace(0, len(signal) / sps, len(signal), endpoint=False)
|
||||||
|
else:
|
||||||
|
xaxis = np.arange(len(signal))
|
||||||
|
ax.plot(xaxis, np.abs(signal[:, i]) ** 2, label=label)
|
||||||
|
ax.set_xlabel("Sample" if sps is None else "Symbol")
|
||||||
|
ax.set_ylabel("normalized power")
|
||||||
|
ax.legend(loc="upper right")
|
||||||
|
if show:
|
||||||
|
plt.show()
|
||||||
|
return fig
|
||||||
|
|
||||||
|
def plot_model_response(
|
||||||
|
self,
|
||||||
|
trial,
|
||||||
|
model=None,
|
||||||
|
title_append="",
|
||||||
|
subtitle="",
|
||||||
|
mode: Literal["eye", "head"] = "head",
|
||||||
|
show=True,
|
||||||
|
):
|
||||||
|
data_settings_backup = copy.deepcopy(self.data_settings)
|
||||||
|
pytorch_settings_backup = copy.deepcopy(self.pytorch_settings)
|
||||||
|
self.data_settings.drop_first = 100*128
|
||||||
|
self.data_settings.shuffle = False
|
||||||
|
self.data_settings.train_split = 1.0
|
||||||
|
self.pytorch_settings.batchsize = (
|
||||||
|
self.pytorch_settings.eye_symbols if mode == "eye" else self.pytorch_settings.head_symbols
|
||||||
|
)
|
||||||
|
plot_loader, _ = self.get_sliced_data(trial, override={"num_symbols": self.pytorch_settings.batchsize})
|
||||||
|
self.data_settings = data_settings_backup
|
||||||
|
self.pytorch_settings = pytorch_settings_backup
|
||||||
|
|
||||||
|
fiber_in, fiber_out, regen = self.run_model(model, plot_loader)
|
||||||
|
fiber_in = fiber_in.view(-1, 2)
|
||||||
|
fiber_out = fiber_out.view(-1, 2)
|
||||||
|
regen = regen.view(-1, 2)
|
||||||
|
|
||||||
|
fiber_in = fiber_in.numpy()
|
||||||
|
fiber_out = fiber_out.numpy()
|
||||||
|
regen = regen.numpy()
|
||||||
|
|
||||||
|
# https://github.com/matplotlib/matplotlib/issues/27713#issue-2104110987
|
||||||
|
# https://github.com/matplotlib/matplotlib/issues/27713#issuecomment-1915497463
|
||||||
|
import gc
|
||||||
|
|
||||||
|
if mode == "head":
|
||||||
|
fig = self._plot_model_response_head(
|
||||||
|
fiber_in,
|
||||||
|
fiber_out,
|
||||||
|
regen,
|
||||||
|
labels=("fiber in", "fiber out", "regen"),
|
||||||
|
sps=plot_loader.dataset.samples_per_symbol,
|
||||||
|
title_append=title_append,
|
||||||
|
subtitle=subtitle,
|
||||||
|
show=show,
|
||||||
|
)
|
||||||
|
elif mode == "eye":
|
||||||
|
# raise NotImplementedError("Eye diagram not implemented")
|
||||||
|
fig = self._plot_model_response_eye(
|
||||||
|
fiber_in,
|
||||||
|
fiber_out,
|
||||||
|
regen,
|
||||||
|
labels=("fiber in", "fiber out", "regen"),
|
||||||
|
sps=plot_loader.dataset.samples_per_symbol,
|
||||||
|
title_append=title_append,
|
||||||
|
subtitle=subtitle,
|
||||||
|
show=show,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown mode: {mode}")
|
||||||
|
gc.collect()
|
||||||
|
|
||||||
|
return fig
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def build_title(trial: optuna.trial.Trial):
|
||||||
|
title_append = f"for trial {trial.number}"
|
||||||
|
model_n_hidden_layers = util.misc.multi_getattr((trial.params, trial.user_attrs), "model_n_hidden_layers", 0)
|
||||||
|
input_dim = util.misc.multi_getattr((trial.params, trial.user_attrs), "model_input_dim", 0)
|
||||||
|
model_dims = [
|
||||||
|
util.misc.multi_getattr((trial.params, trial.user_attrs), f"model_hidden_dim_{i}", 0)
|
||||||
|
for i in range(model_n_hidden_layers)
|
||||||
|
]
|
||||||
|
model_dims.insert(0, input_dim)
|
||||||
|
model_dims.append(2)
|
||||||
|
model_dims = [str(dim) for dim in model_dims]
|
||||||
|
model_activation_func = util.misc.multi_getattr(
|
||||||
|
(trial.params, trial.user_attrs),
|
||||||
|
"model_activation_func",
|
||||||
|
"unknown act. fun",
|
||||||
|
)
|
||||||
|
model_dtype = util.misc.multi_getattr((trial.params, trial.user_attrs), "model_dtype", "unknown dtype")
|
||||||
|
|
||||||
|
subtitle = (
|
||||||
|
f"{model_n_hidden_layers+2} layers à ({', '.join(model_dims)}) units, {model_activation_func}, {model_dtype}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return title_append, subtitle
|
||||||
98
src/single-core-regen/hypertraining/settings.py
Normal file
98
src/single-core-regen/hypertraining/settings.py
Normal file
@@ -0,0 +1,98 @@
|
|||||||
|
from dataclasses import dataclass, field
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
|
||||||
|
# global settings
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class GlobalSettings:
|
||||||
|
seed: int = 42
|
||||||
|
|
||||||
|
|
||||||
|
# data settings
|
||||||
|
@dataclass
|
||||||
|
class DataSettings:
|
||||||
|
config_path: str # = "data/*-128-16384-100000-0-0-17-0-PAM4-0.ini"
|
||||||
|
dtype: tuple = ("complex64", "float64")
|
||||||
|
symbols: tuple | float | int = 8
|
||||||
|
output_size: tuple | float | int = 64
|
||||||
|
shuffle: bool = True
|
||||||
|
in_out_delay: float = 0
|
||||||
|
xy_delay: tuple | float | int = 0
|
||||||
|
drop_first: int = 1000
|
||||||
|
train_split: float = 0.8
|
||||||
|
|
||||||
|
|
||||||
|
# pytorch settings
|
||||||
|
@dataclass
|
||||||
|
class PytorchSettings:
|
||||||
|
epochs: int = 1
|
||||||
|
batchsize: int = 2**10
|
||||||
|
|
||||||
|
device: str = "cuda"
|
||||||
|
|
||||||
|
dataloader_workers: int = 2
|
||||||
|
dataloader_prefetch: int = 2
|
||||||
|
|
||||||
|
save_models: bool = True
|
||||||
|
model_dir: str = ".models"
|
||||||
|
|
||||||
|
summary_dir: str = ".runs"
|
||||||
|
write_every: int = 10
|
||||||
|
head_symbols: int = 40
|
||||||
|
eye_symbols: int = 400
|
||||||
|
|
||||||
|
|
||||||
|
# model settings
|
||||||
|
@dataclass
|
||||||
|
class ModelSettings:
|
||||||
|
output_dim: int = 2
|
||||||
|
n_hidden_layers: tuple | int = 3
|
||||||
|
n_hidden_nodes: tuple | int = 8
|
||||||
|
model_activation_func: tuple | str = "ModReLU"
|
||||||
|
overrides: dict = field(default_factory=dict)
|
||||||
|
dropout_prob: float | None = None
|
||||||
|
model_layer_function: str | None = None
|
||||||
|
model_layer_parametrizations: list= field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class OptimizerSettings:
|
||||||
|
optimizer: tuple | str = ("Adam", "RMSprop", "SGD")
|
||||||
|
learning_rate: tuple | float = (1e-5, 1e-1)
|
||||||
|
scheduler: str | None = None
|
||||||
|
scheduler_kwargs: dict | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def _pruner_default_kwargs():
|
||||||
|
# MedianPruner
|
||||||
|
return {
|
||||||
|
"n_startup_trials": 0,
|
||||||
|
"n_warmup_steps": 5,
|
||||||
|
}
|
||||||
|
# optuna settings
|
||||||
|
@dataclass
|
||||||
|
class OptunaSettings:
|
||||||
|
n_trials: int = 128
|
||||||
|
n_workers: int = 1
|
||||||
|
timeout: int = None
|
||||||
|
pruner: str = "MedianPruner"
|
||||||
|
pruner_kwargs: dict = field(default_factory=_pruner_default_kwargs)
|
||||||
|
directions: tuple = ("minimize",)
|
||||||
|
metrics_names: tuple = ("mse",)
|
||||||
|
limit_examples: bool = True
|
||||||
|
n_train_batches: int = float("inf")
|
||||||
|
n_valid_batches: int = float("inf")
|
||||||
|
storage: str = "sqlite:///example.db"
|
||||||
|
study_name: str = (
|
||||||
|
f"optuna_study_{datetime.now().strftime('%Y-%m-%d_%H:%M:%S')}"
|
||||||
|
)
|
||||||
|
n_trials_filter: tuple|list = None, #(optuna.trial.TrialState.COMPLETE,)
|
||||||
|
remove_outliers: float|int = None
|
||||||
|
## reserved, set by HyperTraining
|
||||||
|
_multi_objective = None
|
||||||
|
_parallel = None
|
||||||
|
_n_threads = None
|
||||||
|
_directions = None
|
||||||
|
_direction = None
|
||||||
|
_n_train_batches = None
|
||||||
|
_n_valid_batches = None
|
||||||
739
src/single-core-regen/hypertraining/training.py
Normal file
739
src/single-core-regen/hypertraining/training.py
Normal file
@@ -0,0 +1,739 @@
|
|||||||
|
import copy
|
||||||
|
from datetime import datetime
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Literal
|
||||||
|
import matplotlib
|
||||||
|
import torch.nn.utils.parametrize
|
||||||
|
|
||||||
|
try:
|
||||||
|
matplotlib.use("cairo")
|
||||||
|
except ImportError:
|
||||||
|
matplotlib.use("Agg")
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
# import torch.nn.functional as F # mse_loss doesn't support complex numbers
|
||||||
|
import torch.optim as optim
|
||||||
|
import torch.utils.data
|
||||||
|
|
||||||
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
|
||||||
|
from rich.progress import (
|
||||||
|
Progress,
|
||||||
|
TextColumn,
|
||||||
|
BarColumn,
|
||||||
|
TaskProgressColumn,
|
||||||
|
TimeRemainingColumn,
|
||||||
|
MofNCompleteColumn,
|
||||||
|
TimeElapsedColumn,
|
||||||
|
)
|
||||||
|
from rich.console import Console
|
||||||
|
|
||||||
|
from util.datasets import FiberRegenerationDataset
|
||||||
|
import util
|
||||||
|
|
||||||
|
from .settings import (
|
||||||
|
GlobalSettings,
|
||||||
|
DataSettings,
|
||||||
|
ModelSettings,
|
||||||
|
OptimizerSettings,
|
||||||
|
PytorchSettings,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class regenerator(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*dims,
|
||||||
|
layer_function=util.complexNN.ONN,
|
||||||
|
layer_parametrizations: list[dict] = None,
|
||||||
|
# [
|
||||||
|
# {
|
||||||
|
# "tensor_name": "weight",
|
||||||
|
# "parametrization": util.complexNN.Unitary,
|
||||||
|
# },
|
||||||
|
# {
|
||||||
|
# "tensor_name": "scale",
|
||||||
|
# "parametrization": util.complexNN.Clamp,
|
||||||
|
# },
|
||||||
|
# ],
|
||||||
|
activation_function=util.complexNN.Pow,
|
||||||
|
dtype=torch.float64,
|
||||||
|
dropout_prob=0.01,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
super(regenerator, self).__init__()
|
||||||
|
if len(dims) == 0:
|
||||||
|
try:
|
||||||
|
dims = kwargs["dims"]
|
||||||
|
except KeyError:
|
||||||
|
raise ValueError("dims must be provided")
|
||||||
|
self._n_hidden_layers = len(dims) - 2
|
||||||
|
self._layers = nn.Sequential()
|
||||||
|
|
||||||
|
for i in range(self._n_hidden_layers + 1):
|
||||||
|
self._layers.append(layer_function(dims[i], dims[i + 1], dtype=dtype))
|
||||||
|
if i < self._n_hidden_layers:
|
||||||
|
if dropout_prob is not None:
|
||||||
|
self._layers.append(util.complexNN.DropoutComplex(p=dropout_prob))
|
||||||
|
self._layers.append(activation_function())
|
||||||
|
|
||||||
|
# add parametrizations
|
||||||
|
if layer_parametrizations is not None:
|
||||||
|
for layer_parametrization in layer_parametrizations:
|
||||||
|
tensor_name = layer_parametrization.get("tensor_name", None)
|
||||||
|
parametrization = layer_parametrization.get("parametrization", None)
|
||||||
|
param_kwargs = layer_parametrization.get("kwargs", {})
|
||||||
|
if (
|
||||||
|
tensor_name is not None
|
||||||
|
and tensor_name in self._layers[-1]._parameters
|
||||||
|
and parametrization is not None
|
||||||
|
):
|
||||||
|
parametrization(self._layers[-1], tensor_name, **param_kwargs)
|
||||||
|
|
||||||
|
def forward(self, input_x):
|
||||||
|
x = input_x
|
||||||
|
# check if tracing
|
||||||
|
if torch.jit.is_tracing():
|
||||||
|
for layer in self._layers:
|
||||||
|
x = layer(x)
|
||||||
|
else:
|
||||||
|
# with torch.nn.utils.parametrize.cached():
|
||||||
|
for layer in self._layers:
|
||||||
|
x = layer(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
def traverse_dict_update(target, source):
|
||||||
|
for k, v in source.items():
|
||||||
|
if isinstance(v, dict):
|
||||||
|
if k not in target:
|
||||||
|
target[k] = {}
|
||||||
|
traverse_dict_update(target[k], v)
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
target[k] = v
|
||||||
|
except TypeError:
|
||||||
|
target.__dict__[k] = v
|
||||||
|
|
||||||
|
class Trainer:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
global_settings=None,
|
||||||
|
data_settings=None,
|
||||||
|
pytorch_settings=None,
|
||||||
|
model_settings=None,
|
||||||
|
optimizer_settings=None,
|
||||||
|
console=None,
|
||||||
|
checkpoint_path=None,
|
||||||
|
settings_override=None,
|
||||||
|
reset_epoch=False,
|
||||||
|
):
|
||||||
|
self.resume = checkpoint_path is not None
|
||||||
|
torch.serialization.add_safe_globals([
|
||||||
|
*util.complexNN.__all__,
|
||||||
|
GlobalSettings,
|
||||||
|
DataSettings,
|
||||||
|
ModelSettings,
|
||||||
|
OptimizerSettings,
|
||||||
|
PytorchSettings,
|
||||||
|
regenerator,
|
||||||
|
torch.nn.utils.parametrizations.orthogonal
|
||||||
|
])
|
||||||
|
if self.resume:
|
||||||
|
self.checkpoint_dict = torch.load(checkpoint_path, weights_only=True)
|
||||||
|
if settings_override is not None:
|
||||||
|
traverse_dict_update(self.checkpoint_dict["settings"], settings_override)
|
||||||
|
if reset_epoch:
|
||||||
|
self.checkpoint_dict["epoch"] = -1
|
||||||
|
|
||||||
|
self.global_settings: GlobalSettings = self.checkpoint_dict["settings"]["global_settings"]
|
||||||
|
self.data_settings: DataSettings = self.checkpoint_dict["settings"]["data_settings"]
|
||||||
|
self.pytorch_settings: PytorchSettings = self.checkpoint_dict["settings"]["pytorch_settings"]
|
||||||
|
self.model_settings: ModelSettings = self.checkpoint_dict["settings"]["model_settings"]
|
||||||
|
self.optimizer_settings: OptimizerSettings = self.checkpoint_dict["settings"]["optimizer_settings"]
|
||||||
|
else:
|
||||||
|
if global_settings is None:
|
||||||
|
raise ValueError("global_settings must be provided")
|
||||||
|
if data_settings is None:
|
||||||
|
raise ValueError("data_settings must be provided")
|
||||||
|
if pytorch_settings is None:
|
||||||
|
raise ValueError("pytorch_settings must be provided")
|
||||||
|
if model_settings is None:
|
||||||
|
raise ValueError("model_settings must be provided")
|
||||||
|
if optimizer_settings is None:
|
||||||
|
raise ValueError("optimizer_settings must be provided")
|
||||||
|
|
||||||
|
self.global_settings: GlobalSettings = global_settings
|
||||||
|
self.data_settings: DataSettings = data_settings
|
||||||
|
self.pytorch_settings: PytorchSettings = pytorch_settings
|
||||||
|
self.model_settings: ModelSettings = model_settings
|
||||||
|
self.optimizer_settings: OptimizerSettings = optimizer_settings
|
||||||
|
|
||||||
|
self.console = console or Console()
|
||||||
|
self.writer = None
|
||||||
|
|
||||||
|
def setup_tb_writer(self, append=None):
|
||||||
|
log_dir = self.pytorch_settings.summary_dir + "/" + (datetime.now().strftime("%Y%m%d_%H%M%S"))
|
||||||
|
if append is not None:
|
||||||
|
log_dir += "_" + str(append)
|
||||||
|
|
||||||
|
print(f"Logging to {log_dir}")
|
||||||
|
self.writer = SummaryWriter(log_dir=log_dir)
|
||||||
|
|
||||||
|
def save_checkpoint(self, save_dict, filename):
|
||||||
|
torch.save(save_dict, filename)
|
||||||
|
|
||||||
|
def build_checkpoint_dict(self, loss=None, epoch=None):
|
||||||
|
return {
|
||||||
|
"epoch": -1 if epoch is None else epoch,
|
||||||
|
"loss": float("inf") if loss is None else loss,
|
||||||
|
"model_state_dict": copy.deepcopy(self.model.state_dict()),
|
||||||
|
"optimizer_state_dict": copy.deepcopy(self.optimizer.state_dict()),
|
||||||
|
"scheduler_state_dict": copy.deepcopy(self.scheduler.state_dict()) if hasattr(self, "scheduler") else None,
|
||||||
|
"model_kwargs": copy.deepcopy(self.model_kwargs),
|
||||||
|
"settings": {
|
||||||
|
"global_settings": copy.deepcopy(self.global_settings),
|
||||||
|
"data_settings": copy.deepcopy(self.data_settings),
|
||||||
|
"pytorch_settings": copy.deepcopy(self.pytorch_settings),
|
||||||
|
"model_settings": copy.deepcopy(self.model_settings),
|
||||||
|
"optimizer_settings": copy.deepcopy(self.optimizer_settings),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
def define_model(self, model_kwargs=None):
|
||||||
|
if model_kwargs is None:
|
||||||
|
n_hidden_layers = self.model_settings.n_hidden_layers
|
||||||
|
|
||||||
|
input_dim = 2 * self.data_settings.output_size
|
||||||
|
|
||||||
|
dtype = getattr(torch, self.data_settings.dtype)
|
||||||
|
|
||||||
|
afunc = getattr(util.complexNN, self.model_settings.model_activation_func)
|
||||||
|
|
||||||
|
layer_func = getattr(util.complexNN, self.model_settings.model_layer_function)
|
||||||
|
|
||||||
|
layer_parametrizations = self.model_settings.model_layer_parametrizations
|
||||||
|
|
||||||
|
hidden_dims = [self.model_settings.overrides.get(f"n_hidden_nodes_{i}") for i in range(n_hidden_layers)]
|
||||||
|
|
||||||
|
self.model_kwargs = {
|
||||||
|
"dims": (input_dim, *hidden_dims, self.model_settings.output_dim),
|
||||||
|
"layer_function": layer_func,
|
||||||
|
"layer_parametrizations": layer_parametrizations,
|
||||||
|
"activation_function": afunc,
|
||||||
|
"dtype": dtype,
|
||||||
|
"dropout_prob": self.model_settings.dropout_prob,
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
self.model_kwargs = model_kwargs
|
||||||
|
input_dim = self.model_kwargs["dims"][0]
|
||||||
|
dtype = self.model_kwargs["dtype"]
|
||||||
|
|
||||||
|
# dims = self.model_kwargs.pop("dims")
|
||||||
|
self.model = regenerator(**self.model_kwargs)
|
||||||
|
|
||||||
|
self.writer.add_graph(self.model, torch.zeros(1, input_dim, dtype=dtype))
|
||||||
|
|
||||||
|
self.model = self.model.to(self.pytorch_settings.device)
|
||||||
|
|
||||||
|
def get_sliced_data(self, override=None):
|
||||||
|
symbols = self.data_settings.symbols
|
||||||
|
|
||||||
|
in_out_delay = self.data_settings.in_out_delay
|
||||||
|
|
||||||
|
xy_delay = self.data_settings.xy_delay
|
||||||
|
|
||||||
|
data_size = self.data_settings.output_size
|
||||||
|
|
||||||
|
dtype = getattr(torch, self.data_settings.dtype)
|
||||||
|
|
||||||
|
num_symbols = None
|
||||||
|
if override is not None:
|
||||||
|
num_symbols = override.get("num_symbols", None)
|
||||||
|
# get dataset
|
||||||
|
dataset = FiberRegenerationDataset(
|
||||||
|
file_path=self.data_settings.config_path,
|
||||||
|
symbols=symbols,
|
||||||
|
output_dim=data_size,
|
||||||
|
target_delay=in_out_delay,
|
||||||
|
xy_delay=xy_delay,
|
||||||
|
drop_first=self.data_settings.drop_first,
|
||||||
|
dtype=dtype,
|
||||||
|
real=not dtype.is_complex,
|
||||||
|
num_symbols=num_symbols,
|
||||||
|
)
|
||||||
|
|
||||||
|
dataset_size = len(dataset)
|
||||||
|
indices = list(range(dataset_size))
|
||||||
|
split = int(np.floor(self.data_settings.train_split * dataset_size))
|
||||||
|
if self.data_settings.shuffle:
|
||||||
|
np.random.seed(self.global_settings.seed)
|
||||||
|
np.random.shuffle(indices)
|
||||||
|
|
||||||
|
train_indices, valid_indices = indices[:split], indices[split:]
|
||||||
|
|
||||||
|
if self.data_settings.shuffle:
|
||||||
|
train_sampler = torch.utils.data.SubsetRandomSampler(train_indices)
|
||||||
|
valid_sampler = torch.utils.data.SubsetRandomSampler(valid_indices)
|
||||||
|
else:
|
||||||
|
train_sampler = train_indices
|
||||||
|
valid_sampler = valid_indices
|
||||||
|
|
||||||
|
train_loader = torch.utils.data.DataLoader(
|
||||||
|
dataset,
|
||||||
|
batch_size=self.pytorch_settings.batchsize,
|
||||||
|
sampler=train_sampler,
|
||||||
|
drop_last=True,
|
||||||
|
pin_memory=True,
|
||||||
|
num_workers=self.pytorch_settings.dataloader_workers,
|
||||||
|
prefetch_factor=self.pytorch_settings.dataloader_prefetch,
|
||||||
|
)
|
||||||
|
|
||||||
|
valid_loader = torch.utils.data.DataLoader(
|
||||||
|
dataset,
|
||||||
|
batch_size=self.pytorch_settings.batchsize,
|
||||||
|
sampler=valid_sampler,
|
||||||
|
drop_last=True,
|
||||||
|
pin_memory=True,
|
||||||
|
num_workers=self.pytorch_settings.dataloader_workers,
|
||||||
|
prefetch_factor=self.pytorch_settings.dataloader_prefetch,
|
||||||
|
)
|
||||||
|
|
||||||
|
return train_loader, valid_loader
|
||||||
|
|
||||||
|
def train_model(
|
||||||
|
self,
|
||||||
|
optimizer,
|
||||||
|
train_loader,
|
||||||
|
epoch,
|
||||||
|
enable_progress=False,
|
||||||
|
):
|
||||||
|
if enable_progress:
|
||||||
|
progress = Progress(
|
||||||
|
TextColumn("[yellow] Training..."),
|
||||||
|
TextColumn("Error: {task.description}"),
|
||||||
|
BarColumn(),
|
||||||
|
TaskProgressColumn(),
|
||||||
|
TextColumn("[green]Batch"),
|
||||||
|
MofNCompleteColumn(),
|
||||||
|
TimeRemainingColumn(),
|
||||||
|
TimeElapsedColumn(),
|
||||||
|
transient=False,
|
||||||
|
console=self.console,
|
||||||
|
refresh_per_second=10,
|
||||||
|
)
|
||||||
|
task = progress.add_task("-.---e--", total=len(train_loader))
|
||||||
|
progress.start()
|
||||||
|
|
||||||
|
running_loss2 = 0.0
|
||||||
|
running_loss = 0.0
|
||||||
|
self.model.train()
|
||||||
|
for batch_idx, (x, y) in enumerate(train_loader):
|
||||||
|
self.model.zero_grad(set_to_none=True)
|
||||||
|
x, y = (
|
||||||
|
x.to(self.pytorch_settings.device),
|
||||||
|
y.to(self.pytorch_settings.device),
|
||||||
|
)
|
||||||
|
y_pred = self.model(x)
|
||||||
|
loss = util.complexNN.complex_mse_loss(y_pred, y, power=True)
|
||||||
|
loss_value = loss.item()
|
||||||
|
loss.backward()
|
||||||
|
optimizer.step()
|
||||||
|
running_loss2 += loss_value
|
||||||
|
running_loss += loss_value
|
||||||
|
|
||||||
|
if enable_progress:
|
||||||
|
progress.update(task, advance=1, description=f"{loss_value:.3e}")
|
||||||
|
|
||||||
|
if batch_idx % self.pytorch_settings.write_every == 0:
|
||||||
|
self.writer.add_scalar(
|
||||||
|
"training loss",
|
||||||
|
running_loss2 / (self.pytorch_settings.write_every if batch_idx > 0 else 1),
|
||||||
|
epoch * len(train_loader) + batch_idx,
|
||||||
|
)
|
||||||
|
running_loss2 = 0.0
|
||||||
|
|
||||||
|
if enable_progress:
|
||||||
|
progress.stop()
|
||||||
|
|
||||||
|
return running_loss / len(train_loader)
|
||||||
|
|
||||||
|
def eval_model(self, valid_loader, epoch, enable_progress=True):
|
||||||
|
if enable_progress:
|
||||||
|
progress = Progress(
|
||||||
|
TextColumn("[green]Evaluating..."),
|
||||||
|
TextColumn("Error: {task.description}"),
|
||||||
|
BarColumn(),
|
||||||
|
TaskProgressColumn(),
|
||||||
|
TextColumn("[green]Batch"),
|
||||||
|
MofNCompleteColumn(),
|
||||||
|
TimeRemainingColumn(),
|
||||||
|
TimeElapsedColumn(),
|
||||||
|
transient=False,
|
||||||
|
console=self.console,
|
||||||
|
refresh_per_second=10,
|
||||||
|
)
|
||||||
|
progress.start()
|
||||||
|
task = progress.add_task("-.---e--", total=len(valid_loader))
|
||||||
|
|
||||||
|
self.model.eval()
|
||||||
|
running_error = 0
|
||||||
|
with torch.no_grad():
|
||||||
|
for batch_idx, (x, y) in enumerate(valid_loader):
|
||||||
|
x, y = (
|
||||||
|
x.to(self.pytorch_settings.device),
|
||||||
|
y.to(self.pytorch_settings.device),
|
||||||
|
)
|
||||||
|
y_pred = self.model(x)
|
||||||
|
error = util.complexNN.complex_mse_loss(y_pred, y, power=True)
|
||||||
|
error_value = error.item()
|
||||||
|
running_error += error_value
|
||||||
|
|
||||||
|
if enable_progress:
|
||||||
|
progress.update(task, advance=1, description=f"{error_value:.3e}")
|
||||||
|
|
||||||
|
|
||||||
|
running_error /= len(valid_loader)
|
||||||
|
self.writer.add_scalar(
|
||||||
|
"eval loss",
|
||||||
|
running_error,
|
||||||
|
epoch,
|
||||||
|
)
|
||||||
|
title_append, subtitle = self.build_title(epoch + 1)
|
||||||
|
self.writer.add_figure(
|
||||||
|
"fiber response",
|
||||||
|
self.plot_model_response(
|
||||||
|
model=self.model,
|
||||||
|
title_append=title_append,
|
||||||
|
subtitle=subtitle,
|
||||||
|
show=False,
|
||||||
|
),
|
||||||
|
epoch + 1,
|
||||||
|
)
|
||||||
|
self.writer.add_figure(
|
||||||
|
"eye diagram",
|
||||||
|
self.plot_model_response(
|
||||||
|
model=self.model,
|
||||||
|
title_append=title_append,
|
||||||
|
subtitle=subtitle,
|
||||||
|
show=False,
|
||||||
|
mode="eye",
|
||||||
|
),
|
||||||
|
epoch + 1,
|
||||||
|
)
|
||||||
|
self.writer_histograms(epoch + 1)
|
||||||
|
|
||||||
|
if enable_progress:
|
||||||
|
progress.stop()
|
||||||
|
|
||||||
|
return running_error
|
||||||
|
|
||||||
|
def run_model(self, model, loader):
|
||||||
|
model.eval()
|
||||||
|
xs = []
|
||||||
|
ys = []
|
||||||
|
y_preds = []
|
||||||
|
with torch.no_grad():
|
||||||
|
model = model.to(self.pytorch_settings.device)
|
||||||
|
for x, y in loader:
|
||||||
|
x, y = (
|
||||||
|
x.to(self.pytorch_settings.device),
|
||||||
|
y.to(self.pytorch_settings.device),
|
||||||
|
)
|
||||||
|
y_pred = model(x).cpu()
|
||||||
|
# x = x.cpu()
|
||||||
|
# y = y.cpu()
|
||||||
|
y_pred = y_pred.view(y_pred.shape[0], -1, 2)
|
||||||
|
y = y.view(y.shape[0], -1, 2)
|
||||||
|
x = x.view(x.shape[0], -1, 2)
|
||||||
|
xs.append(x[:, 0, :].squeeze())
|
||||||
|
ys.append(y.squeeze())
|
||||||
|
y_preds.append(y_pred.squeeze())
|
||||||
|
|
||||||
|
xs = torch.vstack(xs).cpu()
|
||||||
|
ys = torch.vstack(ys).cpu()
|
||||||
|
y_preds = torch.vstack(y_preds).cpu()
|
||||||
|
return ys, xs, y_preds
|
||||||
|
|
||||||
|
def writer_histograms(self, epoch, attributes=["weight", "weight_U", "weight_V", "bias", "sigma", "scale"]):
|
||||||
|
for i, layer in enumerate(self.model._layers):
|
||||||
|
tag = f"layer {i}"
|
||||||
|
for attribute in attributes:
|
||||||
|
if hasattr(layer, attribute):
|
||||||
|
vals: np.ndarray = getattr(layer, attribute).detach().cpu().numpy().flatten()
|
||||||
|
if vals.ndim <= 1 and len(vals) == 1:
|
||||||
|
if np.iscomplexobj(vals):
|
||||||
|
self.writer.add_scalar(f"{tag} {attribute} (Mag)", np.abs(vals), epoch)
|
||||||
|
self.writer.add_scalar(f"{tag} {attribute} (Phase)", np.angle(vals), epoch)
|
||||||
|
else:
|
||||||
|
self.writer.add_scalar(f"{tag} {attribute}", vals, epoch)
|
||||||
|
else:
|
||||||
|
if np.iscomplexobj(vals):
|
||||||
|
self.writer.add_histogram(f"{tag} {attribute} (Mag)", np.abs(vals), epoch, bins="fd")
|
||||||
|
self.writer.add_histogram(f"{tag} {attribute} (Phase)", np.angle(vals), epoch, bins="fd")
|
||||||
|
else:
|
||||||
|
self.writer.add_histogram(f"{tag} {attribute}", vals, epoch, bins="fd")
|
||||||
|
|
||||||
|
def train(self):
|
||||||
|
if self.writer is None:
|
||||||
|
self.setup_tb_writer()
|
||||||
|
|
||||||
|
if self.resume:
|
||||||
|
model_kwargs = self.checkpoint_dict["model_kwargs"]
|
||||||
|
else:
|
||||||
|
model_kwargs = None
|
||||||
|
|
||||||
|
self.define_model(model_kwargs=model_kwargs)
|
||||||
|
|
||||||
|
print(f"number of parameters (trainable): {sum(p.numel() for p in self.model.parameters())} ({sum(p.numel() for p in self.model.parameters() if p.requires_grad)})")
|
||||||
|
|
||||||
|
title_append, subtitle = self.build_title(0)
|
||||||
|
|
||||||
|
self.writer.add_figure(
|
||||||
|
"fiber response",
|
||||||
|
self.plot_model_response(
|
||||||
|
model=self.model,
|
||||||
|
title_append=title_append,
|
||||||
|
subtitle=subtitle,
|
||||||
|
show=False,
|
||||||
|
),
|
||||||
|
0,
|
||||||
|
)
|
||||||
|
self.writer.add_figure(
|
||||||
|
"eye diagram",
|
||||||
|
self.plot_model_response(
|
||||||
|
model=self.model,
|
||||||
|
title_append=title_append,
|
||||||
|
subtitle=subtitle,
|
||||||
|
mode="eye",
|
||||||
|
show=False,
|
||||||
|
),
|
||||||
|
0,
|
||||||
|
)
|
||||||
|
self.writer_histograms(0)
|
||||||
|
|
||||||
|
train_loader, valid_loader = self.get_sliced_data()
|
||||||
|
|
||||||
|
optimizer_name = self.optimizer_settings.optimizer
|
||||||
|
|
||||||
|
lr = self.optimizer_settings.learning_rate
|
||||||
|
|
||||||
|
self.optimizer: optim.Optimizer = getattr(optim, optimizer_name)(self.model.parameters(), lr=lr)
|
||||||
|
if self.optimizer_settings.scheduler is not None:
|
||||||
|
self.scheduler = getattr(optim.lr_scheduler, self.optimizer_settings.scheduler)(
|
||||||
|
self.optimizer, **self.optimizer_settings.scheduler_kwargs
|
||||||
|
)
|
||||||
|
if self.resume:
|
||||||
|
try:
|
||||||
|
self.scheduler.load_state_dict(self.checkpoint_dict["scheduler_state_dict"])
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
self.writer.add_scalar("learning rate", self.scheduler.get_last_lr()[0], -1)
|
||||||
|
|
||||||
|
|
||||||
|
if not self.resume:
|
||||||
|
self.best = self.build_checkpoint_dict()
|
||||||
|
else:
|
||||||
|
self.best = self.checkpoint_dict
|
||||||
|
self.model.load_state_dict(self.best["model_state_dict"], strict=False)
|
||||||
|
try:
|
||||||
|
self.optimizer.load_state_dict(self.best["optimizer_state_dict"])
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
for epoch in range(self.best["epoch"] + 1, self.pytorch_settings.epochs):
|
||||||
|
enable_progress = True
|
||||||
|
if enable_progress:
|
||||||
|
self.console.rule(f"Epoch {epoch + 1}/{self.pytorch_settings.epochs}")
|
||||||
|
self.train_model(
|
||||||
|
self.optimizer,
|
||||||
|
train_loader,
|
||||||
|
epoch,
|
||||||
|
enable_progress=enable_progress,
|
||||||
|
)
|
||||||
|
loss = self.eval_model(
|
||||||
|
valid_loader,
|
||||||
|
epoch,
|
||||||
|
enable_progress=enable_progress,
|
||||||
|
)
|
||||||
|
if self.optimizer_settings.scheduler is not None:
|
||||||
|
lr_old = self.scheduler.get_last_lr()
|
||||||
|
self.scheduler.step(loss)
|
||||||
|
lr_new = self.scheduler.get_last_lr()
|
||||||
|
if lr_old[0] != lr_new[0]:
|
||||||
|
self.writer.add_scalar("learning rate", lr_new[0], epoch)
|
||||||
|
|
||||||
|
if self.pytorch_settings.save_models and self.model is not None:
|
||||||
|
save_path = (
|
||||||
|
Path(self.pytorch_settings.model_dir) / f"{self.writer.get_logdir().split('/')[-1]}_{epoch}.tar"
|
||||||
|
)
|
||||||
|
save_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
checkpoint = self.build_checkpoint_dict(loss, epoch)
|
||||||
|
self.save_checkpoint(checkpoint, save_path)
|
||||||
|
|
||||||
|
if loss < self.best["loss"]:
|
||||||
|
self.best = checkpoint
|
||||||
|
save_path = (
|
||||||
|
Path(self.pytorch_settings.model_dir) / f"best_{self.writer.get_logdir().split('/')[-1]}.tar"
|
||||||
|
)
|
||||||
|
save_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
self.save_checkpoint(self.best, save_path)
|
||||||
|
self.writer.flush()
|
||||||
|
|
||||||
|
self.writer.close()
|
||||||
|
return self.best
|
||||||
|
|
||||||
|
def _plot_model_response_eye(self, *signals, labels=None, sps=None, title_append="", subtitle="", show=True):
|
||||||
|
if sps is None:
|
||||||
|
raise ValueError("sps must be provided")
|
||||||
|
if not hasattr(labels, "__iter__") or isinstance(labels, (str, type(None))):
|
||||||
|
labels = [labels]
|
||||||
|
else:
|
||||||
|
labels = list(labels)
|
||||||
|
|
||||||
|
while len(labels) < len(signals):
|
||||||
|
labels.append(None)
|
||||||
|
|
||||||
|
# check if there are any labels
|
||||||
|
if not any(labels):
|
||||||
|
labels = [f"signal {i + 1}" for i in range(len(signals))]
|
||||||
|
|
||||||
|
fig, axs = plt.subplots(1, 2 * len(signals), sharex=True, sharey=True)
|
||||||
|
fig.set_figwidth(18)
|
||||||
|
fig.suptitle(f"Eye diagram{f' {title_append}' if title_append else ''}{f'\n{subtitle}' if subtitle else ''}")
|
||||||
|
xaxis = np.linspace(0, 2, 2 * sps, endpoint=False)
|
||||||
|
for j, (label, signal) in enumerate(zip(labels, signals)):
|
||||||
|
# signal = signal.cpu().numpy()
|
||||||
|
for i in range(len(signal) // sps - 1):
|
||||||
|
x, y = signal[i * sps : (i + 2) * sps].T
|
||||||
|
axs[0 + 2 * j].plot(xaxis, np.abs(x) ** 2, color=f"C{j}", alpha=1 / (len(signal) // sps) * 10)
|
||||||
|
axs[1 + 2 * j].plot(xaxis, np.abs(y) ** 2, color=f"C{j}", alpha=1 / (len(signal) // sps) * 10)
|
||||||
|
axs[0 + 2 * j].set_title(label + " x")
|
||||||
|
axs[1 + 2 * j].set_title(label + " y")
|
||||||
|
axs[0 + 2 * j].set_xlabel("Symbol")
|
||||||
|
axs[1 + 2 * j].set_xlabel("Symbol")
|
||||||
|
axs[0 + 2 * j].set_box_aspect(1)
|
||||||
|
axs[1 + 2 * j].set_box_aspect(1)
|
||||||
|
axs[0].set_ylabel("normalized power")
|
||||||
|
fig.tight_layout()
|
||||||
|
# axs[1+2*len(labels)-1].set_ylabel("normalized power")
|
||||||
|
|
||||||
|
if show:
|
||||||
|
plt.show()
|
||||||
|
return fig
|
||||||
|
|
||||||
|
def _plot_model_response_head(self, *signals, labels=None, sps=None, title_append="", subtitle="", show=True):
|
||||||
|
if not hasattr(labels, "__iter__") or isinstance(labels, (str, type(None))):
|
||||||
|
labels = [labels]
|
||||||
|
else:
|
||||||
|
labels = list(labels)
|
||||||
|
|
||||||
|
while len(labels) < len(signals):
|
||||||
|
labels.append(None)
|
||||||
|
|
||||||
|
# check if there are any labels
|
||||||
|
if not any(labels):
|
||||||
|
labels = [f"signal {i + 1}" for i in range(len(signals))]
|
||||||
|
|
||||||
|
fig, axs = plt.subplots(1, 2, sharex=True, sharey=True)
|
||||||
|
fig.set_figwidth(18)
|
||||||
|
fig.set_figheight(4)
|
||||||
|
fig.suptitle(f"Fiber response{f' {title_append}' if title_append else ''}{f'\n{subtitle}' if subtitle else ''}")
|
||||||
|
for i, ax in enumerate(axs):
|
||||||
|
for signal, label in zip(signals, labels):
|
||||||
|
if sps is not None:
|
||||||
|
xaxis = np.linspace(0, len(signal) / sps, len(signal), endpoint=False)
|
||||||
|
else:
|
||||||
|
xaxis = np.arange(len(signal))
|
||||||
|
ax.plot(xaxis, np.abs(signal[:, i]) ** 2, label=label)
|
||||||
|
ax.set_xlabel("Sample" if sps is None else "Symbol")
|
||||||
|
ax.set_ylabel("normalized power")
|
||||||
|
ax.legend(loc="upper right")
|
||||||
|
fig.tight_layout()
|
||||||
|
if show:
|
||||||
|
plt.show()
|
||||||
|
return fig
|
||||||
|
|
||||||
|
def plot_model_response(
|
||||||
|
self,
|
||||||
|
model=None,
|
||||||
|
title_append="",
|
||||||
|
subtitle="",
|
||||||
|
mode: Literal["eye", "head"] = "head",
|
||||||
|
show=False,
|
||||||
|
):
|
||||||
|
data_settings_backup = copy.deepcopy(self.data_settings)
|
||||||
|
pytorch_settings_backup = copy.deepcopy(self.pytorch_settings)
|
||||||
|
self.data_settings.drop_first = 100 * 128
|
||||||
|
self.data_settings.shuffle = False
|
||||||
|
self.data_settings.train_split = 1.0
|
||||||
|
self.pytorch_settings.batchsize = (
|
||||||
|
self.pytorch_settings.eye_symbols if mode == "eye" else self.pytorch_settings.head_symbols
|
||||||
|
)
|
||||||
|
plot_loader, _ = self.get_sliced_data(override={"num_symbols": self.pytorch_settings.batchsize})
|
||||||
|
self.data_settings = data_settings_backup
|
||||||
|
self.pytorch_settings = pytorch_settings_backup
|
||||||
|
|
||||||
|
fiber_in, fiber_out, regen = self.run_model(model, plot_loader)
|
||||||
|
fiber_in = fiber_in.view(-1, 2)
|
||||||
|
fiber_out = fiber_out.view(-1, 2)
|
||||||
|
regen = regen.view(-1, 2)
|
||||||
|
|
||||||
|
fiber_in = fiber_in.numpy()
|
||||||
|
fiber_out = fiber_out.numpy()
|
||||||
|
regen = regen.numpy()
|
||||||
|
|
||||||
|
# https://github.com/matplotlib/matplotlib/issues/27713#issue-2104110987
|
||||||
|
# https://github.com/matplotlib/matplotlib/issues/27713#issuecomment-1915497463
|
||||||
|
import gc
|
||||||
|
|
||||||
|
if mode == "head":
|
||||||
|
fig = self._plot_model_response_head(
|
||||||
|
fiber_in,
|
||||||
|
fiber_out,
|
||||||
|
regen,
|
||||||
|
labels=("fiber in", "fiber out", "regen"),
|
||||||
|
sps=plot_loader.dataset.samples_per_symbol,
|
||||||
|
title_append=title_append,
|
||||||
|
subtitle=subtitle,
|
||||||
|
show=show,
|
||||||
|
)
|
||||||
|
elif mode == "eye":
|
||||||
|
# raise NotImplementedError("Eye diagram not implemented")
|
||||||
|
fig = self._plot_model_response_eye(
|
||||||
|
fiber_in,
|
||||||
|
fiber_out,
|
||||||
|
regen,
|
||||||
|
labels=("fiber in", "fiber out", "regen"),
|
||||||
|
sps=plot_loader.dataset.samples_per_symbol,
|
||||||
|
title_append=title_append,
|
||||||
|
subtitle=subtitle,
|
||||||
|
show=show,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown mode: {mode}")
|
||||||
|
gc.collect()
|
||||||
|
|
||||||
|
return fig
|
||||||
|
|
||||||
|
def build_title(self, number: int):
|
||||||
|
title_append = f"epoch {number}"
|
||||||
|
model_n_hidden_layers = self.model_settings.n_hidden_layers
|
||||||
|
input_dim = 2 * self.data_settings.output_size
|
||||||
|
model_dims = [
|
||||||
|
self.model_settings.overrides.get(f"n_hidden_nodes_{i}", -1) for i in range(model_n_hidden_layers)
|
||||||
|
]
|
||||||
|
model_dims.insert(0, input_dim)
|
||||||
|
model_dims.append(2)
|
||||||
|
model_dims = [str(dim) for dim in model_dims]
|
||||||
|
model_activation_func = self.model_settings.model_activation_func
|
||||||
|
model_dtype = self.data_settings.dtype
|
||||||
|
|
||||||
|
subtitle = f"{model_n_hidden_layers + 2} layers à ({', '.join(model_dims)}) units, {model_activation_func}, {model_dtype}"
|
||||||
|
|
||||||
|
return title_append, subtitle
|
||||||
118
src/single-core-regen/regen.py
Normal file
118
src/single-core-regen/regen.py
Normal file
@@ -0,0 +1,118 @@
|
|||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
import optuna
|
||||||
|
from hypertraining.hypertraining import HyperTraining
|
||||||
|
from hypertraining.settings import (
|
||||||
|
GlobalSettings,
|
||||||
|
DataSettings,
|
||||||
|
PytorchSettings,
|
||||||
|
ModelSettings,
|
||||||
|
OptimizerSettings,
|
||||||
|
OptunaSettings,
|
||||||
|
)
|
||||||
|
|
||||||
|
global_settings = GlobalSettings(
|
||||||
|
seed=42,
|
||||||
|
)
|
||||||
|
|
||||||
|
data_settings = DataSettings(
|
||||||
|
config_path="data/*-128-16384-100000-0-0-17-0-PAM4-0.ini",
|
||||||
|
dtype="complex64",
|
||||||
|
# symbols = (9, 20), # 13 symbol @ 10GBd <-> 1.3ns <-> 0.26m of fiber
|
||||||
|
symbols=13, # study: single_core_regen_20241123_011232
|
||||||
|
# output_size = (11, 32), # ballpark 26 taps -> 2 taps per input symbol -> 1 tap every 0.01m (model has 52 inputs)
|
||||||
|
output_size=26, # study: single_core_regen_20241123_011232 (model_input_dim/2)
|
||||||
|
shuffle=True,
|
||||||
|
in_out_delay=0,
|
||||||
|
xy_delay=0,
|
||||||
|
drop_first=128 * 100,
|
||||||
|
train_split=0.8,
|
||||||
|
)
|
||||||
|
|
||||||
|
pytorch_settings = PytorchSettings(
|
||||||
|
epochs=10000,
|
||||||
|
batchsize=2**10,
|
||||||
|
device="cuda",
|
||||||
|
dataloader_workers=12,
|
||||||
|
dataloader_prefetch=4,
|
||||||
|
summary_dir=".runs",
|
||||||
|
write_every=2**5,
|
||||||
|
save_models=True,
|
||||||
|
model_dir=".models",
|
||||||
|
)
|
||||||
|
|
||||||
|
model_settings = ModelSettings(
|
||||||
|
output_dim=2,
|
||||||
|
# n_hidden_layers = (3, 8),
|
||||||
|
n_hidden_layers=4,
|
||||||
|
overrides={
|
||||||
|
"n_hidden_nodes_0": 8,
|
||||||
|
"n_hidden_nodes_1": 6,
|
||||||
|
"n_hidden_nodes_2": 4,
|
||||||
|
"n_hidden_nodes_3": 8,
|
||||||
|
},
|
||||||
|
model_activation_func="Mag",
|
||||||
|
# satabsT0=(1e-6, 1),
|
||||||
|
)
|
||||||
|
|
||||||
|
optimizer_settings = OptimizerSettings(
|
||||||
|
optimizer="Adam",
|
||||||
|
# learning_rate = (1e-5, 1e-1),
|
||||||
|
learning_rate=5e-3
|
||||||
|
# learning_rate=5e-4,
|
||||||
|
)
|
||||||
|
|
||||||
|
optuna_settings = OptunaSettings(
|
||||||
|
n_trials=1,
|
||||||
|
n_workers=1,
|
||||||
|
timeout=3600,
|
||||||
|
directions=("minimize",),
|
||||||
|
metrics_names=("mse",),
|
||||||
|
limit_examples=False,
|
||||||
|
n_train_batches=500,
|
||||||
|
# n_valid_batches = 100,
|
||||||
|
storage="sqlite:///data/single_core_regen.db",
|
||||||
|
study_name=f"single_core_regen_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
|
||||||
|
n_trials_filter=(optuna.trial.TrialState.COMPLETE, optuna.trial.TrialState.PRUNED),
|
||||||
|
pruner="MedianPruner",
|
||||||
|
pruner_kwargs=None
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
hyper_training = HyperTraining(
|
||||||
|
global_settings=global_settings,
|
||||||
|
data_settings=data_settings,
|
||||||
|
pytorch_settings=pytorch_settings,
|
||||||
|
model_settings=model_settings,
|
||||||
|
optimizer_settings=optimizer_settings,
|
||||||
|
optuna_settings=optuna_settings,
|
||||||
|
)
|
||||||
|
|
||||||
|
hyper_training.setup_study()
|
||||||
|
|
||||||
|
hyper_training.run_study()
|
||||||
|
# best_trial = hyper_training.study.best_trial
|
||||||
|
|
||||||
|
# best_model = hyper_training.define_model(best_trial).to(
|
||||||
|
# hyper_training.pytorch_settings.device
|
||||||
|
# )
|
||||||
|
|
||||||
|
# title_append, subtitle = hyper_training.build_title(best_trial)
|
||||||
|
# hyper_training.plot_model_response(
|
||||||
|
# best_trial,
|
||||||
|
# model=best_model,
|
||||||
|
# title_append=title_append,
|
||||||
|
# subtitle=subtitle,
|
||||||
|
# mode="eye",
|
||||||
|
# show=True,
|
||||||
|
# )
|
||||||
|
|
||||||
|
# print(f"Best model found for trial {best_trial.number}")
|
||||||
|
# print(f"Best model error: {best_trial.value}")
|
||||||
|
# print(f"Best model params: {best_trial.params}")
|
||||||
|
# print()
|
||||||
|
# print(best_model)
|
||||||
|
|
||||||
|
# eye_fig = hyper_training.plot_eye()
|
||||||
|
...
|
||||||
130
src/single-core-regen/regen_no_hyper.py
Normal file
130
src/single-core-regen/regen_no_hyper.py
Normal file
@@ -0,0 +1,130 @@
|
|||||||
|
from hypertraining.settings import (
|
||||||
|
GlobalSettings,
|
||||||
|
DataSettings,
|
||||||
|
PytorchSettings,
|
||||||
|
ModelSettings,
|
||||||
|
OptimizerSettings,
|
||||||
|
)
|
||||||
|
|
||||||
|
from hypertraining.training import Trainer
|
||||||
|
import torch
|
||||||
|
import json
|
||||||
|
import util
|
||||||
|
|
||||||
|
global_settings = GlobalSettings(
|
||||||
|
seed=42,
|
||||||
|
)
|
||||||
|
|
||||||
|
data_settings = DataSettings(
|
||||||
|
config_path="data/*-128-16384-100000-0-0-17-0-PAM4-0.ini",
|
||||||
|
dtype="complex64",
|
||||||
|
# symbols = (9, 20), # 13 symbol @ 10GBd <-> 1.3ns <-> 0.26m of fiber
|
||||||
|
symbols=13, # study: single_core_regen_20241123_011232
|
||||||
|
# output_size = (11, 32), # 26 taps -> 2 taps per input symbol -> 1 tap every 1cm (model has 52 inputs (x/y))
|
||||||
|
output_size=26, # study: single_core_regen_20241123_011232 (model_input_dim/2)
|
||||||
|
shuffle=True,
|
||||||
|
in_out_delay=0,
|
||||||
|
xy_delay=0,
|
||||||
|
drop_first=128*64,
|
||||||
|
train_split=0.8,
|
||||||
|
)
|
||||||
|
|
||||||
|
pytorch_settings = PytorchSettings(
|
||||||
|
epochs=10000,
|
||||||
|
batchsize=2**12,
|
||||||
|
device="cuda",
|
||||||
|
dataloader_workers=12,
|
||||||
|
dataloader_prefetch=8,
|
||||||
|
summary_dir=".runs",
|
||||||
|
write_every=2**5,
|
||||||
|
save_models=True,
|
||||||
|
model_dir=".models",
|
||||||
|
)
|
||||||
|
|
||||||
|
model_settings = ModelSettings(
|
||||||
|
output_dim=2,
|
||||||
|
n_hidden_layers=4,
|
||||||
|
overrides={
|
||||||
|
"n_hidden_nodes_0": 8,
|
||||||
|
"n_hidden_nodes_1": 8,
|
||||||
|
"n_hidden_nodes_2": 4,
|
||||||
|
"n_hidden_nodes_3": 6,
|
||||||
|
},
|
||||||
|
model_activation_func="PowScale",
|
||||||
|
# dropout_prob=0.01,
|
||||||
|
model_layer_function="ONN",
|
||||||
|
model_layer_parametrizations=[
|
||||||
|
{
|
||||||
|
"tensor_name": "weight",
|
||||||
|
"parametrization": torch.nn.utils.parametrizations.orthogonal,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"tensor_name": "scales",
|
||||||
|
"parametrization": util.complexNN.clamp,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"tensor_name": "scale",
|
||||||
|
"parametrization": util.complexNN.clamp,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"tensor_name": "bias",
|
||||||
|
"parametrization": util.complexNN.clamp,
|
||||||
|
},
|
||||||
|
# {
|
||||||
|
# "tensor_name": "V",
|
||||||
|
# "parametrization": torch.nn.utils.parametrizations.orthogonal,
|
||||||
|
# },
|
||||||
|
# {
|
||||||
|
# "tensor_name": "S",
|
||||||
|
# "parametrization": util.complexNN.clamp,
|
||||||
|
# },
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
optimizer_settings = OptimizerSettings(
|
||||||
|
optimizer="Adam",
|
||||||
|
learning_rate=0.05,
|
||||||
|
scheduler="ReduceLROnPlateau",
|
||||||
|
scheduler_kwargs={
|
||||||
|
"patience": 2**6,
|
||||||
|
"factor": 0.9,
|
||||||
|
# "threshold": 1e-3,
|
||||||
|
"min_lr": 1e-6,
|
||||||
|
"cooldown": 10,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
def save_dict_to_file(dictionary, filename):
|
||||||
|
"""
|
||||||
|
Save the best dictionary to a JSON file.
|
||||||
|
|
||||||
|
:param best: Dictionary containing the best training results.
|
||||||
|
:type best: dict
|
||||||
|
:param filename: Path to the JSON file where the dictionary will be saved.
|
||||||
|
:type filename: str
|
||||||
|
"""
|
||||||
|
with open(filename, 'w') as f:
|
||||||
|
json.dump(dictionary, f, indent=4)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
trainer = Trainer(
|
||||||
|
global_settings=global_settings,
|
||||||
|
data_settings=data_settings,
|
||||||
|
pytorch_settings=pytorch_settings,
|
||||||
|
model_settings=model_settings,
|
||||||
|
optimizer_settings=optimizer_settings,
|
||||||
|
checkpoint_path='.models/20241128_084935_8885.tar',
|
||||||
|
settings_override={
|
||||||
|
"model_settings": {
|
||||||
|
# "model_activation_func": "PowScale",
|
||||||
|
"dropout_prob": 0,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
reset_epoch=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
best = trainer.train()
|
||||||
|
save_dict_to_file(best, ".models/best_results.json")
|
||||||
|
|
||||||
|
...
|
||||||
194
src/single-core-regen/testing/learn_optuna.py
Normal file
194
src/single-core-regen/testing/learn_optuna.py
Normal file
@@ -0,0 +1,194 @@
|
|||||||
|
from datetime import datetime
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import optuna
|
||||||
|
import warnings
|
||||||
|
from util.optuna_vis import show_figures
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import torch.optim as optim
|
||||||
|
import torch.utils.data
|
||||||
|
|
||||||
|
from torchvision import datasets
|
||||||
|
from torchvision import transforms
|
||||||
|
import multiprocessing
|
||||||
|
|
||||||
|
# from util.dataset import SlicedDataset
|
||||||
|
|
||||||
|
DEVICE = torch.device("cuda")
|
||||||
|
BATCHSIZE = 128
|
||||||
|
CLASSES = 10
|
||||||
|
DIR = Path(__file__).parent
|
||||||
|
EPOCHS = 100
|
||||||
|
N_TRAIN_EXAMPLES = BATCHSIZE * 30
|
||||||
|
N_VALID_EXAMPLES = BATCHSIZE * 10
|
||||||
|
|
||||||
|
n_trials = 128
|
||||||
|
n_threads = 16
|
||||||
|
|
||||||
|
|
||||||
|
def define_model(trial):
|
||||||
|
n_layers = trial.suggest_int("n_layers", 1, 3)
|
||||||
|
layers = []
|
||||||
|
|
||||||
|
in_features = 28 * 28
|
||||||
|
for i in range(n_layers):
|
||||||
|
out_features = trial.suggest_int(f"n_units_l{i}", 4, 128)
|
||||||
|
layers.append(nn.Linear(in_features, out_features))
|
||||||
|
layers.append(nn.ReLU())
|
||||||
|
p = trial.suggest_float(f"dropout_l{i}", 0.2, 0.5)
|
||||||
|
layers.append(nn.Dropout(p))
|
||||||
|
|
||||||
|
in_features = out_features
|
||||||
|
|
||||||
|
layers.append(nn.Linear(in_features, CLASSES))
|
||||||
|
layers.append(nn.LogSoftmax(dim=1))
|
||||||
|
|
||||||
|
return nn.Sequential(*layers)
|
||||||
|
|
||||||
|
|
||||||
|
def get_mnist():
|
||||||
|
# Load FashionMNIST dataset.
|
||||||
|
train_loader = torch.utils.data.DataLoader(
|
||||||
|
datasets.FashionMNIST(
|
||||||
|
DIR / ".data", train=True, download=True, transform=transforms.ToTensor()
|
||||||
|
),
|
||||||
|
batch_size=BATCHSIZE,
|
||||||
|
shuffle=True,
|
||||||
|
)
|
||||||
|
valid_loader = torch.utils.data.DataLoader(
|
||||||
|
datasets.FashionMNIST(
|
||||||
|
DIR / ".data", train=False, transform=transforms.ToTensor()
|
||||||
|
),
|
||||||
|
batch_size=BATCHSIZE,
|
||||||
|
shuffle=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
return train_loader, valid_loader
|
||||||
|
|
||||||
|
|
||||||
|
def objective(trial):
|
||||||
|
model = define_model(trial).to(DEVICE)
|
||||||
|
|
||||||
|
optimizer_name = trial.suggest_categorical("optimizer", ["Adam", "RMSprop", "SGD"])
|
||||||
|
lr = trial.suggest_float("lr", 1e-5, 1e-1, log=True)
|
||||||
|
optimizer = getattr(optim, optimizer_name)(model.parameters(), lr=lr)
|
||||||
|
|
||||||
|
train_loader, valid_loader = get_mnist()
|
||||||
|
|
||||||
|
for epoch in range(EPOCHS):
|
||||||
|
train_model(model, optimizer, train_loader)
|
||||||
|
accuracy, num_params = eval_model(model, valid_loader)
|
||||||
|
|
||||||
|
return accuracy, num_params
|
||||||
|
|
||||||
|
|
||||||
|
def eval_model(model, valid_loader):
|
||||||
|
model.eval()
|
||||||
|
correct = 0
|
||||||
|
with torch.no_grad():
|
||||||
|
for batch_idx, (data, target) in enumerate(valid_loader):
|
||||||
|
if batch_idx * BATCHSIZE >= N_VALID_EXAMPLES:
|
||||||
|
break
|
||||||
|
|
||||||
|
data, target = data.view(data.size(0), -1).to(DEVICE), target.to(DEVICE)
|
||||||
|
output = model(data)
|
||||||
|
|
||||||
|
pred = output.argmax(dim=1, keepdim=True)
|
||||||
|
correct += pred.eq(target.view_as(pred)).sum().item()
|
||||||
|
|
||||||
|
accuracy = correct / min(len(valid_loader.dataset), N_VALID_EXAMPLES)
|
||||||
|
num_params = sum(p.numel() for p in model.parameters())
|
||||||
|
return accuracy, num_params
|
||||||
|
|
||||||
|
|
||||||
|
def train_model(model, optimizer, train_loader):
|
||||||
|
model.train()
|
||||||
|
for batch_idx, (data, target) in enumerate(train_loader):
|
||||||
|
if batch_idx * BATCHSIZE >= N_TRAIN_EXAMPLES:
|
||||||
|
break
|
||||||
|
|
||||||
|
data, target = data.view(data.size(0), -1).to(DEVICE), target.to(DEVICE)
|
||||||
|
|
||||||
|
optimizer.zero_grad()
|
||||||
|
output = model(data)
|
||||||
|
loss = F.nll_loss(output, target)
|
||||||
|
loss.backward()
|
||||||
|
optimizer.step()
|
||||||
|
|
||||||
|
|
||||||
|
def run_optimize(n_trials, study):
|
||||||
|
study.optimize(objective, n_trials=n_trials, timeout=600)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
study_name = f"{datetime.now().strftime('%Y-%m-%d %H:%M:%S')} mnist example"
|
||||||
|
storage = "sqlite:///db.sqlite3"
|
||||||
|
directions = ["maximize", "minimize"]
|
||||||
|
|
||||||
|
study = optuna.create_study(
|
||||||
|
directions=directions,
|
||||||
|
storage=storage,
|
||||||
|
study_name=study_name,
|
||||||
|
)
|
||||||
|
|
||||||
|
with warnings.catch_warnings(action="ignore"):
|
||||||
|
study.set_metric_names(["accuracy", "num params"])
|
||||||
|
|
||||||
|
n_threads = min(n_trials, n_threads)
|
||||||
|
|
||||||
|
processes = []
|
||||||
|
for _ in range(n_threads):
|
||||||
|
p = multiprocessing.Process(
|
||||||
|
target=run_optimize, args=(n_trials // n_threads, study)
|
||||||
|
)
|
||||||
|
p.start()
|
||||||
|
processes.append(p)
|
||||||
|
|
||||||
|
for p in processes:
|
||||||
|
p.join()
|
||||||
|
|
||||||
|
remaining_trials = n_trials - ((n_trials // n_threads) * n_threads)
|
||||||
|
if remaining_trials:
|
||||||
|
print(
|
||||||
|
f"\nRunning last {remaining_trials} trial{'s' if remaining_trials > 1 else ''}:"
|
||||||
|
)
|
||||||
|
run_optimize(directions, remaining_trials, study_name, storage)
|
||||||
|
|
||||||
|
print(f"Number of trials on the Pareto front: {len(study.best_trials)}")
|
||||||
|
|
||||||
|
trial_with_highest_accuracy = max(study.best_trials, key=lambda t: t.values[1])
|
||||||
|
print("Trial with highest accuracy: ")
|
||||||
|
print(f"\tnumber: {trial_with_highest_accuracy.number}")
|
||||||
|
print(f"\tparams: {trial_with_highest_accuracy.params}")
|
||||||
|
print(f"\tvalues: {trial_with_highest_accuracy.values}")
|
||||||
|
|
||||||
|
# for trial in trials:
|
||||||
|
# print(f"Trial {trial.number}")
|
||||||
|
# print(f" Accuracy: {trial.values[0]}")
|
||||||
|
# print(f" n_params: {int(trial.values[1])}")
|
||||||
|
# print( " Params: ")
|
||||||
|
# for key, value in trial.params.items():
|
||||||
|
# print(" {}: {}".format(key, value))
|
||||||
|
# print()
|
||||||
|
|
||||||
|
# print(" Value: ", trial.value)
|
||||||
|
|
||||||
|
# print(" Params: ")
|
||||||
|
# for key, value in trial.params.items():
|
||||||
|
# print(" {}: {}".format(key, value))
|
||||||
|
|
||||||
|
figures = []
|
||||||
|
figures.append(
|
||||||
|
optuna.visualization.plot_pareto_front(
|
||||||
|
study, target_names=["accuracy", "num_params"]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
figures.append(optuna.visualization.plot_timeline(study))
|
||||||
|
|
||||||
|
plt = show_figures(*figures)
|
||||||
|
|
||||||
|
print()
|
||||||
|
# plt.show()
|
||||||
51
src/single-core-regen/testing/sliced_dataset_test.py
Normal file
51
src/single-core-regen/testing/sliced_dataset_test.py
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
# move into dir single-core-regen before running
|
||||||
|
|
||||||
|
from util.dataset import SlicedDataset
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
from matplotlib import pyplot as plt
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
def eye_dataset(dataset, no_symbols=None, offset=False, show=True):
|
||||||
|
if no_symbols is None:
|
||||||
|
no_symbols = len(dataset)
|
||||||
|
_, axs = plt.subplots(2,2, sharex=True, sharey=True)
|
||||||
|
|
||||||
|
xaxis = np.linspace(0,dataset.symbols_per_slice,dataset.samples_per_slice)
|
||||||
|
roll = dataset.samples_per_symbol//2 if offset else 0
|
||||||
|
for E_out, E_in in dataset[roll:dataset.samples_per_symbol*no_symbols+roll:dataset.samples_per_symbol]:
|
||||||
|
E_in_x, E_in_y, E_out_x, E_out_y = E_in[0], E_in[1], E_out[0], E_out[1]
|
||||||
|
axs[0,0].plot(xaxis, np.abs( E_in_x.numpy())**2, alpha=0.05, color='C0')
|
||||||
|
axs[1,0].plot(xaxis, np.abs( E_in_y.numpy())**2, alpha=0.05, color='C0')
|
||||||
|
axs[0,1].plot(xaxis, np.abs(E_out_x.numpy())**2, alpha=0.05, color='C0')
|
||||||
|
axs[1,1].plot(xaxis, np.abs(E_out_y.numpy())**2, alpha=0.05, color='C0')
|
||||||
|
|
||||||
|
if show:
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
# def plt_dataloader(dataloader, show=True):
|
||||||
|
# _, axs = plt.subplots(2,2, sharex=True, sharey=True)
|
||||||
|
|
||||||
|
# E_outs, E_ins = next(iter(dataloader))
|
||||||
|
# for i, (E_out, E_in) in enumerate(zip(E_outs, E_ins)):
|
||||||
|
# xaxis = np.linspace(dataset.symbols_per_slice*i,dataset.symbols_per_slice+dataset.symbols_per_slice*i,dataset.samples_per_slice)
|
||||||
|
# E_in_x, E_in_y, E_out_x, E_out_y = E_in[0], E_in[1], E_out[0], E_out[1]
|
||||||
|
# axs[0,0].plot(xaxis, np.abs(E_in_x.numpy())**2)
|
||||||
|
# axs[1,0].plot(xaxis, np.abs(E_in_y.numpy())**2)
|
||||||
|
# axs[0,1].plot(xaxis, np.abs(E_out_x.numpy())**2)
|
||||||
|
# axs[1,1].plot(xaxis, np.abs(E_out_y.numpy())**2)
|
||||||
|
|
||||||
|
# if show:
|
||||||
|
# plt.show()
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
|
||||||
|
dataset = SlicedDataset("data/20241115-175517-128-16384-10000-0-0-17-0-PAM4-0.ini", symbols=1, drop_first=100)
|
||||||
|
print(dataset[0][0].shape)
|
||||||
|
|
||||||
|
eye_dataset(dataset, 1000, offset=True, show=False)
|
||||||
|
|
||||||
|
train_loader = DataLoader(dataset, batch_size=10, shuffle=False)
|
||||||
|
|
||||||
|
# plt_dataloader(train_loader, show=False)
|
||||||
|
|
||||||
|
plt.show()
|
||||||
68
src/single-core-regen/testing/torch-import-test.py
Normal file
68
src/single-core-regen/testing/torch-import-test.py
Normal file
@@ -0,0 +1,68 @@
|
|||||||
|
import torch
|
||||||
|
import time
|
||||||
|
|
||||||
|
def print_torch_env():
|
||||||
|
print("Torch version: ", torch.__version__)
|
||||||
|
print("CUDA available: ", torch.cuda.is_available())
|
||||||
|
print("CUDA version: ", torch.version.cuda)
|
||||||
|
print("CUDNN version: ", torch.backends.cudnn.version())
|
||||||
|
print("Device count: ", torch.cuda.device_count())
|
||||||
|
print("Current device: ", torch.cuda.current_device())
|
||||||
|
print("Device name: ", torch.cuda.get_device_name(0))
|
||||||
|
print("Device capability: ", torch.cuda.get_device_capability(0))
|
||||||
|
print("Device memory: ", torch.cuda.get_device_properties(0).total_memory)
|
||||||
|
|
||||||
|
def measure_runtime(func):
|
||||||
|
"""
|
||||||
|
Measure the runtime of a function.
|
||||||
|
|
||||||
|
:param func: Function to measure
|
||||||
|
:type func: function
|
||||||
|
:return: Wrapped function with runtime measurement
|
||||||
|
:rtype: function
|
||||||
|
"""
|
||||||
|
def wrapper(*args, **kwargs):
|
||||||
|
start_time = time.time()
|
||||||
|
result = func(*args, **kwargs)
|
||||||
|
end_time = time.time()
|
||||||
|
print(f"Runtime: {end_time - start_time:.6f} seconds")
|
||||||
|
return result, end_time - start_time
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
@measure_runtime
|
||||||
|
def tensor_addition(a, b):
|
||||||
|
"""
|
||||||
|
Perform tensor addition.
|
||||||
|
|
||||||
|
:param a: First tensor
|
||||||
|
:type a: torch.Tensor
|
||||||
|
:param b: Second tensor
|
||||||
|
:type b: torch.Tensor
|
||||||
|
:return: Sum of tensors
|
||||||
|
:rtype: torch.Tensor
|
||||||
|
"""
|
||||||
|
return a + b
|
||||||
|
|
||||||
|
def runtime_test():
|
||||||
|
x = torch.rand(2**18, 2**10)
|
||||||
|
y = torch.rand(2**18, 2**10)
|
||||||
|
|
||||||
|
print("Tensor addition on CPU")
|
||||||
|
_, cpu_time = tensor_addition(x, y)
|
||||||
|
|
||||||
|
print()
|
||||||
|
print("Tensor addition on GPU")
|
||||||
|
if not torch.cuda.is_available():
|
||||||
|
print("CUDA is not available")
|
||||||
|
return
|
||||||
|
|
||||||
|
_, gpu_time = tensor_addition(x.cuda(), y.cuda())
|
||||||
|
|
||||||
|
print()
|
||||||
|
print(f"Speedup: {cpu_time / gpu_time *100:.2f}%")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
print_torch_env()
|
||||||
|
print()
|
||||||
|
runtime_test()
|
||||||
19
src/single-core-regen/util/__init__.py
Normal file
19
src/single-core-regen/util/__init__.py
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
from . import datasets # noqa: F401
|
||||||
|
# from .datasets import FiberRegenerationDataset # noqa: F401
|
||||||
|
# from .datasets import load_data # noqa: F401
|
||||||
|
|
||||||
|
|
||||||
|
from . import plot # noqa: F401
|
||||||
|
# from .plot import eye # noqa: F401
|
||||||
|
|
||||||
|
from . import optuna_helpers # noqa: F401
|
||||||
|
# from .optuna_helpers import optional_suggest_categorical # noqa: F401
|
||||||
|
# from .optuna_helpers import optional_suggest_float # noqa: F401
|
||||||
|
# from .optuna_helpers import optional_suggest_int # noqa: F401
|
||||||
|
|
||||||
|
from . import complexNN # noqa: F401
|
||||||
|
# from .complexNN import UnitaryLayer # noqa: F401
|
||||||
|
# from .complexNN import complex_mse_loss # noqa: F401
|
||||||
|
# from .complexNN import complex_sse_loss # noqa: F401
|
||||||
|
|
||||||
|
from . import misc # noqa: F401
|
||||||
504
src/single-core-regen/util/complexNN.py
Normal file
504
src/single-core-regen/util/complexNN.py
Normal file
@@ -0,0 +1,504 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
# from torchlambertw.special import lambertw
|
||||||
|
|
||||||
|
|
||||||
|
def complex_mse_loss(input, target, power=False, reduction="mean"):
|
||||||
|
"""
|
||||||
|
Compute the mean squared error between two complex tensors.
|
||||||
|
If power is set to True, the loss is computed as |input|^2 - |target|^2
|
||||||
|
"""
|
||||||
|
reduce = getattr(torch, reduction)
|
||||||
|
|
||||||
|
if power:
|
||||||
|
input = (input * input.conj()).real.to(dtype=input.dtype.to_real())
|
||||||
|
target = (target * target.conj()).real.to(dtype=target.dtype.to_real())
|
||||||
|
|
||||||
|
if input.is_complex() and target.is_complex():
|
||||||
|
return reduce(torch.square(input.real - target.real) + torch.square(input.imag - target.imag))
|
||||||
|
elif input.is_complex() or target.is_complex():
|
||||||
|
raise ValueError("Input and target must have the same type (real or complex)")
|
||||||
|
else:
|
||||||
|
return F.mse_loss(input, target, reduction=reduction)
|
||||||
|
|
||||||
|
|
||||||
|
def complex_sse_loss(input, target):
|
||||||
|
"""
|
||||||
|
Compute the sum squared error between two complex tensors.
|
||||||
|
"""
|
||||||
|
if input.is_complex():
|
||||||
|
return torch.sum(torch.square(input.real - target.real) + torch.square(input.imag - target.imag))
|
||||||
|
else:
|
||||||
|
return torch.sum(torch.square(input - target))
|
||||||
|
|
||||||
|
|
||||||
|
class UnitaryLayer(nn.Module):
|
||||||
|
def __init__(self, in_features, out_features, dtype=None):
|
||||||
|
assert in_features >= out_features
|
||||||
|
super(UnitaryLayer, self).__init__()
|
||||||
|
self.in_features = in_features
|
||||||
|
self.out_features = out_features
|
||||||
|
self.weight = nn.Parameter(torch.randn(in_features, out_features, dtype=dtype))
|
||||||
|
self.reset_parameters()
|
||||||
|
|
||||||
|
def reset_parameters(self):
|
||||||
|
q, _ = torch.linalg.qr(self.weight)
|
||||||
|
self.weight.data = q
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return torch.matmul(x, self.weight)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f"UnitaryLayer({self.in_features}, {self.out_features})"
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class _Unitary(nn.Module):
|
||||||
|
def forward(self, X:torch.Tensor):
|
||||||
|
if X.ndim < 2:
|
||||||
|
raise ValueError(
|
||||||
|
"Only tensors with 2 or more dimensions are supported. "
|
||||||
|
f"Got a tensor of shape {X.shape}"
|
||||||
|
)
|
||||||
|
n, k = X.size(-2), X.size(-1)
|
||||||
|
transpose = n<k
|
||||||
|
if transpose:
|
||||||
|
X = X.transpose(-2, -1)
|
||||||
|
q, r = torch.linalg.qr(X)
|
||||||
|
# q: torch.Tensor = q
|
||||||
|
# r: torch.Tensor = r
|
||||||
|
d = r.diagonal(dim1=-2, dim2=-1).sgn()
|
||||||
|
q*=d.unsqueeze(-2)
|
||||||
|
if transpose:
|
||||||
|
q = q.transpose(-2, -1)
|
||||||
|
if n == k:
|
||||||
|
mask = (torch.linalg.det(q).abs() >= 0).to(q.dtype.to_real())
|
||||||
|
mask[mask == 0] = -1
|
||||||
|
mask = mask.unsqueeze(-1)
|
||||||
|
q[..., 0] *= mask
|
||||||
|
# X.copy_(q)
|
||||||
|
return q
|
||||||
|
|
||||||
|
def unitary(module: nn.Module, name: str = "weight") -> nn.Module:
|
||||||
|
weight = getattr(module, name, None)
|
||||||
|
if not isinstance(weight, torch.Tensor):
|
||||||
|
raise ValueError(f"Module '{module}' has no parameter or buffer '{name}'")
|
||||||
|
|
||||||
|
if weight.ndim < 2:
|
||||||
|
raise ValueError(f"Expected a matrix or batch of matrices. Got a tensor of {weight.ndim} dimensions.")
|
||||||
|
|
||||||
|
if weight.shape[-2] != weight.shape[-1]:
|
||||||
|
raise ValueError(f"Expected a square matrix or batch of square matrices. Got a tensor of shape {weight.shape}")
|
||||||
|
|
||||||
|
unit = _Unitary()
|
||||||
|
nn.utils.parametrize.register_parametrization(module, name, unit)
|
||||||
|
return module
|
||||||
|
|
||||||
|
class _SpecialUnitary(nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def forward(self, X:torch.Tensor):
|
||||||
|
n, k = X.size(-2), X.size(-1)
|
||||||
|
if n != k:
|
||||||
|
raise ValueError(f"Expected a square matrix. Got a tensor of shape {X.shape}")
|
||||||
|
q, _ = torch.linalg.qr(X)
|
||||||
|
q = q / torch.linalg.det(q).pow(1/n)
|
||||||
|
|
||||||
|
return q
|
||||||
|
|
||||||
|
def special_unitary(module: nn.Module, name: str = "weight") -> nn.Module:
|
||||||
|
weight = getattr(module, name, None)
|
||||||
|
if not isinstance(weight, torch.Tensor):
|
||||||
|
raise ValueError(f"Module '{module}' has no parameter or buffer '{name}'")
|
||||||
|
|
||||||
|
if weight.ndim < 2:
|
||||||
|
raise ValueError(f"Expected a matrix or batch of matrices. Got a tensor of {weight.ndim} dimensions.")
|
||||||
|
|
||||||
|
if weight.shape[-2] != weight.shape[-1]:
|
||||||
|
raise ValueError(f"Expected a square matrix or batch of square matrices. Got a tensor of shape {weight.shape}")
|
||||||
|
|
||||||
|
unit = _SpecialUnitary()
|
||||||
|
nn.utils.parametrize.register_parametrization(module, name, unit)
|
||||||
|
return module
|
||||||
|
|
||||||
|
class _Clamp(nn.Module):
|
||||||
|
def __init__(self, min, max):
|
||||||
|
super(_Clamp, self).__init__()
|
||||||
|
self.min = min
|
||||||
|
self.max = max
|
||||||
|
def forward(self, x):
|
||||||
|
if x.is_complex():
|
||||||
|
# clamp magnitude, ignore phase
|
||||||
|
return torch.clamp(x.abs(), self.min, self.max) * x / x.abs()
|
||||||
|
return torch.clamp(x, self.min, self.max)
|
||||||
|
|
||||||
|
|
||||||
|
def clamp(module: nn.Module, name: str = "scale", min=0, max=1) -> nn.Module:
|
||||||
|
scale = getattr(module, name, None)
|
||||||
|
if not isinstance(scale, torch.Tensor):
|
||||||
|
raise ValueError(f"Module '{module}' has no parameter or buffer '{name}'")
|
||||||
|
|
||||||
|
cl = _Clamp(min, max)
|
||||||
|
nn.utils.parametrize.register_parametrization(module, name, cl)
|
||||||
|
return module
|
||||||
|
|
||||||
|
|
||||||
|
class ONNMiller(nn.Module):
|
||||||
|
def __init__(self, input_dim, output_dim, dtype=None) -> None:
|
||||||
|
super(ONNMiller, self).__init__()
|
||||||
|
self.input_dim = input_dim
|
||||||
|
self.output_dim = output_dim
|
||||||
|
self.dtype = dtype
|
||||||
|
|
||||||
|
self.dim = max(input_dim, output_dim)
|
||||||
|
|
||||||
|
# zero pad input to internal size if smaller
|
||||||
|
if self.input_dim < self.dim:
|
||||||
|
self.pad = lambda x: F.pad(x, ((self.dim - self.input_dim) // 2, (self.dim - self.input_dim + 1) // 2))
|
||||||
|
else:
|
||||||
|
self.pad = lambda x: x
|
||||||
|
self.pad.__doc__ = f"Zero pad input from {self.input_dim} to {self.dim}"
|
||||||
|
|
||||||
|
# crop output to desired size
|
||||||
|
if self.output_dim < self.dim:
|
||||||
|
self.crop = lambda x: x[:, (self.dim - self.output_dim) // 2 : (x.shape[1] - (self.dim - self.output_dim + 1) // 2)]
|
||||||
|
else:
|
||||||
|
self.crop = lambda x: x
|
||||||
|
self.crop.__doc__ = f"Crop output from {self.dim} to {self.output_dim}"
|
||||||
|
|
||||||
|
self.U = nn.Parameter(torch.randn(self.dim, self.dim, dtype=self.dtype)) # -> parametrization: Unitary
|
||||||
|
self.S = nn.Parameter(torch.randn(self.dim, dtype=self.dtype)) # -> parametrization: Clamp (magnitude 0..1)
|
||||||
|
self.V = nn.Parameter(torch.randn(self.dim, self.dim, dtype=self.dtype)) # -> parametrization: Unitary
|
||||||
|
self.register_buffer("MZI_scale", torch.tensor(2, dtype=self.dtype.to_real()).sqrt())
|
||||||
|
# V is actually V.H, but
|
||||||
|
|
||||||
|
def forward(self, x_in):
|
||||||
|
x = x_in
|
||||||
|
x = self.pad(x)
|
||||||
|
x = x @ self.U
|
||||||
|
x = x * (self.S.squeeze() / self.MZI_scale)
|
||||||
|
x = x @ self.V
|
||||||
|
x = self.crop(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
class ONN(nn.Module):
|
||||||
|
def __init__(self, input_dim, output_dim, dtype=None) -> None:
|
||||||
|
super(ONN, self).__init__()
|
||||||
|
self.input_dim = input_dim
|
||||||
|
self.output_dim = output_dim
|
||||||
|
self.dtype = dtype
|
||||||
|
|
||||||
|
self.dim = max(input_dim, output_dim)
|
||||||
|
|
||||||
|
# zero pad input to internal size if smaller
|
||||||
|
if self.input_dim < self.dim:
|
||||||
|
self.pad = lambda x: F.pad(x, ((self.dim - self.input_dim) // 2, (self.dim - self.input_dim + 1) // 2))
|
||||||
|
self.pad.__doc__ = f"Zero pad input from {self.input_dim} to {self.dim}"
|
||||||
|
else:
|
||||||
|
self.pad = lambda x: x
|
||||||
|
self.pad.__doc__ = f"Input size equals internal size {self.dim}"
|
||||||
|
|
||||||
|
# crop output to desired size
|
||||||
|
if self.output_dim < self.dim:
|
||||||
|
self.crop = lambda x: x[:, (self.dim - self.output_dim) // 2 : (x.shape[1] - (self.dim - self.output_dim + 1) // 2)]
|
||||||
|
self.crop.__doc__ = f"Crop output from {self.dim} to {self.output_dim}"
|
||||||
|
else:
|
||||||
|
self.crop = lambda x: x
|
||||||
|
self.crop.__doc__ = f"Output size equals internal size {self.dim}"
|
||||||
|
|
||||||
|
|
||||||
|
self.weight = nn.Parameter(torch.randn(self.dim, self.dim, dtype=self.dtype))
|
||||||
|
|
||||||
|
def reset_parameters(self):
|
||||||
|
q, _ = torch.linalg.qr(self.weight)
|
||||||
|
self.weight.data = q
|
||||||
|
# def get_M(self):
|
||||||
|
# return self.U @ self.sigma @ self.V
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.crop(self.pad(x) @ self.weight)
|
||||||
|
|
||||||
|
|
||||||
|
class SemiUnitaryLayer(nn.Module):
|
||||||
|
def __init__(self, input_dim, output_dim, dtype=None):
|
||||||
|
super(SemiUnitaryLayer, self).__init__()
|
||||||
|
self.input_dim = input_dim
|
||||||
|
self.output_dim = output_dim
|
||||||
|
|
||||||
|
# Create a larger square matrix for QR decomposition
|
||||||
|
self.weight = nn.Parameter(torch.randn(max(input_dim, output_dim), max(input_dim, output_dim), dtype=dtype))
|
||||||
|
self.scale = nn.Parameter(torch.tensor(1.0, dtype=dtype.to_real()))
|
||||||
|
self.reset_parameters()
|
||||||
|
|
||||||
|
def reset_parameters(self):
|
||||||
|
# Ensure the weights are unitary by QR decomposition
|
||||||
|
q, _ = torch.linalg.qr(self.weight)
|
||||||
|
# A = QR with A being a complex square matrix -> Q is unitary, R is upper triangular
|
||||||
|
|
||||||
|
# truncate the matrix to the desired size
|
||||||
|
if self.input_dim > self.output_dim:
|
||||||
|
self.weight.data = q[: self.input_dim, : self.output_dim]
|
||||||
|
else:
|
||||||
|
self.weight.data = q[: self.output_dim, : self.input_dim].t()
|
||||||
|
...
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
with torch.no_grad():
|
||||||
|
scale = torch.clamp(self.scale, 0.0, 1.0)
|
||||||
|
out = torch.matmul(x, scale * self.weight)
|
||||||
|
return out
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f"SemiUnitaryLayer({self.input_dim}, {self.output_dim})"
|
||||||
|
|
||||||
|
|
||||||
|
# class SaturableAbsorberLambertW(nn.Module):
|
||||||
|
# """
|
||||||
|
# Implements the activation function for an optical saturable absorber
|
||||||
|
|
||||||
|
# base eqn: sigma*tau*I0 = 0.5*(log(Tm/T0))/(1-Tm),
|
||||||
|
# where: sigma is the absorption cross section
|
||||||
|
# tau is the radiative lifetime of the absorber material
|
||||||
|
# T0 is the initial transmittance
|
||||||
|
# I0 is the input intensity
|
||||||
|
# Tm is the transmittance of the absorber
|
||||||
|
|
||||||
|
# The activation function is defined as:
|
||||||
|
# Iout = I0 * Tm(I0)
|
||||||
|
|
||||||
|
# where Tm(I0) is the transmittance of the absorber as a function of the input intensity I0
|
||||||
|
|
||||||
|
# for a unit sigma*tau product, he solution Tm(I0) is given by:
|
||||||
|
# Tm(I0) = (W(2*exp(2*I0)*I0*T0))/(2*I0),
|
||||||
|
# where W is the Lambert W function
|
||||||
|
|
||||||
|
# if sigma*tau is not 1, I0 has to be scaled by sigma*tau
|
||||||
|
# (-> x has to be scaled by sqrt(sigma*tau))
|
||||||
|
|
||||||
|
# """
|
||||||
|
|
||||||
|
# def __init__(self, T0):
|
||||||
|
# super(SaturableAbsorberLambertW, self).__init__()
|
||||||
|
# self.register_buffer("T0", torch.tensor(T0))
|
||||||
|
|
||||||
|
# def forward(self, x: torch.Tensor):
|
||||||
|
# xc = x.conj()
|
||||||
|
# two_x_xc = (2 * x * xc).real
|
||||||
|
# return (lambertw(2 * torch.exp(two_x_xc) * (x * self.T0 * xc).real) / two_x_xc).to(dtype=x.dtype)
|
||||||
|
|
||||||
|
# def backward(self, x):
|
||||||
|
# xc = x.conj()
|
||||||
|
# lambert_eval = lambertw(2 * torch.exp(2 * x * xc).real * (x * self.T0 * xc).real)
|
||||||
|
# return (((xc * (-2 * lambert_eval + 2 * torch.square(x) - 1) + 2 * x * torch.square(xc) + x) * lambert_eval) / (
|
||||||
|
# 2 * torch.pow(x, 3) * xc * (lambert_eval + 1)
|
||||||
|
# )).to(dtype=x.dtype)
|
||||||
|
|
||||||
|
|
||||||
|
# class SaturableAbsorber(nn.Module):
|
||||||
|
# def __init__(self, alpha, I0):
|
||||||
|
# super(SaturableAbsorber, self).__init__()
|
||||||
|
# self.register_buffer("alpha", torch.tensor(alpha))
|
||||||
|
# self.register_buffer("I0", torch.tensor(I0))
|
||||||
|
|
||||||
|
# def forward(self, x):
|
||||||
|
# I = (x*x.conj()).to(dtype=x.dtype.to_real())
|
||||||
|
# A = self.alpha/(1+I/self.I0)
|
||||||
|
|
||||||
|
|
||||||
|
# class SpreadLayer(nn.Module):
|
||||||
|
# def __init__(self, in_features, out_features, dtype=None):
|
||||||
|
# super(SpreadLayer, self).__init__()
|
||||||
|
# self.in_features = in_features
|
||||||
|
# self.out_features = out_features
|
||||||
|
# self.mat = torch.ones(in_features, out_features, dtype=dtype)*torch.sqrt(torch.tensor(in_features/out_features))
|
||||||
|
|
||||||
|
# def forward(self, x):
|
||||||
|
# # N in_features -> M out_features, Enery is preserved (P = abs(x)^2)
|
||||||
|
# out = torch.matmul(x, self.mat)
|
||||||
|
# return out
|
||||||
|
|
||||||
|
|
||||||
|
#### as defined by zhang et al
|
||||||
|
|
||||||
|
|
||||||
|
class DropoutComplex(nn.Module):
|
||||||
|
def __init__(self, p=0.5):
|
||||||
|
super(DropoutComplex, self).__init__()
|
||||||
|
self.dropout = nn.Dropout(p=p)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
if x.is_complex():
|
||||||
|
mask = self.dropout(torch.ones_like(x.real))
|
||||||
|
return x * mask
|
||||||
|
else:
|
||||||
|
return self.dropout(x)
|
||||||
|
|
||||||
|
|
||||||
|
class Identity(nn.Module):
|
||||||
|
"""
|
||||||
|
implements the "activation" function
|
||||||
|
M(z) = z
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super(Identity, self).__init__()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return x
|
||||||
|
|
||||||
|
class PowRot(nn.Module):
|
||||||
|
def __init__(self, bias=False):
|
||||||
|
super(PowRot, self).__init__()
|
||||||
|
self.scale = nn.Parameter(torch.tensor(1.0))
|
||||||
|
if bias:
|
||||||
|
self.bias = nn.Parameter(torch.tensor(0.0))
|
||||||
|
else:
|
||||||
|
self.register_buffer("bias", torch.tensor(0.0))
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor):
|
||||||
|
if x.is_complex():
|
||||||
|
return x * torch.exp(-self.scale*1j*x.abs().square()+self.bias.to(dtype=x.dtype))
|
||||||
|
else:
|
||||||
|
return x
|
||||||
|
|
||||||
|
class Pow(nn.Module):
|
||||||
|
"""
|
||||||
|
implements the activation function
|
||||||
|
M(z) = ||z||^2 + b
|
||||||
|
"""
|
||||||
|
def __init__(self, bias=False):
|
||||||
|
super(Pow, self).__init__()
|
||||||
|
if bias:
|
||||||
|
self.bias = nn.Parameter(torch.tensor(0.0))
|
||||||
|
else:
|
||||||
|
self.register_buffer("bias", torch.tensor(0.0))
|
||||||
|
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor):
|
||||||
|
return x.abs().square().add(self.bias).to(dtype=x.dtype)
|
||||||
|
|
||||||
|
|
||||||
|
class Mag(nn.Module):
|
||||||
|
"""
|
||||||
|
implements the activation function
|
||||||
|
M(z) = ||z||+b
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, bias=False):
|
||||||
|
super(Mag, self).__init__()
|
||||||
|
if bias:
|
||||||
|
self.bias = nn.Parameter(torch.tensor(0.0))
|
||||||
|
else:
|
||||||
|
self.register_buffer("bias", torch.tensor(0.0))
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor):
|
||||||
|
return x.abs().add(self.bias).to(dtype=x.dtype)
|
||||||
|
|
||||||
|
|
||||||
|
class MagScale(nn.Module):
|
||||||
|
def __init__(self, bias=False):
|
||||||
|
super(MagScale, self).__init__()
|
||||||
|
if bias:
|
||||||
|
self.bias = nn.Parameter(torch.tensor(0.0))
|
||||||
|
else:
|
||||||
|
self.register_buffer("bias", torch.tensor(0.0))
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor):
|
||||||
|
return x.abs().add(self.bias).to(dtype=x.dtype).sin().mul(x)
|
||||||
|
|
||||||
|
class PowScale(nn.Module):
|
||||||
|
def __init__(self, bias=False):
|
||||||
|
super(PowScale, self).__init__()
|
||||||
|
if bias:
|
||||||
|
self.bias = nn.Parameter(torch.tensor(0.0))
|
||||||
|
else:
|
||||||
|
self.register_buffer("bias", torch.tensor(0.0))
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor):
|
||||||
|
return x.mul(x.abs().square().add(self.bias).to(dtype=x.dtype).sin())
|
||||||
|
|
||||||
|
|
||||||
|
class ModReLU(nn.Module):
|
||||||
|
"""
|
||||||
|
implements the activation function
|
||||||
|
M(z) = ReLU(||z|| + b)*exp(j*theta_z)
|
||||||
|
= ReLU(||z|| + b)*z/||z||
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, bias=True):
|
||||||
|
super(ModReLU, self).__init__()
|
||||||
|
if bias:
|
||||||
|
self.bias = nn.Parameter(torch.tensor(0.0))
|
||||||
|
else:
|
||||||
|
self.register_buffer("bias", torch.tensor(0.0))
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
if x.is_complex():
|
||||||
|
mod = x.abs()
|
||||||
|
out = torch.relu(mod + self.bias) * x / mod
|
||||||
|
return out.to(dtype=x.dtype)
|
||||||
|
|
||||||
|
else:
|
||||||
|
return torch.relu(x + self.bias).to(dtype=x.dtype)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f"ModReLU(b={self.b})"
|
||||||
|
|
||||||
|
|
||||||
|
class CReLU(nn.Module):
|
||||||
|
"""
|
||||||
|
implements the activation function
|
||||||
|
M(z) = ReLU(Re(z)) + j*ReLU(Im(z))
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super(CReLU, self).__init__()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
if x.is_complex():
|
||||||
|
return torch.relu(x.real) + 1j * torch.relu(x.imag)
|
||||||
|
else:
|
||||||
|
return torch.relu(x)
|
||||||
|
|
||||||
|
|
||||||
|
class ZReLU(nn.Module):
|
||||||
|
"""
|
||||||
|
implements the activation function
|
||||||
|
|
||||||
|
M(z) = z if 0 <= angle(z) <= pi/2
|
||||||
|
= 0 otherwise
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super(ZReLU, self).__init__()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
if x.is_complex():
|
||||||
|
return x * (torch.angle(x) >= 0) * (torch.angle(x) <= torch.pi / 2)
|
||||||
|
else:
|
||||||
|
return torch.relu(x)
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
complex_sse_loss,
|
||||||
|
complex_mse_loss,
|
||||||
|
UnitaryLayer,
|
||||||
|
unitary,
|
||||||
|
clamp,
|
||||||
|
ONN,
|
||||||
|
ONNMiller,
|
||||||
|
SemiUnitaryLayer,
|
||||||
|
DropoutComplex,
|
||||||
|
Identity,
|
||||||
|
Pow,
|
||||||
|
PowRot,
|
||||||
|
Mag,
|
||||||
|
ModReLU,
|
||||||
|
CReLU,
|
||||||
|
ZReLU,
|
||||||
|
# SaturableAbsorberLambertW,
|
||||||
|
# SaturableAbsorber,
|
||||||
|
# SpreadLayer,
|
||||||
|
]
|
||||||
282
src/single-core-regen/util/datasets.py
Normal file
282
src/single-core-regen/util/datasets.py
Normal file
@@ -0,0 +1,282 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
import torch
|
||||||
|
from torch.utils.data import Dataset
|
||||||
|
|
||||||
|
# from torch.utils.data import Sampler
|
||||||
|
import numpy as np
|
||||||
|
import configparser
|
||||||
|
|
||||||
|
# class SubsetSampler(Sampler[int]):
|
||||||
|
# """
|
||||||
|
# Samples elements from a given list of indices.
|
||||||
|
|
||||||
|
# :param indices: List of indices to sample from.
|
||||||
|
# :type indices: list[int]
|
||||||
|
# """
|
||||||
|
|
||||||
|
# def __init__(self, indices):
|
||||||
|
# self.indices = indices
|
||||||
|
|
||||||
|
# def __iter__(self):
|
||||||
|
# return iter(self.indices)
|
||||||
|
|
||||||
|
# def __len__(self):
|
||||||
|
# return len(self.indices)
|
||||||
|
|
||||||
|
|
||||||
|
def load_data(config_path, skipfirst=0, symbols=None, real=False, normalize=False, device=None, dtype=None):
|
||||||
|
filepath = Path(config_path)
|
||||||
|
filepath = filepath.parent.glob(filepath.name)
|
||||||
|
config = configparser.ConfigParser()
|
||||||
|
config.read(filepath)
|
||||||
|
path_elements = (
|
||||||
|
config["data"]["dir"],
|
||||||
|
config["data"]["npy_dir"],
|
||||||
|
config["data"]["file"],
|
||||||
|
)
|
||||||
|
datapath = Path("/".join(path_elements).replace('"', ""))
|
||||||
|
sps = int(config["glova"]["sps"])
|
||||||
|
|
||||||
|
if symbols is None:
|
||||||
|
symbols = int(config["glova"]["nos"]) - skipfirst
|
||||||
|
|
||||||
|
data = np.load(datapath)[skipfirst * sps : symbols * sps + skipfirst * sps]
|
||||||
|
|
||||||
|
if normalize:
|
||||||
|
# square gets normalized to 1, as the power is (proportional to) the square of the amplitude
|
||||||
|
a, b, c, d = np.square(data.T)
|
||||||
|
a, b, c, d = a / np.max(np.abs(a)), b / np.max(np.abs(b)), c / np.max(np.abs(c)), d / np.max(np.abs(d))
|
||||||
|
data = np.sqrt(np.array([a, b, c, d]).T)
|
||||||
|
|
||||||
|
if real:
|
||||||
|
data = np.abs(data)
|
||||||
|
|
||||||
|
config["glova"]["nos"] = str(symbols)
|
||||||
|
|
||||||
|
data = torch.tensor(data, device=device, dtype=dtype)
|
||||||
|
|
||||||
|
return data, config
|
||||||
|
|
||||||
|
|
||||||
|
def roll_along(arr, shifts, dim):
|
||||||
|
# https://stackoverflow.com/a/76920720
|
||||||
|
# (c) Mateen Ulhaq, 2023
|
||||||
|
# CC BY-SA 4.0
|
||||||
|
shifts = torch.tensor(shifts)
|
||||||
|
assert arr.ndim - 1 == shifts.ndim
|
||||||
|
dim %= arr.ndim
|
||||||
|
shape = (1,) * dim + (-1,) + (1,) * (arr.ndim - dim - 1)
|
||||||
|
dim_indices = torch.arange(arr.shape[dim]).reshape(shape)
|
||||||
|
indices = (dim_indices - shifts.unsqueeze(dim)) % arr.shape[dim]
|
||||||
|
return torch.gather(arr, dim, indices)
|
||||||
|
|
||||||
|
|
||||||
|
class FiberRegenerationDataset(Dataset):
|
||||||
|
"""
|
||||||
|
Dataset for fiber regeneration training.
|
||||||
|
|
||||||
|
The dataset is loaded from a configuration file, which must contain (at least) the following sections:
|
||||||
|
```
|
||||||
|
[data]
|
||||||
|
dir = <data_dir>
|
||||||
|
npy_dir = <npy_dir>
|
||||||
|
file = <data_file>
|
||||||
|
|
||||||
|
[glova]
|
||||||
|
sps = <samples per symbol>
|
||||||
|
```
|
||||||
|
The data is loaded from the file `<data_dir>/<npy_dir>/<data_file>` and is assumed to be in the following format:
|
||||||
|
```
|
||||||
|
[ E_in_x,
|
||||||
|
E_in_y,
|
||||||
|
E_out_x,
|
||||||
|
E_out_y ]
|
||||||
|
```
|
||||||
|
|
||||||
|
The dataset is sliced into slices, where each slice consists of a (fractional) number of symbols.
|
||||||
|
The target can be delayed relative to the input data by a (fractional) number of symbols.
|
||||||
|
The x and y channels can be delayed relative to each other by a (fractional) number of symbols.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
file_path: str | Path,
|
||||||
|
symbols: int | float,
|
||||||
|
*,
|
||||||
|
output_dim: int = None,
|
||||||
|
target_delay: float | int = 0,
|
||||||
|
xy_delay: float | int = 0,
|
||||||
|
drop_first: float | int = 0,
|
||||||
|
dtype: torch.dtype = None,
|
||||||
|
real: bool = False,
|
||||||
|
device=None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initialize the dataset.
|
||||||
|
|
||||||
|
:param file_path: Path to the data file. Can contain wildcards (*). The first
|
||||||
|
:type file_path: str | pathlib.Path
|
||||||
|
:param symbols: Number of symbols in each slice. Can be a float to specify a fraction of a symbol.
|
||||||
|
:type symbols: float | int
|
||||||
|
:param data_size: Number of samples in each slice. The data is reduced by taking equally spaced samples. If unset, each slice will contain symbols*samples_per_symbol samples.
|
||||||
|
:type data_size: int, optional
|
||||||
|
:param target_delay: Delay (in fractional symbols) between data and target. A positive delay means the target is delayed relative to the data. Default is 0.
|
||||||
|
:type target_delay: float | int, optional
|
||||||
|
:param xy_delay: Delay (in fractional symbols) between the x and y channels. A positive delay means the y channel is delayed relative to the x channel. Default is 0.
|
||||||
|
:type xy_delay: float | int, optional
|
||||||
|
:param drop_first: Number of (fractional) symbols to drop from the beginning
|
||||||
|
:type drop_first: float | int
|
||||||
|
"""
|
||||||
|
|
||||||
|
# check types
|
||||||
|
assert isinstance(file_path, str), "file_path must be a string"
|
||||||
|
assert isinstance(symbols, (float, int)), "symbols must be a float or an integer"
|
||||||
|
assert output_dim is None or isinstance(output_dim, int), "output_len must be an integer"
|
||||||
|
assert isinstance(target_delay, (float, int)), "target_delay must be a float or an integer"
|
||||||
|
assert isinstance(xy_delay, (float, int)), "xy_delay must be a float or an integer"
|
||||||
|
assert isinstance(drop_first, int), "drop_first must be an integer"
|
||||||
|
|
||||||
|
# check values
|
||||||
|
assert symbols > 0, "symbols must be positive"
|
||||||
|
assert output_dim is None or output_dim > 0, "output_len must be positive or None"
|
||||||
|
assert drop_first >= 0, "drop_first must be non-negative"
|
||||||
|
|
||||||
|
faux = kwargs.pop("faux", False)
|
||||||
|
|
||||||
|
if faux:
|
||||||
|
data_raw = np.array(
|
||||||
|
[[i + 0.1j, i + 0.2j, i + 1.1j, i + 1.2j] for i in range(12800)],
|
||||||
|
dtype=np.complex128,
|
||||||
|
)
|
||||||
|
data_raw = torch.tensor(data_raw, device=device, dtype=dtype)
|
||||||
|
self.config = {
|
||||||
|
"data": {"dir": '"."', "npy_dir": '"."', "file": "faux"},
|
||||||
|
"glova": {"sps": 128},
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
data_raw, self.config = load_data(
|
||||||
|
file_path,
|
||||||
|
skipfirst=drop_first,
|
||||||
|
symbols=kwargs.pop("num_symbols", None),
|
||||||
|
real=real,
|
||||||
|
normalize=True,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.device = data_raw.device
|
||||||
|
|
||||||
|
self.samples_per_symbol = int(self.config["glova"]["sps"])
|
||||||
|
self.samples_per_slice = int(symbols * self.samples_per_symbol)
|
||||||
|
self.symbols_per_slice = self.samples_per_slice / self.samples_per_symbol
|
||||||
|
|
||||||
|
self.output_dim = output_dim or self.samples_per_slice
|
||||||
|
self.target_delay = target_delay or 0
|
||||||
|
self.xy_delay = xy_delay or 0
|
||||||
|
|
||||||
|
ovrd_target_delay_samples = kwargs.pop("ovrd_target_delay_samples", None)
|
||||||
|
ovrd_xy_delay_samples = kwargs.pop("ovrd_xy_delay_samples", None)
|
||||||
|
|
||||||
|
self.target_delay_samples = (
|
||||||
|
ovrd_target_delay_samples
|
||||||
|
if ovrd_target_delay_samples is not None
|
||||||
|
else int(self.target_delay * self.samples_per_symbol)
|
||||||
|
)
|
||||||
|
self.xy_delay_samples = (
|
||||||
|
ovrd_xy_delay_samples if ovrd_xy_delay_samples is not None else int(self.xy_delay * self.samples_per_symbol)
|
||||||
|
)
|
||||||
|
|
||||||
|
# data_raw = torch.tensor(data_raw, dtype=dtype)
|
||||||
|
|
||||||
|
# data layout
|
||||||
|
# [ [E_in_x0, E_in_y0, E_out_x0, E_out_y0],
|
||||||
|
# [E_in_x1, E_in_y1, E_out_x1, E_out_y1],
|
||||||
|
# ...
|
||||||
|
# [E_in_xN, E_in_yN, E_out_xN, E_out_yN] ]
|
||||||
|
|
||||||
|
data_raw = data_raw.transpose(0, 1)
|
||||||
|
|
||||||
|
# data layout
|
||||||
|
# [ E_in_x[0:N],
|
||||||
|
# E_in_y[0:N],
|
||||||
|
# E_out_x[0:N],
|
||||||
|
# E_out_y[0:N] ]
|
||||||
|
|
||||||
|
# shift x data by xy_delay_samples relative to the y data (example value: 3)
|
||||||
|
# [ E_in_x [0:N], [ E_in_x [ 0:N ], [ E_in_x [3:N ],
|
||||||
|
# E_in_y [0:N], -> E_in_y [-3:N-3], -> E_in_y [0:N-3],
|
||||||
|
# E_out_x[0:N], E_out_x[ 0:N ], E_out_x[3:N ],
|
||||||
|
# E_out_y[0:N] ] E_out_y[-3:N-3] ] E_out_y[0:N-3] ]
|
||||||
|
|
||||||
|
if self.xy_delay_samples != 0:
|
||||||
|
data_raw = roll_along(data_raw, [0, self.xy_delay_samples, 0, self.xy_delay_samples], dim=1)
|
||||||
|
if self.xy_delay_samples > 0:
|
||||||
|
data_raw = data_raw[:, self.xy_delay_samples :]
|
||||||
|
elif self.xy_delay_samples < 0:
|
||||||
|
data_raw = data_raw[:, : self.xy_delay_samples]
|
||||||
|
|
||||||
|
# shift fiber input data (target) by target_delay_samples relative to the fiber output data (input)
|
||||||
|
# (example value: 5)
|
||||||
|
# [ E_in_x [0:N], [ E_in_x [-5:N-5], [ E_in_x [0:N-5],
|
||||||
|
# E_in_y [0:N], -> E_in_y [-5:N-5], -> E_in_y [0:N-5],
|
||||||
|
# E_out_x[0:N], E_out_x[ 0:N ], E_out_x[5:N ],
|
||||||
|
# E_out_y[0:N] ] E_out_y[ 0:N ] ] E_out_y[5:N ] ]
|
||||||
|
|
||||||
|
if self.target_delay_samples != 0:
|
||||||
|
data_raw = roll_along(
|
||||||
|
data_raw,
|
||||||
|
[self.target_delay_samples, self.target_delay_samples, 0, 0],
|
||||||
|
dim=1,
|
||||||
|
)
|
||||||
|
if self.target_delay_samples > 0:
|
||||||
|
data_raw = data_raw[:, self.target_delay_samples :]
|
||||||
|
elif self.target_delay_samples < 0:
|
||||||
|
data_raw = data_raw[:, : self.target_delay_samples]
|
||||||
|
|
||||||
|
data_raw = data_raw.view(2, 2, -1)
|
||||||
|
# data layout
|
||||||
|
# [ [E_in_x, E_in_y],
|
||||||
|
# [E_out_x, E_out_y] ]
|
||||||
|
|
||||||
|
self.data = data_raw.unfold(dimension=-1, size=self.samples_per_slice, step=1)
|
||||||
|
self.data = self.data.movedim(-2, 0)
|
||||||
|
# -> [no_slices, 2, 2, samples_per_slice]
|
||||||
|
|
||||||
|
# data layout
|
||||||
|
# [
|
||||||
|
# [ [E_in_x[0:N+0], E_in_y[0:N+0] ], [ E_out_x[0:N+0], E_out_y[0:N+0] ] ],
|
||||||
|
# [ [E_in_x[1:N+1], E_in_y[1:N+1] ], [ E_out_x[1:N+1], E_out_y[1:N+1] ] ],
|
||||||
|
# ...
|
||||||
|
# ] -> [no_slices, 2, 2, samples_per_slice]
|
||||||
|
|
||||||
|
...
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return self.data.shape[0]
|
||||||
|
|
||||||
|
def __getitem__(self, idx):
|
||||||
|
if isinstance(idx, slice):
|
||||||
|
return [self.__getitem__(i) for i in range(*idx.indices(len(self)))]
|
||||||
|
else:
|
||||||
|
data, target = self.data[idx, 1].squeeze(), self.data[idx, 0].squeeze()
|
||||||
|
|
||||||
|
# reduce by by taking self.output_dim equally spaced samples
|
||||||
|
data = data[:, : data.shape[1] // self.output_dim * self.output_dim]
|
||||||
|
data = data.view(data.shape[0], self.output_dim, -1)
|
||||||
|
data = data[:, :, 0]
|
||||||
|
|
||||||
|
# target is corresponding to the middle of the data as the output sample is influenced by the data before and after it
|
||||||
|
target = target[:, : target.shape[1] // self.output_dim * self.output_dim]
|
||||||
|
target = target.view(target.shape[0], self.output_dim, -1)
|
||||||
|
target = target[:, 0, target.shape[2] // 2]
|
||||||
|
|
||||||
|
data = data.transpose(0, 1).flatten().squeeze()
|
||||||
|
target = target.flatten().squeeze()
|
||||||
|
|
||||||
|
# data layout:
|
||||||
|
# [sample_x0, sample_y0, sample_x1, sample_y1, ...]
|
||||||
|
# target layout:
|
||||||
|
# [sample_x0, sample_y0]
|
||||||
|
|
||||||
|
return data, target
|
||||||
21
src/single-core-regen/util/misc.py
Normal file
21
src/single-core-regen/util/misc.py
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
def multi_getattr(objs, attr, fallback=None):
|
||||||
|
"""
|
||||||
|
tries to get the attribute from a list of objects, returning the first hit
|
||||||
|
if no object has the attribute, it returns the fallback value if provided, otherwise raises AttributeError
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
return _multi_getattr(objs, attr)
|
||||||
|
except AttributeError as e:
|
||||||
|
if fallback is not None:
|
||||||
|
return fallback
|
||||||
|
raise e
|
||||||
|
|
||||||
|
def _multi_getattr(objs, attr):
|
||||||
|
if not isinstance(objs, (list, tuple)):
|
||||||
|
objs = [objs]
|
||||||
|
for obj in objs:
|
||||||
|
try:
|
||||||
|
return getattr(obj, attr)
|
||||||
|
except AttributeError:
|
||||||
|
pass
|
||||||
|
raise AttributeError(f"None of the objects has attribute {attr}")
|
||||||
335
src/single-core-regen/util/optuna_helpers.py
Normal file
335
src/single-core-regen/util/optuna_helpers.py
Normal file
@@ -0,0 +1,335 @@
|
|||||||
|
from typing import Any
|
||||||
|
from optuna import trial
|
||||||
|
|
||||||
|
|
||||||
|
def install_optional_suggests():
|
||||||
|
trial.Trial.suggest_categorical_optional = suggest_categorical_optional_wrapper
|
||||||
|
trial.Trial.suggest_int_optional = suggest_int_optional_wrapper
|
||||||
|
trial.Trial.suggest_float_optional = suggest_float_optional_wrapper
|
||||||
|
|
||||||
|
|
||||||
|
def _is_listlike(obj: Any) -> bool:
|
||||||
|
return hasattr(obj, "__iter__") and not isinstance(obj, str)
|
||||||
|
|
||||||
|
|
||||||
|
def _optional_suggest(
|
||||||
|
*,
|
||||||
|
trial: trial.Trial,
|
||||||
|
name: str,
|
||||||
|
range_or_value: Any,
|
||||||
|
type: str,
|
||||||
|
log: bool = False,
|
||||||
|
step: int | float | None = None,
|
||||||
|
add_user: bool = True,
|
||||||
|
force: bool = False,
|
||||||
|
multiply: float | int = 1,
|
||||||
|
set_new: bool = True,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Suggest a value for a parameter with more control over the process
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
type : str
|
||||||
|
The type of the parameter
|
||||||
|
trial : optuna.trial.Trial
|
||||||
|
The trial object
|
||||||
|
name : str
|
||||||
|
The name of the parameter
|
||||||
|
range_or_value : Any
|
||||||
|
The range of values or a single value
|
||||||
|
log : bool, optional
|
||||||
|
Whether to use a logarithmic scale, by default False
|
||||||
|
step : int|float|None, optional
|
||||||
|
The step size, by default None
|
||||||
|
add_user : bool, optional
|
||||||
|
Whether to add the suggested value to the user attributes if not added as a parameter, by default False
|
||||||
|
force : bool, optional
|
||||||
|
Whether to force a single value to be suggested, by default False
|
||||||
|
multiply : float| int, optional
|
||||||
|
A multiplier to apply to the range or value, by default 1. Ignored for type "categorical".
|
||||||
|
set_new : bool, optional
|
||||||
|
Whether to override the parameter if it already exists, by default True
|
||||||
|
"""
|
||||||
|
|
||||||
|
# value should be retrieved from trial
|
||||||
|
if not set_new and name in trial.params:
|
||||||
|
return trial.params[name]
|
||||||
|
|
||||||
|
# value is not a list or tuple
|
||||||
|
if not _is_listlike(range_or_value):
|
||||||
|
range_or_value = (range_or_value,)
|
||||||
|
|
||||||
|
# range with only one value
|
||||||
|
if len(range_or_value) == 1 and not force:
|
||||||
|
if add_user:
|
||||||
|
trial.set_user_attr(name, range_or_value[0])
|
||||||
|
return range_or_value[0]
|
||||||
|
|
||||||
|
# normal operation
|
||||||
|
if type == "categorical":
|
||||||
|
return trial.suggest_categorical(name, range_or_value)
|
||||||
|
|
||||||
|
# multiply range
|
||||||
|
range_or_value = tuple(multiply * x for x in range_or_value)
|
||||||
|
#
|
||||||
|
if len(range_or_value) > 2:
|
||||||
|
raise UserWarning("More than two values in range, using highest and lowest")
|
||||||
|
low = min(range_or_value)
|
||||||
|
high = max(range_or_value)
|
||||||
|
|
||||||
|
if type == "float":
|
||||||
|
return trial.suggest_float(name, low, high, step=step, log=log)
|
||||||
|
|
||||||
|
if type == "int":
|
||||||
|
step = step or 1
|
||||||
|
lowi = int(low)
|
||||||
|
highi = int(high)
|
||||||
|
if lowi != low or highi != high:
|
||||||
|
raise ValueError(f"Range {low} to {high} (using multiplier {multiply}) is not valid for int")
|
||||||
|
return trial.suggest_int(name, lowi, highi, step=step, log=log)
|
||||||
|
|
||||||
|
raise ValueError(f"Unknown type: {type}")
|
||||||
|
|
||||||
|
|
||||||
|
def suggest_categorical_optional(
|
||||||
|
trial: trial.Trial,
|
||||||
|
name: str,
|
||||||
|
choices_or_value: tuple[Any] | list[Any] | Any,
|
||||||
|
add_user: bool = True,
|
||||||
|
force: bool = False,
|
||||||
|
set_new: bool = True,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Suggest a value for a categorical parameter with more control over the process
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
trial : optuna.trial.Trial
|
||||||
|
The trial object
|
||||||
|
name : str
|
||||||
|
The name of the parameter
|
||||||
|
choices_or_value : tuple|list|Any
|
||||||
|
The choices or a single value
|
||||||
|
add_user : bool, optional
|
||||||
|
Whether to add the suggested value to the user attributes if not added as a parameter, by default False
|
||||||
|
force : bool, optional
|
||||||
|
Whether to suggest a single value as a parameter, by default False
|
||||||
|
set_new : bool, optional
|
||||||
|
Whether to override the parameter if it already exists, by default True
|
||||||
|
"""
|
||||||
|
return _optional_suggest(
|
||||||
|
trial=trial, name=name, range_or_value=choices_or_value, type="categorical", add_user=add_user, force=force, set_new=set_new
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def suggest_int_optional(
|
||||||
|
trial: trial.Trial,
|
||||||
|
name: str,
|
||||||
|
range_or_value: tuple[int] | list[int] | int,
|
||||||
|
step: int = 1,
|
||||||
|
log: bool = False,
|
||||||
|
add_user: bool = True,
|
||||||
|
force: bool = False,
|
||||||
|
multiply: int = 1,
|
||||||
|
set_new: bool = True,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Suggest a value for an integer parameter with more control over the process
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
trial : optuna.trial.Trial
|
||||||
|
The trial object
|
||||||
|
name : str
|
||||||
|
The name of the parameter
|
||||||
|
range_or_value : tuple|list|int
|
||||||
|
The range of values or a single value.
|
||||||
|
step : int, optional
|
||||||
|
The step size, by default 1
|
||||||
|
log : bool, optional
|
||||||
|
Whether to use a logarithmic scale, by default False
|
||||||
|
add_user : bool, optional
|
||||||
|
Whether to add the suggested value to the user attributes if not added as a parameter, by default False
|
||||||
|
force : bool, optional
|
||||||
|
Whether to suggest a single value as a parameter, by default False
|
||||||
|
"""
|
||||||
|
return _optional_suggest(
|
||||||
|
trial=trial,
|
||||||
|
name=name,
|
||||||
|
range_or_value=range_or_value,
|
||||||
|
step=step,
|
||||||
|
log=log,
|
||||||
|
type="int",
|
||||||
|
add_user=add_user,
|
||||||
|
force=force,
|
||||||
|
multiply=multiply,
|
||||||
|
set_new=set_new,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def suggest_float_optional(
|
||||||
|
trial: trial.Trial,
|
||||||
|
name: str,
|
||||||
|
range_or_value: tuple[float] | list[float] | float,
|
||||||
|
step: float | None = None,
|
||||||
|
log: bool = False,
|
||||||
|
add_user: bool = True,
|
||||||
|
force: bool = False,
|
||||||
|
multiply: float = 1,
|
||||||
|
set_new: bool = True,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Suggest a value for a float parameter with more control over the process
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
trial : optuna.trial.Trial
|
||||||
|
The trial object
|
||||||
|
name : str
|
||||||
|
The name of the parameter
|
||||||
|
range_or_value : tuple|list|float
|
||||||
|
The range of values or a single value
|
||||||
|
step : float|None, optional
|
||||||
|
The step size, by default None
|
||||||
|
log : bool, optional
|
||||||
|
Whether to use a logarithmic scale, by default False
|
||||||
|
add_user : bool, optional
|
||||||
|
Whether to add the suggested value to the user attributes if not added as a parameter, by default False
|
||||||
|
force : bool, optional
|
||||||
|
Whether to suggest a single value as a parameter, by default False
|
||||||
|
multiply : float, optional
|
||||||
|
A multiplier to apply to the range or value, by default 1
|
||||||
|
set_new : bool, optional
|
||||||
|
Whether to override the parameter if it already exists, by default True
|
||||||
|
"""
|
||||||
|
|
||||||
|
return _optional_suggest(
|
||||||
|
trial=trial,
|
||||||
|
name=name,
|
||||||
|
range_or_value=range_or_value,
|
||||||
|
step=step,
|
||||||
|
log=log,
|
||||||
|
type="float",
|
||||||
|
add_user=add_user,
|
||||||
|
force=force,
|
||||||
|
multiply=multiply,
|
||||||
|
set_new=set_new,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def suggest_categorical_optional_wrapper(
|
||||||
|
self: trial.Trial,
|
||||||
|
name: str,
|
||||||
|
choices_or_value: tuple[Any] | list[Any] | Any,
|
||||||
|
add_user: bool = True,
|
||||||
|
force: bool = False,
|
||||||
|
set_new: bool = True,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Suggest a value for a categorical parameter with more control over the process
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
name : str
|
||||||
|
The name of the parameter
|
||||||
|
choices_or_value : tuple|list|Any
|
||||||
|
The choices or a single value
|
||||||
|
add_user : bool, optional
|
||||||
|
Whether to add the suggested value to the user attributes if not added as a parameter, by default False
|
||||||
|
force : bool, optional
|
||||||
|
Whether to suggest a single value as a parameter, by default False
|
||||||
|
set_new : bool, optional
|
||||||
|
Whether to override the parameter if it already exists, by default True
|
||||||
|
"""
|
||||||
|
return suggest_categorical_optional(
|
||||||
|
trial=self, name=name, choices_or_value=choices_or_value, add_user=add_user, force=force, set_new=set_new
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def suggest_int_optional_wrapper(
|
||||||
|
self: trial.Trial,
|
||||||
|
name: str,
|
||||||
|
range_or_value: tuple[int] | list[int] | int,
|
||||||
|
step: int = 1,
|
||||||
|
log: bool = False,
|
||||||
|
add_user: bool = True,
|
||||||
|
force: bool = False,
|
||||||
|
multiply: int = 1,
|
||||||
|
set_new: bool = True,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Suggest a value for an integer parameter with more control over the process
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
name : str
|
||||||
|
The name of the parameter
|
||||||
|
range_or_value : tuple|list|int
|
||||||
|
The range of values or a single value.
|
||||||
|
step : int, optional
|
||||||
|
The step size, by default 1
|
||||||
|
log : bool, optional
|
||||||
|
Whether to use a logarithmic scale, by default False
|
||||||
|
add_user : bool, optional
|
||||||
|
Whether to add the suggested value to the user attributes if not added as a parameter, by default False
|
||||||
|
force : bool, optional
|
||||||
|
Whether to suggest a single value as a parameter, by default False
|
||||||
|
"""
|
||||||
|
return suggest_int_optional(
|
||||||
|
trial=self,
|
||||||
|
name=name,
|
||||||
|
range_or_value=range_or_value,
|
||||||
|
step=step,
|
||||||
|
log=log,
|
||||||
|
add_user=add_user,
|
||||||
|
force=force,
|
||||||
|
multiply=multiply,
|
||||||
|
set_new=set_new,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def suggest_float_optional_wrapper(
|
||||||
|
self: trial.Trial,
|
||||||
|
name: str,
|
||||||
|
range_or_value: tuple[float] | list[float] | float,
|
||||||
|
step: float | None = None,
|
||||||
|
log: bool = False,
|
||||||
|
add_user: bool = True,
|
||||||
|
force: bool = False,
|
||||||
|
multiply: float = 1,
|
||||||
|
set_new: bool = True,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Suggest a value for a float parameter with more control over the process
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
name : str
|
||||||
|
The name of the parameter
|
||||||
|
range_or_value : tuple|list|float
|
||||||
|
The range of values or a single value
|
||||||
|
step : float|None, optional
|
||||||
|
The step size, by default None
|
||||||
|
log : bool, optional
|
||||||
|
Whether to use a logarithmic scale, by default False
|
||||||
|
add_user : bool, optional
|
||||||
|
Whether to add the suggested value to the user attributes if not added as a parameter, by default False
|
||||||
|
force : bool, optional
|
||||||
|
Whether to suggest a single value as a parameter, by default False
|
||||||
|
multiply : float, optional
|
||||||
|
A multiplier to apply to the range or value, by default 1
|
||||||
|
set_new : bool, optional
|
||||||
|
Whether to override the parameter if it already exists, by default True
|
||||||
|
"""
|
||||||
|
return suggest_float_optional(
|
||||||
|
trial=self,
|
||||||
|
name=name,
|
||||||
|
range_or_value=range_or_value,
|
||||||
|
step=step,
|
||||||
|
log=log,
|
||||||
|
add_user=add_user,
|
||||||
|
force=force,
|
||||||
|
multiply=multiply,
|
||||||
|
set_new=set_new,
|
||||||
|
)
|
||||||
18
src/single-core-regen/util/optuna_vis.py
Normal file
18
src/single-core-regen/util/optuna_vis.py
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
from dash import Dash, dcc, html
|
||||||
|
import logging
|
||||||
|
import dash_bootstrap_components as dbc
|
||||||
|
|
||||||
|
|
||||||
|
def show_figures(*figures):
|
||||||
|
for figure in figures:
|
||||||
|
figure.layout.template = 'plotly_dark'
|
||||||
|
|
||||||
|
app = Dash(external_stylesheets=[dbc.themes.DARKLY])
|
||||||
|
app.layout = html.Div([
|
||||||
|
dcc.Graph(figure=figure) for figure in figures
|
||||||
|
])
|
||||||
|
log = logging.getLogger('werkzeug')
|
||||||
|
log.setLevel(logging.ERROR)
|
||||||
|
|
||||||
|
app.show = lambda *args, **kwargs: app.run_server(*args, **kwargs, debug=False)
|
||||||
|
return app
|
||||||
73
src/single-core-regen/util/plot.py
Normal file
73
src/single-core-regen/util/plot.py
Normal file
@@ -0,0 +1,73 @@
|
|||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import numpy as np
|
||||||
|
from .datasets import load_data
|
||||||
|
|
||||||
|
def eye(*, path=None, data=None, sps=None, title=None, symbols=1000, skipfirst=0, width=2, alpha=None, complex=False, show=True):
|
||||||
|
"""Plot an eye diagram for the data given by filepath.
|
||||||
|
|
||||||
|
Either path or data and sps must be given.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
path (str): Path to the data description file.
|
||||||
|
data (np.ndarray): Data to plot.
|
||||||
|
sps (int): Samples per symbol.
|
||||||
|
title (str): Title of the plot.
|
||||||
|
head (int): Number of symbols to plot.
|
||||||
|
skipfirst (int): Number of symbols to skip.
|
||||||
|
show (bool): Whether to call plt.show().
|
||||||
|
"""
|
||||||
|
if path is None and data is None:
|
||||||
|
raise ValueError("Either path or data and sps must be given.")
|
||||||
|
if path is not None:
|
||||||
|
data, config = load_data(path, skipfirst, symbols)
|
||||||
|
sps = int(config["glova"]["sps"])
|
||||||
|
if sps is None:
|
||||||
|
raise ValueError("sps not set.")
|
||||||
|
|
||||||
|
xaxis = np.linspace(0, width, width*sps, endpoint=False)
|
||||||
|
fig, axs = plt.subplots(2, 2, figsize=(10, 10), sharex=True, sharey=True)
|
||||||
|
if complex:
|
||||||
|
# create secondary axis for phase
|
||||||
|
axs2 = axs[0, 0].twinx(), axs[0, 1].twinx(), axs[1, 0].twinx(), axs[1, 1].twinx()
|
||||||
|
axs2 = np.reshape(axs2, (2, 2))
|
||||||
|
|
||||||
|
for i in range(symbols-(width-1)):
|
||||||
|
inx, iny, outx, outy = data[i*sps:(i+width)*sps].T
|
||||||
|
if complex:
|
||||||
|
axs[0, 0].plot(xaxis, np.abs(inx), color="C0", alpha=alpha or 0.1)
|
||||||
|
axs[0, 1].plot(xaxis, np.abs(outx), color="C0", alpha=alpha or 0.1)
|
||||||
|
axs[1, 0].plot(xaxis, np.abs(iny), color="C0", alpha=alpha or 0.1)
|
||||||
|
axs[1, 1].plot(xaxis, np.abs(outy), color="C0", alpha=alpha or 0.1)
|
||||||
|
axs[0, 0].set_ylim(0, 1.1*np.max(np.abs(data)))
|
||||||
|
|
||||||
|
axs2[0, 0].plot(xaxis, np.angle(inx), color="C1", alpha=alpha or 0.1)
|
||||||
|
axs2[0, 1].plot(xaxis, np.angle(outx), color="C1", alpha=alpha or 0.1)
|
||||||
|
axs2[1, 0].plot(xaxis, np.angle(iny), color="C1", alpha=alpha or 0.1)
|
||||||
|
axs2[1, 1].plot(xaxis, np.angle(outy), color="C1", alpha=alpha or 0.1)
|
||||||
|
else:
|
||||||
|
axs[0, 0].plot(xaxis, np.abs(inx)**2, color="C0", alpha=alpha or 0.1)
|
||||||
|
axs[0, 1].plot(xaxis, np.abs(outx)**2, color="C0", alpha=alpha or 0.1)
|
||||||
|
axs[1, 0].plot(xaxis, np.abs(iny)**2, color="C0", alpha=alpha or 0.1)
|
||||||
|
axs[1, 1].plot(xaxis, np.abs(outy)**2, color="C0", alpha=alpha or 0.1)
|
||||||
|
|
||||||
|
if complex:
|
||||||
|
axs2[0, 0].sharey(axs2[0, 1])
|
||||||
|
axs2[0, 1].sharey(axs2[1, 0])
|
||||||
|
axs2[1, 0].sharey(axs2[1, 1])
|
||||||
|
# make y axis symmetric
|
||||||
|
ylim = np.max(np.abs(np.angle(data)))*1.1
|
||||||
|
if ylim != 0:
|
||||||
|
axs2[0, 0].set_ylim(-ylim, ylim)
|
||||||
|
else:
|
||||||
|
axs[0,0].set_ylim(0, 1.1*np.max(np.abs(data))**2)
|
||||||
|
|
||||||
|
axs[0, 0].set_title("Input x")
|
||||||
|
axs[0, 1].set_title("Output x")
|
||||||
|
axs[1, 0].set_title("Input y")
|
||||||
|
axs[1, 1].set_title("Output y")
|
||||||
|
fig.suptitle(title or "Eye diagram")
|
||||||
|
|
||||||
|
if show:
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
return fig
|
||||||
Reference in New Issue
Block a user