Compare commits
9 Commits
main
...
cdca5de473
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
cdca5de473 | ||
|
|
1622c38582 | ||
|
|
2bba760378 | ||
|
|
9ec548757d | ||
|
|
05a3ee9394 | ||
|
|
086240489a | ||
|
|
87f40fc37c | ||
|
|
90aa6dbaf8 | ||
|
|
744c5f5166 |
1
.gitattributes
vendored
1
.gitattributes
vendored
@@ -1,4 +1,5 @@
|
||||
data/**/* filter=lfs diff=lfs merge=lfs -text
|
||||
data/*.db filter=lfs diff=lfs merge=lfs -text
|
||||
data/*.ini filter=lfs diff=lfs merge=lfs text
|
||||
|
||||
## lfs setup
|
||||
|
||||
4
.gitignore
vendored
4
.gitignore
vendored
@@ -1,7 +1,5 @@
|
||||
src/**/*.ini
|
||||
|
||||
# VSCode
|
||||
.vscode
|
||||
.*
|
||||
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
|
||||
1
LICENSE
1
LICENSE
@@ -10,6 +10,7 @@ use is covered by a right of the copyright holder of the Work).
|
||||
The Work is provided under the terms of this Licence when the Licensor (as
|
||||
defined below) has placed the following notice immediately following the
|
||||
copyright notice for the Work:
|
||||
|
||||
```raw
|
||||
Licensed under the EUPL
|
||||
```
|
||||
|
||||
@@ -13,7 +13,7 @@ Full license text in LICENSE file
|
||||
|
||||
# optical-regeneration
|
||||
|
||||
## Notes on cloning:
|
||||
## Notes on cloning
|
||||
|
||||
- `pypho` is added as a submodule -> `--recurse-submodules`
|
||||
- This repo has about 7.5GB of datasets in it. The `git lfs fetch` step will take a while.
|
||||
@@ -29,4 +29,5 @@ git lfs checkout
|
||||
```
|
||||
|
||||
## License
|
||||
This project is licensed under EUPL-1.2.
|
||||
|
||||
This project is licensed under EUPL-1.2.
|
||||
|
||||
3
data/optuna_single_core_regen.db
Normal file
3
data/optuna_single_core_regen.db
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:72460af57347d35df91cd76982231bcf538a82fd7f1b8522795202fa298a2dcb
|
||||
size 696320
|
||||
@@ -1,6 +1,6 @@
|
||||
# CUDA 12.4 Install
|
||||
|
||||
> https://unknowndba.blogspot.com/2024/04/cuda-getting-started-on-wsl.html (@\_freddenis\_, 2024)
|
||||
> <https://unknowndba.blogspot.com/2024/04/cuda-getting-started-on-wsl.html> (@\_freddenis\_, 2024)
|
||||
|
||||
```bash
|
||||
# get md5sums
|
||||
@@ -49,10 +49,9 @@ make
|
||||
./devicequery
|
||||
```
|
||||
|
||||
|
||||
### if the cuda-toolkit install fails with unmet dependencies
|
||||
|
||||
>https://askubuntu.com/a/1493087 (jspinella, 2023, CC BY-SA 4.0)
|
||||
><https://askubuntu.com/a/1493087> (jspinella, 2023, CC BY-SA 4.0)
|
||||
|
||||
1. Open the *new* file for storing the sources list
|
||||
|
||||
@@ -60,7 +59,7 @@ make
|
||||
sudo nano /etc/apt/sources.list.d/ubuntu.sources
|
||||
```
|
||||
|
||||
2. Paste in the following at the end of the file:
|
||||
2. Paste in the following at the end of the file:
|
||||
|
||||
```raw
|
||||
Types: deb
|
||||
@@ -71,4 +70,3 @@ make
|
||||
```
|
||||
|
||||
3. Save the file and run `sudo apt update` - now the install command for CUDA should work.
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
# useful links
|
||||
|
||||
- (Optuna)[https://optuna.org] Hyperparameter optimization framework
|
||||
- [Optuna](https://optuna.org) Hyperparameter optimization framework
|
||||
`pip install optuna`
|
||||
|
||||
@@ -1,46 +1,42 @@
|
||||
# pyenv install
|
||||
# pyenv installation
|
||||
|
||||
## install
|
||||
## pyenv
|
||||
|
||||
nice to have:
|
||||
1. Install pyenv
|
||||
|
||||
```bash
|
||||
sudo apt install python-is-python3
|
||||
```
|
||||
```bash
|
||||
curl https://pyenv.run | bash
|
||||
```
|
||||
|
||||
```bash
|
||||
curl https://pyenv.run | bash
|
||||
```
|
||||
2. setup zsh
|
||||
|
||||
## setup zsh
|
||||
add the following to `.zshrc`:
|
||||
|
||||
add the following to `.zshrc`:
|
||||
```bash
|
||||
export PYENV_ROOT="$HOME/.pyenv"
|
||||
[[ -d $PYENV_ROOT/bin ]] && export PATH="$PYENV_ROOT/bin:$PATH"
|
||||
eval "$(pyenv init -)"
|
||||
```
|
||||
|
||||
```bash
|
||||
export PYENV_ROOT="$HOME/.pyenv"
|
||||
[[ -d $PYENV_ROOT/bin ]] && export PATH="$PYENV_ROOT/bin:$PATH"
|
||||
eval "$(pyenv init -)"
|
||||
```
|
||||
## python installation
|
||||
|
||||
## pyenv install
|
||||
1. prerequisites
|
||||
|
||||
prerequisites:
|
||||
```bash
|
||||
sudo apt update
|
||||
sudo apt install build-essential libssl-dev zlib1g-dev \
|
||||
libbz2-dev libreadline-dev libsqlite3-dev curl git \
|
||||
libncursesw5-dev xz-utils tk-dev libxml2-dev libxmlsec1-dev \
|
||||
libffi-dev liblzma-dev python3-pip
|
||||
```
|
||||
|
||||
```bash
|
||||
sudo apt update
|
||||
sudo apt install build-essential libssl-dev zlib1g-dev \
|
||||
libbz2-dev libreadline-dev libsqlite3-dev curl git \
|
||||
libncursesw5-dev xz-utils tk-dev libxml2-dev libxmlsec1-dev \
|
||||
libffi-dev liblzma-dev python3-pip
|
||||
```
|
||||
2. install
|
||||
|
||||
install:
|
||||
```bash
|
||||
# using python 3.12.7 as an example
|
||||
pyenv install 3.12.7
|
||||
|
||||
```bash
|
||||
# using python 3.12.7 as an example
|
||||
pyenv install 3.12.7
|
||||
|
||||
# optional
|
||||
pyenv global 3.12.7
|
||||
pyenv versions
|
||||
```
|
||||
# optional
|
||||
pyenv global 3.12.7
|
||||
pyenv versions
|
||||
```
|
||||
|
||||
@@ -8,7 +8,8 @@ source ./.venv/bin/activate
|
||||
```
|
||||
|
||||
## install pytorch
|
||||
> https://pytorch.org/get-started/locally/
|
||||
|
||||
> <https://pytorch.org/get-started/locally/>
|
||||
|
||||
```bash
|
||||
pip install torch torchvision torchaudio
|
||||
|
||||
37
src/single-core-data-gen/add_pypho.py
Normal file
37
src/single-core-data-gen/add_pypho.py
Normal file
@@ -0,0 +1,37 @@
|
||||
|
||||
# add_pypho.py
|
||||
#
|
||||
# This file is part of the repo "optical-regeneration"
|
||||
# https://git.suuppl.dev/seppl/optical-regeneration.git
|
||||
#
|
||||
# (c) Joseph Hopfmüller, 2024
|
||||
# Licensed under the EUPL
|
||||
#
|
||||
# Full license text in LICENSE file
|
||||
|
||||
###
|
||||
|
||||
# copy this file into the directory where you want to use pypho
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# messages recorded while this module is imported; printed by show_log()
__log = []

# start the search in the directory this file lives in
__parent_dir = Path(__file__).resolve().parent

# Walk upwards until a directory containing ./pypho/pypho is found.
# The top of the tree is detected with `parent == self`: this also holds for
# Windows drive roots, where a comparison against Path("/") never matches and
# the loop would not terminate.
while (
    not (__parent_dir / "pypho" / "pypho").exists()
    and __parent_dir != __parent_dir.parent
):
    __parent_dir = __parent_dir.parent

if (__parent_dir / "pypho" / "pypho").exists():
    # add the outer ./pypho so that `import pypho` resolves to ./pypho/pypho
    __pypho_dir = __parent_dir / "pypho"
    sys.path.append(str(__pypho_dir))
    __log.append(f"Added '{__pypho_dir}' to 'PATH'")
else:
    __log.append('pypho not found')


def show_log():
    """Print every message recorded while searching for pypho."""
    for entry in __log:
        print(entry)
|
||||
@@ -19,9 +19,8 @@ import time
|
||||
from matplotlib import pyplot as plt # noqa: F401
|
||||
import numpy as np
|
||||
|
||||
import _path_fix # noqa: F401
|
||||
import path_fix
|
||||
import pypho
|
||||
# import inspect
|
||||
|
||||
default_config = f"""
|
||||
[glova]
|
||||
@@ -498,6 +497,7 @@ def plot_eye_diagram(
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
path_fix.show_log()
|
||||
config = get_config()
|
||||
|
||||
length_ranges = [1000, 10000]
|
||||
@@ -1,9 +0,0 @@
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# hack to add the parent directory to the path -> pypho doesn't have to be installed as package
|
||||
parent_dir = Path(__file__).parent
|
||||
while not (parent_dir / "pypho" / "pypho").exists() and parent_dir != Path("/"):
|
||||
parent_dir = parent_dir.parent
|
||||
print(f"Adding '{parent_dir / "pypho"}' to 'sys.path' to enable import of '{parent_dir / 'pypho' / 'pypho'}'")
|
||||
sys.path.append(str(parent_dir / "pypho"))
|
||||
464
src/single-core-regen/regen.py
Normal file
464
src/single-core-regen/regen.py
Normal file
@@ -0,0 +1,464 @@
|
||||
import copy
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
import numpy as np
|
||||
import optuna
|
||||
import warnings
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
# import torch.nn.functional as F # mse_loss doesn't support complex numbers
|
||||
import torch.optim as optim
|
||||
import torch.utils.data
|
||||
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from rich.progress import Progress, TextColumn, BarColumn, TaskProgressColumn, TimeRemainingColumn, MofNCompleteColumn
|
||||
from rich.console import Console
|
||||
|
||||
import multiprocessing
|
||||
|
||||
from util.datasets import FiberRegenerationDataset
|
||||
from util.complexNN import complex_sse_loss
|
||||
from util.optuna_helpers import optional_suggest_categorical, optional_suggest_float, optional_suggest_int
|
||||
import util
|
||||
# global settings
|
||||
@dataclass
class GlobalSettings:
    """Settings shared by the whole run."""

    # seed passed to np.random.seed() before the dataset indices are shuffled
    seed: int = 42
|
||||
|
||||
|
||||
# data settings
|
||||
@dataclass
class DataSettings:
    """Dataset selection and slicing parameters.

    The ``*_range`` fields are handed to the ``optional_suggest_*`` helpers
    together with the Optuna trial.
    NOTE(review): presumably a tuple is a (low, high) search range and a scalar
    a fixed value -- confirm against util.optuna_helpers.
    """

    # passed as file_path to FiberRegenerationDataset and to util.datasets.load_data
    config_path: str = "data/*-128-16384-1000-0-0-17-0-PAM4-0.ini"
    # dtype handed to the dataset and to the model's nn.Linear layers
    dtype: torch.dtype = torch.complex64
    # suggested under the parameter name "dataset_symbols"
    symbols_range: tuple|float|int = 16
    # suggested as "dataset_data_size"; the model input width is twice this value
    data_size_range: tuple|float|int = 32
    # shuffle the sample indices (seeded with GlobalSettings.seed) before splitting
    shuffle: bool = True
    # forwarded unchanged to FiberRegenerationDataset
    target_delay: float = 0
    # suggested under the parameter name "dataset_xy_delay"
    xy_delay_range: tuple|float|int = 0
    # forwarded to FiberRegenerationDataset as drop_first
    drop_first: int = 10
    # fraction of the dataset indices used for training; the rest is validation
    train_split: float = 0.8
|
||||
|
||||
|
||||
# pytorch settings
|
||||
@dataclass
class PytorchSettings:
    """Training-loop settings."""

    # device that the model and every batch are moved to
    device: str = "cuda"
    # batch size of the training and validation data loaders
    batchsize: int = 1024
    # number of train/eval epochs run per trial
    epochs: int = 10
    # parent directory of the TensorBoard log directories
    summary_dir: str = ".runs"
|
||||
|
||||
|
||||
# model settings
|
||||
@dataclass
class ModelSettings:
    """Architecture settings of the fully connected model."""

    # number of outputs of the final linear layer
    output_size: int = 2
    # suggested as "model_n_layers" (number of hidden linear layers)
    n_layer_range: tuple|float|int = (2,8)
    # suggested per layer as "model_n_units_l{i}" (with log=True)
    n_units_range: tuple|float|int = (2,32)
    # activation_func_range: tuple = ("ReLU",)
|
||||
|
||||
|
||||
@dataclass
class OptimizerSettings:
    """Optimizer selection; the chosen name is looked up on ``torch.optim``."""

    # optimizer_range: tuple|str = ("Adam", "RMSprop", "SGD")
    # suggested under the parameter name "optimizer"
    optimizer_range: tuple|str = "RMSprop"
    # lr_range: tuple|float = (1e-5, 1e-1)
    # suggested under the parameter name "lr" (with log=True)
    lr_range: tuple|float = 2e-5
|
||||
|
||||
|
||||
# optuna settings
|
||||
@dataclass
class OptunaSettings:
    """Settings of the Optuna study and of the per-trial example limits."""

    # total number of trials requested for the study
    n_trials: int = 128
    # upper bound for the number of worker processes (capped at n_trials)
    n_threads: int = 8
    # passed as timeout to study.optimize()
    timeout: int = 600
    # one entry per objective; more than one entry means multi-objective
    directions: tuple = ("minimize",)
    # names handed to study.set_metric_names()
    metrics_names: tuple = ("sse",)

    # when False, both example limits below are replaced by infinity
    limit_examples: bool = True
    # training stops once batch_idx * batch_size reaches this number
    n_train_examples: int | float = PytorchSettings.batchsize * 50
    # n_valid_examples: int = PytorchSettings.batchsize * 100
    # validation stops once batch_idx * batch_size reaches this number
    n_valid_examples: int | float = float("inf")
    # storage URL used for creating and listing studies
    storage: str = "sqlite:///optuna_single_core_regen.db"
    # default is computed once, when this class body is executed
    study_name: str = (
        f"single_core_regen_{datetime.now().strftime('%Y-%m-%d_%H:%M:%S')}"
    )
|
||||
|
||||
|
||||
class HyperTraining:
|
||||
def __init__(self):
    """Create every settings group with its defaults and derive the extra Optuna settings."""
    self.global_settings = GlobalSettings()
    self.data_settings = DataSettings()
    self.pytorch_settings = PytorchSettings()
    self.model_settings = ModelSettings()
    self.optimizer_settings = OptimizerSettings()
    self.optuna_settings = OptunaSettings()

    # console shared by the progress displays
    self.console = Console()

    # set some extra settings to make the code more readable
    # (must run after optuna_settings has been created)
    self._extra_optuna_settings()
|
||||
|
||||
def setup_tb_writer(self, study_name=None, append=None):
    """Return a ``SummaryWriter`` logging below the configured summary directory.

    Args:
        study_name: directory name to use; falls back to the configured study name.
        append: optional suffix, joined to the directory name with ``_``.
    """
    name = study_name or self.optuna_settings.study_name
    suffix = "" if append is None else "_" + str(append)
    return SummaryWriter(self.pytorch_settings.summary_dir + "/" + name + suffix)
|
||||
|
||||
def resume_latest_study(self, verbose=True):
    """Point ``optuna_settings.study_name`` at the most recently started study.

    Nothing is changed when the storage holds no study.

    Args:
        verbose: print which study is resumed (also forwarded to
            ``get_latest_study``).
    """
    # use this instance, not the module-level `hyper_training` object, which
    # only exists when the file is run as a script
    study_name = self.get_latest_study(verbose=verbose)

    if study_name:
        if verbose:
            print(f"Resuming study: {study_name}")
        self.optuna_settings.study_name = study_name
|
||||
|
||||
def get_latest_study(self, verbose=True):
    """Return the name of the study with the latest start time, or ``None``.

    Summaries without a start time get ``datetime.min`` assigned so that they
    compare lower than every started study.

    Args:
        verbose: print the outcome of the lookup.
    """
    summaries = self.get_studies()
    for summary in summaries:
        summary.datetime_start = summary.datetime_start or datetime.min

    if not summaries:
        if verbose:
            print("No previous studies found")
        return None

    latest = max(summaries, key=lambda s: s.datetime_start)
    if verbose:
        print(f"Last study: {latest.study_name}")
    return latest.study_name
|
||||
|
||||
def get_studies(self):
    """Return the summaries of all studies found in the configured storage."""
    storage = self.optuna_settings.storage
    return optuna.get_all_study_summaries(storage=storage)
|
||||
|
||||
def setup_study(self):
    """Create (or load) the Optuna study and prepare the worker processes.

    Sets ``self.study``, ``self.n_threads`` and ``self.processes``. The
    processes are only created here; ``run_study`` starts them.
    """
    self.study = optuna.create_study(
        study_name=self.optuna_settings.study_name,
        storage=self.optuna_settings.storage,
        load_if_exists=True,
        # _extra_optuna_settings() sets one of the following two to None
        direction=self.optuna_settings.direction,
        directions=self.optuna_settings.directions,
    )

    with warnings.catch_warnings(action="ignore"):
        self.study.set_metric_names(self.optuna_settings.metrics_names)

    # never use more workers than there are trials
    self.n_threads = min(
        self.optuna_settings.n_trials, self.optuna_settings.n_threads
    )
    self.processes = []
    if self.n_threads > 1:
        for _ in range(self.n_threads):
            p = multiprocessing.Process(
                # target=lambda n_trials: self._run_optimize(self, n_trials),
                target = self._run_optimize,
                # every worker gets the same share; the integer division may
                # leave a remainder of trials unassigned
                args = (self.optuna_settings.n_trials // self.n_threads,),
            )
            self.processes.append(p)
|
||||
|
||||
def run_study(self):
    """Run all configured trials.

    When worker processes were prepared by ``setup_study`` they are started
    and joined first; each of them handles ``n_trials // n_threads`` trials.
    Whatever is left over is then run in the calling process. Without worker
    processes every trial runs in the calling process.
    """
    if self.processes:
        for p in self.processes:
            p.start()
        for p in self.processes:
            p.join()

        # only the remainder of the integer division is still outstanding;
        # self.n_threads is the (capped) worker count the shares were based on
        remaining_trials = self.optuna_settings.n_trials % self.n_threads
    else:
        remaining_trials = self.optuna_settings.n_trials

    if remaining_trials:
        self._run_optimize(remaining_trials)
|
||||
|
||||
def _run_optimize(self, n_trials):
|
||||
self.study.optimize(
|
||||
self.objective, n_trials=n_trials, timeout=self.optuna_settings.timeout
|
||||
)
|
||||
|
||||
def plot_eye(self, show=True):
    """Plot an eye diagram of the configured dataset via ``util.plot.eye``.

    The data is loaded on the first call and cached in ``self.eye_data``.

    Args:
        show: forwarded to ``util.plot.eye``.

    Returns:
        Whatever ``util.plot.eye`` returns.
    """
    if not hasattr(self, "eye_data"):
        samples, cfg = util.datasets.load_data(
            self.data_settings.config_path, skipfirst=10, symbols=1000
        )
        sps = int(cfg["glova"]["sps"])
        self.eye_data = {"data": samples, "sps": sps}
    return util.plot.eye(show=show, **self.eye_data)
|
||||
|
||||
def _extra_optuna_settings(self):
|
||||
self.optuna_settings.multi_objective = len(self.optuna_settings.directions) > 1
|
||||
if self.optuna_settings.multi_objective:
|
||||
self.optuna_settings.direction = None
|
||||
else:
|
||||
self.optuna_settings.direction = self.optuna_settings.directions[0]
|
||||
self.optuna_settings.directions = None
|
||||
|
||||
self.optuna_settings.n_train_examples = (
|
||||
self.optuna_settings.n_train_examples
|
||||
if self.optuna_settings.limit_examples
|
||||
else float("inf")
|
||||
)
|
||||
self.optuna_settings.n_valid_examples = (
|
||||
self.optuna_settings.n_valid_examples
|
||||
if self.optuna_settings.limit_examples
|
||||
else float("inf")
|
||||
)
|
||||
|
||||
def define_model(self, trial: optuna.Trial, writer=None):
    """Build a stack of linear layers from the trial's hyperparameters.

    The input width is twice the trial's "dataset_data_size" and is stored as
    the user attribute "input_dim". No activation is inserted between layers.

    Args:
        trial: trial the layer count and layer widths are suggested on.
        writer: optional TensorBoard writer that receives the model graph.

    Returns:
        ``nn.Sequential`` ending in a layer with ``output_size`` outputs.
    """
    dtype = self.data_settings.dtype
    depth = optional_suggest_int(trial, "model_n_layers", self.model_settings.n_layer_range)

    # the default argument is evaluated eagerly, so the suggestion always runs
    width = 2 * trial.params.get(
        "dataset_data_size",
        optional_suggest_int(trial, "dataset_data_size", self.data_settings.data_size_range),
    )
    trial.set_user_attr("input_dim", width)

    modules = []
    for idx in range(depth):
        units = optional_suggest_int(trial, f"model_n_units_l{idx}", self.model_settings.n_units_range, log=True)
        modules.append(nn.Linear(width, units, dtype=dtype))
        # layers.append(getattr(nn, activation_func)())
        width = units

    modules.append(nn.Linear(width, self.model_settings.output_size, dtype=dtype))

    if writer is not None:
        dummy_input = torch.zeros(1, trial.user_attrs["input_dim"], dtype=dtype)
        writer.add_graph(nn.Sequential(*modules), dummy_input)

    return nn.Sequential(*modules)
|
||||
|
||||
def get_sliced_data(self, trial: optuna.Trial):
    """Build the training and validation data loaders for a trial.

    The dataset parameters "dataset_symbols", "dataset_xy_delay" and
    "dataset_data_size" are suggested on the trial. The sample indices are
    split according to ``data_settings.train_split``. With
    ``data_settings.shuffle`` enabled they are shuffled (seeded) before the
    split and drawn in random order; otherwise both loaders yield their
    samples in dataset order.

    Args:
        trial: trial the dataset parameters are suggested on.

    Returns:
        Tuple ``(train_loader, valid_loader)``.
    """
    symbols = optional_suggest_float(trial, "dataset_symbols", self.data_settings.symbols_range)

    xy_delay = optional_suggest_float(trial, "dataset_xy_delay", self.data_settings.xy_delay_range)

    data_size = trial.params.get(
        "dataset_data_size",
        optional_suggest_int(trial, "dataset_data_size", self.data_settings.data_size_range),
    )

    # get dataset
    dataset = FiberRegenerationDataset(
        file_path=self.data_settings.config_path,
        symbols=symbols,
        data_size=data_size,
        target_delay=self.data_settings.target_delay,
        xy_delay=xy_delay,
        drop_first=self.data_settings.drop_first,
        dtype=self.data_settings.dtype,
    )

    dataset_size = len(dataset)
    indices = list(range(dataset_size))
    split = int(np.floor(self.data_settings.train_split * dataset_size))
    if self.data_settings.shuffle:
        np.random.seed(self.global_settings.seed)
        np.random.shuffle(indices)

    train_indices, valid_indices = indices[:split], indices[split:]

    if self.data_settings.shuffle:
        train_sampler = torch.utils.data.SubsetRandomSampler(train_indices)
        valid_sampler = torch.utils.data.SubsetRandomSampler(valid_indices)
    else:
        # SubsetRandomSampler would randomize the order even though shuffling
        # is disabled; a plain index list is iterated in order by DataLoader
        train_sampler = train_indices
        valid_sampler = valid_indices

    train_loader = torch.utils.data.DataLoader(
        dataset, batch_size=self.pytorch_settings.batchsize, sampler=train_sampler, drop_last=True
    )
    valid_loader = torch.utils.data.DataLoader(
        dataset, batch_size=self.pytorch_settings.batchsize, sampler=valid_sampler, drop_last=True
    )

    return train_loader, valid_loader
|
||||
|
||||
def train_model(self, model, optimizer, train_loader, epoch, writer=None, enable_progress=True):
    """Train ``model`` for one epoch.

    Training stops early once ``batch_idx * batch_size`` reaches
    ``optuna_settings.n_train_examples``. After every optimizer step all
    parameters are clamped to [-1, 1] (component-wise for complex tensors).

    Args:
        model: network to train (already on the configured device).
        optimizer: optimizer bound to the model's parameters.
        train_loader: loader yielding ``(x, y)`` batches.
        epoch: zero-based epoch index, used for the TensorBoard step.
        writer: optional TensorBoard writer; receives the mean loss of every
            10 batches as "training loss".
        enable_progress: show a live progress bar on ``self.console``.
    """
    progress = None
    if enable_progress:
        progress = Progress(
            TextColumn("[yellow] Training..."),
            TextColumn(" Loss: {task.description}"),
            BarColumn(),
            TaskProgressColumn(),
            TextColumn("[green]Batch"),
            MofNCompleteColumn(),
            TimeRemainingColumn(),
            # description="Training",
            transient=False,
            console=self.console,
            refresh_per_second=10,
        )
        task = progress.add_task("-.---e--", total=len(train_loader))
        # without start() the live display is never rendered
        progress.start()

    running_loss = 0.0
    last_loss = 0.0
    model.train()
    try:
        for batch_idx, (x, y) in enumerate(train_loader):
            if (
                batch_idx * train_loader.batch_size
                >= self.optuna_settings.n_train_examples
            ):
                break
            optimizer.zero_grad()
            x, y = (
                x.to(self.pytorch_settings.device),
                y.to(self.pytorch_settings.device),
            )
            y_pred = model(x)
            loss = complex_sse_loss(y_pred, y)
            loss.backward()
            optimizer.step()

            # clamp weights to keep energy bounded
            # clamp_ rejects complex tensors, so complex parameters are
            # clamped through their real view (real and imaginary part)
            for p in model.parameters():
                values = torch.view_as_real(p.data) if p.is_complex() else p.data
                values.clamp_(-1.0, 1.0)

            last_loss = loss.item()

            if progress is not None:
                progress.update(task, advance=1, description=f"{last_loss:.3e}")

            running_loss += last_loss
            # log after every full window of 10 batches so that the division
            # by 10 is a real mean (batch 0 alone would be under-reported)
            if writer is not None and (batch_idx + 1) % 10 == 0:
                writer.add_scalar("training loss", running_loss/10, epoch*min(len(train_loader), self.optuna_settings.n_train_examples/train_loader.batch_size) + batch_idx)
                running_loss = 0.0

        if progress is not None:
            progress.update(task, description=f"{last_loss:.3e}")
    finally:
        # restore the terminal even when a batch raises
        if progress is not None:
            progress.stop()
|
||||
|
||||
|
||||
def eval_model(self, model, valid_loader, epoch, writer=None, enable_progress=True):
    """Evaluate ``model`` on the validation data and return the mean batch error.

    Evaluation stops early once ``batch_idx * batch_size`` reaches
    ``optuna_settings.n_valid_examples``.

    Args:
        model: network to evaluate (already on the configured device).
        valid_loader: loader yielding ``(x, y)`` batches.
        epoch: zero-based epoch index, used for the TensorBoard step.
        writer: optional TensorBoard writer; receives the mean error of every
            10 batches as "sse".
        enable_progress: show a live progress bar on ``self.console``.

    Returns:
        Mean of ``complex_sse_loss`` over the evaluated batches, or NaN when
        no batch was evaluated.
    """
    progress = None
    if enable_progress:
        progress = Progress(
            TextColumn("[green]Evaluating..."),
            TextColumn("Error: {task.description}"),
            BarColumn(),
            TaskProgressColumn(),
            TextColumn("[green]Batch"),
            MofNCompleteColumn(),
            TimeRemainingColumn(),
            # description="Training",
            transient=False,
            console=self.console,
            refresh_per_second=10,
        )
        task = progress.add_task("-.---e--", total=len(valid_loader))
        # without start() the live display is never rendered
        progress.start()

    model.eval()
    running_error = 0
    running_error_2 = 0
    # counted explicitly: after an early break batch_idx points at a batch
    # that was not evaluated, and it is unbound for an empty loader
    n_batches = 0
    try:
        with torch.no_grad():
            for batch_idx, (x, y) in enumerate(valid_loader):
                if (
                    batch_idx * valid_loader.batch_size
                    >= self.optuna_settings.n_valid_examples
                ):
                    break
                x, y = (
                    x.to(self.pytorch_settings.device),
                    y.to(self.pytorch_settings.device),
                )
                y_pred = model(x)
                error = complex_sse_loss(y_pred, y)
                batch_error = error.item()
                running_error += batch_error
                running_error_2 += batch_error
                n_batches += 1

                if progress is not None:
                    progress.update(task, advance=1, description=f"{batch_error:.3e}")

                # log after every full window of 10 batches so that the
                # division by 10 is a real mean
                if writer is not None and (batch_idx + 1) % 10 == 0:
                    writer.add_scalar("sse", running_error_2/10, epoch*min(len(valid_loader), self.optuna_settings.n_valid_examples/valid_loader.batch_size) + batch_idx)
                    running_error_2 = 0.0

        # no data must not look like a perfect score of 0
        running_error = running_error / n_batches if n_batches else float("nan")

        if progress is not None:
            progress.update(task, description=f"{running_error:.3e}")
    finally:
        # restore the terminal even when a batch raises
        if progress is not None:
            progress.stop()

    return running_error
|
||||
|
||||
def run_model(self, model, loader):
    """Run ``model`` in eval mode over every batch of ``loader``.

    Args:
        model: network to run.
        loader: iterable of ``(x, y)`` batches; both are moved to the
            configured device, only ``x`` is fed to the model.

    Returns:
        The per-batch predictions stacked along a new leading dimension.
    """
    device = self.pytorch_settings.device
    model.eval()
    predictions = []
    with torch.no_grad():
        for inputs, targets in loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            predictions.append(model(inputs))
    return torch.stack(predictions)
|
||||
|
||||
|
||||
def objective(self, trial: optuna.Trial):
    """Optuna objective: train a model for the trial and return its validation error.

    Every trial logs to its own TensorBoard directory, named after the study
    with the zero-padded trial number appended. For single-objective studies
    the error is reported after each epoch and the trial may be pruned.

    Args:
        trial: trial the hyperparameters are suggested on.

    Returns:
        The validation error of the last epoch.

    Raises:
        optuna.exceptions.TrialPruned: when the pruner stops the trial.
    """
    # pad the trial number to the width of n_trials; the width has to be a
    # nested replacement field, an expression is not allowed in a format spec
    width = len(str(self.optuna_settings.n_trials))
    writer = self.setup_tb_writer(self.optuna_settings.study_name, f"{trial.number:0>{width}}")
    try:
        train_loader, valid_loader = self.get_sliced_data(trial)

        model = self.define_model(trial, writer).to(self.pytorch_settings.device)

        optimizer_name = optional_suggest_categorical(trial, "optimizer", self.optimizer_settings.optimizer_range)

        lr = optional_suggest_float(trial, "lr", self.optimizer_settings.lr_range, log=True)

        optimizer = getattr(optim, optimizer_name)(model.parameters(), lr=lr)

        # progress bars are only shown when a single worker is configured
        enable_progress = self.optuna_settings.n_threads == 1
        for epoch in range(self.pytorch_settings.epochs):
            if enable_progress:
                print(f"Epoch {epoch+1}/{self.pytorch_settings.epochs}")
            self.train_model(model, optimizer, train_loader, epoch, writer, enable_progress=enable_progress)
            sse = self.eval_model(model, valid_loader, epoch, writer, enable_progress=enable_progress)

            if not self.optuna_settings.multi_objective:
                trial.report(sse, epoch)
                if trial.should_prune():
                    raise optuna.exceptions.TrialPruned()
    finally:
        # also flush and close the log when the trial is pruned or fails
        writer.close()

    return sse
|
||||
|
||||
|
||||
|
||||
|
||||
# Script entry point: run the hyperparameter study, then plot the output of a
# model built from the best trial against the original signal.
if __name__ == "__main__":
    hyper_training = HyperTraining()

    # hyper_training.resume_latest_study()

    hyper_training.setup_study()
    hyper_training.run_study()

    # NOTE(review): define_model() creates new layers from the best trial's
    # hyperparameters; nothing in this script loads trained weights into
    # them -- confirm that plotting this model's output is intended.
    best_model = hyper_training.define_model(hyper_training.study.best_trial).to(hyper_training.pytorch_settings.device)
    # NOTE(review): the backup is not read again in this script
    data_settings_backup = copy.copy(hyper_training.data_settings)
    # small, unshuffled split used only for plotting
    hyper_training.data_settings.shuffle = False
    hyper_training.data_settings.train_split = 0.01
    plot_loader, _ = hyper_training.get_sliced_data(hyper_training.study.best_trial)

    regen = hyper_training.run_model(best_model, plot_loader)
    regen = regen.view(-1, 2)
    # [batch_no, batch_size, 2] -> [no, 2]

    original, _ = util.datasets.load_data(hyper_training.data_settings.config_path, skipfirst=hyper_training.data_settings.drop_first)
    # trim the reference to the number of regenerated samples
    original = original[:len(regen)]

    regen = regen.cpu().numpy()
    # one subplot per output column, showing the squared magnitude
    _, axs = plt.subplots(2)
    for i, ax in enumerate(axs):
        ax.plot(np.abs(original[:, i])**2, label="original")
        ax.plot(np.abs(regen[:, i])**2, label="regen")
        ax.legend()
    plt.show()

    print(f"Best model: {best_model}")

    # eye_fig = hyper_training.plot_eye()
    ...
|
||||
429
src/single-core-regen/regen_no_hyper.py
Normal file
429
src/single-core-regen/regen_no_hyper.py
Normal file
@@ -0,0 +1,429 @@
|
||||
import copy
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
# import torch.nn.functional as F # mse_loss doesn't support complex numbers
|
||||
import torch.optim as optim
|
||||
import torch.utils.data
|
||||
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from rich.progress import (
|
||||
Progress,
|
||||
TextColumn,
|
||||
BarColumn,
|
||||
TaskProgressColumn,
|
||||
TimeRemainingColumn,
|
||||
MofNCompleteColumn,
|
||||
TimeElapsedColumn,
|
||||
)
|
||||
from rich.console import Console
|
||||
from rich import print as rprint
|
||||
|
||||
# from util.optuna_helpers import optional_suggest_categorical, optional_suggest_float, optional_suggest_int
|
||||
import util
|
||||
|
||||
|
||||
# global settings
|
||||
@dataclass
class GlobalSettings:
    """Settings shared by the whole run."""

    # seed passed to np.random.seed() before the dataset indices are shuffled
    seed: int = 42
|
||||
|
||||
|
||||
# data settings
|
||||
@dataclass
class DataSettings:
    """Dataset selection and slicing parameters (fixed values, no search ranges)."""

    # passed as file_path to FiberRegenerationDataset and to util.datasets.load_data
    config_path: str = "data/*-128-16384-100000-0-0-17-0-PAM4-0.ini"
    # dtype handed to the dataset; a non-complex dtype switches on real=True
    dtype: torch.dtype = torch.complex64
    # forwarded to FiberRegenerationDataset as symbols
    symbols_range: float | int = 8
    # forwarded as data_size; the model input width is twice this value
    data_size_range: float | int = 64
    # shuffle the sample indices (seeded) and draw them in random order
    shuffle: bool = True
    # forwarded unchanged to FiberRegenerationDataset
    target_delay: float = 0
    # forwarded to FiberRegenerationDataset as xy_delay
    xy_delay_range: float | int = 0
    # forwarded to FiberRegenerationDataset as drop_first
    drop_first: int = 10
    # fraction of the dataset indices used for training; the rest is validation
    train_split: float = 0.8
|
||||
|
||||
|
||||
# pytorch settings
|
||||
@dataclass
class PytorchSettings:
    """Training-loop settings."""

    # NOTE(review): presumably the number of training epochs -- not read in the code shown here
    epochs: int = 1000
    # batch size of the training and validation data loaders
    batchsize: int = 2**12
    # device that the model and every batch are moved to
    device: str = "cuda"
    # parent directory of the TensorBoard log directory
    summary_dir: str = ".runs"
    # NOTE(review): presumably where trained models are stored -- not read in the code shown here
    model_dir: str = ".models"
|
||||
|
||||
|
||||
# model settings
|
||||
@dataclass
class ModelSettings:
    """Architecture settings of the model built by ``Training.define_model``."""

    # number of outputs of the final layer
    output_size: int = 2
    # n_layer_range: float|int = 2
    # n_units_range: float|int = 32
    # number of hidden layer/activation pairs
    n_layers: int = 3
    # width of every hidden layer
    n_units: int = 32
    # attribute name looked up on util.complexNN and instantiated without arguments
    activation_func: tuple | str = "ModReLU"
|
||||
|
||||
|
||||
@dataclass
class OptimizerSettings:
    """Optimizer selection.

    NOTE(review): neither field is read in the code shown here; presumably the
    name is looked up on torch.optim -- confirm.
    """

    optimizer_range: str = "Adam"
    lr_range: float = 2e-3
|
||||
|
||||
|
||||
class Training:
|
||||
def __init__(self):
    """Create every settings group with its defaults and name the run after the current time."""
    self.global_settings = GlobalSettings()
    self.data_settings = DataSettings()
    self.pytorch_settings = PytorchSettings()
    self.model_settings = ModelSettings()
    self.optimizer_settings = OptimizerSettings()
    # also used as the TensorBoard log directory name
    self.study_name = (
        f"single_core_regen_{datetime.now().strftime('%Y-%m-%d_%H:%M:%S')}"
    )

    # fallback for settings objects without the attribute; PytorchSettings
    # as declared in this file always provides model_dir
    if not hasattr(self.pytorch_settings, "model_dir"):
        self.pytorch_settings.model_dir = ".models"

    # no TensorBoard logging until setup_tb_writer() is called
    self.writer = None
    self.console = Console()
|
||||
|
||||
def setup_tb_writer(self, study_name=None):
    """Create ``self.writer``, logging below the configured summary directory.

    Args:
        study_name: directory name to use; falls back to ``self.study_name``.
    """
    name = study_name or self.study_name
    self.writer = SummaryWriter(self.pytorch_settings.summary_dir + "/" + name)
|
||||
|
||||
def plot_eye(self, width=2, symbols=None, alpha=None, complex=False, show=True):
    """Plot an eye diagram of the configured dataset via ``util.plot.eye``.

    The data is loaded on the first call and cached in ``self.eye_data``;
    later calls reuse it.

    Args:
        width, alpha, complex, show: forwarded to ``util.plot.eye``.
        symbols: number of symbols to load and plot; 1000 when falsy.

    Returns:
        Whatever ``util.plot.eye`` returns.
    """
    n_symbols = symbols or 1000

    if not hasattr(self, "eye_data"):
        samples, cfg = util.datasets.load_data(
            self.data_settings.config_path,
            skipfirst=10,
            symbols=n_symbols,
            real=not self.data_settings.dtype.is_complex,
            normalize=True,
        )
        sps = int(cfg["glova"]["sps"])
        self.eye_data = {"data": samples, "sps": sps}

    plot_options = {
        "width": width,
        "show": show,
        "alpha": alpha,
        "complex": complex,
        "symbols": n_symbols,
        "skipfirst": 0,
    }
    return util.plot.eye(**self.eye_data, **plot_options)
|
||||
|
||||
def define_model(self):
    """Build the model: ``n_layers`` unitary layer/activation pairs plus an output layer.

    The input width is twice ``data_settings.data_size_range``. When a
    TensorBoard writer is set up, the model graph is added to it.

    Returns:
        ``nn.Sequential`` ending in a layer with ``output_size`` outputs.
    """
    hidden_units = self.model_settings.n_units
    width = 2 * self.data_settings.data_size_range

    modules = []
    for _ in range(self.model_settings.n_layers):
        modules.append(util.complexNN.UnitaryLayer(width, hidden_units))
        # layers.append(getattr(nn, self.model_settings.activation_func)())
        activation_cls = getattr(util.complexNN, self.model_settings.activation_func)
        modules.append(activation_cls())
        width = hidden_units

    modules.append(
        util.complexNN.UnitaryLayer(width, self.model_settings.output_size)
    )

    if self.writer is not None:
        dummy_input = torch.zeros(
            1, modules[0].in_features, dtype=self.data_settings.dtype
        )
        self.writer.add_graph(nn.Sequential(*modules), dummy_input)

    return nn.Sequential(*modules)
|
||||
|
||||
def get_sliced_data(self):
    """Build the training and validation data loaders.

    The sample indices are split according to ``data_settings.train_split``.
    With ``data_settings.shuffle`` enabled they are shuffled (seeded) before
    the split and drawn in random order; otherwise both loaders iterate
    their index lists in order.

    Returns:
        Tuple ``(train_loader, valid_loader)``.
    """
    cfg = self.data_settings

    # get dataset
    dataset = util.datasets.FiberRegenerationDataset(
        file_path=cfg.config_path,
        symbols=cfg.symbols_range,
        data_size=cfg.data_size_range,
        target_delay=cfg.target_delay,
        xy_delay=cfg.xy_delay_range,
        drop_first=cfg.drop_first,
        dtype=cfg.dtype,
        real=not cfg.dtype.is_complex,
        # device=self.pytorch_settings.device,
    )

    n_samples = len(dataset)
    order = list(range(n_samples))
    n_train = int(np.floor(cfg.train_split * n_samples))
    if cfg.shuffle:
        np.random.seed(self.global_settings.seed)
        np.random.shuffle(order)

    train_sampler, valid_sampler = order[:n_train], order[n_train:]
    if cfg.shuffle:
        train_sampler = torch.utils.data.SubsetRandomSampler(train_sampler)
        valid_sampler = torch.utils.data.SubsetRandomSampler(valid_sampler)

    # options shared by both loaders
    loader_options = {
        "batch_size": self.pytorch_settings.batchsize,
        "drop_last": True,
        "pin_memory": True,
        "num_workers": 24,
        "prefetch_factor": 4,
        # persistent_workers=True
    }
    train_loader = torch.utils.data.DataLoader(
        dataset, sampler=train_sampler, **loader_options
    )
    valid_loader = torch.utils.data.DataLoader(
        dataset, sampler=valid_sampler, **loader_options
    )

    return train_loader, valid_loader
|
||||
|
||||
def train_model(self, model, optimizer, train_loader, epoch):
    """Train ``model`` for one epoch with a live progress bar.

    Args:
        model: network to train.
        optimizer: optimizer bound to the model's parameters.
        train_loader: loader yielding ``(x, y)`` batches.
        epoch: zero-based epoch index, used for the TensorBoard step.

    Returns:
        The loss accumulated since the last TensorBoard log (the sum over all
        batches when no writer is set).
    """
    with Progress(
        TextColumn("[yellow] Training..."),
        TextColumn("Loss: {task.description}"),
        BarColumn(),
        TaskProgressColumn(),
        TextColumn("[green]Batch"),
        MofNCompleteColumn(),
        TimeRemainingColumn(),
        TimeElapsedColumn(),
        # description="Training",
        transient=False,
        console=self.console,
        refresh_per_second=10,
    ) as progress:
        # the task description doubles as the loss read-out
        task = progress.add_task("-.---e--", total=len(train_loader))

        running_loss = 0.0
        model.train()
        for batch_idx, (x, y) in enumerate(train_loader):
            model.zero_grad(set_to_none=True)
            x, y = (
                x.to(self.pytorch_settings.device),
                y.to(self.pytorch_settings.device),
            )
            y_pred = model(x)
            loss = util.complexNN.complex_mse_loss(y_pred, y)
            loss.backward()
            optimizer.step()

            progress.update(task, advance=1, description=f"{loss.item():.3e}")

            running_loss += loss.item()
            if self.writer is not None:
                # log the mean of every full window of 10 batches
                if (batch_idx + 1) % 10 == 0:
                    self.writer.add_scalar(
                        "training loss",
                        running_loss / 10,
                        epoch * len(train_loader) + batch_idx,
                    )
                    running_loss = 0.0

    return running_loss
|
||||
|
||||
def eval_model(self, model, valid_loader, epoch):
    """
    Evaluate ``model`` on ``valid_loader`` without tracking gradients.

    If a TensorBoard writer is configured (``self.writer``), the mean loss of
    every 10 batches is logged under the tag ``"loss"`` and, after the last
    batch, a "fiber response" figure is added for step ``epoch + 1``.

    :param model: network to evaluate, called as ``model(x)``
    :param valid_loader: iterable of ``(x, y)`` batches supporting ``len()``
    :param epoch: epoch index, used to compute the TensorBoard global step
    :return: mean of the per-batch losses over the whole loader
    """
    with Progress(
        TextColumn("[green]Evaluating..."),
        TextColumn("Loss: {task.description}"),
        BarColumn(),
        TaskProgressColumn(),
        TextColumn("[green]Batch"),
        MofNCompleteColumn(),
        TimeRemainingColumn(),
        TimeElapsedColumn(),
        # description="Training",
        transient=False,
        console=self.console,
        refresh_per_second=10,
    ) as progress:
        task = progress.add_task("-.---e--", total=len(valid_loader))

        model.eval()
        window_loss = 0  # reset every 10 batches, feeds TensorBoard
        total_loss = 0  # never reset, feeds the return value
        with torch.no_grad():
            for batch_idx, (x, y) in enumerate(valid_loader):
                x, y = (
                    x.to(self.pytorch_settings.device),
                    y.to(self.pytorch_settings.device),
                )
                y_pred = model(x)
                loss = util.complexNN.complex_mse_loss(y_pred, y)

                # .item() forces a device sync, so read the scalar only once
                loss_value = loss.item()
                window_loss += loss_value
                total_loss += loss_value

                progress.update(task, advance=1, description=f"{loss_value:.3e}")
                if self.writer is not None:
                    if (batch_idx + 1) % 10 == 0:
                        self.writer.add_scalar(
                            "loss",
                            window_loss / 10,
                            epoch * len(valid_loader) + batch_idx,
                        )
                        window_loss = 0.0

        if self.writer is not None:
            self.writer.add_figure("fiber response", self.plot_model_response(model, plot=False), epoch+1)

    return total_loss / len(valid_loader)
|
||||
|
||||
def run_model(self, model, loader):
    """
    Run ``model`` over every batch of ``loader`` and collect the results.

    Each batch is viewed as ``(batch, -1, 2)``. For the inputs only the first
    pair along the second axis is kept.

    :param model: network to run; it is put into eval mode and moved to
        ``self.pytorch_settings.device``
    :param loader: iterable of ``(x, y)`` batches
    :return: tuple ``(ys, xs, y_preds)`` of CPU tensors, in that order
    """
    device = self.pytorch_settings.device
    model.eval()
    inputs, targets, predictions = [], [], []
    with torch.no_grad():
        model = model.to(device)
        for batch_x, batch_y in loader:
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)
            batch_pred = model(batch_x).cpu()

            # regroup the flat trailing axis into pairs
            batch_pred = batch_pred.view(batch_pred.shape[0], -1, 2)
            batch_y = batch_y.view(batch_y.shape[0], -1, 2)
            batch_x = batch_x.view(batch_x.shape[0], -1, 2)

            inputs.append(batch_x[:, 0, :].squeeze())
            targets.append(batch_y.squeeze())
            predictions.append(batch_pred.squeeze())

    stacked_inputs = torch.vstack(inputs).cpu()
    stacked_targets = torch.vstack(targets).cpu()
    stacked_predictions = torch.vstack(predictions).cpu()
    return stacked_targets, stacked_inputs, stacked_predictions
|
||||
|
||||
def dummy_model(self, loader):
    """
    Collect inputs and targets from ``loader`` without running any model.

    Each batch is moved to the CPU and viewed as ``(batch, -1, 2)``. For the
    inputs only the first pair along the second axis is kept.

    :param loader: iterable of ``(x, y)`` batches
    :return: tuple ``(xs, ys)`` of vertically stacked tensors
    """
    collected = [
        (
            batch_x.cpu().view(batch_x.shape[0], -1, 2)[:, 0, :].squeeze(),
            batch_y.cpu().view(batch_y.shape[0], -1, 2).squeeze(),
        )
        for batch_x, batch_y in loader
    ]
    stacked_inputs = torch.vstack([pair[0] for pair in collected])
    stacked_targets = torch.vstack([pair[1] for pair in collected])
    return stacked_inputs, stacked_targets
|
||||
|
||||
def objective(self, save=False, plot_before=False):
    """
    Build, train and evaluate the model described by the current settings.

    The model is stored in ``self.model``. A KeyboardInterrupt during the run
    is swallowed (the method then returns ``None``). On every exit path the
    current model is additionally written to
    ``.models/exception/<study_name>.pth``.

    :param save: if true, save the trained model to
        ``<pytorch_settings.model_dir>/<study_name>.pth``
    :param plot_before: forwarded as ``plot`` to the initial
        ``plot_model_response`` call (only made when a writer is configured)
    :return: validation loss of the last epoch
    """
    try:
        rprint(*self.study_name.split("_"))

        self.model = self.define_model().to(self.pytorch_settings.device)

        if self.writer is not None:
            initial_figure = self.plot_model_response(plot=plot_before)
            self.writer.add_figure("fiber response", initial_figure, 0)

        train_loader, valid_loader = self.get_sliced_data()

        # optimizer class is looked up by name in torch.optim
        optimizer_cls = getattr(optim, self.optimizer_settings.optimizer_range)
        optimizer = optimizer_cls(
            self.model.parameters(), lr=self.optimizer_settings.lr_range
        )

        n_epochs = self.pytorch_settings.epochs
        for epoch in range(n_epochs):
            self.console.rule(f"Epoch {epoch + 1}/{self.pytorch_settings.epochs}")
            self.train_model(self.model, optimizer, train_loader, epoch)
            eval_loss = self.eval_model(self.model, valid_loader, epoch)

        if save:
            model_file = Path(self.pytorch_settings.model_dir) / f"{self.study_name}.pth"
            model_file.parent.mkdir(parents=True, exist_ok=True)
            torch.save(self.model, model_file)

        return eval_loss
    except KeyboardInterrupt:
        pass
    finally:
        # runs on success, interruption and error alike
        if hasattr(self, "model"):
            fallback_file = Path(".models/exception") / f"{self.study_name}.pth"
            fallback_file.parent.mkdir(parents=True, exist_ok=True)
            torch.save(self.model, fallback_file)
|
||||
|
||||
def _plot_model_response_plotter(self, fiber_in, fiber_out, regen, plot=True):
    """
    Plot the squared magnitude of three signals on two stacked axes.

    Axis ``i`` shows column ``i`` of each array.

    :param fiber_in: 2-D array plotted with the label "fiber in"
    :param fiber_out: 2-D array plotted with the label "fiber out"
    :param regen: 2-D array plotted with the label "regenerated"
    :param plot: if true, call ``plt.show()`` before returning
    :return: the created matplotlib figure
    """
    fig, axs = plt.subplots(2)
    traces = (
        (fiber_in, "fiber in"),
        (fiber_out, "fiber out"),
        (regen, "regenerated"),
    )
    for channel, ax in enumerate(axs):
        for signal, label in traces:
            ax.plot(np.abs(signal[:, channel]) ** 2, label=label)
        ax.legend()

    if plot:
        plt.show()
    return fig
|
||||
|
||||
def plot_model_response(self, model=None, plot=True):
    """
    Plot fiber input, fiber output and regenerated signal for a small,
    unshuffled slice of the data.

    ``self.data_settings`` is temporarily overridden (``shuffle=False``,
    ``train_split=0.01``, ``drop_first=100``) to build the plot loader and is
    restored afterwards, even if building the loader fails.

    :param model: model to run; defaults to ``self.model`` when ``None``
    :param plot: forwarded to the plotter; if true the figure is shown
    :return: the matplotlib figure
    """
    data_settings_backup = copy.copy(self.data_settings)
    try:
        self.data_settings.shuffle = False
        self.data_settings.train_split = 0.01
        self.data_settings.drop_first = 100
        plot_loader, _ = self.get_sliced_data()
    finally:
        # always restore, otherwise a failed plot would leave the training
        # configuration modified
        self.data_settings = data_settings_backup

    # explicit None check: modules defining __len__ (e.g. an empty
    # nn.Sequential) are falsy and would be replaced by `model or self.model`
    active_model = model if model is not None else self.model

    fiber_in, fiber_out, regen = self.run_model(active_model, plot_loader)
    fiber_in = fiber_in.view(-1, 2)
    fiber_out = fiber_out.view(-1, 2)
    regen = regen.view(-1, 2)

    fiber_in = fiber_in.numpy()
    fiber_out = fiber_out.numpy()
    regen = regen.numpy()

    # https://github.com/matplotlib/matplotlib/issues/27713#issue-2104110987
    # https://github.com/matplotlib/matplotlib/issues/27713#issuecomment-1915497463
    import gc
    fig = self._plot_model_response_plotter(fiber_in, fiber_out, regen, plot=plot)
    gc.collect()

    return fig
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point: train once with TensorBoard logging, save the
    # model, then plot its response.
    trainer = Training()

    # trainer.plot_eye()
    trainer.setup_tb_writer()
    trainer.objective(save=True)

    # objective() stores the trained network on the trainer
    best_model = trainer.model

    # best_model = trainer.define_model(trainer.study.best_trial).to(trainer.pytorch_settings.device)
    trainer.plot_model_response(best_model)
|
||||
|
||||
# print(f"Best model: {best_model}")
|
||||
|
||||
...
|
||||
194
src/single-core-regen/testing/learn_optuna.py
Normal file
194
src/single-core-regen/testing/learn_optuna.py
Normal file
@@ -0,0 +1,194 @@
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
import optuna
|
||||
import warnings
|
||||
from util.optuna_vis import show_figures
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.optim as optim
|
||||
import torch.utils.data
|
||||
|
||||
from torchvision import datasets
|
||||
from torchvision import transforms
|
||||
import multiprocessing
|
||||
|
||||
# from util.dataset import SlicedDataset
|
||||
|
||||
DEVICE = torch.device("cuda")
|
||||
BATCHSIZE = 128
|
||||
CLASSES = 10
|
||||
DIR = Path(__file__).parent
|
||||
EPOCHS = 100
|
||||
N_TRAIN_EXAMPLES = BATCHSIZE * 30
|
||||
N_VALID_EXAMPLES = BATCHSIZE * 10
|
||||
|
||||
n_trials = 128
|
||||
n_threads = 16
|
||||
|
||||
|
||||
def define_model(trial):
    """
    Build a fully connected classifier whose shape is chosen by ``trial``.

    The trial suggests the number of hidden layers (1-3) and, per layer, the
    number of units (4-128) and the dropout probability (0.2-0.5). The
    network takes 28*28 input features and ends in a ``LogSoftmax`` over
    ``CLASSES`` outputs.

    :param trial: Optuna trial used to sample the hyperparameters
    :return: the assembled ``nn.Sequential`` model
    """
    depth = trial.suggest_int("n_layers", 1, 3)
    modules = []

    width_in = 28 * 28
    for idx in range(depth):
        width_out = trial.suggest_int(f"n_units_l{idx}", 4, 128)
        modules.extend([nn.Linear(width_in, width_out), nn.ReLU()])
        drop_prob = trial.suggest_float(f"dropout_l{idx}", 0.2, 0.5)
        modules.append(nn.Dropout(drop_prob))

        # the next layer consumes this layer's output
        width_in = width_out

    modules.extend([nn.Linear(width_in, CLASSES), nn.LogSoftmax(dim=1)])

    return nn.Sequential(*modules)
|
||||
|
||||
|
||||
def get_mnist():
    """
    Create shuffled FashionMNIST data loaders.

    The data lives in ``DIR / ".data"``. Only the training split is created
    with ``download=True``.

    :return: tuple ``(train_loader, valid_loader)``
    """
    data_root = DIR / ".data"

    # Load FashionMNIST dataset.
    train_set = datasets.FashionMNIST(
        data_root, train=True, download=True, transform=transforms.ToTensor()
    )
    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=BATCHSIZE, shuffle=True
    )

    valid_set = datasets.FashionMNIST(
        data_root, train=False, transform=transforms.ToTensor()
    )
    valid_loader = torch.utils.data.DataLoader(
        valid_set, batch_size=BATCHSIZE, shuffle=True
    )

    return train_loader, valid_loader
|
||||
|
||||
|
||||
def objective(trial):
    """
    Optuna objective: train a sampled model and report its quality and size.

    The trial chooses the architecture (via ``define_model``), the optimizer
    class ("Adam", "RMSprop" or "SGD") and a log-uniform learning rate in
    [1e-5, 1e-1]. The model is trained for ``EPOCHS`` epochs and evaluated
    after each one.

    :param trial: Optuna trial used to sample the hyperparameters
    :return: tuple ``(accuracy, num_params)`` from the last evaluation
    """
    net = define_model(trial).to(DEVICE)

    chosen_optimizer = trial.suggest_categorical("optimizer", ["Adam", "RMSprop", "SGD"])
    learning_rate = trial.suggest_float("lr", 1e-5, 1e-1, log=True)
    optimizer_cls = getattr(optim, chosen_optimizer)
    optimizer = optimizer_cls(net.parameters(), lr=learning_rate)

    train_loader, valid_loader = get_mnist()

    for _ in range(EPOCHS):
        train_model(net, optimizer, train_loader)
        accuracy, num_params = eval_model(net, valid_loader)

    return accuracy, num_params
|
||||
|
||||
|
||||
def eval_model(model, valid_loader):
    """
    Measure classification accuracy on at most ``N_VALID_EXAMPLES`` samples.

    :param model: classifier fed with flattened batches
    :param valid_loader: loader yielding ``(data, target)`` batches
    :return: tuple ``(accuracy, num_params)``; ``num_params`` is the total
        number of elements over all model parameters
    """
    model.eval()
    n_correct = 0
    with torch.no_grad():
        for batch_idx, (images, labels) in enumerate(valid_loader):
            # stop once the evaluation budget is used up
            if batch_idx * BATCHSIZE >= N_VALID_EXAMPLES:
                break

            images = images.view(images.size(0), -1).to(DEVICE)
            labels = labels.to(DEVICE)
            scores = model(images)

            predicted = scores.argmax(dim=1, keepdim=True)
            n_correct += predicted.eq(labels.view_as(predicted)).sum().item()

    n_evaluated = min(len(valid_loader.dataset), N_VALID_EXAMPLES)
    accuracy = n_correct / n_evaluated
    num_params = sum(param.numel() for param in model.parameters())
    return accuracy, num_params
|
||||
|
||||
|
||||
def train_model(model, optimizer, train_loader):
    """
    Train ``model`` for one epoch on at most ``N_TRAIN_EXAMPLES`` samples.

    Uses the negative log-likelihood loss on the model output.

    :param model: classifier fed with flattened batches
    :param optimizer: optimizer stepped once per batch
    :param train_loader: loader yielding ``(data, target)`` batches
    """
    model.train()
    for batch_idx, (images, labels) in enumerate(train_loader):
        # stop once the training budget for this epoch is used up
        if batch_idx * BATCHSIZE >= N_TRAIN_EXAMPLES:
            break

        images = images.view(images.size(0), -1).to(DEVICE)
        labels = labels.to(DEVICE)

        optimizer.zero_grad()
        batch_loss = F.nll_loss(model(images), labels)
        batch_loss.backward()
        optimizer.step()
|
||||
|
||||
|
||||
def run_optimize(n_trials, study):
    """
    Run ``objective`` on ``study`` for up to ``n_trials`` trials.

    The optimization stops after 600 seconds at the latest.

    :param n_trials: number of trials to run
    :param study: Optuna study to optimize
    """
    time_limit_s = 600
    study.optimize(objective, n_trials=n_trials, timeout=time_limit_s)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Multi-objective study: maximize accuracy, minimize parameter count.
    study_name = f"{datetime.now().strftime('%Y-%m-%d %H:%M:%S')} mnist example"
    storage = "sqlite:///db.sqlite3"
    directions = ["maximize", "minimize"]

    study = optuna.create_study(
        directions=directions,
        storage=storage,
        study_name=study_name,
    )

    with warnings.catch_warnings(action="ignore"):
        study.set_metric_names(["accuracy", "num params"])

    # never start more workers than there are trials
    n_threads = min(n_trials, n_threads)
    trials_per_process = n_trials // n_threads

    processes = []
    for _ in range(n_threads):
        p = multiprocessing.Process(
            target=run_optimize, args=(trials_per_process, study)
        )
        p.start()
        processes.append(p)

    for p in processes:
        p.join()

    # trials that did not divide evenly among the workers run in this process
    remaining_trials = n_trials - (trials_per_process * n_threads)
    if remaining_trials:
        print(
            f"\nRunning last {remaining_trials} trial{'s' if remaining_trials > 1 else ''}:"
        )
        # run_optimize takes (n_trials, study)
        run_optimize(remaining_trials, study)

    print(f"Number of trials on the Pareto front: {len(study.best_trials)}")

    # objective() returns (accuracy, num_params), so accuracy is values[0]
    trial_with_highest_accuracy = max(study.best_trials, key=lambda t: t.values[0])
    print("Trial with highest accuracy: ")
    print(f"\tnumber: {trial_with_highest_accuracy.number}")
    print(f"\tparams: {trial_with_highest_accuracy.params}")
    print(f"\tvalues: {trial_with_highest_accuracy.values}")

    figures = []
    figures.append(
        optuna.visualization.plot_pareto_front(
            study, target_names=["accuracy", "num_params"]
        )
    )
    figures.append(optuna.visualization.plot_timeline(study))

    plt = show_figures(*figures)

    print()
    # plt.show()
|
||||
51
src/single-core-regen/testing/sliced_dataset_test.py
Normal file
51
src/single-core-regen/testing/sliced_dataset_test.py
Normal file
@@ -0,0 +1,51 @@
|
||||
# move into dir single-core-regen before running
|
||||
|
||||
from util.dataset import SlicedDataset
|
||||
from torch.utils.data import DataLoader
|
||||
from matplotlib import pyplot as plt
|
||||
import numpy as np
|
||||
|
||||
def eye_dataset(dataset, no_symbols=None, offset=False, show=True):
    """
    Draw eye diagrams (squared magnitude) for the slices of ``dataset``.

    One slice per symbol is overlaid on a 2x2 grid: the left column shows the
    second element of each dataset item, the right column the first; the top
    row shows index 0, the bottom row index 1.

    :param dataset: sliceable dataset exposing ``symbols_per_slice``,
        ``samples_per_slice`` and ``samples_per_symbol``
    :param no_symbols: number of symbols to overlay; defaults to ``len(dataset)``
    :param offset: if true, start half a symbol into the data
    :param show: if true, call ``plt.show()`` at the end
    """
    if no_symbols is None:
        no_symbols = len(dataset)
    _, axs = plt.subplots(2,2, sharex=True, sharey=True)

    xaxis = np.linspace(0,dataset.symbols_per_slice,dataset.samples_per_slice)
    stride = dataset.samples_per_symbol
    first = stride // 2 if offset else 0
    last = stride * no_symbols + first
    for E_out, E_in in dataset[first:last:stride]:
        panels = (
            (axs[0,0], E_in[0]),
            (axs[1,0], E_in[1]),
            (axs[0,1], E_out[0]),
            (axs[1,1], E_out[1]),
        )
        for ax, field in panels:
            ax.plot(xaxis, np.abs(field.numpy())**2, alpha=0.05, color='C0')

    if show:
        plt.show()
|
||||
|
||||
# def plt_dataloader(dataloader, show=True):
|
||||
# _, axs = plt.subplots(2,2, sharex=True, sharey=True)
|
||||
|
||||
# E_outs, E_ins = next(iter(dataloader))
|
||||
# for i, (E_out, E_in) in enumerate(zip(E_outs, E_ins)):
|
||||
# xaxis = np.linspace(dataset.symbols_per_slice*i,dataset.symbols_per_slice+dataset.symbols_per_slice*i,dataset.samples_per_slice)
|
||||
# E_in_x, E_in_y, E_out_x, E_out_y = E_in[0], E_in[1], E_out[0], E_out[1]
|
||||
# axs[0,0].plot(xaxis, np.abs(E_in_x.numpy())**2)
|
||||
# axs[1,0].plot(xaxis, np.abs(E_in_y.numpy())**2)
|
||||
# axs[0,1].plot(xaxis, np.abs(E_out_x.numpy())**2)
|
||||
# axs[1,1].plot(xaxis, np.abs(E_out_y.numpy())**2)
|
||||
|
||||
# if show:
|
||||
# plt.show()
|
||||
|
||||
if __name__ == "__main__":
    # Manual smoke test: load one dataset file, print the shape of the first
    # element of the first item and show its eye diagram.
    config_file = "data/20241115-175517-128-16384-10000-0-0-17-0-PAM4-0.ini"
    dataset = SlicedDataset(config_file, symbols=1, drop_first=100)
    first_item = dataset[0]
    print(first_item[0].shape)

    eye_dataset(dataset, 1000, offset=True, show=False)

    train_loader = DataLoader(dataset, batch_size=10, shuffle=False)

    # plt_dataloader(train_loader, show=False)

    plt.show()
|
||||
68
src/single-core-regen/testing/torch-import-test.py
Normal file
68
src/single-core-regen/testing/torch-import-test.py
Normal file
@@ -0,0 +1,68 @@
|
||||
import torch
|
||||
import time
|
||||
|
||||
def print_torch_env():
    """
    Print the PyTorch version and CUDA/cuDNN details.

    Per-device details are printed only when CUDA is available; querying
    them on a CPU-only installation would raise.
    """
    cuda_available = torch.cuda.is_available()
    print("Torch version: ", torch.__version__)
    print("CUDA available: ", cuda_available)
    print("CUDA version: ", torch.version.cuda)
    print("CUDNN version: ", torch.backends.cudnn.version())
    print("Device count: ", torch.cuda.device_count())
    if not cuda_available:
        # no device to describe
        return
    print("Current device: ", torch.cuda.current_device())
    print("Device name: ", torch.cuda.get_device_name(0))
    print("Device capability: ", torch.cuda.get_device_capability(0))
    print("Device memory: ", torch.cuda.get_device_properties(0).total_memory)
|
||||
|
||||
def measure_runtime(func):
    """
    Measure the runtime of a function.

    The wrapper prints the elapsed time and returns it together with the
    wrapped function's result.

    :param func: Function to measure
    :type func: function
    :return: Wrapped function returning ``(result, elapsed_seconds)``
    :rtype: function
    """
    import functools

    @functools.wraps(func)  # keep the wrapped function's name and docstring
    def wrapper(*args, **kwargs):
        # perf_counter is monotonic and high-resolution; time.time() can
        # jump when the system clock is adjusted
        start_time = time.perf_counter()
        result = func(*args, **kwargs)
        elapsed = time.perf_counter() - start_time
        print(f"Runtime: {elapsed:.6f} seconds")
        return result, elapsed
    return wrapper
|
||||
|
||||
@measure_runtime
def tensor_addition(a, b):
    """
    Add two tensors element-wise.

    Because of the ``measure_runtime`` decorator, callers receive the sum
    together with the measured runtime.

    :param a: First tensor
    :type a: torch.Tensor
    :param b: Second tensor
    :type b: torch.Tensor
    :return: Sum of tensors
    :rtype: torch.Tensor
    """
    total = a + b
    return total
|
||||
|
||||
def runtime_test():
    """
    Compare the runtime of a large tensor addition on CPU and GPU.

    Prints both runtimes and the CPU/GPU time ratio in percent. If CUDA is
    not available, only the CPU measurement is made.
    """
    x = torch.rand(2**18, 2**10)
    y = torch.rand(2**18, 2**10)

    print("Tensor addition on CPU")
    _, cpu_time = tensor_addition(x, y)

    print()
    print("Tensor addition on GPU")
    if not torch.cuda.is_available():
        print("CUDA is not available")
        return

    @measure_runtime
    def synced_gpu_addition(a, b):
        # CUDA kernels are launched asynchronously; without synchronizing,
        # only the launch overhead would be timed
        out = a + b
        torch.cuda.synchronize()
        return out

    # finish the host-to-device copies before the timer starts
    x_gpu, y_gpu = x.cuda(), y.cuda()
    torch.cuda.synchronize()
    _, gpu_time = synced_gpu_addition(x_gpu, y_gpu)

    print()
    print(f"Speedup: {cpu_time / gpu_time *100:.2f}%")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Report the environment first, then run the CPU/GPU timing comparison.
    print_torch_env()
    print()
    runtime_test()
|
||||
17
src/single-core-regen/util/__init__.py
Normal file
17
src/single-core-regen/util/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
||||
from . import datasets # noqa: F401
|
||||
# from .datasets import FiberRegenerationDataset # noqa: F401
|
||||
# from .datasets import load_data # noqa: F401
|
||||
|
||||
|
||||
from . import plot # noqa: F401
|
||||
# from .plot import eye # noqa: F401
|
||||
|
||||
from . import optuna_helpers # noqa: F401
|
||||
# from .optuna_helpers import optional_suggest_categorical # noqa: F401
|
||||
# from .optuna_helpers import optional_suggest_float # noqa: F401
|
||||
# from .optuna_helpers import optional_suggest_int # noqa: F401
|
||||
|
||||
from . import complexNN # noqa: F401
|
||||
# from .complexNN import UnitaryLayer # noqa: F401
|
||||
# from .complexNN import complex_mse_loss # noqa: F401
|
||||
# from .complexNN import complex_sse_loss # noqa: F401
|
||||
141
src/single-core-regen/util/complexNN.py
Normal file
141
src/single-core-regen/util/complexNN.py
Normal file
@@ -0,0 +1,141 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
def complex_mse_loss(input, target):
    """
    Compute the mean squared error between two complex tensors.

    Real-valued tensors are accepted as well (plain MSE), matching
    ``complex_sse_loss``; ``.imag`` is not defined for real tensors.

    :param input: predicted tensor (complex or real)
    :param target: reference tensor of the same kind
    :return: scalar tensor holding the mean of the squared error magnitudes
    """
    if input.is_complex():
        return torch.mean(torch.square(input.real - target.real) + torch.square(input.imag - target.imag))
    else:
        return torch.mean(torch.square(input - target))
|
||||
|
||||
def complex_sse_loss(input, target):
    """
    Compute the sum squared error between two complex tensors.

    If ``input`` is not complex, the plain sum of squared differences is
    returned instead.

    :param input: predicted tensor
    :param target: reference tensor
    :return: scalar tensor holding the summed squared error
    """
    if not input.is_complex():
        return torch.sum(torch.square(input - target))

    # squared magnitude of the complex error, per element
    real_err = torch.square(input.real - target.real)
    imag_err = torch.square(input.imag - target.imag)
    return torch.sum(real_err + imag_err)
|
||||
|
||||
|
||||
|
||||
|
||||
class UnitaryLayer(nn.Module):
    """
    Complex linear layer without bias, ``y = x @ W``.

    ``W`` has shape ``(in_features, out_features)`` and is initialised with
    the Q factor of a QR decomposition of a random complex matrix.
    Nothing in this class re-applies that decomposition after initialisation.
    """

    def __init__(self, in_features, out_features):
        super().__init__()
        assert in_features >= out_features
        self.in_features = in_features
        self.out_features = out_features
        initial = torch.randn(in_features, out_features, dtype=torch.cfloat)
        self.weight = nn.Parameter(initial)
        self.reset_parameters()

    def reset_parameters(self):
        """Replace the weight by the Q factor of its QR decomposition."""
        q_factor, _ = torch.linalg.qr(self.weight)
        self.weight.data = q_factor

    @staticmethod
    @torch.jit.script
    def _unitary_forward(x, weight):
        return torch.matmul(x, weight)

    def forward(self, x):
        """Multiply ``x`` with the weight matrix."""
        return self._unitary_forward(x, self.weight)
|
||||
|
||||
|
||||
#### as defined by zhang et al
|
||||
|
||||
class Identity(nn.Module):
    """
    Pass-through "activation" function

        M(z) = z
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return the input unchanged."""
        return x
|
||||
|
||||
class Mag(nn.Module):
    """
    implements the activation function
        M(z) = ||z||
    """

    def __init__(self):
        super().__init__()

    # NOTE: no @torch.jit.script here. Decorating the method replaces it by
    # a ScriptFunction that is not bound to the instance, so calling the
    # module fails with a missing argument.
    def forward(self, x):
        """Return the element-wise magnitude of ``x`` as a real tensor."""
        # torch.abs yields the magnitude ||z||; the previous
        # real**2 + imag**2 was the squared magnitude
        return torch.abs(x)
|
||||
|
||||
# class Tanh(nn.Module):
|
||||
# """
|
||||
# implements the activation function
|
||||
# M(z) = tanh(z) = sinh(z)/cosh(z) = (exp(z)-exp(-z))/(exp(z)+exp(-z)) = (exp(2*z)-1)/(exp(2*z)+1)
|
||||
# """
|
||||
# def __init__(self):
|
||||
# super(Tanh, self).__init__()
|
||||
|
||||
# def forward(self, x):
|
||||
# return torch.tanh(x)
|
||||
|
||||
class ModReLU(nn.Module):
    """
    implements the activation function
        M(z) = ReLU(||z|| + b)*exp(j*theta_z)
             = ReLU(||z|| + b)*z/||z||

    For z = 0 the output is 0.
    """

    def __init__(self, b=0):
        super().__init__()
        self.b = b
        self.relu = nn.ReLU()

    @staticmethod
    # @torch.jit.script
    def _mod_relu(x, b):
        # magnitude ||z||; the previous real**2 + imag**2 was ||z||**2 and
        # did not match the formula above
        mod = torch.abs(x)
        # avoid 0/0 -> NaN for zero inputs; x itself is 0 there, so the
        # product below is 0 regardless of the substituted divisor
        safe_mod = torch.where(mod > 0, mod, torch.ones_like(mod))
        return torch.relu(mod + b) / safe_mod * x

    def forward(self, x):
        """Apply modReLU with the bias ``self.b`` to the complex tensor ``x``."""
        return self._mod_relu(x, self.b)
|
||||
|
||||
class CReLU(nn.Module):
    """
    implements the activation function
        M(z) = ReLU(Re(z)) + j*ReLU(Im(z))
    """

    def __init__(self):
        super().__init__()
        self.relu = nn.ReLU()

    # NOTE: no @torch.jit.script here. Decorating the method replaces it by
    # a ScriptFunction that is not bound to the instance, so calling the
    # module fails with a missing argument.
    def forward(self, x):
        """Apply ReLU separately to the real and imaginary parts of ``x``."""
        return torch.complex(self.relu(x.real), self.relu(x.imag))
|
||||
|
||||
class ZReLU(nn.Module):
    """
    implements the activation function

        M(z) = z if 0 <= angle(z) <= pi/2
             = 0 otherwise
    """

    def __init__(self):
        super().__init__()

    # NOTE: no @torch.jit.script here. Decorating the method replaces it by
    # a ScriptFunction that is not bound to the instance, so calling the
    # module fails with a missing argument.
    def forward(self, x):
        """Keep values whose phase lies in [0, pi/2] and zero the rest."""
        phase = torch.angle(x)  # computed once and reused for both bounds
        keep = (phase >= 0) & (phase <= torch.pi / 2)
        return torch.where(keep, x, torch.zeros_like(x))
|
||||
|
||||
# class ComplexFeedForwardNN(nn.Module):
|
||||
# def __init__(self, in_features, hidden_features, out_features):
|
||||
# super(ComplexFeedForwardNN, self).__init__()
|
||||
# self.in_features = in_features
|
||||
# self.hidden_features = hidden_features
|
||||
# self.out_features = out_features
|
||||
# self.fc1 = UnitaryLayer(in_features, hidden_features)
|
||||
# self.fc2 = UnitaryLayer(hidden_features, out_features)
|
||||
|
||||
# def forward(self, x):
|
||||
# x = self.fc1(x)
|
||||
# x = self.fc2(x)
|
||||
# return x
|
||||
281
src/single-core-regen/util/datasets.py
Normal file
281
src/single-core-regen/util/datasets.py
Normal file
@@ -0,0 +1,281 @@
|
||||
from pathlib import Path
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
# from torch.utils.data import Sampler
|
||||
import numpy as np
|
||||
import configparser
|
||||
|
||||
# class SubsetSampler(Sampler[int]):
|
||||
# """
|
||||
# Samples elements from a given list of indices.
|
||||
|
||||
# :param indices: List of indices to sample from.
|
||||
# :type indices: list[int]
|
||||
# """
|
||||
|
||||
# def __init__(self, indices):
|
||||
# self.indices = indices
|
||||
|
||||
# def __iter__(self):
|
||||
# return iter(self.indices)
|
||||
|
||||
# def __len__(self):
|
||||
# return len(self.indices)
|
||||
|
||||
def load_data(config_path, skipfirst=0, symbols=None, real=False, normalize=False, device=None, dtype=None):
    """
    Load a simulation data set described by a configuration file.

    ``config_path`` may contain glob wildcards in its file name; every match
    is read into the same parser (in sorted order, so later files override
    earlier ones). The ``.npy`` file is located at
    ``<data.dir>/<data.npy_dir>/<data.file>`` with double quotes stripped.

    :param config_path: path (or glob pattern) of the ``.ini`` file
    :param skipfirst: number of symbols to skip at the start of the data
    :param symbols: number of symbols to load; defaults to ``nos - skipfirst``
    :param real: if true, return the magnitude of the data
    :param normalize: if true, scale every column to a maximum magnitude of 1
    :param device: device passed to ``torch.tensor``
    :param dtype: dtype passed to ``torch.tensor``
    :return: tuple ``(data, config)``; ``config["glova"]["nos"]`` is updated
        to the number of symbols loaded
    :raises FileNotFoundError: if no configuration file matches ``config_path``
    """
    pattern = Path(config_path)
    matches = sorted(pattern.parent.glob(pattern.name))
    if not matches:
        # ConfigParser.read() silently ignores missing files, which would
        # otherwise surface later as a confusing KeyError
        raise FileNotFoundError(f"No configuration file matches '{config_path}'")

    config = configparser.ConfigParser()
    config.read(matches)
    path_elements = (
        config["data"]["dir"],
        config["data"]["npy_dir"],
        config["data"]["file"],
    )
    datapath = Path("/".join(path_elements).replace('"', ""))
    sps = int(config["glova"]["sps"])

    if symbols is None:
        symbols = int(config["glova"]["nos"]) - skipfirst

    # rows are samples: keep `symbols` symbols after skipping `skipfirst`
    data = np.load(datapath)[skipfirst * sps : symbols * sps + skipfirst * sps]

    if normalize:
        # per-column scaling to a peak magnitude of 1
        data = data / np.max(np.abs(data), axis=0)

    if real:
        data = np.abs(data)

    config["glova"]["nos"] = str(symbols)

    data = torch.tensor(data, device=device, dtype=dtype)

    return data, config
|
||||
|
||||
def roll_along(arr, shifts, dim):
    """
    Roll ``arr`` along ``dim`` with an individual shift per slice.

    Works like ``torch.roll`` along one dimension, except that every index
    combination of the remaining dimensions gets its own shift.

    :param arr: tensor to roll
    :param shifts: shifts with the shape of ``arr`` minus dimension ``dim``
        (list or tensor)
    :param dim: dimension to roll along (may be negative)
    :return: new tensor with the rolled data
    """
    # https://stackoverflow.com/a/76920720
    # (c) Mateen Ulhaq, 2023
    # CC BY-SA 4.0

    # index tensors must live on the same device as `arr`, otherwise
    # torch.gather fails for CUDA inputs
    shifts = torch.as_tensor(shifts, device=arr.device)
    assert arr.ndim - 1 == shifts.ndim
    dim %= arr.ndim
    # broadcastable shape: -1 at `dim`, 1 everywhere else
    shape = (1,) * dim + (-1,) + (1,) * (arr.ndim - dim - 1)
    dim_indices = torch.arange(arr.shape[dim], device=arr.device).reshape(shape)
    indices = (dim_indices - shifts.unsqueeze(dim)) % arr.shape[dim]
    return torch.gather(arr, dim, indices)
|
||||
|
||||
class FiberRegenerationDataset(Dataset):
|
||||
"""
|
||||
Dataset for fiber regeneration training.
|
||||
|
||||
The dataset is loaded from a configuration file, which must contain (at least) the following sections:
|
||||
```
|
||||
[data]
|
||||
dir = <data_dir>
|
||||
npy_dir = <npy_dir>
|
||||
file = <data_file>
|
||||
|
||||
[glova]
|
||||
sps = <samples per symbol>
|
||||
```
|
||||
The data is loaded from the file `<data_dir>/<npy_dir>/<data_file>` and is assumed to be in the following format:
|
||||
```
|
||||
[ E_in_x,
|
||||
E_in_y,
|
||||
E_out_x,
|
||||
E_out_y ]
|
||||
```
|
||||
|
||||
The dataset is sliced into slices, where each slice consists of a (fractional) number of symbols.
|
||||
The target can be delayed relative to the input data by a (fractional) number of symbols.
|
||||
The x and y channels can be delayed relative to each other by a (fractional) number of symbols.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
file_path: str | Path,
|
||||
symbols: int | float,
|
||||
*,
|
||||
data_size: int = None,
|
||||
target_delay: float | int = 0,
|
||||
xy_delay: float | int = 0,
|
||||
drop_first: float | int = 0,
|
||||
dtype: torch.dtype = None,
|
||||
real: bool = False,
|
||||
device = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Initialize the dataset.
|
||||
|
||||
:param file_path: Path to the data file. Can contain wildcards (*). The first
|
||||
:type file_path: str | pathlib.Path
|
||||
:param symbols: Number of symbols in each slice. Can be a float to specify a fraction of a symbol.
|
||||
:type symbols: float | int
|
||||
:param data_size: Number of samples in each slice. The data is reduced by taking equally spaced samples. If unset, each slice will contain symbols*samples_per_symbol samples.
|
||||
:type data_size: int, optional
|
||||
:param target_delay: Delay (in fractional symbols) between data and target. A positive delay means the target is delayed relative to the data. Default is 0.
|
||||
:type target_delay: float | int, optional
|
||||
:param xy_delay: Delay (in fractional symbols) between the x and y channels. A positive delay means the y channel is delayed relative to the x channel. Default is 0.
|
||||
:type xy_delay: float | int, optional
|
||||
:param drop_first: Number of (fractional) symbols to drop from the beginning
|
||||
:type drop_first: float | int
|
||||
"""
|
||||
|
||||
# check types
|
||||
assert isinstance(file_path, str), "file_path must be a string"
|
||||
assert isinstance(symbols, (float, int)), (
|
||||
"symbols must be a float or an integer"
|
||||
)
|
||||
assert data_size is None or isinstance(data_size, int), (
|
||||
"output_len must be an integer"
|
||||
)
|
||||
assert isinstance(target_delay, (float, int)), (
|
||||
"target_delay must be a float or an integer"
|
||||
)
|
||||
assert isinstance(xy_delay, (float, int)), (
|
||||
"xy_delay must be a float or an integer"
|
||||
)
|
||||
assert isinstance(drop_first, int), "drop_first must be an integer"
|
||||
|
||||
# check values
|
||||
assert symbols > 0, "symbols must be positive"
|
||||
assert data_size is None or data_size > 0, "output_len must be positive or None"
|
||||
assert drop_first >= 0, "drop_first must be non-negative"
|
||||
|
||||
faux = kwargs.pop("faux", False)
|
||||
|
||||
if faux:
|
||||
data_raw = np.array(
|
||||
[[i + 0.1j, i + 0.2j, i + 1.1j, i + 1.2j] for i in range(12800)],
|
||||
dtype=np.complex128,
|
||||
)
|
||||
data_raw = torch.tensor(data_raw, device=device, dtype=dtype)
|
||||
self.config = {
|
||||
"data": {"dir": '"."', "npy_dir": '"."', "file": "faux"},
|
||||
"glova": {"sps": 128},
|
||||
}
|
||||
else:
|
||||
data_raw, self.config = load_data(file_path, skipfirst=drop_first, real=real, normalize=True, device=device, dtype=dtype)
|
||||
|
||||
self.device = data_raw.device
|
||||
|
||||
self.samples_per_symbol = int(self.config["glova"]["sps"])
|
||||
self.samples_per_slice = int(symbols * self.samples_per_symbol)
|
||||
self.symbols_per_slice = self.samples_per_slice / self.samples_per_symbol
|
||||
|
||||
self.data_size = data_size or self.samples_per_slice
|
||||
self.target_delay = target_delay or 0
|
||||
self.xy_delay = xy_delay or 0
|
||||
|
||||
ovrd_target_delay_samples = kwargs.pop("ovrd_target_delay_samples", None)
|
||||
ovrd_xy_delay_samples = kwargs.pop("ovrd_xy_delay_samples", None)
|
||||
|
||||
self.target_delay_samples = (
|
||||
ovrd_target_delay_samples
|
||||
if ovrd_target_delay_samples is not None
|
||||
else int(self.target_delay * self.samples_per_symbol)
|
||||
)
|
||||
self.xy_delay_samples = (
|
||||
ovrd_xy_delay_samples
|
||||
if ovrd_xy_delay_samples is not None
|
||||
else int(self.xy_delay * self.samples_per_symbol)
|
||||
)
|
||||
|
||||
# data_raw = torch.tensor(data_raw, dtype=dtype)
|
||||
|
||||
# data layout
|
||||
# [ [E_in_x0, E_in_y0, E_out_x0, E_out_y0],
|
||||
# [E_in_x1, E_in_y1, E_out_x1, E_out_y1],
|
||||
# ...
|
||||
# [E_in_xN, E_in_yN, E_out_xN, E_out_yN] ]
|
||||
|
||||
data_raw = data_raw.transpose(0, 1)
|
||||
|
||||
# data layout
|
||||
# [ E_in_x[0:N],
|
||||
# E_in_y[0:N],
|
||||
# E_out_x[0:N],
|
||||
# E_out_y[0:N] ]
|
||||
|
||||
# shift x data by xy_delay_samples relative to the y data (example value: 3)
|
||||
# [ E_in_x [0:N], [ E_in_x [ 0:N ], [ E_in_x [3:N ],
|
||||
# E_in_y [0:N], -> E_in_y [-3:N-3], -> E_in_y [0:N-3],
|
||||
# E_out_x[0:N], E_out_x[ 0:N ], E_out_x[3:N ],
|
||||
# E_out_y[0:N] ] E_out_y[-3:N-3] ] E_out_y[0:N-3] ]
|
||||
|
||||
if self.xy_delay_samples != 0:
|
||||
data_raw = roll_along(
|
||||
data_raw, [0, self.xy_delay_samples, 0, self.xy_delay_samples], dim=1
|
||||
)
|
||||
if self.xy_delay_samples > 0:
|
||||
data_raw = data_raw[:, self.xy_delay_samples :]
|
||||
elif self.xy_delay_samples < 0:
|
||||
data_raw = data_raw[:, : self.xy_delay_samples]
|
||||
|
||||
# shift fiber input data (target) by target_delay_samples relative to the fiber output data (input)
|
||||
# (example value: 5)
|
||||
# [ E_in_x [0:N], [ E_in_x [-5:N-5], [ E_in_x [0:N-5],
|
||||
# E_in_y [0:N], -> E_in_y [-5:N-5], -> E_in_y [0:N-5],
|
||||
# E_out_x[0:N], E_out_x[ 0:N ], E_out_x[5:N ],
|
||||
# E_out_y[0:N] ] E_out_y[ 0:N ] ] E_out_y[5:N ] ]
|
||||
|
||||
if self.target_delay_samples != 0:
|
||||
data_raw = roll_along(
|
||||
data_raw,
|
||||
[self.target_delay_samples, self.target_delay_samples, 0, 0],
|
||||
dim=1,
|
||||
)
|
||||
if self.target_delay_samples > 0:
|
||||
data_raw = data_raw[:, self.target_delay_samples :]
|
||||
elif self.target_delay_samples < 0:
|
||||
data_raw = data_raw[:, : self.target_delay_samples]
|
||||
|
||||
data_raw = data_raw.view(2, 2, -1)
|
||||
# data layout
|
||||
# [ [E_in_x, E_in_y],
|
||||
# [E_out_x, E_out_y] ]
|
||||
|
||||
self.data = data_raw.unfold(dimension=-1, size=self.samples_per_slice, step=1)
|
||||
self.data = self.data.movedim(-2, 0)
|
||||
# -> [no_slices, 2, 2, samples_per_slice]
|
||||
|
||||
# data layout
|
||||
# [
|
||||
# [ [E_in_x[0:N+0], E_in_y[0:N+0] ], [ E_out_x[0:N+0], E_out_y[0:N+0] ] ],
|
||||
# [ [E_in_x[1:N+1], E_in_y[1:N+1] ], [ E_out_x[1:N+1], E_out_y[1:N+1] ] ],
|
||||
# ...
|
||||
# ] -> [no_slices, 2, 2, samples_per_slice]
|
||||
|
||||
...
|
||||
|
||||
def __len__(self):
|
||||
return self.data.shape[0]
|
||||
|
||||
def __getitem__(self, idx):
    """Return the ``(data, target)`` pair for one slice, or a list of pairs for a ``slice``.

    ``self.data[idx, 1]`` (fiber output) is turned into the network input and
    ``self.data[idx, 0]`` (fiber input) into the target.

    Args:
        idx: Index of a slice, or a ``slice`` object selecting several.

    Returns:
        For a single index a tuple ``(data, target)``:
        data layout:   [sample_x0, sample_y0, sample_x1, sample_y1, ...]
        target layout: [sample_x0, sample_y0]
        For a ``slice`` a list of such tuples.
    """
    if isinstance(idx, slice):
        first, last, stride = idx.indices(len(self))
        return [self.__getitem__(pos) for pos in range(first, last, stride)]

    def split_into_groups(rows):
        # drop the trailing samples that do not fill a whole group, then
        # reshape every row into self.data_size equally long groups
        usable = rows.shape[1] // self.data_size * self.data_size
        return rows[:, :usable].view(rows.shape[0], self.data_size, -1)

    fiber_out = self.data[idx, 1].squeeze()
    fiber_in = self.data[idx, 0].squeeze()

    # reduce the input by taking self.data_size equally spaced samples
    # (the first sample of every group)
    data = split_into_groups(fiber_out)[:, :, 0]

    # the target is a single sample per row: the one in the middle of the first group
    # NOTE(review): this is the middle of the first group, not the middle of
    # the whole slice - confirm this is the intended alignment
    target_groups = split_into_groups(fiber_in)
    target = target_groups[:, 0, target_groups.shape[2] // 2]

    # interleave the rows sample by sample
    data = data.transpose(0, 1).flatten().squeeze()
    target = target.flatten().squeeze()

    return data, target
|
||||
30
src/single-core-regen/util/optuna_helpers.py
Normal file
30
src/single-core-regen/util/optuna_helpers.py
Normal file
@@ -0,0 +1,30 @@
|
||||
def _optional_suggest(trial, name, range_or_value, log=False, step=None, type='int'):
    """Return a fixed value, or let ``trial`` suggest one from a range / choice list.

    Args:
        trial: Object providing ``suggest_int``, ``suggest_float`` and
            ``suggest_categorical`` (e.g. an Optuna trial).
        name: Parameter name forwarded to the trial.
        range_or_value: A fixed value (anything non-iterable, or a ``str``),
            which is returned unchanged, or an iterable. An iterable with a
            single element yields that element without consulting the trial;
            otherwise its elements are the bounds (``'int'`` / ``'float'``) or
            the choices (``'categorical'``).
        log: Forwarded to ``suggest_int`` / ``suggest_float``.
        step: Forwarded to ``suggest_int`` / ``suggest_float``; for ``'int'``
            a falsy step is replaced by 1.
        type: One of ``'int'``, ``'float'`` or ``'categorical'``. The name
            shadows the builtin but is kept for backward compatibility.

    Returns:
        The fixed value or the value suggested by the trial.

    Raises:
        ValueError: If ``range_or_value`` is an empty iterable, or ``type`` is
            unknown when a suggestion is actually required.
    """
    # not a range
    if not hasattr(range_or_value, '__iter__') or isinstance(range_or_value, str):
        return range_or_value

    # Materialise anything that is not already a list/tuple so that sets,
    # generators and similar iterables can be measured and indexed safely.
    if isinstance(range_or_value, (list, tuple)):
        candidates = range_or_value
    else:
        candidates = tuple(range_or_value)

    if not candidates:
        raise ValueError(f"Empty range for parameter {name!r}")

    # range with only one value
    if len(candidates) == 1:
        return candidates[0]

    if type == 'int':
        step = step or 1
        return trial.suggest_int(name, *candidates, step=step, log=log)

    if type == 'float':
        return trial.suggest_float(name, *candidates, step=step, log=log)

    if type == 'categorical':
        return trial.suggest_categorical(name, candidates)

    raise ValueError(f"Unknown type: {type}")


def optional_suggest_categorical(trial, name, choices_or_value):
    """Return ``choices_or_value`` if it is a fixed value, else a categorical suggestion from ``trial``."""
    return _optional_suggest(trial, name, choices_or_value, type='categorical')


def optional_suggest_int(trial, name, range_or_value, step=None, log=False):
    """Return ``range_or_value`` if it is a fixed value, else an integer suggestion from ``trial``."""
    return _optional_suggest(trial, name, range_or_value, step=step, log=log, type='int')


def optional_suggest_float(trial, name, range_or_value, step=None, log=False):
    """Return ``range_or_value`` if it is a fixed value, else a float suggestion from ``trial``."""
    return _optional_suggest(trial, name, range_or_value, step=step, log=log, type='float')
|
||||
18
src/single-core-regen/util/optuna_vis.py
Normal file
18
src/single-core-regen/util/optuna_vis.py
Normal file
@@ -0,0 +1,18 @@
|
||||
from dash import Dash, dcc, html
|
||||
import logging
|
||||
import dash_bootstrap_components as dbc
|
||||
|
||||
|
||||
def show_figures(*figures):
    """Build a dark-themed Dash app that shows the given figures one below another.

    Args:
        *figures: Figure objects to display. Each figure's ``layout.template``
            is set to ``'plotly_dark'`` in place.

    Returns:
        The ``Dash`` app. It gains a ``show(*args, **kwargs)`` attribute that
        starts the server; ``debug`` defaults to ``False`` but may be
        overridden by the caller.
    """
    for figure in figures:
        figure.layout.template = 'plotly_dark'

    app = Dash(external_stylesheets=[dbc.themes.DARKLY])
    app.layout = html.Div([dcc.Graph(figure=figure) for figure in figures])

    # only let werkzeug report errors, not every request
    logging.getLogger('werkzeug').setLevel(logging.ERROR)

    def show(*args, **kwargs):
        # setdefault instead of a hard-wired debug=False: passing debug
        # explicitly used to fail with a duplicate-keyword TypeError
        kwargs.setdefault('debug', False)
        # prefer ``run`` when the installed Dash provides it; ``run_server``
        # is the older entry point and is kept as a fallback
        start = getattr(app, 'run', None) or app.run_server
        return start(*args, **kwargs)

    app.show = show
    return app
|
||||
73
src/single-core-regen/util/plot.py
Normal file
73
src/single-core-regen/util/plot.py
Normal file
@@ -0,0 +1,73 @@
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
from .datasets import load_data
|
||||
|
||||
def eye(*, path=None, data=None, sps=None, title=None, symbols=1000, skipfirst=0, width=2, alpha=None, complex=False, show=True):
    """Plot eye diagrams for the x/y input and output signals.

    Either path or data and sps must be given.

    Args:
        path (str): Path to the data description file. When given, ``data``
            and ``sps`` are loaded from it and replace the arguments.
        data (np.ndarray): Data to plot, one row per sample with the four
            columns input x, input y, output x, output y.
        sps (int): Samples per symbol.
        title (str): Title of the plot; defaults to "Eye diagram".
        symbols (int): Number of symbols to plot. Limited to the number of
            complete symbols contained in the data.
        skipfirst (int): Number of symbols to skip when loading from ``path``.
            It is not applied to ``data`` passed in directly.
        width (int): Length of every overlaid trace, in symbols.
        alpha (float): Opacity of the traces; a falsy value means 0.1.
        complex (bool): If True, plot the magnitude and, on a secondary
            y axis, the phase. Otherwise plot the squared magnitude.
        show (bool): Whether to call plt.show().

    Returns:
        The created matplotlib figure.

    Raises:
        ValueError: If neither ``path`` nor ``data`` is given, or ``sps`` is
            not known.
    """
    if path is None and data is None:
        raise ValueError("Either path or data and sps must be given.")
    if path is not None:
        data, config = load_data(path, skipfirst, symbols)
        sps = int(config["glova"]["sps"])
    if sps is None:
        raise ValueError("sps not set.")

    # never slice past the end of the data: a short trace would not match xaxis
    symbols = min(symbols, len(data) // sps)

    line_alpha = alpha or 0.1
    xaxis = np.linspace(0, width, width*sps, endpoint=False)
    fig, axs = plt.subplots(2, 2, figsize=(10, 10), sharex=True, sharey=True)
    if complex:
        # create secondary axis for phase
        axs2 = axs[0, 0].twinx(), axs[0, 1].twinx(), axs[1, 0].twinx(), axs[1, 1].twinx()
        axs2 = np.reshape(axs2, (2, 2))

    # subplot position of each signal, in the order the traces are listed below
    panels = ((0, 0), (0, 1), (1, 0), (1, 1))

    for i in range(symbols-(width-1)):
        inx, iny, outx, outy = data[i*sps:(i+width)*sps].T
        for (row, col), trace in zip(panels, (inx, outx, iny, outy)):
            if complex:
                axs[row, col].plot(xaxis, np.abs(trace), color="C0", alpha=line_alpha)
                axs2[row, col].plot(xaxis, np.angle(trace), color="C1", alpha=line_alpha)
            else:
                axs[row, col].plot(xaxis, np.abs(trace)**2, color="C0", alpha=line_alpha)

    # the limits depend on the whole data set only, so compute them once
    peak = np.max(np.abs(data))
    if complex:
        axs[0, 0].set_ylim(0, 1.1*peak)

        axs2[0, 0].sharey(axs2[0, 1])
        axs2[0, 1].sharey(axs2[1, 0])
        axs2[1, 0].sharey(axs2[1, 1])
        # make y axis symmetric
        ylim = np.max(np.abs(np.angle(data)))*1.1
        if ylim != 0:
            axs2[0, 0].set_ylim(-ylim, ylim)
    else:
        axs[0, 0].set_ylim(0, 1.1*peak**2)

    axs[0, 0].set_title("Input x")
    axs[0, 1].set_title("Output x")
    axs[1, 0].set_title("Input y")
    axs[1, 1].set_title("Output y")
    fig.suptitle(title or "Eye diagram")

    if show:
        plt.show()

    return fig
|
||||
Reference in New Issue
Block a user