Compare commits

...

5 Commits

Author SHA1 Message Date
Joseph Hopfmüller
249fe1e940 wip 2025-01-27 21:05:49 +01:00
Joseph Hopfmüller
f38d0ca3bb model robustness testing 2025-01-10 23:40:54 +01:00
Joseph Hopfmüller
3af73343c1 add corrected datasets with PMD and dispersion 2024-12-29 16:41:52 +01:00
Joseph Hopfmüller
7a0b65f82d Merge branch 'machine_learning' of git.suuppl.dev:seppl/optical-regeneration into machine_learning 2024-12-29 16:00:41 +01:00
Joseph Hopfmüller
98305fdf47 update dataset configurations, add rotation module, and refine model settings for training, new hyperparameter tuning run for corrected datasets 2024-12-29 16:00:36 +01:00
56 changed files with 3863 additions and 931 deletions

2
.gitignore vendored
View File

@@ -163,3 +163,5 @@ cython_debug/
# and can be added to the global gitignore or merged into this file. For a more nuclear # and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder. # option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/ #.idea/
tolerance_results/*
data/*

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:bc02b0099ea3bb136733e3d20817cad79b6c50c2e4b845f0d206455dde188cc4
size 615

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d80ff6f2a84acf973fbdf81a05ed0b1902f8bf97856cd5132b646f6b1173f496
size 615

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:54ac6b6a452aa6b7d312a4c8ab8f7ebe2f96c1c4170cbc56147e8f2f9d934ad6
size 615

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:bd2c6f4050488b6e857759d48aa1f1f37399d81cee1667d3668145e938d17c83
size 615

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c59b8113092459a8751b385a7b1a6f10828626d2ec2f29775512157fd9bbc75c
size 615

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:04eaa2a29b3302e5de3bf99bf6a57fc8f27361fd3df3cac9245e25ab99324829
size 615

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:704d3b0b17b9d320f4717b5a5a8bdbc5714f3caa4efa7153e980766429e834f4
size 615

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:29304f35a88fd777566105f8666fff1c9927beb32756822365bcf9c159feb98e
size 615

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4a87488c12e0253b2bb5d1ac7aa1536f69ef569e62c2aab6a10149d753e049b4
size 615

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ff47c8d5413881edb03cfadfcde3b550ef7089543615a467c4f0027edaf1455e
size 615

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:9e05a0a54c7e3aaffeca0ac90cce1b274a544d90b329a93d453392f5df4e91a8
size 616

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:25e45b06b551ab8031c2159030d658999fcb3d1f0a34538c90768b94c8116771
size 616

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:707ed73713b6c2d80d4e333b1ccdf4650f50aefefff58ddb471c1d5411954b3d
size 616

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a8c3baf878943741d83835c1afed05c1f9780ff3f0df260c0d92706343e59c50
size 616

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:98878d09510dedc24f1429bbd12cce49c66be6c9d279a28765b120efe820a171
size 616

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:fcbdaffa211d6b0b44b3ae1c66645999e95901bfdb2fffee4c45e34a0d901ee1
size 649

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:73e922310068a66ab1a0c3d39b0b0fd7db798f2637b199a8c4bd113a38bb28c8
size 134217856

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:43e80dc7d21aeff62c73f0ed02b20a4ac9b573d0e131a3ab8d1471077e03634b
size 134217856

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:fb74cfbeec54b4f263c08510312881527fc7e484604fa0a1213b596f175fecc2
size 134217856

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:447cca0af25e309c8be61216cea4fb2d3a8a967b0522760f1dab8f15e6b41574
size 134217856

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:63321906920de825fc7857aa9f2e4944c3b32f3eadf99da06b66f6599924bc4c
size 134217856

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:92e3fcc41f05380cb7b334fefc0a30bb8c1dfd9ca5c3b2cfaad36b0c7093914e
size 134217856

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:0bf671d666d35617edd7bfb58548a5120bd92e5f9cb6edb4b5fc8d3bf5db8987
size 134217856

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:bd429a5776555671b19d42e3ae152026dafd5bf95aeed9847df1432ed37f3eba
size 134217856

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1df90745cc2e6d4b0ad964fca2de1441e6e0b4b8345fbb0fbc1ffe9820674269
size 134481920

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:477365edd696f66610298198257422a017f825e1f8bec4363bdfb0da0d741ebc
size 134217856

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1814f8ae0e6cdb69c0030741a3c6b35c74f2d6985eed34c0d5b4135384014abc
size 134217856

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d53bc0fd0897c7372c2f182b84163bcf401ad91f26b6949c6d7a1d70c5dbb513
size 134217856

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:fc882f29530f7d683be631214f6667611e0aba87453a11157d71c50f3548fb3c
size 134217856

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:0f0bc616c2e581444a6fa658e45ee942c1ef5a4d21f22363518331f8d80cbe62
size 134217856

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ad94f2af43a2a06ebddc78566cbef0ea538c3da9191682e2fde4ecddb061b0f0
size 134217856

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:6930349d1ae1479dcb4c9ee9eaefe2da3eae62539cb4b97de873a7a8b175e809
size 134217856

View File

@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1 version https://git-lfs.github.com/spec/v1
oid sha256:746ea83013870351296e01e294905ee027291ef79fd78c1e6b69dd9ebaa1cba0 oid sha256:76934d1d202aea1311ba67f5ea35eeb99a9c5c856f491565032e7d54ca6f072d
size 10240000 size 13598720

37
notes/models.md Normal file
View File

@@ -0,0 +1,37 @@
# models
## no polarisation flipping
```py
config_path="data/20241229-163*-128-16384-50000-*.ini"
model=".models/best_20241230_011907.tar"
```
```py
config_path="data/20241229-163*-128-16384-80000-*.ini"
model=".models/best_20241230_103752.tar"
```
```py
config_path="data/20241229-163*-128-16384-100000-*.ini"
model=".models/best_20241230_164534.tar"
```
## with polarisation flipping
Polarisation flipping: the signal is randomly rotated by 180°. A polarisation rotation can be detected by adding a tone on one of the polarisations, but with a direct-detection setup only modulo 180°. Training on the randomly flipped signal should allow the network to learn to compensate for dispersion and PMD independently of the polarisation rotation. The training data includes the flipped signal as well, but no indication of whether the polarisation is flipped.
```py
config_path="data/20241229-163*-128-16384-50000-*.ini"
model=".models/best_20241231_000328.tar"
```
```py
config_path="data/20241229-163*-128-16384-80000-*.ini"
model=".models/best_20241231_163614.tar"
```
```py
config_path="data/20241229-163*-128-16384-100000-*.ini"
model=".models/best_20241231_170532.tar"
```

View File

@@ -0,0 +1,59 @@
# Baseline Models
## a) D+S, pol_error 0, ortho_error 0, DGD 0
dataset
```raw
data/20250118-225840-128-16384-inf-100000-0-0.2-17.0-0.058-PAM4-0-0-10.ini
```
model
```raw
.models/best_20250118_225918.tar
```
## b) D+S, pol_error 0.4, ortho_error 0, DGD 0
dataset
```raw
data/20250116-214606-128-16384-inf-100000-0-0.2-17.0-0.058-PAM4-0-0-10.ini
```
model
```raw
.models/best_20250116_214816.tar
```
## c) D+S, pol_error 0, ortho_error 0.1, DGD 0
dataset
```raw
data/20250116-214547-128-16384-inf-100000-0-0.2-17.0-0.058-PAM4-0-0-10.ini
```
model
```raw
.models/best_20250117_122319.tar
```
## d) D+S, pol_error 0, ortho_error 0, DGD 10ps (1 T_sym)
birefringence angle pi/2 (worst case)
dataset
```raw
data/20250117-143926-128-16384-inf-100000-0-0.2-17.0-0.058-PAM4-0-0-10.ini
```
model
```raw
.models/best_20250117_144001.tar
```

2
pypho

Submodule pypho updated: dd015f4852...e44fc477fe

View File

@@ -26,6 +26,8 @@ import torch
import torch.optim as optim import torch.optim as optim
import torch.utils.data import torch.utils.data
import hypertraining.models as models
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
import multiprocessing import multiprocessing
@@ -253,14 +255,17 @@ class HyperTraining:
model_kwargs = { model_kwargs = {
"dims": (input_dim, *hidden_dims, self.model_settings.output_dim), "dims": (input_dim, *hidden_dims, self.model_settings.output_dim),
"layer_function": layer_func, "layer_function": layer_func,
"layer_parametrizations": layer_parametrizations, "layer_func_kwargs": self.model_settings.model_layer_kwargs,
"activation_function": afunc, "act_function": afunc,
"act_func_kwargs": None,
"parametrizations": layer_parametrizations,
"dtype": dtype, "dtype": dtype,
"droupout_prob": self.model_settings.dropout_prob, "dropout_prob": self.model_settings.dropout_prob,
"scale": scale_layers, "scale_layers": scale_layers,
"rotate": False,
} }
model = util.complexNN.regenerator(**model_kwargs) model = models.regenerator(*model_kwargs.pop("dims"), **model_kwargs)
n_nodes = sum(hidden_dims) n_nodes = sum(hidden_dims)
if writer is not None: if writer is not None:
@@ -381,7 +386,10 @@ class HyperTraining:
running_loss = 0.0 running_loss = 0.0
model.train() model.train()
loader_len = len(train_loader) loader_len = len(train_loader)
for batch_idx, (x, y, _) in enumerate(train_loader): for batch_idx, batch in enumerate(train_loader):
x = batch["x"]
y = batch["y"]
if batch_idx >= self.optuna_settings._n_train_batches: if batch_idx >= self.optuna_settings._n_train_batches:
break break
model.zero_grad(set_to_none=True) model.zero_grad(set_to_none=True)
@@ -390,7 +398,7 @@ class HyperTraining:
y.to(self.pytorch_settings.device), y.to(self.pytorch_settings.device),
) )
y_pred = model(x) y_pred = model(x)
loss = util.complexNN.complex_mse_loss(y_pred, y) loss = util.complexNN.complex_mse_loss(y_pred, y, power=True)
loss_value = loss.item() loss_value = loss.item()
loss.backward() loss.backward()
optimizer.step() optimizer.step()
@@ -444,7 +452,9 @@ class HyperTraining:
model.eval() model.eval()
running_error = 0 running_error = 0
with torch.no_grad(): with torch.no_grad():
for batch_idx, (x, y, _) in enumerate(valid_loader): for batch_idx, batch in enumerate(valid_loader):
x = batch["x"]
y = batch["y"]
if batch_idx >= self.optuna_settings._n_valid_batches: if batch_idx >= self.optuna_settings._n_valid_batches:
break break
x, y = ( x, y = (
@@ -452,50 +462,44 @@ class HyperTraining:
y.to(self.pytorch_settings.device), y.to(self.pytorch_settings.device),
) )
y_pred = model(x) y_pred = model(x)
error = util.complexNN.complex_mse_loss(y_pred, y) error = util.complexNN.complex_mse_loss(y_pred, y, power=True)
error_value = error.item() error_value = error.item()
running_error += error_value running_error += error_value
running_error /= min(len(valid_loader), self.optuna_settings._n_valid_batches) running_error /= min(len(valid_loader), self.optuna_settings._n_valid_batches)
if writer is not None: if writer is not None:
title_append, subtitle = self.build_title(trial) writer.add_scalar(
writer.add_figure( "eval loss",
"fiber response", running_error,
self.plot_model_response( epoch,
trial,
model=model,
title_append=title_append,
subtitle=subtitle,
show=False,
),
epoch + 1,
)
writer.add_figure(
"eye diagram",
self.plot_model_response(
trial,
model=self.model,
title_append=title_append,
subtitle=subtitle,
show=False,
mode="eye",
),
epoch + 1,
) )
# if (epoch + 1) % 10 == 0 or epoch < 10:
# # plotting is slow, so only do it every 10 epochs
# title_append, subtitle = self.build_title(trial)
# head_fig, eye_fig, powers_fig = self.plot_model_response(
# model=model,
# title_append=title_append,
# subtitle=subtitle,
# show=False,
# )
# writer.add_figure(
# "fiber response",
# head_fig,
# epoch + 1,
# )
# writer.add_figure(
# "eye diagram",
# eye_fig,
# epoch + 1,
# )
writer.add_figure( # writer.add_figure(
"powers", # "powers",
self.plot_model_response( # powers_fig,
trial, # epoch + 1,
model=self.model, # )
title_append=title_append, # writer.flush()
subtitle=subtitle,
mode="powers",
show=False,
),
epoch + 1,
)
# if enable_progress: # if enable_progress:
# progress.stop() # progress.stop()
@@ -511,15 +515,18 @@ class HyperTraining:
with torch.no_grad(): with torch.no_grad():
model = model.to(self.pytorch_settings.device) model = model.to(self.pytorch_settings.device)
for x, y, timestamp in loader: for batch in loader:
x = batch["x"]
y = batch["y"]
timestamp = batch["timestamp"]
x, y = ( x, y = (
x.to(self.pytorch_settings.device), x.to(self.pytorch_settings.device),
y.to(self.pytorch_settings.device), y.to(self.pytorch_settings.device),
) )
if trace_powers: if trace_powers:
y_pred, powers = model(x, trace_powers).cpu() y_pred, powers = model(x, trace_powers=True).cpu()
else: else:
y_pred = model(x, trace_powers).cpu() y_pred = model(x, trace_powers=True).cpu()
# x = x.cpu() # x = x.cpu()
# y = y.cpu() # y = y.cpu()
y_pred = y_pred.view(y_pred.shape[0], -1, 2) y_pred = y_pred.view(y_pred.shape[0], -1, 2)
@@ -539,7 +546,7 @@ class HyperTraining:
return fiber_in, fiber_out, regen, timestamps, powers return fiber_in, fiber_out, regen, timestamps, powers
return fiber_in, fiber_out, regen, timestamps return fiber_in, fiber_out, regen, timestamps
def objective(self, trial: optuna.Trial, plot_before=False): def objective(self, trial: optuna.Trial):
if self.stop_study: if self.stop_study:
trial.study.stop() trial.study.stop()
model = None model = None
@@ -555,54 +562,54 @@ class HyperTraining:
title_append, subtitle = self.build_title(trial) title_append, subtitle = self.build_title(trial)
writer.add_figure( # writer.add_figure(
"fiber response", # "fiber response",
self.plot_model_response( # self.plot_model_response(
trial, # trial,
model=model, # model=model,
title_append=title_append, # title_append=title_append,
subtitle=subtitle, # subtitle=subtitle,
show=False, # show=False,
), # ),
0, # 0,
) # )
writer.add_figure( # writer.add_figure(
"eye diagram", # "eye diagram",
self.plot_model_response( # self.plot_model_response(
trial, # trial,
model=self.model, # model=self.model,
title_append=title_append, # title_append=title_append,
subtitle=subtitle, # subtitle=subtitle,
mode="eye", # mode="eye",
show=False, # show=False,
), # ),
0, # 0,
) # )
writer.add_figure( # writer.add_figure(
"powers", # "powers",
self.plot_model_response( # self.plot_model_response(
trial, # trial,
model=self.model, # model=self.model,
title_append=title_append, # title_append=title_append,
subtitle=subtitle, # subtitle=subtitle,
mode="powers", # mode="powers",
show=False, # show=False,
), # ),
0, # 0,
) # )
train_loader, valid_loader = self.get_sliced_data(trial) train_loader, valid_loader = self.get_sliced_data(trial)
optimizer_name = trial.suggest_categorical_optional("optimizer", self.optimizer_settings.optimizer) optimizer_name = trial.suggest_categorical_optional("optimizer", self.optimizer_settings.optimizer)
lr = trial.suggest_float_optional("lr", self.optimizer_settings.learning_rate, log=True) lr = trial.suggest_float_optional("lr", self.optimizer_settings.optimizer_kwargs["lr"], log=True)
optimizer = getattr(optim, optimizer_name)(model.parameters(), lr=lr) optimizer = getattr(optim, optimizer_name)(model.parameters(), lr=lr)
if self.optimizer_settings.scheduler is not None: # if self.optimizer_settings.scheduler is not None:
scheduler = getattr(optim.lr_scheduler, self.optimizer_settings.scheduler)( # scheduler = getattr(optim.lr_scheduler, self.optimizer_settings.scheduler)(
optimizer, **self.optimizer_settings.scheduler_kwargs # optimizer, **self.optimizer_settings.scheduler_kwargs
) # )
for epoch in range(self.pytorch_settings.epochs): for epoch in range(self.pytorch_settings.epochs):
trial.set_user_attr("epoch", epoch) trial.set_user_attr("epoch", epoch)
@@ -628,8 +635,8 @@ class HyperTraining:
writer, writer,
# enable_progress=enable_progress, # enable_progress=enable_progress,
) )
if self.optimizer_settings.scheduler is not None: # if self.optimizer_settings.scheduler is not None:
scheduler.step(error) # scheduler.step(error)
trial.set_user_attr("mse", error) trial.set_user_attr("mse", error)
trial.set_user_attr("log_mse", np.log10(error + np.finfo(float).eps)) trial.set_user_attr("log_mse", np.log10(error + np.finfo(float).eps))
@@ -645,10 +652,10 @@ class HyperTraining:
if self.optuna_settings._multi_objective: if self.optuna_settings._multi_objective:
return -np.log10(error + np.finfo(float).eps), trial.user_attrs.get("model_n_nodes", -1) return -np.log10(error + np.finfo(float).eps), trial.user_attrs.get("model_n_nodes", -1)
if self.pytorch_settings.save_models and model is not None: # if self.pytorch_settings.save_models and model is not None:
save_path = Path(self.pytorch_settings.model_dir) / f"{self.optuna_settings.study_name}_{trial.number}.pth" # save_path = Path(self.pytorch_settings.model_dir) / f"{self.optuna_settings.study_name}_{trial.number}.pth"
save_path.parent.mkdir(parents=True, exist_ok=True) # save_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(model, save_path) # torch.save(model, save_path)
return error return error

View File

@@ -8,7 +8,8 @@ from util.complexNN import (
photodiode, photodiode,
EOActivation, EOActivation,
polarimeter, polarimeter,
normalize_by_first # normalize_by_first,
rotate,
) )
@@ -19,11 +20,11 @@ class polarisation_estimator2(Module):
polarimeter(), polarimeter(),
torch.nn.Linear(4, 4), torch.nn.Linear(4, 4),
torch.nn.ReLU(), torch.nn.ReLU(),
torch.nn.Dropout(p=0.01), # torch.nn.Dropout(p=0.01),
torch.nn.Linear(4, 4), torch.nn.Linear(4, 4),
torch.nn.ReLU(), torch.nn.ReLU(),
torch.nn.Dropout(p=0.01), # torch.nn.Dropout(p=0.01),
torch.nn.Linear(4, 4), torch.nn.Linear(4, 1),
) )
def forward(self, x): def forward(self, x):
@@ -123,7 +124,8 @@ class regenerator(Module):
parametrizations: list[dict] = None, parametrizations: list[dict] = None,
dtype=torch.float64, dtype=torch.float64,
dropout_prob=0.01, dropout_prob=0.01,
scale_layers=False, prescale=1,
rotate=False,
): ):
super(regenerator, self).__init__() super(regenerator, self).__init__()
self._n_hidden_layers = len(dims) - 2 self._n_hidden_layers = len(dims) - 2
@@ -131,14 +133,15 @@ class regenerator(Module):
layer_func_kwargs = layer_func_kwargs or {} layer_func_kwargs = layer_func_kwargs or {}
act_func_kwargs = act_func_kwargs or {} act_func_kwargs = act_func_kwargs or {}
self.build_model(dims, layer_function, layer_func_kwargs, act_function, act_func_kwargs, parametrizations, dtype, dropout_prob, scale_layers) self.rotation = rotate
self.prescale = prescale
def build_model(self, dims, layer_function, layer_func_kwargs, act_function, act_func_kwargs, parametrizations, dtype, dropout_prob, scale_layers): self.build_model(dims, layer_function, layer_func_kwargs, act_function, act_func_kwargs, parametrizations, dtype, dropout_prob)
def build_model(self, dims, layer_function, layer_func_kwargs, act_function, act_func_kwargs, parametrizations, dtype, dropout_prob):
for i in range(0, self._n_hidden_layers): for i in range(0, self._n_hidden_layers):
self.add_module(f"layer_{i}", Sequential()) self.add_module(f"layer_{i}", Sequential())
if scale_layers:
self.get_submodule(f"layer_{i}").add_module("scale", Scale(dims[i]))
module = layer_function(dims[i], dims[i + 1], dtype=dtype, **layer_func_kwargs) module = layer_function(dims[i], dims[i + 1], dtype=dtype, **layer_func_kwargs)
self.get_submodule(f"layer_{i}").add_module("ONN", module) self.get_submodule(f"layer_{i}").add_module("ONN", module)
@@ -146,13 +149,14 @@ class regenerator(Module):
module = act_function(size=dims[i + 1], **act_func_kwargs) module = act_function(size=dims[i + 1], **act_func_kwargs)
self.get_submodule(f"layer_{i}").add_module("activation", module) self.get_submodule(f"layer_{i}").add_module("activation", module)
module = DropoutComplex(p=dropout_prob) if dropout_prob is not None and dropout_prob > 0:
self.get_submodule(f"layer_{i}").add_module("dropout", module) module = DropoutComplex(p=dropout_prob)
self.get_submodule(f"layer_{i}").add_module("dropout", module)
self.add_module(f"layer_{self._n_hidden_layers}", Sequential()) self.add_module(f"layer_{self._n_hidden_layers}", Sequential())
if scale_layers: # if scale_layers:
self.get_submodule(f"layer_{self._n_hidden_layers}").add_module("scale", Scale(dims[-2])) # self.get_submodule(f"layer_{self._n_hidden_layers}").add_module("scale", Scale(dims[-2]))
module = layer_function(dims[-2], dims[-1], dtype=dtype, **layer_func_kwargs) module = layer_function(dims[-2], dims[-1], dtype=dtype, **layer_func_kwargs)
self.get_submodule(f"layer_{self._n_hidden_layers}").add_module("ONN", module) self.get_submodule(f"layer_{self._n_hidden_layers}").add_module("ONN", module)
@@ -160,6 +164,14 @@ class regenerator(Module):
module = act_function(size=dims[-1], **act_func_kwargs) module = act_function(size=dims[-1], **act_func_kwargs)
self.get_submodule(f"layer_{self._n_hidden_layers}").add_module("activation", module) self.get_submodule(f"layer_{self._n_hidden_layers}").add_module("activation", module)
module = Scale(size=dims[-1])
self.get_submodule(f"layer_{self._n_hidden_layers}").add_module("out_scale", module)
if self.rotation:
module = rotate()
self.add_module("rotate", module)
# module = Scale(size=dims[-1]) # module = Scale(size=dims[-1])
# self.get_submodule(f"layer_{self._n_hidden_layers}").add_module("out_scale", module) # self.get_submodule(f"layer_{self._n_hidden_layers}").add_module("out_scale", module)
@@ -190,15 +202,28 @@ class regenerator(Module):
powers.append(x.abs().square().sum()) powers.append(x.abs().square().sum())
return powers return powers
def forward(self, x, trace_powers=False): def forward(self, x, angle=None, pre_rot=False, trace_powers=False):
x = x * self.prescale
powers = self._trace_powers(trace_powers, x) powers = self._trace_powers(trace_powers, x)
x = self.layer_0(x) # x = self.layer_0(x)
powers = self._trace_powers(trace_powers, x, powers) # powers = self._trace_powers(trace_powers, x, powers)
for i in range(1, self._n_hidden_layers): for i in range(0, self._n_hidden_layers):
x = getattr(self, f"layer_{i}")(x) x = getattr(self, f"layer_{i}")(x)
powers = self._trace_powers(trace_powers, x, powers) powers = self._trace_powers(trace_powers, x, powers)
x = getattr(self, f"layer_{self._n_hidden_layers}")(x) x = getattr(self, f"layer_{self._n_hidden_layers}")(x)
powers = self._trace_powers(trace_powers, x, powers) if self.rotation:
if trace_powers: try:
return x, powers x_rot = self.rotate(x, angle)
return x except AttributeError:
pass
powers = self._trace_powers(trace_powers, x_rot, powers)
else:
x_rot = x
if pre_rot and trace_powers:
return x_rot, x, powers
if pre_rot and not trace_powers:
return x_rot, x
if not pre_rot and trace_powers:
return x_rot, powers
return x_rot

View File

@@ -18,10 +18,14 @@ class DataSettings:
shuffle: bool = True shuffle: bool = True
in_out_delay: float = 0 in_out_delay: float = 0
xy_delay: tuple | float | int = 0 xy_delay: tuple | float | int = 0
drop_first: int = 1000 drop_first: int = 64
drop_last: int = 64
train_split: float = 0.8 train_split: float = 0.8
polarisations: tuple | list = (0,) polarisations: tuple | list = (0,)
# cross_pol_interference: float = 0
randomise_polarisations: bool = False randomise_polarisations: bool = False
osnr: float | int = None
seed: int = None
""" """
change to: change to:
@@ -91,6 +95,12 @@ class ModelSettings:
""" """
def _early_stop_default_kwargs():
    """Return a fresh dict with the default early-stopping settings.

    Used as the ``default_factory`` of ``OptimizerSettings.early_stop_kwargs``,
    so every settings instance gets its own independent dict.

    NOTE(review): the exact meaning of ``threshold`` and ``plateau`` is defined
    by the consumer of these kwargs, which is not visible here -- confirm
    against the early-stopping implementation.
    """
    defaults = dict(threshold=1e-05, plateau=25)
    return defaults
@dataclass @dataclass
class OptimizerSettings: class OptimizerSettings:
optimizer: tuple | str = ("Adam", "RMSprop", "SGD") optimizer: tuple | str = ("Adam", "RMSprop", "SGD")
@@ -99,6 +109,9 @@ class OptimizerSettings:
scheduler: str | None = None scheduler: str | None = None
scheduler_kwargs: dict | None = None scheduler_kwargs: dict | None = None
early_stopping: bool = False
early_stop_kwargs: dict = field(default_factory=_early_stop_default_kwargs)
""" """
change to: change to:

View File

@@ -2,9 +2,9 @@ import copy
from datetime import datetime from datetime import datetime
from pathlib import Path from pathlib import Path
import random import random
from typing import Literal
import matplotlib import matplotlib
from matplotlib.colors import LinearSegmentedColormap from matplotlib.colors import LinearSegmentedColormap
from mpl_toolkits.axes_grid1 import make_axes_locatable
import torch.nn.utils.parametrize import torch.nn.utils.parametrize
try: try:
@@ -47,46 +47,107 @@ from .settings import (
PytorchSettings, PytorchSettings,
) )
from cmcrameri import cm
# from matplotlib import colors as mcolors
# alpha_map = mcolors.LinearSegmentedColormap(
# 'alphamap',
# {
# 'red': [(0, 0, 0), (1, 0, 0)],
# 'green': [(0, 0, 0), (1, 0, 0)],
# 'blue': [(0, 0, 0), (1, 0, 0)],
# 'alpha': [
# (0, 1, 1),
# # (0.2, 0.2, 0.1),
# (1, 0, 0)
# ]
# }
# )
# alpha_map.set_bad(color="#AAAAAA")
def pad_to_size(array, size):
    """Pad the first two axes of ``array`` with NaN up to ``size``.

    The padding is centred: when the number of missing elements along an axis
    is odd, the extra element goes in front of the data (lower indices).
    Any axes beyond the first two are left unpadded.

    Args:
        array: Array-like with at least two dimensions. Integer and boolean
            inputs are converted to float, because NaN cannot be stored in
            those dtypes.
        size: Target extent. Either a single value applied to both axes or a
            pair ``(size_axis0, size_axis1)``. A ``None`` entry leaves the
            corresponding axis unpadded.

    Returns:
        A new ``np.ndarray`` whose first two axes have the requested extents,
        with the added border filled with NaN.

    Raises:
        ValueError: If ``array`` is already larger than the requested size
            along a padded axis.
    """
    if not hasattr(size, "__len__"):
        size = (size, size)

    array = np.asarray(array)
    if array.dtype.kind in "biu":
        # NaN is not representable in bool/integer arrays.
        array = array.astype(float)

    pad_width = []
    for axis in range(2):
        target = size[axis]
        if target is None:
            pad_width.append((0, 0))
            continue
        missing = target - array.shape[axis]
        if missing < 0:
            raise ValueError(
                f"cannot pad axis {axis} of length {array.shape[axis]} "
                f"to smaller size {target}"
            )
        # The leading side receives the extra element for odd differences.
        pad_width.append(((missing + 1) // 2, missing // 2))

    # Trailing axes (e.g. colour channels) keep their extent.
    pad_width.extend((0, 0) for _ in range(array.ndim - 2))

    return np.pad(array, pad_width, constant_values=np.nan)
def traverse_dict_update(target, source): def traverse_dict_update(target, source):
for k, v in source.items(): for k, v in source.items():
if isinstance(v, dict): if isinstance(v, dict):
if k not in target: try:
target[k] = {} if k not in target:
traverse_dict_update(target[k], v) target[k] = {}
traverse_dict_update(target[k], v)
except TypeError:
if k not in target.__dict__:
setattr(target, k, {})
traverse_dict_update(target.__dict__[k], v)
else: else:
try: try:
target[k] = v target[k] = v
except TypeError: except TypeError:
target.__dict__[k] = v target.__dict__[k] = v
def get_parameter_names_and_values(model): def get_parameter_names_and_values(model):
def is_parametrized(module): def is_parametrized(module):
if hasattr(module, "parametrizations"): if hasattr(module, "parametrizations"):
return True return True
return False return False
def _get_param_info(module, prefix='', parametrization=False): def _get_param_info(module, prefix="", parametrization=False):
param_list = [] param_list = []
for name, param in module.named_parameters(recurse = parametrization): for name, param in module.named_parameters(recurse=parametrization):
if parametrization and name.startswith("parametrizations"): if parametrization and name.startswith("parametrizations"):
name_parts = name.split('.') name_parts = name.split(".")
name = name_parts[1] name = name_parts[1]
param = getattr(module, name) param = getattr(module, name)
full_name = prefix + ('.' if prefix else '') + name full_name = prefix + ("." if prefix else "") + name
param_value = param.data param_value = param.data
param_list.append((full_name, param_value)) param_list.append((full_name, param_value))
for child_name, child_module in module.named_children(): for child_name, child_module in module.named_children():
child_prefix = prefix + ('.' if prefix else '') + child_name child_prefix = prefix + ("." if prefix else "") + child_name
if child_name == "parametrizations": if child_name == "parametrizations":
continue continue
param_list.extend(_get_param_info(child_module, child_prefix, is_parametrized(child_module))) param_list.extend(_get_param_info(child_module, child_prefix, is_parametrized(child_module)))
return param_list return param_list
return _get_param_info(model) return _get_param_info(model)
class PolarizationTrainer: class PolarizationTrainer:
def __init__( def __init__(
self, self,
@@ -101,7 +162,7 @@ class PolarizationTrainer:
settings_override=None, settings_override=None,
reset_epoch=False, reset_epoch=False,
): ):
self.mod = torch.pi/2 self.mod = torch.pi / 2
self.resume = checkpoint_path is not None self.resume = checkpoint_path is not None
torch.serialization.add_safe_globals([ torch.serialization.add_safe_globals([
*util.complexNN.__all__, *util.complexNN.__all__,
@@ -219,7 +280,7 @@ class PolarizationTrainer:
# dims = self.model_kwargs.pop("dims") # dims = self.model_kwargs.pop("dims")
model_kwargs = copy.deepcopy(self.model_kwargs) model_kwargs = copy.deepcopy(self.model_kwargs)
self.model = models.polarisation_estimator(*model_kwargs.pop('dims'),**model_kwargs) self.model = models.polarisation_estimator(*model_kwargs.pop("dims"), **model_kwargs)
# self.model = models.polarisation_estimator2() # self.model = models.polarisation_estimator2()
if self.writer is not None: if self.writer is not None:
@@ -260,6 +321,7 @@ class PolarizationTrainer:
target_delay=in_out_delay, target_delay=in_out_delay,
xy_delay=xy_delay, xy_delay=xy_delay,
drop_first=self.data_settings.drop_first, drop_first=self.data_settings.drop_first,
drop_last=self.data_settings.drop_last,
dtype=dtype, dtype=dtype,
real=not dtype.is_complex, real=not dtype.is_complex,
num_symbols=num_symbols, num_symbols=num_symbols,
@@ -336,17 +398,20 @@ class PolarizationTrainer:
write_div = 0 write_div = 0
loss_div = 0 loss_div = 0
for batch_idx, batch in enumerate(train_loader): for batch_idx, batch in enumerate(train_loader):
x = batch["x"] x = batch["angle_data2"]
y = batch["sop"] y = batch["center_angle"]
self.model.zero_grad(set_to_none=True) self.model.zero_grad(set_to_none=True)
x, y = ( x, y = (
x.to(self.pytorch_settings.device), x.to(self.pytorch_settings.device),
y.to(self.pytorch_settings.device), y.to(self.pytorch_settings.device),
) )
y_pred = self.model(x) y_pred = self.model(x).abs().real
# y_pred = torch.fmod(y_pred, self.mod)
y = y.abs().real
# y = torch.fmod(y, self.mod)
# loss = torch.nn.functional.mse_loss(torch.cos(y_pred), torch.cos(y))
# loss = torch.nn.functional.smooth_l1_loss(torch.cos(torch.fmod(y_pred, torch.pi/2)).squeeze(), torch.cos(torch.fmod(y, torch.pi/2)).squeeze(), beta=0.5) # loss = torch.nn.functional.smooth_l1_loss(torch.cos(torch.fmod(y_pred, torch.pi/2)).squeeze(), torch.cos(torch.fmod(y, torch.pi/2)).squeeze(), beta=0.5)
loss = torch.nn.functional.mse_loss(y_pred, y) loss = util.complexNN.naive_angle_loss(y_pred, y, mod=self.mod)
# loss = util.complexNN.naive_angle_loss(y_pred, y, mod=torch.pi/2)
loss_value = loss.item() loss_value = loss.item()
loss.backward() loss.backward()
optimizer.step() optimizer.step()
@@ -356,7 +421,7 @@ class PolarizationTrainer:
loss_div += 1 loss_div += 1
if enable_progress: if enable_progress:
progress.update(task, advance=1, description=f"{loss_value:.3e}") progress.update(task, advance=1, description=f"{loss_value/np.pi*180:.3e} °")
if batch_idx % self.pytorch_settings.write_every == 0: if batch_idx % self.pytorch_settings.write_every == 0:
self.writer.add_scalar( self.writer.add_scalar(
@@ -395,24 +460,28 @@ class PolarizationTrainer:
loss_div = 0 loss_div = 0
with torch.no_grad(): with torch.no_grad():
for _, batch in enumerate(valid_loader): for _, batch in enumerate(valid_loader):
x = batch["x"] # x = batch["angle_data2"]
y = batch["sop"] x = batch["angle_data2"]
y = batch["center_angle"]
x, y = ( x, y = (
x.to(self.pytorch_settings.device), x.to(self.pytorch_settings.device),
y.to(self.pytorch_settings.device), y.to(self.pytorch_settings.device),
) )
y_pred = self.model(x) y_pred = self.model(x).abs().real
# y_pred = torch.fmod(y_pred, self.mod)
y = y.abs().real
# y = torch.fmod(y, self.mod)
# loss = torch.nn.functional.mse_loss(torch.cos(y_pred), torch.cos(y))
# loss = torch.nn.functional.smooth_l1_loss(torch.cos(torch.fmod(y_pred, torch.pi/2)).squeeze(), torch.cos(torch.fmod(y, torch.pi/2)).squeeze(), beta=0.5) # loss = torch.nn.functional.smooth_l1_loss(torch.cos(torch.fmod(y_pred, torch.pi/2)).squeeze(), torch.cos(torch.fmod(y, torch.pi/2)).squeeze(), beta=0.5)
loss = torch.nn.functional.mse_loss(y_pred, y) loss = util.complexNN.naive_angle_loss(y_pred, y, mod=self.mod)
# loss = util.complexNN.naive_angle_loss(y_pred, y, mod=torch.pi/2)
loss_value = loss.item() loss_value = loss.item()
running_loss += loss_value running_loss += loss_value
loss_div += 1 loss_div += 1
if enable_progress: if enable_progress:
progress.update(task, advance=1, description=f"{loss_value:.3e}") progress.update(task, advance=1, description=f"{loss_value/np.pi*180:.3e} °")
running_loss = running_loss/loss_div running_loss = running_loss / loss_div
self.writer.add_scalar( self.writer.add_scalar(
"eval loss", "eval loss",
@@ -506,19 +575,19 @@ class PolarizationTrainer:
for i, config_path in enumerate(self.data_settings.config_path): for i, config_path in enumerate(self.data_settings.config_path):
paths = Path.cwd().glob(config_path) paths = Path.cwd().glob(config_path)
for j, path in enumerate(paths): for j, path in enumerate(paths):
text = str(path) + '\n' text = str(path) + "\n"
with open(path, 'r') as f: with open(path, "r") as f:
text += f.read() text += f.read()
text += '\n' text += "\n"
self.writer.add_text(f"config_{i*len(self.data_settings.config_path)+j}", text) self.writer.add_text(f"config_{i * len(self.data_settings.config_path) + j}", text)
elif isinstance(self.data_settings.config_path, str): elif isinstance(self.data_settings.config_path, str):
paths = Path.cwd().glob(self.data_settings.config_path) paths = Path.cwd().glob(self.data_settings.config_path)
for j, path in enumerate(paths): for j, path in enumerate(paths):
text = str(path) + '\n' text = str(path) + "\n"
with open(path, 'r') as f: with open(path, "r") as f:
text += f.read() text += f.read()
text += '\n' text += "\n"
self.writer.add_text(f"config_{j}", text) self.writer.add_text(f"config_{j}", text)
self.writer.flush() self.writer.flush()
@@ -571,7 +640,8 @@ class PolarizationTrainer:
if loss < self.best["loss"]: if loss < self.best["loss"]:
self.best = checkpoint self.best = checkpoint
save_path = ( save_path = (
Path(self.pytorch_settings.model_dir) / f"best_pol_{self.writer.get_logdir().split('/')[-1]}.tar" Path(self.pytorch_settings.model_dir)
/ f"best_pol_{self.writer.get_logdir().split('/')[-1]}.tar"
) )
save_path.parent.mkdir(parents=True, exist_ok=True) save_path.parent.mkdir(parents=True, exist_ok=True)
self.save_checkpoint(self.best, save_path) self.save_checkpoint(self.best, save_path)
@@ -580,6 +650,7 @@ class PolarizationTrainer:
self.writer.close() self.writer.close()
return self.best return self.best
class RegenerationTrainer: class RegenerationTrainer:
def __init__( def __init__(
self, self,
@@ -592,6 +663,7 @@ class RegenerationTrainer:
console=None, console=None,
checkpoint_path=None, checkpoint_path=None,
settings_override=None, settings_override=None,
new_model=False,
reset_epoch=False, reset_epoch=False,
): ):
self.resume = checkpoint_path is not None self.resume = checkpoint_path is not None
@@ -605,12 +677,23 @@ class RegenerationTrainer:
models.regenerator, models.regenerator,
torch.nn.utils.parametrizations.orthogonal, torch.nn.utils.parametrizations.orthogonal,
]) ])
# self.new_model = True
self.model_name = datetime.now().strftime("%Y%m%d_%H%M%S")
if self.resume: if self.resume:
print(f"loading checkpoint from {checkpoint_path}") print(f"loading checkpoint from {checkpoint_path}")
self.checkpoint_dict = torch.load(checkpoint_path, weights_only=True) self.checkpoint_dict = torch.load(checkpoint_path, weights_only=True)
if settings_override is not None: if settings_override is not None:
traverse_dict_update(self.checkpoint_dict["settings"], settings_override) traverse_dict_update(self.checkpoint_dict["settings"], settings_override)
if reset_epoch:
if not new_model:
# self.new_model = False
checkpoint_file = checkpoint_path.split("/")[-1].split(".")[0]
if checkpoint_file.startswith("best"):
self.model_name = "_".join(checkpoint_file.split("_")[1:])
else:
self.model_name = "_".join(checkpoint_file.split("_")[:-1])
if new_model or reset_epoch:
self.checkpoint_dict["epoch"] = -1 self.checkpoint_dict["epoch"] = -1
self.global_settings: GlobalSettings = self.checkpoint_dict["settings"]["global_settings"] self.global_settings: GlobalSettings = self.checkpoint_dict["settings"]["global_settings"]
@@ -636,11 +719,15 @@ class RegenerationTrainer:
self.model_settings: ModelSettings = model_settings self.model_settings: ModelSettings = model_settings
self.optimizer_settings: OptimizerSettings = optimizer_settings self.optimizer_settings: OptimizerSettings = optimizer_settings
if self.global_settings.seed is not None:
random.seed(self.global_settings.seed)
np.random.seed(self.global_settings.seed)
self.console = console or Console() self.console = console or Console()
self.writer = None self.writer = None
def setup_tb_writer(self, append=None): def setup_tb_writer(self, append=None):
log_dir = self.pytorch_settings.summary_dir + "/" + (datetime.now().strftime("%Y%m%d_%H%M%S")) log_dir = self.pytorch_settings.summary_dir + "/" + self.model_name
if append is not None: if append is not None:
log_dir += "_" + str(append) log_dir += "_" + str(append)
@@ -669,7 +756,7 @@ class RegenerationTrainer:
def define_model(self, model_kwargs=None): def define_model(self, model_kwargs=None):
if self.resume: if self.resume:
model_kwargs = self.checkpoint_dict["model_kwargs"] model_kwargs = None
else: else:
model_kwargs = model_kwargs model_kwargs = model_kwargs
@@ -678,6 +765,14 @@ class RegenerationTrainer:
input_dim = 2 * self.data_settings.output_size input_dim = 2 * self.data_settings.output_size
# if self.data_settings.polarisations:
# input_dim *= 2
output_dim = self.model_settings.output_dim
if self.data_settings.polarisations:
output_dim *= 2
dtype = getattr(torch, self.data_settings.dtype) dtype = getattr(torch, self.data_settings.dtype)
afunc = getattr(util.complexNN, self.model_settings.model_activation_func) afunc = getattr(util.complexNN, self.model_settings.model_activation_func)
@@ -689,7 +784,7 @@ class RegenerationTrainer:
hidden_dims = [self.model_settings.overrides.get(f"n_hidden_nodes_{i}") for i in range(n_hidden_layers)] hidden_dims = [self.model_settings.overrides.get(f"n_hidden_nodes_{i}") for i in range(n_hidden_layers)]
self.model_kwargs = { self.model_kwargs = {
"dims": (input_dim, *hidden_dims, self.model_settings.output_dim), "dims": (input_dim, *hidden_dims, output_dim),
"layer_function": layer_func, "layer_function": layer_func,
"layer_func_kwargs": self.model_settings.model_layer_kwargs, "layer_func_kwargs": self.model_settings.model_layer_kwargs,
"act_function": afunc, "act_function": afunc,
@@ -697,7 +792,7 @@ class RegenerationTrainer:
"parametrizations": layer_parametrizations, "parametrizations": layer_parametrizations,
"dtype": dtype, "dtype": dtype,
"dropout_prob": self.model_settings.dropout_prob, "dropout_prob": self.model_settings.dropout_prob,
"scale_layers": self.model_settings.scale, "prescale": self.model_settings.scale,
} }
else: else:
self.model_kwargs = model_kwargs self.model_kwargs = model_kwargs
@@ -706,10 +801,12 @@ class RegenerationTrainer:
# dims = self.model_kwargs.pop("dims") # dims = self.model_kwargs.pop("dims")
model_kwargs = copy.deepcopy(self.model_kwargs) model_kwargs = copy.deepcopy(self.model_kwargs)
self.model = models.regenerator(*model_kwargs.pop('dims'),**model_kwargs) self.model = models.regenerator(*model_kwargs.pop("dims"), **model_kwargs)
if self.writer is not None: if self.writer is not None:
self.writer.add_graph(self.model, torch.zeros(1, input_dim, dtype=dtype)) self.writer.add_graph(
self.model, (torch.rand(1, input_dim, dtype=dtype), torch.rand(1, 1, dtype=dtype.to_real()))
)
self.model = self.model.to(self.pytorch_settings.device) self.model = self.model.to(self.pytorch_settings.device)
if self.resume: if self.resume:
@@ -728,13 +825,16 @@ class RegenerationTrainer:
num_symbols = None num_symbols = None
config_path = self.data_settings.config_path config_path = self.data_settings.config_path
polarisations = self.data_settings.polarisations
randomise_polarisations = self.data_settings.randomise_polarisations randomise_polarisations = self.data_settings.randomise_polarisations
polarisations = self.data_settings.polarisations
osnr = self.data_settings.osnr
# cross_pol_interference = self.data_settings.cross_pol_interference
if override is not None: if override is not None:
num_symbols = override.get("num_symbols", None) num_symbols = override.get("num_symbols", None)
config_path = override.get("config_path", config_path) config_path = override.get("config_path", config_path)
polarisations = override.get("polarisations", polarisations) polarisations = override.get("polarisations", polarisations)
randomise_polarisations = override.get("randomise_polarisation", randomise_polarisations) randomise_polarisations = override.get("randomise_polarisation", randomise_polarisations)
# cross_pol_interference = override.get("angle_var", 0)
# get dataset # get dataset
dataset = FiberRegenerationDataset( dataset = FiberRegenerationDataset(
file_path=config_path, file_path=config_path,
@@ -743,11 +843,14 @@ class RegenerationTrainer:
target_delay=in_out_delay, target_delay=in_out_delay,
xy_delay=xy_delay, xy_delay=xy_delay,
drop_first=self.data_settings.drop_first, drop_first=self.data_settings.drop_first,
drop_last=self.data_settings.drop_last,
dtype=dtype, dtype=dtype,
real=not dtype.is_complex, real=not dtype.is_complex,
num_symbols=num_symbols, num_symbols=num_symbols,
polarisations=polarisations,
randomise_polarisations=randomise_polarisations, randomise_polarisations=randomise_polarisations,
polarisations=polarisations,
# cross_pol_interference=cross_pol_interference,
osnr = osnr,
) )
dataset_size = len(dataset) dataset_size = len(dataset)
@@ -816,16 +919,25 @@ class RegenerationTrainer:
running_loss = 0.0 running_loss = 0.0
self.model.train() self.model.train()
loader_len = len(train_loader) loader_len = len(train_loader)
# x_key = "x_stacked"# if self.data_settings.polarisations else "x"
# y_key = "y_stacked"# if self.data_settings.polarisations else "y"
x_key = "x"
y_key = "y"
for batch_idx, batch in enumerate(train_loader): for batch_idx, batch in enumerate(train_loader):
x = batch["x"] x = batch[x_key]
y = batch["y"] y = batch[y_key]
angle = batch["mean_angle"]
self.model.zero_grad(set_to_none=True) self.model.zero_grad(set_to_none=True)
x, y = ( x, y, angle = (
x.to(self.pytorch_settings.device), x.to(self.pytorch_settings.device),
y.to(self.pytorch_settings.device), y.to(self.pytorch_settings.device),
angle.to(self.pytorch_settings.device),
) )
y_pred = self.model(x) y_pred = self.model(x, -angle)
# loss = util.complexNN.complex_mse_loss(y_pred*torch.sqrt(torch.ones(1, device=y.device)*5), y*torch.sqrt(torch.ones(1, device=y.device)*5), power=True)
loss = util.complexNN.complex_mse_loss(y_pred, y, power=True) loss = util.complexNN.complex_mse_loss(y_pred, y, power=True)
loss_value = loss.item() loss_value = loss.item()
loss.backward() loss.backward()
optimizer.step() optimizer.step()
@@ -868,23 +980,31 @@ class RegenerationTrainer:
self.model.eval() self.model.eval()
running_error = 0 running_error = 0
# x_key = "x_stacked"# if self.data_settings.polarisations else "x"
# y_key = "y_stacked"# if self.data_settings.polarisations else "y"
x_key = "x"
y_key = "y"
with torch.no_grad(): with torch.no_grad():
for _, batch in enumerate(valid_loader): for _, batch in enumerate(valid_loader):
x = batch["x"] x = batch[x_key]
y = batch["y"] y = batch[y_key]
x, y = ( angle = batch["mean_angle"]
x, y, angle = (
x.to(self.pytorch_settings.device), x.to(self.pytorch_settings.device),
y.to(self.pytorch_settings.device), y.to(self.pytorch_settings.device),
angle.to(self.pytorch_settings.device),
) )
y_pred = self.model(x) y_pred = self.model(x, -angle)
# error = util.complexNN.complex_mse_loss(y_pred*torch.sqrt(torch.ones(1, device=y.device)*5), y*torch.sqrt(torch.ones(1, device=y.device)*5), power=True)
error = util.complexNN.complex_mse_loss(y_pred, y, power=True) error = util.complexNN.complex_mse_loss(y_pred, y, power=True)
error_value = error.item() error_value = error.item()
running_error += error_value running_error += error_value
if enable_progress: if enable_progress:
progress.update(task, advance=1, description=f"{error_value:.3e}") progress.update(task, advance=1, description=f"{error_value:.3e}")
running_error = running_error/len(valid_loader) running_error = running_error / len(valid_loader)
self.writer.add_scalar( self.writer.add_scalar(
"eval loss", "eval loss",
@@ -894,7 +1014,7 @@ class RegenerationTrainer:
if (epoch + 1) % 10 == 0 or epoch < 10: if (epoch + 1) % 10 == 0 or epoch < 10:
# plotting is slow, so only do it every 10 epochs # plotting is slow, so only do it every 10 epochs
title_append, subtitle = self.build_title(epoch + 1) title_append, subtitle = self.build_title(epoch + 1)
head_fig, eye_fig, powers_fig = self.plot_model_response( head_fig, eye_fig, weight_fig, powers_fig = self.plot_model_response(
model=self.model, model=self.model,
title_append=title_append, title_append=title_append,
subtitle=subtitle, subtitle=subtitle,
@@ -910,6 +1030,11 @@ class RegenerationTrainer:
eye_fig, eye_fig,
epoch + 1, epoch + 1,
) )
self.writer.add_figure(
"weights",
weight_fig,
epoch + 1,
)
self.writer.add_figure( self.writer.add_figure(
"powers", "powers",
@@ -928,45 +1053,70 @@ class RegenerationTrainer:
def run_model(self, model, loader, trace_powers=False):
    """Run ``model`` over every batch of ``loader`` and collect data for plotting.

    Every batch is a mapping that must provide the keys ``"x"``,
    ``"mean_angle"``, ``"plot_target"``, ``"plot_data"``,
    ``"plot_data_rot"`` and ``"timestamp"``.  The model is called as
    ``model(x, -angle)`` (or ``model(x, -angle, True)`` when
    ``trace_powers`` is set) on ``self.pytorch_settings.device``.

    Args:
        model: Network to evaluate.  It is switched to eval mode and moved
            to the configured device.
        loader: Iterable of batch mappings as described above.
        trace_powers: When true, the model is expected to return a
            ``(prediction, powers)`` pair and ``powers`` is returned too.

    Returns:
        ``(fiber_in, fiber_out, fiber_out_rot, angles, regen, timestamps)``
        as CPU tensors; ``powers`` of the last processed batch is appended
        as a seventh element when ``trace_powers`` is true.
    """
    model.eval()
    fiber_out = []
    fiber_out_rot = []
    fiber_in = []
    regen = []
    timestamps = []
    angles = []
    powers = None  # stays None only if the loader yields no batches

    x_key = "x"
    device = self.pytorch_settings.device

    with torch.no_grad():
        model = model.to(device)
        for batch in loader:
            x = batch[x_key].to(device)
            angle = batch["mean_angle"].to(device)
            plot_target = batch["plot_target"]
            timestamp = batch["timestamp"]
            plot_data = batch["plot_data"]
            plot_data_rot = batch["plot_data_rot"]

            if trace_powers:
                # The model returns a tuple here, so it has to be unpacked
                # before the prediction can be moved to the CPU.
                y_pred, powers = model(x, -angle, True)
                y_pred = y_pred.cpu()
            else:
                y_pred = model(x, -angle).cpu()

            # Keep only the first two output columns, regroup them into
            # pairs and take the centre pair of each sample.
            y_pred = y_pred[:, :2]
            y_pred = y_pred.view(y_pred.shape[0], -1, 2)
            y_pred = y_pred[:, y_pred.shape[1] // 2, :]

            fiber_out.append(plot_data.squeeze())
            fiber_out_rot.append(plot_data_rot.squeeze())
            fiber_in.append(plot_target.squeeze())
            regen.append(y_pred.squeeze())
            timestamps.append(timestamp.squeeze())
            angles.append(angle.squeeze())

    fiber_out = torch.vstack(fiber_out).cpu()
    fiber_out_rot = torch.vstack(fiber_out_rot).cpu()
    fiber_in = torch.vstack(fiber_in).cpu()
    regen = torch.vstack(regen).cpu()
    # NOTE(review): angles are stacked row-wise while timestamps are
    # concatenated; if the per-batch angle tensors are 1-D after squeeze()
    # this requires every batch to have the same size -- confirm.
    angles = torch.vstack(angles).cpu()
    timestamps = torch.concat(timestamps).cpu()

    if trace_powers:
        return fiber_in, fiber_out, fiber_out_rot, angles, regen, timestamps, powers
    return fiber_in, fiber_out, fiber_out_rot, angles, regen, timestamps
def write_parameters(self, epoch, attributes: list[str] | tuple[str] = None): def write_parameters(self, epoch, attributes: list[str] | tuple[str] = None):
parameter_list = get_parameter_names_and_values(self.model) parameter_list = get_parameter_names_and_values(self.model)
@@ -998,7 +1148,7 @@ class RegenerationTrainer:
) )
title_append, subtitle = self.build_title(0) title_append, subtitle = self.build_title(0)
head_fig, eye_fig, powers_fig = self.plot_model_response( head_fig, eye_fig, weight_fig, powers_fig = self.plot_model_response(
model=self.model, model=self.model,
title_append=title_append, title_append=title_append,
subtitle=subtitle, subtitle=subtitle,
@@ -1014,6 +1164,11 @@ class RegenerationTrainer:
eye_fig, eye_fig,
0, 0,
) )
self.writer.add_figure(
"weights",
weight_fig,
0,
)
self.writer.add_figure( self.writer.add_figure(
"powers", "powers",
@@ -1027,24 +1182,27 @@ class RegenerationTrainer:
for i, config_path in enumerate(self.data_settings.config_path): for i, config_path in enumerate(self.data_settings.config_path):
paths = Path.cwd().glob(config_path) paths = Path.cwd().glob(config_path)
for j, path in enumerate(paths): for j, path in enumerate(paths):
text = str(path) + '\n' text = str(path) + "\n"
with open(path, 'r') as f: with open(path, "r") as f:
text += f.read() text += f.read()
text += '\n' text += "\n"
self.writer.add_text(f"config_{i*len(self.data_settings.config_path)+j}", text) self.writer.add_text(f"config_{i * len(self.data_settings.config_path) + j}", text)
elif isinstance(self.data_settings.config_path, str): elif isinstance(self.data_settings.config_path, str):
paths = Path.cwd().glob(self.data_settings.config_path) paths = Path.cwd().glob(self.data_settings.config_path)
for j, path in enumerate(paths): for j, path in enumerate(paths):
text = str(path) + '\n' text = str(path) + "\n"
with open(path, 'r') as f: with open(path, "r") as f:
text += f.read() text += f.read()
text += '\n' text += "\n"
self.writer.add_text(f"config_{j}", text) self.writer.add_text(f"config_{j}", text)
self.writer.flush() self.writer.flush()
train_loader, valid_loader = self.get_sliced_data() train_loader, valid_loader = self.get_sliced_data()
# train_loader.dataset.fiber_out.to(self.pytorch_settings.device)
# train_loader.dataset.fiber_in.to(self.pytorch_settings.device)
optimizer_name = self.optimizer_settings.optimizer optimizer_name = self.optimizer_settings.optimizer
# lr = self.optimizer_settings.learning_rate # lr = self.optimizer_settings.learning_rate
@@ -1074,6 +1232,7 @@ class RegenerationTrainer:
# except ValueError: # except ValueError:
# pass # pass
self.early_stop_vals = {"min_loss": float("inf"), "plateau_cnt": 0}
for epoch in range(self.best["epoch"] + 1, self.pytorch_settings.epochs): for epoch in range(self.best["epoch"] + 1, self.pytorch_settings.epochs):
enable_progress = True enable_progress = True
if enable_progress: if enable_progress:
@@ -1089,33 +1248,69 @@ class RegenerationTrainer:
epoch, epoch,
enable_progress=enable_progress, enable_progress=enable_progress,
) )
if self.early_stop(loss):
self.save_model_checkpoints(epoch, loss)
break
if self.optimizer_settings.scheduler is not None: if self.optimizer_settings.scheduler is not None:
self.scheduler.step(loss) self.scheduler.step(loss)
self.writer.add_scalar("learning rate", self.optimizer.param_groups[0]["lr"], epoch) self.writer.add_scalar("learning rate", self.optimizer.param_groups[0]["lr"], epoch)
if self.pytorch_settings.save_models and self.model is not None: self.save_model_checkpoints(epoch, loss)
save_path = (
Path(self.pytorch_settings.model_dir) / f"{self.writer.get_logdir().split('/')[-1]}_{epoch}.tar"
)
save_path.parent.mkdir(parents=True, exist_ok=True)
checkpoint = self.build_checkpoint_dict(loss, epoch)
self.save_checkpoint(checkpoint, save_path)
if loss < self.best["loss"]:
self.best = checkpoint
save_path = (
Path(self.pytorch_settings.model_dir) / f"best_{self.writer.get_logdir().split('/')[-1]}.tar"
)
save_path.parent.mkdir(parents=True, exist_ok=True)
self.save_checkpoint(self.best, save_path)
self.writer.flush() self.writer.flush()
save_path = (Path(self.pytorch_settings.model_dir) / f"best_{self.writer.get_logdir().split('/')[-1]}.tar")
print(f"Training complete. Best checkpoint: {save_path}")
self.writer.close() self.writer.close()
return self.best return self.best
def early_stop(self, loss):
    """Decide whether training should stop after the current epoch.

    Reads ``self.optimizer_settings.early_stopping`` (master switch) and
    ``self.optimizer_settings.early_stop_kwargs`` with the optional keys

    * ``"threshold"``: stop as soon as ``loss`` is less than or equal to it
    * ``"plateau"``: stop once this many calls in a row failed to improve
      on the best loss seen so far

    The best loss and the plateau counter live in ``self.early_stop_vals``
    (keys ``"min_loss"`` and ``"plateau_cnt"``), which must be initialised
    by the caller before the first call.

    Args:
        loss: Loss value of the epoch that just finished.

    Returns:
        ``True`` if training should stop, ``False`` otherwise.
    """
    settings = self.optimizer_settings
    if not settings.early_stopping:
        return False

    # Tolerate a missing or None kwargs entry: early stopping is then
    # enabled but has no criteria, so it can never trigger.
    stop_kwargs = getattr(settings, "early_stop_kwargs", None) or {}
    state = self.early_stop_vals

    # Absolute threshold.
    loss_thr = stop_kwargs.get("threshold", None)
    if loss_thr is not None and loss <= loss_thr:
        print(f"Early stop: loss is below threshold ({loss:.2e} <= {loss_thr:.2e})")
        return True

    # A new best loss resets the plateau counter.
    if loss < state["min_loss"]:
        state["min_loss"] = loss
        state["plateau_cnt"] = 0
        return False

    # No improvement: count it towards the plateau, if one is configured.
    plateau_thresh = stop_kwargs.get("plateau", None)
    if plateau_thresh is not None:
        state["plateau_cnt"] += 1
        if state["plateau_cnt"] >= plateau_thresh:
            print(f"Early stop: loss plateau length over threshold ({state['plateau_cnt']} >= {plateau_thresh})")
            return True

    return False
def save_model_checkpoints(self, epoch, loss):
    """Write the checkpoint of ``epoch`` and refresh the best checkpoint.

    Does nothing unless ``self.pytorch_settings.save_models`` is set and
    ``self.model`` exists.  Otherwise a checkpoint built by
    ``self.build_checkpoint_dict(loss, epoch)`` is saved as
    ``<model_dir>/<run>_<epoch>.tar``, where ``<run>`` is the last
    component of the writer's log directory.  If ``loss`` is lower than
    ``self.best["loss"]``, the checkpoint also becomes ``self.best`` and is
    saved as ``<model_dir>/best_<run>.tar``.

    Args:
        epoch: Epoch index used in the checkpoint and in the file name.
        loss: Loss of that epoch, compared against the best loss so far.
    """
    if not self.pytorch_settings.save_models or self.model is None:
        return

    # Path.name instead of split("/")[-1]: a trailing slash or an
    # OS-specific separator in the log dir still yields the run name.
    run_name = Path(self.writer.get_logdir()).name
    model_dir = Path(self.pytorch_settings.model_dir)
    model_dir.mkdir(parents=True, exist_ok=True)

    checkpoint = self.build_checkpoint_dict(loss, epoch)
    self.save_checkpoint(checkpoint, model_dir / f"{run_name}_{epoch}.tar")

    if loss < self.best["loss"]:
        self.best = checkpoint
        self.save_checkpoint(self.best, model_dir / f"best_{run_name}.tar")
def _plot_model_response_powers(self, powers, layer_names, title_append="", subtitle="", show=True): def _plot_model_response_powers(self, powers, layer_names, title_append="", subtitle="", show=True):
powers = [power / powers[0] for power in powers] powers = [power / powers[0] for power in powers]
fig, ax = plt.subplots() fig, ax = plt.subplots()
fig.set_figwidth(18) fig.set_figwidth(18)
fig.set_figheight(4)
fig.suptitle( fig.suptitle(
f"Energy conservation{f' {title_append}' if title_append else ''}{f'\n{subtitle}' if subtitle else ''}" f"Energy conservation{f' {title_append}' if title_append else ''}{f'\n{subtitle}' if subtitle else ''}"
) )
@@ -1131,6 +1326,77 @@ class RegenerationTrainer:
plt.show() plt.show()
return fig return fig
def _plot_model_weights(self, model, title_append="", subtitle="", show=True):
    """Plot magnitude and phase of every layer's ``ONN.weight`` side by side.

    The weight matrices of all direct children of ``model`` are padded to a
    common number of rows with ``pad_to_size`` and concatenated
    horizontally.  NaN entries of the combined image are masked and drawn
    in grey.

    Args:
        model: Module whose direct children each expose ``ONN.weight``.
        title_append: Optional text appended to the figure title.
        subtitle: Optional second line of the figure title.
        show: Call ``plt.show()`` before returning when true.

    Returns:
        The matplotlib figure with the "Values" and "Angles" panels.
    """
    magnitudes = []
    phases = []
    for _, layer in model.named_children():
        weights = layer.ONN.weight.detach().cpu().numpy()
        magnitudes.append(np.abs(weights).real)
        phases.append(np.mod(np.angle(weights), 2 * np.pi).real)

    max_size = np.max([values.shape[0] for values in magnitudes])

    value_panels = []
    angle_panels = []
    for num, (values, angles) in enumerate(zip(magnitudes, phases)):
        # The first layer keeps its width; later layers are padded one
        # column wider than their data.
        # NOTE(review): assumes pad_to_size returns a new array padded with
        # NaN -- confirm.  The padded first layer is now actually used, so
        # concatenation also works when layer 0 is not the tallest matrix.
        width = None if num == 0 else values.shape[1] + 1
        value_panels.append(pad_to_size(values, (max_size, width)))
        angle_panels.append(pad_to_size(angles, (max_size, width)))

    value_img = np.concatenate(value_panels, axis=1)
    angle_img = np.concatenate(angle_panels, axis=1)
    value_img = np.ma.array(value_img, mask=np.isnan(value_img))
    angle_img = np.ma.array(angle_img, mask=np.isnan(angle_img))

    fig, axs = plt.subplots(1, 2, sharex=True, sharey=True, figsize=(18, 6.5))
    fig.tight_layout()
    caxs = [make_axes_locatable(ax).append_axes("right", size="5%", pad=0.05) for ax in axs]

    # Work on copies so set_bad() does not alter the shared colormap objects.
    value_cmap = copy.copy(cm.batlow)
    value_cmap.set_bad(color="#AAAAAA")
    im_val = axs[0].imshow(value_img, cmap=value_cmap, vmin=0, vmax=1)
    fig.colorbar(im_val, cax=caxs[0], orientation="vertical")

    angle_cmap = copy.copy(cm.romaO)
    angle_cmap.set_bad(color="#AAAAAA")
    im_ang = axs[1].imshow(angle_img, cmap=angle_cmap, vmin=0, vmax=2 * np.pi)
    cbar = fig.colorbar(
        im_ang, cax=caxs[1], orientation="vertical", ticks=[i / 8 * 2 * np.pi for i in range(9)]
    )
    cbar.ax.set_yticklabels(["0", "π/4", "π/2", "3π/4", "π", "5π/4", "3π/2", "7π/4", ""])

    axs[0].axis("off")
    axs[1].axis("off")
    axs[0].set_title("Values")
    axs[1].set_title("Angles")

    title = "Layer Weights"
    if title_append:
        title += f" {title_append}"
    if subtitle:
        title += f"\n{subtitle}"
    fig.suptitle(title)

    if show:
        plt.show()
    return fig
def _plot_model_response_eye( def _plot_model_response_eye(
self, *signals, timestamps, labels=None, sps=None, title_append="", subtitle="", show=True self, *signals, timestamps, labels=None, sps=None, title_append="", subtitle="", show=True
): ):
@@ -1184,6 +1450,7 @@ class RegenerationTrainer:
fig, axs = plt.subplots(1, 2 * len(signals), sharex=True, sharey=True) fig, axs = plt.subplots(1, 2 * len(signals), sharex=True, sharey=True)
fig.set_figwidth(18) fig.set_figwidth(18)
fig.set_figheight(4)
fig.suptitle(f"Eye diagram{f' {title_append}' if title_append else ''}{f'\n{subtitle}' if subtitle else ''}") fig.suptitle(f"Eye diagram{f' {title_append}' if title_append else ''}{f'\n{subtitle}' if subtitle else ''}")
# xaxis = timestamps / sps # xaxis = timestamps / sps
# xaxis = np.arange(2 * sps) / sps # xaxis = np.arange(2 * sps) / sps
@@ -1253,7 +1520,7 @@ class RegenerationTrainer:
xaxis = timestamps / sps xaxis = timestamps / sps
else: else:
xaxis = timestamps xaxis = timestamps
ax.plot(xaxis, np.abs(signal[:, i]) ** 2, label=label) ax.plot(xaxis, np.abs(signal[:, i]) ** 2, label=label, alpha=0.7)
ax.set_xlabel("Sample" if sps is None else "Symbol") ax.set_xlabel("Sample" if sps is None else "Symbol")
ax.set_ylabel("normalized power") ax.set_ylabel("normalized power")
ax.minorticks_on() ax.minorticks_on()
@@ -1269,7 +1536,7 @@ class RegenerationTrainer:
def plot_model_response( def plot_model_response(
self, self,
model:torch.nn.Module=None, model: torch.nn.Module = None,
title_append="", title_append="",
subtitle="", subtitle="",
# mode: Literal["eye", "head", "powers"] = "head", # mode: Literal["eye", "head", "powers"] = "head",
@@ -1281,7 +1548,9 @@ class RegenerationTrainer:
model = model.to(self.pytorch_settings.device) model = model.to(self.pytorch_settings.device)
model.eval() model.eval()
with torch.no_grad(): with torch.no_grad():
_, powers = model(input_data, trace_powers=True) _, powers = model(
input_data, torch.zeros(input_data.shape[0], 1).to(self.pytorch_settings.device), trace_powers=True
)
powers = [power.item() for power in powers] powers = [power.item() for power in powers]
layer_names = [name for (name, _) in model.named_children()] layer_names = [name for (name, _) in model.named_children()]
@@ -1292,33 +1561,48 @@ class RegenerationTrainer:
data_settings_backup = copy.deepcopy(self.data_settings) data_settings_backup = copy.deepcopy(self.data_settings)
pytorch_settings_backup = copy.deepcopy(self.pytorch_settings) pytorch_settings_backup = copy.deepcopy(self.pytorch_settings)
self.data_settings.drop_first = 99.5 + random.randint(0, 1000) self.data_settings.drop_first = int(64 + random.randint(0, 1000))
self.data_settings.shuffle = False self.data_settings.shuffle = False
self.data_settings.train_split = 1.0 self.data_settings.train_split = 1.0
self.pytorch_settings.batchsize = max(self.pytorch_settings.head_symbols, self.pytorch_settings.eye_symbols) self.pytorch_settings.batchsize = max(self.pytorch_settings.head_symbols, self.pytorch_settings.eye_symbols)
config_path = random.choice(self.data_settings.config_path) if isinstance(self.data_settings.config_path, (list, tuple)) else self.data_settings.config_path config_path = (
fiber_length = int(float(str(config_path).split('-')[4])/1000) random.choice(self.data_settings.config_path)
if isinstance(self.data_settings.config_path, (list, tuple))
else self.data_settings.config_path
)
# fiber_length = int(float(str(config_path).split("-")[4]) / 1000)
if not hasattr(self, "_plot_loader"): if not hasattr(self, "_plot_loader"):
self._plot_loader, _ = self.get_sliced_data( self._plot_loader, _ = self.get_sliced_data(
override={ override={
"num_symbols": self.pytorch_settings.batchsize, "num_symbols": self.pytorch_settings.batchsize,
"config_path": config_path, "config_path": config_path,
"shuffle": False, "shuffle": False,
"polarisations": (np.random.rand(1)*np.pi*2,), # "polarisations": (np.random.rand(1) * np.pi * 2,),
"randomise_polarisation": False, "polarisations": self.data_settings.polarisations,
"randomise_polarisation": self.data_settings.randomise_polarisations,
} }
) )
self._sps = self._plot_loader.dataset.samples_per_symbol self._sps = self._plot_loader.dataset.samples_per_symbol
fiber_length = float(self._plot_loader.dataset.config["fiber"]["length"])/1000
self.data_settings = data_settings_backup self.data_settings = data_settings_backup
self.pytorch_settings = pytorch_settings_backup self.pytorch_settings = pytorch_settings_backup
fiber_in, fiber_out, regen, timestamps = self.run_model(model, self._plot_loader) fiber_in, fiber_out, fiber_out_rot, angles, regen, timestamps = self.run_model(model, self._plot_loader)
fiber_in = fiber_in.view(-1, 2) fiber_in = fiber_in.view(-1, 2)
fiber_out = fiber_out.view(-1, 2) fiber_out = fiber_out.view(-1, 2)
fiber_out_rot = fiber_out_rot.view(-1, 2)
angles = angles.view(-1, 1)
angles = angles.real
angles = torch.fmod(angles, 2*torch.pi)
angles = torch.div(angles, 2*torch.pi)
angles = torch.repeat_interleave(angles, 2, dim=1)
regen = regen.view(-1, 2) regen = regen.view(-1, 2)
fiber_in = fiber_in.numpy() fiber_in = fiber_in.numpy()
fiber_out = fiber_out.numpy() fiber_out = fiber_out.numpy()
fiber_out_rot = fiber_out_rot.numpy()
angles = angles.numpy()
regen = regen.numpy() regen = regen.numpy()
timestamps = timestamps.numpy() timestamps = timestamps.numpy()
@@ -1327,31 +1611,34 @@ class RegenerationTrainer:
import gc import gc
head_fig = self._plot_model_response_head( head_fig = self._plot_model_response_head(
fiber_in[:self.pytorch_settings.head_symbols*self._sps], fiber_out[: self.pytorch_settings.head_symbols * self._sps],
fiber_out[:self.pytorch_settings.head_symbols*self._sps], fiber_in[: self.pytorch_settings.head_symbols * self._sps],
regen[:self.pytorch_settings.head_symbols*self._sps], regen[: self.pytorch_settings.head_symbols * self._sps],
timestamps=timestamps[:self.pytorch_settings.head_symbols*self._sps], angles[: self.pytorch_settings.head_symbols * self._sps],
timestamps=timestamps[: self.pytorch_settings.head_symbols * self._sps],
labels=("fiber out", "fiber in", "regen", "normed angle"),
sps=self._sps,
title_append=title_append + f" ({fiber_length} km)",
subtitle=subtitle,
show=show,
)
# raise NotImplementedError("Eye diagram not implemented")
eye_fig = self._plot_model_response_eye(
fiber_in[: self.pytorch_settings.eye_symbols * self._sps],
fiber_out[: self.pytorch_settings.eye_symbols * self._sps],
regen[: self.pytorch_settings.eye_symbols * self._sps],
timestamps=timestamps[: self.pytorch_settings.eye_symbols * self._sps],
labels=("fiber in", "fiber out", "regen"), labels=("fiber in", "fiber out", "regen"),
sps=self._sps, sps=self._sps,
title_append=title_append + f" ({fiber_length} km)", title_append=title_append + f" ({fiber_length} km)",
subtitle=subtitle, subtitle=subtitle,
show=show, show=show,
) )
# raise NotImplementedError("Eye diagram not implemented")
eye_fig = self._plot_model_response_eye( weight_fig = self._plot_model_weights(model, title_append=title_append, subtitle=subtitle, show=show)
fiber_in[:self.pytorch_settings.eye_symbols*self._sps],
fiber_out[:self.pytorch_settings.eye_symbols*self._sps],
regen[:self.pytorch_settings.eye_symbols*self._sps],
timestamps=timestamps[:self.pytorch_settings.eye_symbols*self._sps],
labels=("fiber in", "fiber out", "regen"),
sps=self._sps,
title_append=title_append + f" ({fiber_length} km)",
subtitle=subtitle,
show=show,
)
gc.collect() gc.collect()
return head_fig, eye_fig, power_fig return head_fig, eye_fig, weight_fig, power_fig
def build_title(self, number: int): def build_title(self, number: int):
title_append = f"epoch {number}" title_append = f"epoch {number}"
@@ -1361,7 +1648,7 @@ class RegenerationTrainer:
self.model_settings.overrides.get(f"n_hidden_nodes_{i}", -1) for i in range(model_n_hidden_layers) self.model_settings.overrides.get(f"n_hidden_nodes_{i}", -1) for i in range(model_n_hidden_layers)
] ]
model_dims.insert(0, input_dim) model_dims.insert(0, input_dim)
model_dims.append(2) model_dims.append(self.model_settings.output_dim)
model_dims = [str(dim) for dim in model_dims] model_dims = [str(dim) for dim in model_dims]
model_activation_func = self.model_settings.model_activation_func model_activation_func = self.model_settings.model_activation_func
model_dtype = self.data_settings.dtype model_dtype = self.data_settings.dtype

View File

@@ -0,0 +1,217 @@
from pathlib import Path
import sys
from matplotlib import pyplot as plt
import numpy as np
import torch
import util
from hypertraining.settings import GlobalSettings, DataSettings, ModelSettings, OptimizerSettings, PytorchSettings
from hypertraining import models
# def move_to_location_in_size(array, location, size):
# array_x, array_y = array.shape
# location_x, location_y = location
# size_x, size_y = size
# left = location_x
# right = size_x - array_x - location_x
# top = location_y
# bottom = size_y - array_y - location_y
# return np.pad(
# array,
# (
# (left, right),
# (top, bottom),
# ),
# constant_values=(-np.inf, -np.inf),
# )
def register_puccs_cmap(puccs_path=None):
    """Register the colours of a puccs CSV file as the matplotlib colormap ``'puccs'``.

    The colour components are read from the fifth column onwards (index 4,
    i.e. the ``R``, ``G``, ``B`` columns of the bundled file). Rows whose
    colour columns are not all numeric -- such as the header row -- are
    skipped, and so are rows that do not yield an RGB or RGBA tuple (blank or
    truncated lines), which previously ended up as empty tuples in the list.

    Args:
        puccs_path: Path of the CSV file. Defaults to ``puccs.csv`` in the
            directory of this module.
    """
    import csv

    import matplotlib as mpl
    from matplotlib.colors import LinearSegmentedColormap

    if puccs_path is None:
        puccs_path = Path(__file__).resolve().parent / "puccs.csv"

    colors = []
    with open(puccs_path, "r", newline="", encoding="utf-8") as f:
        for row in csv.reader(f):
            try:
                rgb = tuple(map(float, row[4:]))
            except ValueError:
                # Non-numeric colour columns, e.g. the "R","G","B" header.
                continue
            if len(rgb) not in (3, 4):
                # Blank or truncated line: not a usable colour.
                continue
            colors.append(rgb)

    # force=True keeps repeated calls (e.g. re-running a notebook cell) from
    # failing because the name 'puccs' is already registered.
    mpl.colormaps.register(
        LinearSegmentedColormap.from_list("puccs", colors),
        force=True,
    )
def pad_to_size(array, size):
    """Centre ``array`` inside a NaN-filled frame of the requested size.

    Only the first two axes are padded; any further axes are left untouched.
    When the amount of padding along an axis is odd, the extra element is
    placed in front of the data (first pad is the larger one).

    Args:
        array: ``numpy.ndarray`` with at least two dimensions.
        size: Target extent of the first two axes. A scalar applies to both
            axes; in a sequence, an entry of ``None`` leaves that axis as is.

    Returns:
        The padded array. Input that is neither floating point nor complex is
        promoted to ``float`` first, because NaN cannot be stored in it.

    Raises:
        ValueError: If ``array`` has fewer than two dimensions. A target
            extent smaller than the array yields a negative pad width, which
            ``np.pad`` rejects.
    """
    if not hasattr(size, "__len__"):
        size = (size, size)

    if array.ndim < 2:
        raise ValueError(
            f"pad_to_size needs an array with at least 2 dimensions, got {array.ndim}"
        )

    if array.dtype.kind not in "fc":
        # Integer/bool data cannot represent the NaN fill value.
        array = array.astype(float)

    pad_width = []
    for axis in (0, 1):
        target = size[axis]
        missing = 0 if target is None else target - array.shape[axis]
        # Ceil half in front, floor half behind.
        pad_width.append(((missing + 1) // 2, missing // 2))
    # Trailing axes (if any) keep their extent.
    pad_width.extend([(0, 0)] * (array.ndim - 2))

    return np.pad(array, pad_width, constant_values=(np.nan, np.nan))
def model_plot(model_path, show=True):
    """Plot magnitude and phase of every layer's ONN weights of a saved model.

    The checkpoint at ``model_path`` is loaded, a ``models.regenerator`` is
    rebuilt from its ``"model_kwargs"`` (including ``"dims"``) and
    ``"model_state_dict"`` entries, and the ``ONN.weight`` of each child layer
    is drawn side by side in two images: absolute values (left) and angles
    wrapped to [0, 2*pi) (right). Every layer is padded with NaN to
    ``max(dims)`` rows; layers after the first get one extra NaN column that
    acts as a separator. NaN cells are drawn in grey.

    Args:
        model_path: Path of the checkpoint passed to ``torch.load``.
        show: If true, call ``plt.show()`` before returning.

    Returns:
        The matplotlib figure holding both images.

    Raises:
        ValueError: If the rebuilt model has no child layers.
    """
    from cmap import Colormap as cm
    from mpl_toolkits.axes_grid1 import make_axes_locatable

    # Allow-list the project classes so the checkpoint can be loaded with
    # weights_only=True. This mutates torch's global allow-list.
    # NOTE(review): assumes util.complexNN.__all__ holds the classes
    # themselves rather than their names -- confirm.
    torch.serialization.add_safe_globals([
        *util.complexNN.__all__,
        GlobalSettings,
        DataSettings,
        ModelSettings,
        OptimizerSettings,
        PytorchSettings,
        models.regenerator,
        torch.nn.utils.parametrizations.orthogonal,
    ])
    # The weights are only moved to numpy below, so map everything to the CPU;
    # this lets checkpoints written on a GPU load on machines without one.
    checkpoint_dict = torch.load(model_path, weights_only=True, map_location="cpu")

    dims = checkpoint_dict["model_kwargs"].pop("dims")

    model = models.regenerator(*dims, **checkpoint_dict["model_kwargs"])
    # NOTE(review): strict=False silently ignores missing/unexpected keys.
    model.load_state_dict(checkpoint_dict["model_state_dict"], strict=False)

    max_size = np.max(dims)

    value_blocks = []
    angle_blocks = []
    for num, (_layer_name, layer) in enumerate(model.named_children()):
        # each layer contains an "ONN" layer and an "activation" layer
        # activation layer is approximately the same for all layers and nodes -> rotation by 90 degrees
        onn_weights = layer.ONN.weight.detach().cpu().numpy()
        onn_values = np.abs(onn_weights).real
        onn_angles = np.mod(np.angle(onn_weights), 2 * np.pi).real

        # The first layer keeps its width; later layers gain one NaN column
        # as a visual separator. All layers -- including the first, whose
        # padded version used to be computed but thrown away -- are padded to
        # the same number of rows so they can be concatenated.
        width = None if num == 0 else onn_weights.shape[1] + 1
        value_blocks.append(pad_to_size(onn_values, (max_size, width)))
        angle_blocks.append(pad_to_size(onn_angles, (max_size, width)))

    if not value_blocks:
        raise ValueError("model has no child layers to plot")

    value_img = np.concatenate(value_blocks, axis=1)
    angle_img = np.concatenate(angle_blocks, axis=1)

    # Mask the NaN padding so it is rendered with the colormap's "bad" colour.
    value_img = np.ma.array(value_img, mask=np.isnan(value_img))
    angle_img = np.ma.array(angle_img, mask=np.isnan(angle_img))

    fig, axs = plt.subplots(1, 2, sharex=True, sharey=True, figsize=(16, 5))
    # One narrow colorbar axis to the right of each image.
    caxs = [
        make_axes_locatable(ax).append_axes("right", size="5%", pad=0.05)
        for ax in axs
    ]

    value_cmap = cm('google:turbo').to_matplotlib()
    value_cmap.set_bad(color="#AAAAAA")
    im_val = axs[0].imshow(value_img, cmap=value_cmap, vmin=0, vmax=1)
    fig.colorbar(im_val, cax=caxs[0], orientation="vertical")

    # Cyclic colormap, since 0 and 2*pi denote the same angle.
    angle_cmap = cm('colorcet:CET_C8').to_matplotlib()
    angle_cmap.set_bad(color="#AAAAAA")
    im_ang = axs[1].imshow(angle_img, cmap=angle_cmap, vmin=0, vmax=2 * np.pi)
    cbar = fig.colorbar(
        im_ang,
        cax=caxs[1],
        orientation="vertical",
        ticks=[i / 8 * 2 * np.pi for i in range(9)],
    )
    # The last tick (2*pi) is left unlabelled.
    cbar.ax.set_yticklabels(["0", "π/4", "π/2", "3π/4", "π", "5π/4", "3π/2", "7π/4", ""])

    axs[0].axis("off")
    axs[1].axis("off")

    axs[0].set_title("Values")
    axs[1].set_title("Angles")

    if show:
        plt.show()
    return fig
# model = models.regenerator(*dims, **model_kwargs)
if __name__ == "__main__":
    # Script entry point: register the custom colormap, then plot the model
    # whose checkpoint path is given as the first command-line argument.
    register_puccs_cmap()

    cli_args = sys.argv[1:]
    if not cli_args:
        print("Please provide a model path as an argument")
    else:
        # e.g. ".models/best_20250114_224234.tar"
        model_plot(cli_args[0])

View File

@@ -0,0 +1,102 @@
"x","L","a","b","R","G","B"
0.,0.5187848173343539,0.6399990176455989,0.67,0.8889427469969852,0.22673227640012172,0.
0.01,0.5374499525557803,0.604014067614707,0.6777967519386492,0.8956274406155226,0.27553288030331824,0.
0.02,0.5560867887452998,0.5680836759482211,0.6855816828789898,0.9019507507843885,0.318608215541461,0.
0.03,0.5746877595125583,0.5322224300667823,0.6933516322080414,0.907905487190649,0.3580633000693721,0.
0.04,0.5932314662487472,0.49647158484797804,0.7010976613543587,0.9134808162089558,0.3949845524063657,0.
0.05,0.6117000836392819,0.46086550613202343,0.7088123243737041,0.918668356138916,0.43002019316005363,0.
0.06,0.6300828534995973,0.4254249348741487,0.7164911273850869,0.923462736751354,0.4635961938811463,0.
0.07,0.6483763163456417,0.3901565406944371,0.7241326253017896,0.9278609626724071,0.49601354353255284,0.
0.08,0.6665840140182806,0.3550534951951814,0.7317382976124045,0.9318616057744784,0.5274983630587982,0.
0.09,0.6847162776119433,0.3200958808181962,0.7393124597949372,0.9354640163365924,0.5582303922647159,0.
0.1,0.7027902128942014,0.2852507189547545,0.7468622572263107,0.9386675557407496,0.5883604892249517,0.004034952213848706
0.11,0.7208298719332069,0.25047163906104203,0.7543977368741345,0.9414708123927996,0.6180221032545026,0.016031521294251994
0.12,0.7388665670611175,0.2156982733607376,0.7619319784446927,0.943870754968487,0.6473392272576862,0.029857267582036696
0.13,0.7569392765472108,0.18085547473834482,0.7694812638396673,0.9458617774020323,0.676432172396153,0.045365670193636125
0.14,0.7750950944867471,0.14585244938794778,0.7770652650825484,0.9474345911958609,0.7054219201084561,0.06017985923530026
0.15,0.793389684293558,0.11058188251425949,0.7847072337503834,0.9485749196617762,0.7344334940032564,0.07418869502646075
0.16,0.8117919447684838,0.07510373484536464,0.792394178330817,0.9492596163836376,0.7634480277996188,0.08767517868137237
0.17,0.8293050962981561,0.03629277424762101,0.799038155466063,0.9462308253550155,0.7922009241807345,0.10066327128139077
0.18,0.8213303100752708,-0.0062517290795987,0.7879999288492758,0.9088702681901394,0.7940579017644396,0.10139639009534024
0.19,0.8134831311534617,-0.048115463155645855,0.7771383286984362,0.8716809050191757,0.7954897210083888,0.10232311621802098
0.2,0.80558613530069,-0.0902449644291895,0.7662077749032042,0.8337524177888596,0.7965471523787845,0.10344968926026826
0.21,0.7975860185564765,-0.13292460297117392,0.7551344872795225,0.7947193410849823,0.7972381033243311,0.10477682283894393
0.22,0.7894147026971006,-0.17651756772919341,0.7438242359834689,0.7540941866826836,0.7975605026647324,0.10631182441371936
0.23,0.7809997374598548,-0.2214103719409295,0.7321767396537806,0.7112894518675287,0.7974995317311054,0.1080672415170634
0.24,0.7722646970273015,-0.2680107379394189,0.7200862142018722,0.6655745739336695,0.7970267795229349,0.11006041388465265
0.25,0.7631307298557146,-0.3167393290089981,0.7074435179925446,0.6160047476007512,0.7960993904970947,0.11231257117602686
0.26,0.7535192192483822,-0.36801555555407994,0.6941398344519211,0.5612859274945571,0.794659599537827,0.11484733363789801
0.27,0.7433557597838075,-0.42223636134393283,0.6800721760037781,0.4994862901720824,0.7926351396848288,0.11768844813479104
0.28,0.732575139048096,-0.479749646583324,0.6651502794883674,0.42731393423789277,0.7899410218414098,0.12085678487511567
0.29,0.7211269294461059,-0.5408244362880141,0.6493043460161184,0.3378265607222193,0.786483110019224,0.124366774034814
0.3,0.7090756028785993,-0.6051167807996883,0.6326236137723747,0.2098475715121697,0.7821998608677176,0.12819222127525928
0.31,0.7094510768540225,-0.6165036055456403,0.5630307498747129,0.15061488620640032,0.7845112116922692,0.21943537230975235
0.32,0.7174669421288304,-0.5917687864932311,0.4797229624661701,0.18766933782916642,0.7905828987725732,0.31091344246312086
0.33,0.7249009746435938,-0.5688293479200438,0.40246208306061504,0.21160609617940718,0.7962175427587832,0.38519766326885596
0.34,0.7317072855135611,-0.5478268906666535,0.3317250285377912,0.22717569971119178,0.8013847719431052,0.4490960048955565
0.35,0.7379328517830899,-0.5286164561226088,0.26702357292455026,0.23690087622812972,0.8061220291668977,0.5056371468159843
0.36,0.7436229063122554,-0.5110584677642499,0.20788761731555405,0.24226377668817778,0.8104638164122776,0.5563570758573497
0.37,0.7488251728809415,-0.4950056627547577,0.15382117501783654,0.24424372086048424,0.8144455902164638,0.6022301663745243
0.38,0.7535943992285348,-0.48028910419451787,0.10425526029155024,0.24352232677523483,0.818107753931944,0.6440238320299774
0.39,0.757994865186593,-0.4667104416936734,0.05852182167144754,0.240562414747303,0.8214980148949816,0.6824536572462205
0.4,0.7620994844391137,-0.4540446830999986,0.015863077249098356,0.2356325204239052,0.8246710357361025,0.7182393675419642
0.41,0.7659871096124125,-0.4420485102716773,-0.024540477496154123,0.22880568593963535,0.8276865975886148,0.7521146815529202
0.42,0.7697410958994951,-0.4304647113488041,-0.06355514164248566,0.21993360985514526,0.8306086550266585,0.7848331944479765
0.43,0.773446484628189,-0.4190308715098135,-0.10206473803580057,0.20858849290850018,0.833503273690861,0.8171544357676854
0.44,0.7771893686864673,-0.4074813310994203,-0.14096401824224686,0.1939295692427068,0.8364382500400466,0.8498448067259188
0.45,0.7810574093604746,-0.3955455908045306,-0.18116403397486242,0.17438366103820427,0.839483669055626,0.8836865023336339
0.46,0.7851360804917298,-0.3829599011818591,-0.2235531031349741,0.14679145002531463,0.8427091517444469,0.9194481212717681
0.47,0.789525027020907,-0.369416784561489,-0.26916682191206776,0.10278921007810798,0.8461971304126237,0.9580316568065935
0.48,0.7942371698732826,-0.35487637041943493,-0.3181394757087982,0.0013920913109500188,0.8499626968466341,0.9995866371771526
0.49,0.7773897680996302,-0.31852357140025195,-0.34537976514700053,0.10740420703601522,0.8254781216972907,1.
0.5,0.7604011244310231,-0.28211213216592784,-0.3722846952738428,0.1581725581872408,0.8008522647497104,1.
0.51,0.7433440454962605,-0.2455540169176899,-0.3992980063927199,0.19300141807932156,0.7761561224913385,1.
0.52,0.7262590833969331,-0.20893614020926626,-0.42635547610418184,0.2194621842292243,0.751443124097109,1.
0.53,0.709058602701224,-0.17207067467417486,-0.453595892719742,0.2405673704012788,0.7265803324554873,1.
0.54,0.6915768892539101,-0.1346024482921609,-0.48128169789479536,0.25788347992973676,0.701321051230534,1.
0.55,0.6736331627810209,-0.09614399811510127,-0.5096991935104321,0.2722888922216317,0.6753950894563805,1.
0.56,0.6551463184003872,-0.05652149358027936,-0.5389768254408652,0.28422807900785235,0.6486730893521468,1.
0.57,0.6361671326276888,-0.01584376303510615,-0.5690341788729347,0.293907374075009,0.6212117649042732,1.
0.58,0.6168396823565967,0.025580396234342995,-0.5996430791016598,0.301442767979156,0.5931976878638505,1.
0.59,0.5973210287815495,0.06741435793529688,-0.6305547881733555,0.30694603901024253,0.5648312189065924,1.
0.6,0.5777303704171711,0.10940264614179468,-0.661580531294122,0.3105418468883679,0.5362525958007331,1.
0.61,0.5581475370499237,0.15137416317967575,-0.6925938819599547,0.3123531986526998,0.5075386530652202,1.
0.62,0.5386227795100639,0.19322120739317136,-0.7235152578861672,0.31248922600720636,0.4787151440558522,1.
0.63,0.5191666876024412,0.23492108185347996,-0.754327887989376,0.31103663081260624,0.44973844514160927,1.
0.64,0.4996990584326256,0.2766456839100268,-0.7851587896650079,0.30803814950244496,0.4204116611935119,1.
0.65,0.479957679121191,0.3189570094767831,-0.8164232296840259,0.30343473603466015,0.390226489453496,1.
0.66,0.4600072725872886,0.3617163391430824,-0.8480187063016573,0.29717122075330515,0.3591178757512998,1.
0.67,0.44600100870220305,0.4113853615984094,-0.8697728377551008,0.3178994129506999,0.3295740682997879,1.
0.68,0.4574651571354146,0.44026390446569547,-0.8504539292487465,0.3842479358768364,0.3280946443367561,1.
0.69,0.4691809168948424,0.46977626401045774,-0.830711015748157,0.44293649140770447,0.3260767554252525,1.
0.7,0.4811696900083858,0.49997635259991063,-0.8105080314416201,0.49708450874457527,0.3234487047238236,1.
0.71,0.49350094811609174,0.5310391714342613,-0.7897279055963483,0.5485591109413528,0.3201099534066949,1.
0.72,0.5062548753068121,0.5631667067020758,-0.7682355153041539,0.5985798481027601,0.3159263917472715,1.
0.73,0.5195243020949684,0.5965928013272943,-0.7458744264238399,0.6480500606439057,0.31071717884730565,1.
0.74,0.5334043922713477,0.6315571758288618,-0.7224842728734379,0.6976685401842261,0.3042411890803418,1.
0.75,0.5479805812358602,0.6682750446095802,-0.697921082452685,0.7479712773579563,0.29618040787504757,1.
0.76,0.5633244502526606,0.7069267230777347,-0.6720642293775535,0.7993701361353484,0.28611136999256687,1.
0.77,0.5794956601139,0.7476624986056212,-0.6448131757501174,0.8521918014427678,0.2734527325942473,1.
0.78,0.5965429098573916,0.7906050455688622,-0.6160858559672187,0.9067003897516911,0.2573693489198746,1.
0.79,0.6145761476424179,0.8360313267658297,-0.5856969899409387,0.963334644317004,0.23648492980159264,1.
0.8,0.6232910688128902,0.859291371252556,-0.5300995185388214,1.,0.21867949406239662,0.9712088595948508
0.81,0.6159984336377875,0.8439887543380684,-0.44635440435952856,1.,0.21606849746358275,0.9041480210597966
0.82,0.6091642745073532,0.8296481879180277,-0.36787420852419694,1.,0.21421830096504035,0.8419706002336461
0.83,0.6025478038652375,0.8157644115969636,-0.2918938425681935,1.,0.21295365915197917,0.7823908751330636
0.84,0.5961857222953111,0.8024144366282877,-0.21883475834162458,0.9971140114799418,0.21220068235083267,0.7256713129328118
0.85,0.5900921771070883,0.7896279492437488,-0.1488594167412921,0.993273906363258,0.2118788857127918,0.671860243327784
0.86,0.5842771639541229,0.7774259239818333,-0.08208260304413262,0.9887084084529413,0.21191070453347688,0.6209624706933893
0.87,0.578741582584259,0.7658102488427286,-0.018514649521559012,0.9835846378805114,0.2122246941077346,0.5728987835613306
0.88,0.5734741590353537,0.7547572669288056,0.04197390858426542,0.9780378159372328,0.21275878699579343,0.5274829957183049
0.89,0.5684517008574971,0.7442183119942206,0.09964940221121898,0.9721670725313721,0.21346242315895625,0.4844270603851604
0.9,0.5636419856510335,0.7341257696545772,0.15488185789614228,0.9660363209686843,0.21429691147008262,0.4433660148378527
0.91,0.5590069340453534,0.7243997354573974,0.20810856081277884,0.9596781387247791,0.2152344151262528,0.4038812338146013
0.92,0.5545051525321143,0.7149533506766244,0.25980485409830323,0.9530986696850675,0.21625626438013962,0.3655130449917989
0.93,0.5500961975299247,0.705701749880514,0.3104351723857584,0.9462863346513658,0.21735046958786286,0.327780364198278
0.94,0.545740378056064,0.6965616468647046,0.36045530782708896,0.93921469089265,0.21851014470332586,0.29014917175372823
0.95,0.5414004092067859,0.6874548042588865,0.41029342232076466,0.9318478255642132,0.21973168075163751,0.2519897371806688
0.96,0.5370416605957644,0.6783085548415655,0.46034719456417006,0.9241434776436454,0.22101341980094052,0.2124579038400577
0.97,0.5326309593934517,0.6690532898786764,0.5109975653738162,0.9160532016485884,0.22235495330179011,0.17018252385769012
0.98,0.5281374148557197,0.6596241892863608,0.5625992691950712,0.90752576202319,0.22375597459867458,0.1223073280126531
0.99,0.5235317096396147,0.6499597345521199,0.615488972291106,0.8985077346125597,0.22521565729028564,0.05933950582860665
1.,0.5187848173343539,0.6399990176455989,0.67,0.8889427469969852,0.22673227640012172,0.
1 x L a b R G B
2 0. 0.5187848173343539 0.6399990176455989 0.67 0.8889427469969852 0.22673227640012172 0.
3 0.01 0.5374499525557803 0.604014067614707 0.6777967519386492 0.8956274406155226 0.27553288030331824 0.
4 0.02 0.5560867887452998 0.5680836759482211 0.6855816828789898 0.9019507507843885 0.318608215541461 0.
5 0.03 0.5746877595125583 0.5322224300667823 0.6933516322080414 0.907905487190649 0.3580633000693721 0.
6 0.04 0.5932314662487472 0.49647158484797804 0.7010976613543587 0.9134808162089558 0.3949845524063657 0.
7 0.05 0.6117000836392819 0.46086550613202343 0.7088123243737041 0.918668356138916 0.43002019316005363 0.
8 0.06 0.6300828534995973 0.4254249348741487 0.7164911273850869 0.923462736751354 0.4635961938811463 0.
9 0.07 0.6483763163456417 0.3901565406944371 0.7241326253017896 0.9278609626724071 0.49601354353255284 0.
10 0.08 0.6665840140182806 0.3550534951951814 0.7317382976124045 0.9318616057744784 0.5274983630587982 0.
11 0.09 0.6847162776119433 0.3200958808181962 0.7393124597949372 0.9354640163365924 0.5582303922647159 0.
12 0.1 0.7027902128942014 0.2852507189547545 0.7468622572263107 0.9386675557407496 0.5883604892249517 0.004034952213848706
13 0.11 0.7208298719332069 0.25047163906104203 0.7543977368741345 0.9414708123927996 0.6180221032545026 0.016031521294251994
14 0.12 0.7388665670611175 0.2156982733607376 0.7619319784446927 0.943870754968487 0.6473392272576862 0.029857267582036696
15 0.13 0.7569392765472108 0.18085547473834482 0.7694812638396673 0.9458617774020323 0.676432172396153 0.045365670193636125
16 0.14 0.7750950944867471 0.14585244938794778 0.7770652650825484 0.9474345911958609 0.7054219201084561 0.06017985923530026
17 0.15 0.793389684293558 0.11058188251425949 0.7847072337503834 0.9485749196617762 0.7344334940032564 0.07418869502646075
18 0.16 0.8117919447684838 0.07510373484536464 0.792394178330817 0.9492596163836376 0.7634480277996188 0.08767517868137237
19 0.17 0.8293050962981561 0.03629277424762101 0.799038155466063 0.9462308253550155 0.7922009241807345 0.10066327128139077
20 0.18 0.8213303100752708 -0.0062517290795987 0.7879999288492758 0.9088702681901394 0.7940579017644396 0.10139639009534024
21 0.19 0.8134831311534617 -0.048115463155645855 0.7771383286984362 0.8716809050191757 0.7954897210083888 0.10232311621802098
22 0.2 0.80558613530069 -0.0902449644291895 0.7662077749032042 0.8337524177888596 0.7965471523787845 0.10344968926026826
23 0.21 0.7975860185564765 -0.13292460297117392 0.7551344872795225 0.7947193410849823 0.7972381033243311 0.10477682283894393
24 0.22 0.7894147026971006 -0.17651756772919341 0.7438242359834689 0.7540941866826836 0.7975605026647324 0.10631182441371936
25 0.23 0.7809997374598548 -0.2214103719409295 0.7321767396537806 0.7112894518675287 0.7974995317311054 0.1080672415170634
26 0.24 0.7722646970273015 -0.2680107379394189 0.7200862142018722 0.6655745739336695 0.7970267795229349 0.11006041388465265
27 0.25 0.7631307298557146 -0.3167393290089981 0.7074435179925446 0.6160047476007512 0.7960993904970947 0.11231257117602686
28 0.26 0.7535192192483822 -0.36801555555407994 0.6941398344519211 0.5612859274945571 0.794659599537827 0.11484733363789801
29 0.27 0.7433557597838075 -0.42223636134393283 0.6800721760037781 0.4994862901720824 0.7926351396848288 0.11768844813479104
30 0.28 0.732575139048096 -0.479749646583324 0.6651502794883674 0.42731393423789277 0.7899410218414098 0.12085678487511567
31 0.29 0.7211269294461059 -0.5408244362880141 0.6493043460161184 0.3378265607222193 0.786483110019224 0.124366774034814
32 0.3 0.7090756028785993 -0.6051167807996883 0.6326236137723747 0.2098475715121697 0.7821998608677176 0.12819222127525928
33 0.31 0.7094510768540225 -0.6165036055456403 0.5630307498747129 0.15061488620640032 0.7845112116922692 0.21943537230975235
34 0.32 0.7174669421288304 -0.5917687864932311 0.4797229624661701 0.18766933782916642 0.7905828987725732 0.31091344246312086
35 0.33 0.7249009746435938 -0.5688293479200438 0.40246208306061504 0.21160609617940718 0.7962175427587832 0.38519766326885596
36 0.34 0.7317072855135611 -0.5478268906666535 0.3317250285377912 0.22717569971119178 0.8013847719431052 0.4490960048955565
37 0.35 0.7379328517830899 -0.5286164561226088 0.26702357292455026 0.23690087622812972 0.8061220291668977 0.5056371468159843
38 0.36 0.7436229063122554 -0.5110584677642499 0.20788761731555405 0.24226377668817778 0.8104638164122776 0.5563570758573497
39 0.37 0.7488251728809415 -0.4950056627547577 0.15382117501783654 0.24424372086048424 0.8144455902164638 0.6022301663745243
40 0.38 0.7535943992285348 -0.48028910419451787 0.10425526029155024 0.24352232677523483 0.818107753931944 0.6440238320299774
41 0.39 0.757994865186593 -0.4667104416936734 0.05852182167144754 0.240562414747303 0.8214980148949816 0.6824536572462205
42 0.4 0.7620994844391137 -0.4540446830999986 0.015863077249098356 0.2356325204239052 0.8246710357361025 0.7182393675419642
43 0.41 0.7659871096124125 -0.4420485102716773 -0.024540477496154123 0.22880568593963535 0.8276865975886148 0.7521146815529202
44 0.42 0.7697410958994951 -0.4304647113488041 -0.06355514164248566 0.21993360985514526 0.8306086550266585 0.7848331944479765
45 0.43 0.773446484628189 -0.4190308715098135 -0.10206473803580057 0.20858849290850018 0.833503273690861 0.8171544357676854
46 0.44 0.7771893686864673 -0.4074813310994203 -0.14096401824224686 0.1939295692427068 0.8364382500400466 0.8498448067259188
47 0.45 0.7810574093604746 -0.3955455908045306 -0.18116403397486242 0.17438366103820427 0.839483669055626 0.8836865023336339
48 0.46 0.7851360804917298 -0.3829599011818591 -0.2235531031349741 0.14679145002531463 0.8427091517444469 0.9194481212717681
49 0.47 0.789525027020907 -0.369416784561489 -0.26916682191206776 0.10278921007810798 0.8461971304126237 0.9580316568065935
50 0.48 0.7942371698732826 -0.35487637041943493 -0.3181394757087982 0.0013920913109500188 0.8499626968466341 0.9995866371771526
51 0.49 0.7773897680996302 -0.31852357140025195 -0.34537976514700053 0.10740420703601522 0.8254781216972907 1.
52 0.5 0.7604011244310231 -0.28211213216592784 -0.3722846952738428 0.1581725581872408 0.8008522647497104 1.
53 0.51 0.7433440454962605 -0.2455540169176899 -0.3992980063927199 0.19300141807932156 0.7761561224913385 1.
54 0.52 0.7262590833969331 -0.20893614020926626 -0.42635547610418184 0.2194621842292243 0.751443124097109 1.
55 0.53 0.709058602701224 -0.17207067467417486 -0.453595892719742 0.2405673704012788 0.7265803324554873 1.
56 0.54 0.6915768892539101 -0.1346024482921609 -0.48128169789479536 0.25788347992973676 0.701321051230534 1.
57 0.55 0.6736331627810209 -0.09614399811510127 -0.5096991935104321 0.2722888922216317 0.6753950894563805 1.
58 0.56 0.6551463184003872 -0.05652149358027936 -0.5389768254408652 0.28422807900785235 0.6486730893521468 1.
59 0.57 0.6361671326276888 -0.01584376303510615 -0.5690341788729347 0.293907374075009 0.6212117649042732 1.
60 0.58 0.6168396823565967 0.025580396234342995 -0.5996430791016598 0.301442767979156 0.5931976878638505 1.
61 0.59 0.5973210287815495 0.06741435793529688 -0.6305547881733555 0.30694603901024253 0.5648312189065924 1.
62 0.6 0.5777303704171711 0.10940264614179468 -0.661580531294122 0.3105418468883679 0.5362525958007331 1.
63 0.61 0.5581475370499237 0.15137416317967575 -0.6925938819599547 0.3123531986526998 0.5075386530652202 1.
64 0.62 0.5386227795100639 0.19322120739317136 -0.7235152578861672 0.31248922600720636 0.4787151440558522 1.
65 0.63 0.5191666876024412 0.23492108185347996 -0.754327887989376 0.31103663081260624 0.44973844514160927 1.
66 0.64 0.4996990584326256 0.2766456839100268 -0.7851587896650079 0.30803814950244496 0.4204116611935119 1.
67 0.65 0.479957679121191 0.3189570094767831 -0.8164232296840259 0.30343473603466015 0.390226489453496 1.
68 0.66 0.4600072725872886 0.3617163391430824 -0.8480187063016573 0.29717122075330515 0.3591178757512998 1.
69 0.67 0.44600100870220305 0.4113853615984094 -0.8697728377551008 0.3178994129506999 0.3295740682997879 1.
70 0.68 0.4574651571354146 0.44026390446569547 -0.8504539292487465 0.3842479358768364 0.3280946443367561 1.
71 0.69 0.4691809168948424 0.46977626401045774 -0.830711015748157 0.44293649140770447 0.3260767554252525 1.
72 0.7 0.4811696900083858 0.49997635259991063 -0.8105080314416201 0.49708450874457527 0.3234487047238236 1.
73 0.71 0.49350094811609174 0.5310391714342613 -0.7897279055963483 0.5485591109413528 0.3201099534066949 1.
74 0.72 0.5062548753068121 0.5631667067020758 -0.7682355153041539 0.5985798481027601 0.3159263917472715 1.
75 0.73 0.5195243020949684 0.5965928013272943 -0.7458744264238399 0.6480500606439057 0.31071717884730565 1.
76 0.74 0.5334043922713477 0.6315571758288618 -0.7224842728734379 0.6976685401842261 0.3042411890803418 1.
77 0.75 0.5479805812358602 0.6682750446095802 -0.697921082452685 0.7479712773579563 0.29618040787504757 1.
78 0.76 0.5633244502526606 0.7069267230777347 -0.6720642293775535 0.7993701361353484 0.28611136999256687 1.
79 0.77 0.5794956601139 0.7476624986056212 -0.6448131757501174 0.8521918014427678 0.2734527325942473 1.
80 0.78 0.5965429098573916 0.7906050455688622 -0.6160858559672187 0.9067003897516911 0.2573693489198746 1.
81 0.79 0.6145761476424179 0.8360313267658297 -0.5856969899409387 0.963334644317004 0.23648492980159264 1.
82 0.8 0.6232910688128902 0.859291371252556 -0.5300995185388214 1. 0.21867949406239662 0.9712088595948508
83 0.81 0.6159984336377875 0.8439887543380684 -0.44635440435952856 1. 0.21606849746358275 0.9041480210597966
84 0.82 0.6091642745073532 0.8296481879180277 -0.36787420852419694 1. 0.21421830096504035 0.8419706002336461
85 0.83 0.6025478038652375 0.8157644115969636 -0.2918938425681935 1. 0.21295365915197917 0.7823908751330636
86 0.84 0.5961857222953111 0.8024144366282877 -0.21883475834162458 0.9971140114799418 0.21220068235083267 0.7256713129328118
87 0.85 0.5900921771070883 0.7896279492437488 -0.1488594167412921 0.993273906363258 0.2118788857127918 0.671860243327784
88 0.86 0.5842771639541229 0.7774259239818333 -0.08208260304413262 0.9887084084529413 0.21191070453347688 0.6209624706933893
89 0.87 0.578741582584259 0.7658102488427286 -0.018514649521559012 0.9835846378805114 0.2122246941077346 0.5728987835613306
90 0.88 0.5734741590353537 0.7547572669288056 0.04197390858426542 0.9780378159372328 0.21275878699579343 0.5274829957183049
91 0.89 0.5684517008574971 0.7442183119942206 0.09964940221121898 0.9721670725313721 0.21346242315895625 0.4844270603851604
92 0.9 0.5636419856510335 0.7341257696545772 0.15488185789614228 0.9660363209686843 0.21429691147008262 0.4433660148378527
93 0.91 0.5590069340453534 0.7243997354573974 0.20810856081277884 0.9596781387247791 0.2152344151262528 0.4038812338146013
94 0.92 0.5545051525321143 0.7149533506766244 0.25980485409830323 0.9530986696850675 0.21625626438013962 0.3655130449917989
95 0.93 0.5500961975299247 0.705701749880514 0.3104351723857584 0.9462863346513658 0.21735046958786286 0.327780364198278
96 0.94 0.545740378056064 0.6965616468647046 0.36045530782708896 0.93921469089265 0.21851014470332586 0.29014917175372823
97 0.95 0.5414004092067859 0.6874548042588865 0.41029342232076466 0.9318478255642132 0.21973168075163751 0.2519897371806688
98 0.96 0.5370416605957644 0.6783085548415655 0.46034719456417006 0.9241434776436454 0.22101341980094052 0.2124579038400577
99 0.97 0.5326309593934517 0.6690532898786764 0.5109975653738162 0.9160532016485884 0.22235495330179011 0.17018252385769012
100 0.98 0.5281374148557197 0.6596241892863608 0.5625992691950712 0.90752576202319 0.22375597459867458 0.1223073280126531
101 0.99 0.5235317096396147 0.6499597345521199 0.615488972291106 0.8985077346125597 0.22521565729028564 0.05933950582860665
102 1. 0.5187848173343539 0.6399990176455989 0.67 0.8889427469969852 0.22673227640012172 0.

View File

@@ -1,6 +1,8 @@
from datetime import datetime from datetime import datetime
import optuna import optuna
import torch
import util
from hypertraining.hypertraining import HyperTraining from hypertraining.hypertraining import HyperTraining
from hypertraining.settings import ( from hypertraining.settings import (
GlobalSettings, GlobalSettings,
@@ -16,24 +18,29 @@ global_settings = GlobalSettings(
) )
data_settings = DataSettings( data_settings = DataSettings(
config_path="data/*-128-16384-100000-0-0-17-0-PAM4-0.ini", # config_path="data/*-128-16384-100000-0-0-17-0-PAM4-0.ini",
config_path="data/20241204-131003-128-16384-100000-0-0-17-0-PAM4-0.ini",
dtype="complex64", dtype="complex64",
# symbols = (9, 20), # 13 symbol @ 10GBd <-> 1.3ns <-> 0.26m of fiber # symbols = (9, 20), # 13 symbol @ 10GBd <-> 1.3ns <-> 0.26m of fiber
symbols=13, # study: single_core_regen_20241123_011232 # symbols=13, # study: single_core_regen_20241123_011232
# symbols = (3, 13),
symbols=4,
# output_size = (11, 32), # ballpark 26 taps -> 2 taps per input symbol -> 1 tap every 0.01m (model has 52 inputs) # output_size = (11, 32), # ballpark 26 taps -> 2 taps per input symbol -> 1 tap every 0.01m (model has 52 inputs)
output_size=26, # study: single_core_regen_20241123_011232 (model_input_dim/2) # output_size=26, # study: single_core_regen_20241123_011232 (model_input_dim/2)
output_size=(8, 30),
shuffle=True, shuffle=True,
in_out_delay=0, in_out_delay=0,
xy_delay=0, xy_delay=0,
drop_first=128 * 100, drop_first=256,
train_split=0.8, train_split=0.8,
randomise_polarisations=False,
) )
pytorch_settings = PytorchSettings( pytorch_settings = PytorchSettings(
epochs=10000, epochs=10,
batchsize=2**10, batchsize=2**10,
device="cuda", device="cuda",
dataloader_workers=12, dataloader_workers=4,
dataloader_prefetch=4, dataloader_prefetch=4,
summary_dir=".runs", summary_dir=".runs",
write_every=2**5, write_every=2**5,
@@ -43,28 +50,70 @@ pytorch_settings = PytorchSettings(
model_settings = ModelSettings( model_settings = ModelSettings(
output_dim=2, output_dim=2,
# n_hidden_layers = (3, 8), n_hidden_layers = (2, 5),
n_hidden_layers=4, n_hidden_nodes=(2, 16),
overrides={ model_activation_func="EOActivation",
"n_hidden_nodes_0": 8, dropout_prob=0,
"n_hidden_nodes_1": 6, model_layer_function="ONNRect",
"n_hidden_nodes_2": 4, model_layer_kwargs={"square": True},
"n_hidden_nodes_3": 8, # scale=(False, True),
}, scale=False,
model_activation_func="Mag", model_layer_parametrizations=[
# satabsT0=(1e-6, 1), {
"tensor_name": "weight",
"parametrization": util.complexNN.energy_conserving,
},
{
"tensor_name": "alpha",
"parametrization": util.complexNN.clamp,
},
{
"tensor_name": "gain",
"parametrization": util.complexNN.clamp,
"kwargs": {
"min": 0,
"max": float("inf"),
},
},
{
"tensor_name": "phase_bias",
"parametrization": util.complexNN.clamp,
"kwargs": {
"min": 0,
"max": 2 * torch.pi,
},
},
{
"tensor_name": "scales",
"parametrization": util.complexNN.clamp,
},
{
"tensor_name": "angle",
"parametrization": util.complexNN.clamp,
"kwargs": {
"min": -torch.pi,
"max": torch.pi,
},
},
{
"tensor_name": "loss",
"parametrization": util.complexNN.clamp,
},
],
) )
optimizer_settings = OptimizerSettings( optimizer_settings = OptimizerSettings(
optimizer="Adam", optimizer="AdamW",
# learning_rate = (1e-5, 1e-1), optimizer_kwargs={
learning_rate=5e-3 "lr": 5e-3,
# learning_rate=5e-4, "amsgrad": True,
# "weight_decay": 1e-7,
},
) )
optuna_settings = OptunaSettings( optuna_settings = OptunaSettings(
n_trials=1, n_trials=1024,
n_workers=1, n_workers=8,
timeout=3600, timeout=3600,
directions=("minimize",), directions=("minimize",),
metrics_names=("mse",), metrics_names=("mse",),

View File

@@ -1,4 +1,4 @@
from datetime import datetime # from datetime import datetime
from pathlib import Path from pathlib import Path
import matplotlib import matplotlib
import numpy as np import numpy as np
@@ -13,7 +13,7 @@ from hypertraining.settings import (
OptimizerSettings, OptimizerSettings,
) )
from hypertraining.training import RegenerationTrainer, PolarizationTrainer from hypertraining.training import RegenerationTrainer#, PolarizationTrainer
# import torch # import torch
import json import json
@@ -26,25 +26,39 @@ global_settings = GlobalSettings(
) )
data_settings = DataSettings( data_settings = DataSettings(
config_path="data/*-128-16384-10000-0-0-0-0-PAM4-0-0.4.ini", # config_path="data/20250115-114319-128-16384-inf-100000-0-0-0-0-PAM4-1-0-10.ini", # no effects enabled - baseline
# config_path=[f"data/20241202-*-128-16384-{length}-0-0-17-0-PAM4-0.ini" for length in range(48000, 53000, 1000)], # config_path = "data/20250115-233553-128-16384-1060.0-100000-0-0.2-17.0-0.058-PAM4-1.0-0.0-10.ini", # dispersion + slope only
# config_path="data/20250115-115836-128-16384-60.0-100000-0-0.2-17-0.058-PAM4-1000-0.2-10.ini", # all linear effects enabled with realistic values + noise + pmd (delta_beta=0.2) + ortho_error = 0.1
# config_path="data/20250118-225840-128-16384-inf-100000-0-0.2-17.0-0.058-PAM4-0-0-10.ini", # a)
# config_path="data/20250116-214606-128-16384-inf-100000-0-0.2-17.0-0.058-PAM4-0-0-10.ini", # b)
# config_path="data/20250116-214547-128-16384-inf-100000-0-0.2-17.0-0.058-PAM4-0-0-10.ini", # c)
# config_path="data/20250117-143926-128-16384-inf-100000-0-0.2-17.0-0.058-PAM4-0-0-10.ini", # d) 10ps dgd
config_path="data/20250120-105720-128-16384-inf-100000-0-0.2-17-0.058-PAM4-0-0-10.ini", # d) 10ns
# config_path="data/20250114-215547-128-16384-60.0-100000-1.15-0.2-17-0.058-PAM4-1-0-10.ini", # with gamma=1.15, 2.5dBm launch power, no pmd
dtype="complex64", dtype="complex64",
# symbols = (9, 20), # 13 symbol @ 10GBd <-> 1.3ns <-> 0.26m of fiber # symbols = (9, 20), # 13 symbol @ 10GBd <-> 1.3ns <-> 0.26m of fiber
symbols=13, # study: single_core_regen_20241123_011232 symbols=4, # study: single_core_regen_20241123_011232 -> taps spread over 4 symbols @ 10GBd
# output_size = (11, 32), # 26 taps -> 2 taps per input symbol -> 1 tap every 1cm (model has 52 inputs (x/y)) # output_size = (11, 32), # 26 taps -> 2 taps per input symbol -> 1 tap every 1cm (model has 52 inputs (x/y))
output_size=26, # study: single_core_regen_20241123_011232 (model_input_dim/2) output_size=20, # = model_input_dim/2 study: single_core_regen_20241123_011232
shuffle=True, shuffle=False,
drop_first=64, drop_first=256,
drop_last=256,
train_split=0.8, train_split=0.8,
randomise_polarisations=True, randomise_polarisations=False,
polarisations=False,
# cross_pol_interference=0.01,
osnr=16, #16dB due to amplification with NF 5
) )
pytorch_settings = PytorchSettings( pytorch_settings = PytorchSettings(
epochs=10000, epochs=1000,
batchsize=2**14, batchsize=2**13,
device="cuda", device="cuda",
dataloader_workers=16, dataloader_workers=32,
dataloader_prefetch=8, dataloader_prefetch=4,
summary_dir=".runs", summary_dir=".runs",
write_every=2**5, write_every=2**5,
save_models=True, save_models=True,
@@ -53,80 +67,51 @@ pytorch_settings = PytorchSettings(
model_settings = ModelSettings( model_settings = ModelSettings(
output_dim=2, output_dim=2,
n_hidden_layers=5, n_hidden_layers=3,
overrides={ overrides={
# "hidden_layer_dims": (8, 8, 4, 4), # "hidden_layer_dims": (8, 8, 4, 4),
"n_hidden_nodes_0": 8, "n_hidden_nodes_0": 16,
"n_hidden_nodes_1": 8, "n_hidden_nodes_1": 8,
"n_hidden_nodes_2": 4, "n_hidden_nodes_2": 8,
"n_hidden_nodes_3": 4, # "n_hidden_nodes_3": 4,
"n_hidden_nodes_4": 2, # "n_hidden_nodes_4": 2,
}, },
model_activation_func="EOActivation", model_activation_func="EOActivation",
dropout_prob=0.01, dropout_prob=0,
model_layer_function="ONNRect", model_layer_function="ONNRect",
model_layer_kwargs={"square": True}, model_layer_kwargs={"square": True},
scale=False, scale=2.0,
model_layer_parametrizations=[ model_layer_parametrizations=[
{ # EOactivation
"tensor_name": "weight",
"parametrization": util.complexNN.energy_conserving,
},
{ {
"tensor_name": "alpha", "tensor_name": "alpha",
"parametrization": util.complexNN.clamp, "parametrization": util.complexNN.clamp,
"kwargs": {
"min": 0,
"max": 1,
},
}, },
# ONNRect
{ {
"tensor_name": "gain", "tensor_name": "weight",
"parametrization": torch.nn.utils.parametrizations.orthogonal,
},
# Scale
{
"tensor_name": "scale",
"parametrization": util.complexNN.clamp, "parametrization": util.complexNN.clamp,
"kwargs": { "kwargs": {
"min": 0, "min": 0,
"max": float("inf"), "max": 10,
}, },
}, }
{
"tensor_name": "phase_bias",
"parametrization": util.complexNN.clamp,
"kwargs": {
"min": 0,
"max": 2 * torch.pi,
},
},
{
"tensor_name": "scales",
"parametrization": util.complexNN.clamp,
},
{
"tensor_name": "angle",
"parametrization": util.complexNN.clamp,
"kwargs": {
"min": -torch.pi,
"max": torch.pi,
},
},
# {
# "tensor_name": "scale",
# "parametrization": util.complexNN.clamp,
# },
# {
# "tensor_name": "bias",
# "parametrization": util.complexNN.clamp,
# },
# {
# "tensor_name": "V",
# "parametrization": torch.nn.utils.parametrizations.orthogonal,
# },
{
"tensor_name": "loss",
"parametrization": util.complexNN.clamp,
},
], ],
) )
optimizer_settings = OptimizerSettings( optimizer_settings = OptimizerSettings(
optimizer="AdamW", optimizer="AdamW",
optimizer_kwargs={ optimizer_kwargs={
"lr": 0.01, "lr": 0.005,
"amsgrad": True, "amsgrad": True,
# "weight_decay": 1e-7, # "weight_decay": 1e-7,
}, },
@@ -134,175 +119,35 @@ optimizer_settings = OptimizerSettings(
scheduler="ReduceLROnPlateau", scheduler="ReduceLROnPlateau",
scheduler_kwargs={ scheduler_kwargs={
"patience": 2**6, "patience": 2**6,
"factor": 0.75, "factor": 0.5,
# "threshold": 1e-3, # "threshold": 1e-3,
"min_lr": 1e-6, "min_lr": 1e-6,
"cooldown": 10, "cooldown": 10,
}, },
early_stopping=True,
early_stop_kwargs={
"threshold": 1e-06,
"plateau": 2**7,
}
) )
def save_dict_to_file(dictionary, filename):
"""
Save the best dictionary to a JSON file.
:param best: Dictionary containing the best training results.
:type best: dict
:param filename: Path to the JSON file where the dictionary will be saved.
:type filename: str
"""
with open(filename, "w") as f:
json.dump(dictionary, f, indent=4)
def sweep_lengths(*lengths, model=None, data_glob: str = None, strategy="newest"):
assert model is not None, "Model must be provided."
assert data_glob is not None, "Data glob must be provided."
model = model
fiber_ins = {}
fiber_outs = {}
regens = {}
timestampss = {}
trainer = RegenerationTrainer(
checkpoint_path=model,
)
trainer.define_model()
for length in lengths:
data_glob_length = data_glob.replace("{length}", str(length))
files = list(Path.cwd().glob(data_glob_length))
if len(files) == 0:
continue
if strategy == "newest":
sorted_kwargs = {
"key": lambda x: x.stat().st_mtime,
"reverse": True,
}
elif strategy == "oldest":
sorted_kwargs = {
"key": lambda x: x.stat().st_mtime,
"reverse": False,
}
else:
raise ValueError(f"Unknown strategy {strategy}.")
file = sorted(files, **sorted_kwargs)[0]
loader, _ = trainer.get_sliced_data(override={"config_path": file})
fiber_in, fiber_out, regen, timestamps = trainer.run_model(trainer.model, loader=loader)
fiber_ins[length] = fiber_in
fiber_outs[length] = fiber_out
regens[length] = regen
timestampss[length] = timestamps
data = torch.zeros(2 * len(timestampss.keys()) + 2, 2, tuple(fiber_outs.values())[-1].shape[0])
channel_names = ["" for _ in range(2 * len(timestampss.keys()) + 2)]
data[1, 0, :] = timestampss[tuple(timestampss.keys())[-1]] / 128
data[1, 1, :] = fiber_ins[tuple(timestampss.keys())[-1]][:, 0].abs().square()
channel_names[1] = "fiber in x"
for li, length in enumerate(timestampss.keys()):
data[2 + 2 * li, 0, :] = timestampss[length] / 128
data[2 + 2 * li, 1, :] = fiber_outs[length][:, 0].abs().square()
data[2 + 2 * li + 1, 0, :] = timestampss[length] / 128
data[2 + 2 * li + 1, 1, :] = regens[length][:, 0].abs().square()
channel_names[2 + 2 * li + 1] = f"regen x {length}"
channel_names[2 + 2 * li] = f"fiber out x {length}"
# get current backend
backend = matplotlib.get_backend()
matplotlib.use("TkCairo")
eye = util.eye_diagram.eye_diagram(data.to(dtype=torch.float32).detach().cpu().numpy(), channel_names=channel_names)
print_attrs = ("channel_name", "success", "min_area")
with np.printoptions(precision=3, suppress=True, formatter={"float": "{:0.3e}".format}):
for result in eye.eye_stats:
print_dict = {attr: result[attr] for attr in print_attrs}
rprint(print_dict)
rprint()
eye.plot(all_stats=False)
matplotlib.use(backend)
if __name__ == "__main__": if __name__ == "__main__":
# lengths = range(90000, 100000+10000, 10000)
# lengths = [100000]
# sweep_lengths(*lengths, model=".models/best_20241204_132605.tar", data_glob="data/202412*-{length}-*-0.ini", strategy="newest")
trainer = RegenerationTrainer( trainer = RegenerationTrainer(
global_settings=global_settings, global_settings=global_settings,
data_settings=data_settings, data_settings=data_settings,
pytorch_settings=pytorch_settings, pytorch_settings=pytorch_settings,
model_settings=model_settings, model_settings=model_settings,
optimizer_settings=optimizer_settings, optimizer_settings=optimizer_settings,
# checkpoint_path=".models/best_20241205_235929.tar", checkpoint_path=".models/best_20250117_144001.tar",
# 20241202_143149 new_model=True,
settings_override={
"data_settings": data_settings.__dict__,
# "optimizer_settings": {
# "early_stop_kwargs":{
# "plateau": 2**8,
# }
# }
}
) )
trainer.train() trainer.train()
# from hypertraining.lighning_models import regenerator, regeneratorData
# import lightning as L
# model = regenerator(
# 2 * data_settings.output_size,
# *model_settings.overrides["hidden_layer_dims"],
# model_settings.output_dim,
# layer_function=getattr(util.complexNN, model_settings.model_layer_function),
# layer_func_kwargs=model_settings.model_layer_kwargs,
# act_function=getattr(util.complexNN, model_settings.model_activation_func),
# act_func_kwargs=None,
# parametrizations=model_settings.model_layer_parametrizations,
# dtype=getattr(torch, data_settings.dtype),
# dropout_prob=model_settings.dropout_prob,
# scale_layers=model_settings.scale,
# optimizer=getattr(torch.optim, optimizer_settings.optimizer),
# optimizer_kwargs=optimizer_settings.optimizer_kwargs,
# lr_scheduler=getattr(torch.optim.lr_scheduler, optimizer_settings.scheduler),
# lr_scheduler_kwargs=optimizer_settings.scheduler_kwargs,
# )
# dm = regeneratorData(
# config_globs=data_settings.config_path,
# output_symbols=data_settings.symbols,
# output_dim=data_settings.output_size,
# dtype=getattr(torch, data_settings.dtype),
# drop_first=data_settings.drop_first,
# shuffle=data_settings.shuffle,
# train_split=data_settings.train_split,
# batch_size=pytorch_settings.batchsize,
# loader_settings={
# "num_workers": pytorch_settings.dataloader_workers,
# "prefetch_factor": pytorch_settings.dataloader_prefetch,
# "pin_memory": True,
# "drop_last": True,
# },
# seed=global_settings.seed,
# )
# # writer = L.SummaryWriter(pytorch_settings.summary_dir + f"/{datetime.now().strftime('%Y%m%d_%H%M%S')}")
# # from torch.utils.tensorboard import SummaryWriter
# subdir = f"{datetime.now().strftime('%Y%m%d_%H%M%S')}"
# # writer = SummaryWriter(pytorch_settings.summary_dir + f"/{subdir}")
# logger = L.pytorch.loggers.TensorBoardLogger(pytorch_settings.summary_dir, name=subdir, log_graph=True)
# trainer = L.Trainer(
# fast_dev_run=False,
# # max_epochs=pytorch_settings.epochs,
# max_epochs=2,
# enable_checkpointing=True,
# default_root_dir=f".models/{subdir}/",
# logger=logger,
# )
# trainer.fit(model, dm)

View File

@@ -12,20 +12,22 @@ Full license text in LICENSE file
""" """
import configparser import configparser
# import copy
from datetime import datetime from datetime import datetime
import hashlib import hashlib
from pathlib import Path from pathlib import Path
import time import time
import h5py
from matplotlib import pyplot as plt # noqa: F401 from matplotlib import pyplot as plt # noqa: F401
import numpy as np import numpy as np
import add_pypho # noqa: F401 from . import add_pypho # noqa: F401
import pypho import pypho
default_config = f""" default_config = f"""
[glova] [glova]
nos = 256 sps = 128
sps = 256 nos = 16384
f0 = 193414489032258.06 f0 = 193414489032258.06
symbolrate = 10e9 symbolrate = 10e9
wisdom_dir = "{str((Path.home() / ".pypho"))}" wisdom_dir = "{str((Path.home() / ".pypho"))}"
@@ -37,9 +39,9 @@ length = 10000
gamma = 1.14 gamma = 1.14
alpha = 0.2 alpha = 0.2
D = 17 D = 17
S = 0 S = 0.058
birefsteps = 0 bireflength = 10
max_delta_beta = 0.4 pmd_q = 0.2
; birefseed = 0xC0FFEE ; birefseed = 0xC0FFEE
[signal] [signal]
@@ -47,17 +49,15 @@ max_delta_beta = 0.4
modulation = "pam" modulation = "pam"
mod_order = 4 mod_order = 4
mod_depth = 0.8 mod_depth = 1
max_jitter = 0.02 max_jitter = 0.02
; jitter_seed = 0xC0FFEE ; jitter_seed = 0xC0FFEE
laser_power = 0 laser_power = 0
edfa_power = 3 edfa_power = 0
edfa_nf = 5 edfa_nf = 5
pulse_shape = "gauss" pulse_shape = "gauss"
fwhm = 0.33 fwhm = 0.33
osnr = "inf"
[data] [data]
dir = "data" dir = "data"
@@ -71,6 +71,7 @@ def get_config(config_file=None):
""" """
if config_file is None: if config_file is None:
config_file = Path(__file__).parent / "signal_generation.ini" config_file = Path(__file__).parent / "signal_generation.ini"
config_file = Path(config_file)
if not config_file.exists(): if not config_file.exists():
with open(config_file, "w") as f: with open(config_file, "w") as f:
f.write(default_config) f.write(default_config)
@@ -83,7 +84,10 @@ def get_config(config_file=None):
conf[section] = {} conf[section] = {}
for key in config[section]: for key in config[section]:
# print(f"{key} = {config[section][key]}") # print(f"{key} = {config[section][key]}")
conf[section][key] = eval(config[section][key]) try:
conf[section][key] = eval(config[section][key])
except NameError:
conf[section][key] = float(config[section][key])
# if isinstance(conf[section][key], str): # if isinstance(conf[section][key], str):
# conf[section][key] = config[section][key].strip('"') # conf[section][key] = config[section][key].strip('"')
return conf return conf
@@ -96,7 +100,9 @@ class PDM_IM_IPM:
mod_order=8, mod_order=8,
seed=None, seed=None,
): ):
assert np.cbrt(mod_order) == int(np.cbrt(mod_order)) and mod_order > 1, "mod_order must be a cube of an integer greater than 1" assert np.cbrt(mod_order) == int(np.cbrt(mod_order)) and mod_order > 1, (
"mod_order must be a cube of an integer greater than 1"
)
self.glova = glova self.glova = glova
self.mod_order = mod_order self.mod_order = mod_order
self.symbols_per_dim = int(np.cbrt(mod_order)) self.symbols_per_dim = int(np.cbrt(mod_order))
@@ -106,18 +112,11 @@ class PDM_IM_IPM:
rs = np.random.RandomState(self.seed) rs = np.random.RandomState(self.seed)
symbols = rs.randint(0, self.mod_order, n) symbols = rs.randint(0, self.mod_order, n)
return symbols return symbols
class pam_generator: class pam_generator:
def __init__( def __init__(
self, self, glova, mod_order=None, mod_depth=0.5, pulse_shape="gauss", fwhm=0.33, seed=None, single_channel=False
glova,
mod_order=None,
mod_depth=0.5,
pulse_shape="gauss",
fwhm=0.33,
seed=None,
single_channel=False
) -> None: ) -> None:
self.glova = glova self.glova = glova
self.pulse_shape = pulse_shape self.pulse_shape = pulse_shape
@@ -133,49 +132,43 @@ class pam_generator:
wavelet = self.gauss(oversampling=6) wavelet = self.gauss(oversampling=6)
else: else:
raise ValueError(f"Unknown pulse shape: {self.pulse_shape}") raise ValueError(f"Unknown pulse shape: {self.pulse_shape}")
# prepare symbols # prepare symbols
symbols_x = symbols[0] / (self.mod_order) symbols_x = symbols[0] / (self.mod_order)
diffs_x = np.diff(symbols_x, prepend=symbols_x[0]) diffs_x = np.diff(symbols_x, prepend=symbols_x[0])
digital_x = self.generate_digital_signal(diffs_x, max_jitter) digital_x = self.generate_digital_signal(diffs_x, max_jitter)
digital_x = np.pad( digital_x = np.pad(digital_x, (0, self.glova.sps // 2), "constant", constant_values=(0, 0))
digital_x, (0, self.glova.sps // 2), "constant", constant_values=(0, 0)
)
# create analog signal of diff of symbols # create analog signal of diff of symbols
E_x = np.convolve(digital_x, wavelet) E_x = np.convolve(digital_x, wavelet)
# convert to pam and set modulation depth (scale and move up such that 1 stays at 1) # convert to pam and set modulation depth (scale and move up such that 1 stays at 1)
E_x = np.cumsum(E_x) * self.modulation_depth + 2*(1 - self.modulation_depth) E_x = np.cumsum(E_x) * self.modulation_depth + (1 - self.modulation_depth)
# cut off the wavelet tails # cut off the wavelet tails
E_x = E_x[self.glova.sps // 2 + len(wavelet) // 2 - 1 : -len(wavelet) // 2] E_x = E_x[self.glova.sps // 2 + len(wavelet) // 2 - 1 : -len(wavelet) // 2]
# modulate the laser # modulate the laser
E[0]["E"][0] = np.sqrt(np.multiply(np.square(E[0]["E"][0]), E_x)) E[0]["E"][0] = np.sqrt(np.multiply(np.square(E[0]["E"][0]), E_x))
if not self.single_channel: if not self.single_channel:
symbols_y = symbols[1] / (self.mod_order) symbols_y = symbols[1] / (self.mod_order)
diffs_y = np.diff(symbols_y, prepend=symbols_y[0]) diffs_y = np.diff(symbols_y, prepend=symbols_y[0])
digital_y = self.generate_digital_signal(diffs_y, max_jitter) digital_y = self.generate_digital_signal(diffs_y, max_jitter)
digital_y = np.pad( digital_y = np.pad(digital_y, (0, self.glova.sps // 2), "constant", constant_values=(0, 0))
digital_y, (0, self.glova.sps // 2), "constant", constant_values=(0, 0)
)
E_y = np.convolve(digital_y, wavelet) E_y = np.convolve(digital_y, wavelet)
E_y = np.cumsum(E_y) * self.modulation_depth + (1 - self.modulation_depth) E_y = np.cumsum(E_y) * self.modulation_depth + (1 - self.modulation_depth)
E_y = E_y[self.glova.sps // 2 + len(wavelet) // 2 - 1 : -len(wavelet) // 2] E_y = E_y[self.glova.sps // 2 + len(wavelet) // 2 - 1 : -len(wavelet) // 2]
E[0]["E"][1] = np.sqrt(np.multiply(np.square(E[0]["E"][1]), E_y)) E[0]["E"][1] = np.sqrt(np.multiply(np.square(E[0]["E"][1]), E_y))
# rotate the signal on the y-polarisation by 90° # rotate the signal on the y-polarisation by 90°
E[0]["E"][1] *= 1j # E[0]["E"][1] *= 1j
else: else:
E[0]["E"][1] = np.zeros_like(E[0]["E"][0], dtype=E[0]["E"][0].dtype) E[0]["E"][1] = np.zeros_like(E[0]["E"][0], dtype=E[0]["E"][0].dtype)
return E return E
def generate_digital_signal(self, symbols, max_jitter=0): def generate_digital_signal(self, symbols, max_jitter=0):
rs = np.random.RandomState(self.seed) rs = np.random.RandomState(self.seed)
@@ -198,19 +191,19 @@ class pam_generator:
endpoint=True, endpoint=True,
) )
sigma = self.fwhm / (1 * np.sqrt(2 * np.log(2))) * self.glova.sps sigma = self.fwhm / (1 * np.sqrt(2 * np.log(2))) * self.glova.sps
pulse = ( pulse = 1 / (sigma * np.sqrt(2 * np.pi)) * np.exp(-np.square(sample_points) / (2 * np.square(sigma)))
1
/ (sigma * np.sqrt(2 * np.pi))
* np.exp(-np.square(sample_points) / (2 * np.square(sigma)))
)
return pulse return pulse
def initialize_fiber_and_data(config, input_data_override=None): def initialize_fiber_and_data(config):
f0 = config["glova"].get("f0", None)
if f0 is None:
f0 = 299792458/(config["glova"].get("lambda0", 1550)*1e-9)
config["glova"]["f0"] = f0
py_glova = pypho.setup( py_glova = pypho.setup(
nos=config["glova"]["nos"], nos=config["glova"]["nos"],
sps=config["glova"]["sps"], sps=config["glova"]["sps"],
f0=config["glova"]["f0"], f0=f0,
symbolrate=config["glova"]["symbolrate"], symbolrate=config["glova"]["symbolrate"],
wisdom_dir=config["glova"]["wisdom_dir"], wisdom_dir=config["glova"]["wisdom_dir"],
flags=config["glova"]["flags"], flags=config["glova"]["flags"],
@@ -221,49 +214,89 @@ def initialize_fiber_and_data(config, input_data_override=None):
c_data = pypho.cfiber.DataWrapper(py_glova.sps * py_glova.nos) c_data = pypho.cfiber.DataWrapper(py_glova.sps * py_glova.nos)
py_edfa = pypho.oamp(py_glova, Pmean=config["signal"]["edfa_power"], NF=config["signal"]["edfa_nf"]) py_edfa = pypho.oamp(py_glova, Pmean=config["signal"]["edfa_power"], NF=config["signal"]["edfa_nf"])
if input_data_override is not None: osnr = config["signal"]["osnr"] = float(config["signal"].get("osnr", "inf"))
c_data.E_in = input_data_override[0]
noise = input_data_override[1] config["signal"]["seed"] = config["signal"].get("seed", (int(time.time() * 1000)) % 2**32)
else: config["signal"]["jitter_seed"] = config["signal"].get("jitter_seed", (int(time.time() * 1000)) % 2**32)
config["signal"]["seed"] = config["signal"].get( symbolsrc = pypho.symbols(
"seed", (int(time.time() * 1000)) % 2**32 py_glova, py_glova.nos, pattern="ones", p1=config["signal"]["mod_order"], seed=config["signal"]["seed"]
) )
config["signal"]["jitter_seed"] = config["signal"].get( laserx = pypho.lasmod(py_glova, power=0, Df=0, theta=np.pi/4)
"jitter_seed", (int(time.time() * 1000)) % 2**32 # lasery = pypho.lasmod(py_glova, power=0, Df=25, theta=0)
)
symbolsrc = pypho.symbols( modulator = pam_generator(
py_glova, py_glova.nos, pattern="ones", p1=config["signal"]["mod_order"], seed=config["signal"]["seed"] py_glova,
) mod_depth=config["signal"]["mod_depth"],
laser = pypho.lasmod( pulse_shape=config["signal"]["pulse_shape"],
py_glova, power=config["signal"]["laser_power"]+1.5, Df=0, theta=np.pi / 4 fwhm=config["signal"]["fwhm"],
) seed=config["signal"]["jitter_seed"],
modulator = pam_generator( mod_order=config["signal"]["mod_order"],
py_glova, )
mod_depth=config["signal"]["mod_depth"],
pulse_shape=config["signal"]["pulse_shape"], symbols_x = symbolsrc(pattern="random")
fwhm=config["signal"]["fwhm"], symbols_y = symbolsrc(pattern="random")
seed=config["signal"]["jitter_seed"], symbols_x[:3] = 0
single_channel=False, symbols_y[:3] = 0
mod_order=config["signal"]["mod_order"], # symbols_x += 1
cw = laserx()
# cwy = lasery()
# cw[0]['E'][0] = cw[0]['E'][0]
# cw[0]['E'][1] = cwy[0]['E'][0]
source_signal = modulator(E=cw, symbols=(symbols_x, symbols_y))
if osnr != float("inf"):
osnr_lin = 10 ** (osnr / 10)
signal_power = np.sum(pypho.functions.getpower_W(source_signal[0]["E"]))
noise_power = signal_power / osnr_lin
noise = np.random.normal(0, 1, source_signal[0]["E"].shape) + 1j * np.random.normal(
0, 1, source_signal[0]["E"].shape
) )
noise_power_is = np.sum(pypho.functions.getpower_W(noise))
noise = noise * np.sqrt(noise_power / noise_power_is)
noise_power_is = np.sum(pypho.functions.getpower_W(noise))
source_signal[0]["E"] += noise
source_signal[0]["noise"] = noise_power_is
symbols_x = symbolsrc(pattern="random") # source_signal[0]['E'][1] += source_signal[0]['E'][1][np.argmin(np.abs(source_signal[0]['E'][1]))]
symbols_y = symbolsrc(pattern="random")
symbols_x[:3] = 0 ## side channels
symbols_y[:3] = 0 # df = 100
# symbols_x += 1 # signal_power = pypho.functions.W_to_dBm(np.sum(pypho.functions.getpower_W(source_signal[0]["E"])))
cw = laser() # symbols_x_side = symbolsrc(pattern="random")
# symbols_y_side = symbolsrc(pattern="random")
# symbols_x_side[:3] = 0
# symbols_y_side[:3] = 0
# cw_left = laser(Df=-df)
# source_signal_left = modulator(E=cw_left, symbols=(symbols_x_side, symbols_y_side))
# cw_right = laser(Df=df)
# source_signal_right = modulator(E=cw_right, symbols=(symbols_y_side, symbols_x_side))
source_signal = modulator(E=cw, symbols=(symbols_x, symbols_y)) E_in_pure = source_signal[0]["E"]
# source_signal[0]['E'][1] += source_signal[0]['E'][1][np.argmin(np.abs(source_signal[0]['E'][1]))] nf = py_edfa.NF
pmean = py_edfa.Pmean
source_signal = py_edfa(E=source_signal)
c_data.E_in = source_signal[0]["E"] # ideal amplification to launch power into fiber
noise = source_signal[0]["noise"] source_signal = py_edfa(E=source_signal, NF=0, Pmean=config["signal"]["laser_power"])
# source_signal_left = py_edfa(E=source_signal_left, NF=0, Pmean=config["signal"]["laser_power"])
# source_signal_right = py_edfa(E=source_signal_right, NF=0, Pmean=config["signal"]["laser_power"])
# source_signal[0]["E"][0] += source_signal_left[0]["E"][0] + source_signal_right[0]["E"][0]
# source_signal[0]["E"][1] += source_signal_left[0]["E"][1] + source_signal_right[0]["E"][1]
c_data.E_in = source_signal[0]["E"]
noise = source_signal[0]["noise"]
py_edfa.NF = nf
py_edfa.Pmean = pmean
py_fiber = pypho.fiber( py_fiber = pypho.fiber(
glova=py_glova, glova=py_glova,
@@ -272,27 +305,32 @@ def initialize_fiber_and_data(config, input_data_override=None):
gamma=config["fiber"]["gamma"], gamma=config["fiber"]["gamma"],
D=config["fiber"]["d"], D=config["fiber"]["d"],
S=config["fiber"]["s"], S=config["fiber"]["s"],
phi_max=0.02,
) )
if config["fiber"].get("birefsteps", 0) > 0:
seed = config["fiber"].get( config["fiber"]["birefsteps"] = config["fiber"].get(
"birefseed", (int(time.time() * 1000)) % 2**32 "birefsteps", config["fiber"]["length"] // config["fiber"].get("bireflength", config["fiber"]["length"])
) )
if config["fiber"]["birefsteps"] > 0:
config["fiber"]["bireflength"] = config["fiber"].get("bireflength", config["fiber"]["length"] / config["fiber"]["birefsteps"])
seed = config["fiber"].get("birefseed", (int(time.time() * 1000)) % 2**32)
py_fiber.birefarray = pypho.birefringence_segment.create_pmd_fibre( py_fiber.birefarray = pypho.birefringence_segment.create_pmd_fibre(
py_fiber.l, config["fiber"]["length"],
py_fiber.l / config["fiber"]["birefsteps"], config["fiber"]["bireflength"],
# maxDeltaD=config["fiber"]["d"]/5, maxDeltaBeta=config["fiber"].get("pmd_q", 0)/np.sqrt(2*config["fiber"]["bireflength"]),
maxDeltaBeta = config["fiber"].get("max_delta_beta", 0),
seed=seed, seed=seed,
) )
c_params = pypho.cfiber.ParamsWrapper.from_fiber( elif (dgd := config['fiber'].get('dgd', 0)) > 0:
py_fiber, max_step=1e3 if py_fiber.gamma == 0 else 200 py_fiber.birefarray = [
) pypho.birefringence_segment(z_point=0, angle=np.pi/2, delta_beta=1000*dgd/config["fiber"]["length"])
]
c_params = pypho.cfiber.ParamsWrapper.from_fiber(py_fiber, max_step=config["fiber"]["length"] if py_fiber.gamma == 0 else 200)
c_fiber = pypho.cfiber.FiberWrapper(c_data, c_params, c_glova) c_fiber = pypho.cfiber.FiberWrapper(c_data, c_params, c_glova)
return c_fiber, c_data, noise, py_edfa return c_fiber, c_data, noise, py_edfa, (symbols_x, symbols_y), py_glova, E_in_pure
def save_data(data, config): def save_data(data, config, **metadata):
data_dir = Path(config["data"]["dir"]) data_dir = Path(config["data"]["dir"])
npy_dir = config["data"].get("npy_dir", "") npy_dir = config["data"].get("npy_dir", "")
save_dir = data_dir / npy_dir if len(npy_dir) else data_dir save_dir = data_dir / npy_dir if len(npy_dir) else data_dir
@@ -307,6 +345,7 @@ def save_data(data, config):
seed = config["signal"].get("seed", False) seed = config["signal"].get("seed", False)
jitter_seed = config["signal"].get("jitter_seed", False) jitter_seed = config["signal"].get("jitter_seed", False)
birefseed = config["fiber"].get("birefseed", False) birefseed = config["fiber"].get("birefseed", False)
osnr = float(config["signal"].get("osnr", "inf"))
config_content = "\n".join(( config_content = "\n".join((
f"; Generated by {str(Path(__file__).name)} @ {timestamp.strftime('%Y-%m-%d %H:%M:%S')}", f"; Generated by {str(Path(__file__).name)} @ {timestamp.strftime('%Y-%m-%d %H:%M:%S')}",
@@ -318,16 +357,19 @@ def save_data(data, config):
f'wisdom_dir = "{config["glova"]["wisdom_dir"]}"', f'wisdom_dir = "{config["glova"]["wisdom_dir"]}"',
f'flags = "{config["glova"]["flags"]}"', f'flags = "{config["glova"]["flags"]}"',
f"nthreads = {config['glova']['nthreads']}", f"nthreads = {config['glova']['nthreads']}",
" ", "",
"[fiber]", "[fiber]",
f"length = {config['fiber']['length']}", f"length = {config['fiber']['length']}",
f"gamma = {config['fiber']['gamma']}", f"gamma = {config['fiber']['gamma']}",
f"alpha = {config['fiber']['alpha']}", f"alpha = {config['fiber']['alpha']}",
f"D = {config['fiber']['d']}", f"D = {config['fiber']['d']}",
f"S = {config['fiber']['s']}", f"S = {config['fiber']['s']}",
f"birefsteps = {config['fiber'].get('birefsteps',0)}", f"birefsteps = {config['fiber'].get('birefsteps', 0)}",
f"max_delta_beta = {config['fiber'].get('max_delta_beta', 0)}", f"pmd_q = {config['fiber'].get('pmd_q', 0)}",
f"birefseed = {hex(birefseed)}" if birefseed else "; birefseed = not set", f"birefseed = {hex(birefseed)}" if birefseed else "; birefseed = not set",
f"dgd = {config['fiber'].get('dgd', 0)}",
f"ortho_error = {config['fiber'].get('ortho_error', 0)}",
f"pol_error = {config['fiber'].get('pol_error', 0)}",
"", "",
"[signal]", "[signal]",
f"seed = {hex(seed)}" if seed else "; seed = not set", f"seed = {hex(seed)}" if seed else "; seed = not set",
@@ -335,100 +377,93 @@ def save_data(data, config):
f'modulation = "{config["signal"]["modulation"]}"', f'modulation = "{config["signal"]["modulation"]}"',
f"mod_order = {config['signal']['mod_order']}", f"mod_order = {config['signal']['mod_order']}",
f"mod_depth = {config['signal']['mod_depth']}", f"mod_depth = {config['signal']['mod_depth']}",
"" "",
f"max_jitter = {config['signal']['max_jitter']}", f"max_jitter = {config['signal']['max_jitter']}",
f"jitter_seed = {hex(jitter_seed)}" if jitter_seed else "; jitter_seed = not set", f"jitter_seed = {hex(jitter_seed)}" if jitter_seed else "; jitter_seed = not set",
"" "",
f"laser_power = {config['signal']['laser_power']}", f"laser_power = {config['signal']['laser_power']}",
f"edfa_power = {config['signal']['edfa_power']}", f"edfa_power = {config['signal']['edfa_power']}",
f"edfa_nf = {config['signal']['edfa_nf']}", f"edfa_nf = {config['signal']['edfa_nf']}",
"" f"osnr = {osnr}",
"",
f'pulse_shape = "{config["signal"]["pulse_shape"]}"', f'pulse_shape = "{config["signal"]["pulse_shape"]}"',
f"fwhm = {config['signal']['fwhm']}", f"fwhm = {config['signal']['fwhm']}",
"", "",
"[data]", "[data]",
f'dir = "{str(data_dir)}"', f'dir = "{str(data_dir)}"',
f'npy_dir = "{npy_dir}"', f'npy_dir = "{npy_dir}"',
"file = " "file = ",
)) ))
config_hash = hashlib.md5(config_content.encode()).hexdigest() config_hash = hashlib.md5(config_content.encode()).hexdigest()
save_file = f"{config_hash}.npy" save_file = f"{config_hash}.h5"
config_content += f'"{str(save_file)}"\n' config_content += f'"{str(save_file)}"\n'
config_filename:Path = create_config_filename(config, data_dir, timestamp)
while config_filename.exists():
time.sleep(1)
config_filename = create_config_filename(config, data_dir=data_dir)
with open(config_filename, "w") as f:
f.write(config_content)
with h5py.File(save_dir / save_file, "w") as outfile:
outfile.create_dataset("data", data=save_data)
outfile.create_dataset("symbols", data=metadata.pop("symbols"))
for key, value in metadata.items():
# if isinstance(value, dict):
# value = json.dumps(model_runner.convert_arrays(value))
outfile.attrs[key] = value
# np.save(save_dir / save_file, save_data)
# print("Saved config to", config_filename)
# print("Saved data to", save_dir / save_file)
return config_filename
def create_config_filename(config, data_dir:Path, timestamp=None):
if timestamp is None:
timestamp = datetime.now()
filename_components = ( filename_components = (
timestamp.strftime("%Y%m%d-%H%M%S"), timestamp.strftime("%Y%m%d-%H%M%S"),
config["glova"]["sps"], config["glova"]["sps"],
config["glova"]["nos"], config["glova"]["nos"],
config["signal"]["osnr"],
config["fiber"]["length"], config["fiber"]["length"],
config["fiber"]["gamma"], config["fiber"]["gamma"],
config["fiber"]["alpha"], config["fiber"]["alpha"],
config["fiber"]["d"], config["fiber"]["d"],
config["fiber"]["s"], config["fiber"]["s"],
f"{config['signal']['modulation'].upper()}{config['signal']['mod_order']}", f"{config['signal']['modulation'].upper()}{config['signal']['mod_order']}",
config['fiber'].get('birefsteps',0), config["fiber"].get("birefsteps", 0),
config["fiber"].get("max_delta_beta", 0), config["fiber"].get("pmd_q", 0),
int(config["glova"]["symbolrate"] / 1e9),
) )
lookup_file = "-".join(map(str, filename_components)) + ".ini" lookup_file = "-".join(map(str, filename_components)) + ".ini"
with open(data_dir / lookup_file, "w") as f: return data_dir / lookup_file
f.write(config_content)
np.save(save_dir / save_file, save_data) def length_loop(config, lengths, save=True):
print("Saved config to", data_dir / lookup_file)
print("Saved data to", save_dir / save_file)
def length_loop(config, lengths, incremental=False, bireflength=None, save=True):
lengths = sorted(lengths) lengths = sorted(lengths)
input_override = None for length in lengths:
birefsteps_running = 0 print(f"\nGenerating data for fiber length {length}m")
for lind, length in enumerate(lengths):
# print(f"\nGenerating data for fiber length {length}")
if lind > 0 and incremental:
# set the length to the difference between the current and previous length -> incremental
length = lengths[lind] - lengths[lind - 1]
if incremental:
print(
f"\nGenerating data for fiber length {lengths[lind]}m [using {length}m increment]"
)
else:
print(f"\nGenerating data for fiber length {length}m")
config["fiber"]["length"] = length config["fiber"]["length"] = length
if bireflength is not None and bireflength > 0:
config["fiber"]["birefsteps"] = length // bireflength
birefsteps_running += config["fiber"]["birefsteps"]
# set the input data to the output data of the previous run
cfiber, cdata, noise, edfa = initialize_fiber_and_data(
config, input_data_override=input_override
)
if lind == 0: cfiber, cdata, noise, edfa, symbols, py_glova = initialize_fiber_and_data(config)
cdata_orig = cdata
mean_power_in = np.sum(pypho.functions.getpower_W(cdata.E_in)) mean_power_in = np.sum(pypho.functions.getpower_W(cdata.E_in))
print(
f"Mean power in: {mean_power_in:.3e} W ({pypho.functions.W_to_dBm(mean_power_in):.3e} dBm)"
)
cfiber() cfiber()
mean_power_out = np.sum(pypho.functions.getpower_W(cdata.E_out)) mean_power_out = np.sum(pypho.functions.getpower_W(cdata.E_out))
print(
f"Mean power out: {mean_power_out:.3e} W ({pypho.functions.W_to_dBm(mean_power_out):.3e} dBm)"
)
if incremental: print(f"Mean power in: {mean_power_in:.3e} W ({pypho.functions.W_to_dBm(mean_power_in):.3e} dBm)")
input_override = (cdata.E_out, noise) print(f"Mean power out: {mean_power_out:.3e} W ({pypho.functions.W_to_dBm(mean_power_out):.3e} dBm)")
cdata.E_in = cdata_orig.E_in
config["fiber"]["length"] = lengths[lind] E_tmp = [{"E": cdata.E_out, "noise": noise * (-cfiber.params.l * cfiber.params.alpha)}]
if bireflength is not None:
config["fiber"]["birefsteps"] = birefsteps_running
E_tmp = [{'E': cdata.E_out, 'noise': noise*(-cfiber.params.l*cfiber.params.alpha)}]
E_tmp = edfa(E=E_tmp) E_tmp = edfa(E=E_tmp)
cdata.E_out = E_tmp[0]['E'] cdata.E_out = E_tmp[0]["E"]
mean_power_amp = np.sum(pypho.functions.getpower_W(cdata.E_out))
print(f"Mean power after EDFA: {mean_power_amp:.3e} W ({pypho.functions.W_to_dBm(mean_power_amp):.3e} dBm)")
if save: if save:
save_data(cdata, config) save_data(cdata, config)
@@ -436,27 +471,55 @@ def length_loop(config, lengths, incremental=False, bireflength=None, save=True)
def single_run_with_plot(config, save=True): def single_run_with_plot(config, save=True):
cfiber, cdata, noise, edfa = initialize_fiber_and_data(config) cfiber, cdata, config_filename = single_run(config, save)
mean_power_in = np.sum(pypho.functions.getpower_W(cdata.E_in))
print(
f"Mean power in: {mean_power_in:.3e} W ({pypho.functions.W_to_dBm(mean_power_in):.3e} dBm)"
)
cfiber()
mean_power_out = np.sum(pypho.functions.getpower_W(cdata.E_out))
print(
f"Mean power out: {mean_power_out:.3e} W ({pypho.functions.W_to_dBm(mean_power_out):.3e} dBm)"
)
E_tmp = [{'E': cdata.E_out, 'noise': noise*(-cfiber.params.l*cfiber.params.alpha)}]
E_tmp = edfa(E=E_tmp)
cdata.E_out = E_tmp[0]['E']
if save:
save_data(cdata, config)
in_out_eyes(cfiber, cdata, show_pols=False) in_out_eyes(cfiber, cdata, show_pols=False)
return config_filename
def single_run(config, save=True, silent=True):
    """
    Run one fiber transmission, amplify the output, apply the configured
    polarisation errors and optionally save the result.

    Parameters
    ----------
    `config` is the simulation configuration. This function itself only
    reads `config["fiber"]["ortho_error"]` and `config["fiber"]["pol_error"]`
    (both default to 0); everything else is passed through to
    `initialize_fiber_and_data` and `save_data`.

    `save`: if True, the result is handed to `save_data` together with the
    transmitted symbols.

    `silent`: if False (and `save` is True), the config file name returned
    by `save_data` is printed.

    Return Value
    ------------
    Returns `(cfiber, cdata, config_filename)`; `config_filename` is None
    when nothing was saved.
    """
    # glova and E_in are returned by the initializer but not needed here
    cfiber, cdata, noise, edfa, symbols, glova, E_in = initialize_fiber_and_data(config)

    # transmit
    cfiber()

    # amplify
    amplified = edfa(E=[{"E": cdata.E_out, "noise": noise}])
    stage = amplified[0]

    # ortho error: symmetric mixing of both components by half the error angle
    # NOTE(review): presumably the angle is in radians -- confirm against the config files
    half_angle = config["fiber"].get("ortho_error", 0) / 2
    mixed_x = stage["E"][0] * np.cos(half_angle) + stage["E"][1] * np.sin(half_angle)
    mixed_y = stage["E"][0] * np.sin(half_angle) + stage["E"][1] * np.cos(half_angle)
    stage["E"] = np.stack((mixed_x, mixed_y), axis=0)

    # pol error: plain rotation of both components by pol_error
    pol_angle = config["fiber"].get("pol_error", 0)
    rotated_x = stage["E"][0] * np.cos(pol_angle) - stage["E"][1] * np.sin(pol_angle)
    rotated_y = stage["E"][0] * np.sin(pol_angle) + stage["E"][1] * np.cos(pol_angle)
    stage["E"] = np.stack((rotated_x, rotated_y), axis=0)

    # output
    cdata.E_out = stage["E"]

    symbols = np.array(symbols)
    if not save:
        return cfiber, cdata, None

    config_filename = save_data(cdata, config, symbols=symbols)
    if not silent:
        print(f"Saved config to {config_filename}")
    return cfiber, cdata, config_filename
def in_out_eyes(cfiber, cdata, show_pols=False): def in_out_eyes(cfiber, cdata, show_pols=False):
fig, axs = plt.subplots(2, 2, sharex=True, sharey=True) fig, axs = plt.subplots(2, 2, sharex=True, sharey=True)
@@ -620,9 +683,7 @@ def plot_eye_diagram(
signal = signal[: head * eye_width] signal = signal[: head * eye_width]
if normalize: if normalize:
signal = signal / np.max(signal) signal = signal / np.max(signal)
slices = np.lib.stride_tricks.sliding_window_view(signal, eye_width + 1)[ slices = np.lib.stride_tricks.sliding_window_view(signal, eye_width + 1)[offset % (eye_width + 1) :: eye_width]
offset % (eye_width + 1) :: eye_width
]
plt_ax = np.arange(-eye_width // 2, eye_width // 2 + 1) / samplerate plt_ax = np.arange(-eye_width // 2, eye_width // 2 + 1) / samplerate
for slice in slices: for slice in slices:
ax.plot(plt_ax, slice, color=color, alpha=0.1) ax.plot(plt_ax, slice, color=color, alpha=0.1)
@@ -642,14 +703,27 @@ if __name__ == "__main__":
# lengths.append(10*max(ranges)) # lengths.append(10*max(ranges))
# lengths = [*lengths, *lengths] # lengths = [*lengths, *lengths]
lengths = ( lengths = (
# 8000, 9000, # 8000, 9000,
10000, 20000, 30000, 40000, 50000, 60000, 70000, 80000, 90000, 10000,
95000, 100000, 105000, 110000, 115000, 120000 20000,
30000,
40000,
50000,
60000,
70000,
80000,
90000,
95000,
100000,
105000,
110000,
115000,
120000,
) )
lengths = sorted(lengths)
length_loop(config, lengths, incremental=False, bireflength=1000, save=True) # lengths = (10000,100000)
# length_loop(config, lengths, save=True)
# birefringence is constant over coupling length -> several 100m -> bireflength=1000 (m) # birefringence is constant over coupling length -> several 100m -> bireflength=1000 (m)
# single_run_with_plot(config, save=True) single_run_with_plot(config, save=False)

File diff suppressed because one or more lines are too long

File diff suppressed because it is too large Load Diff

View File

@@ -26,7 +26,7 @@ global_settings = GlobalSettings(
) )
data_settings = DataSettings( data_settings = DataSettings(
config_path="data/*-128-16384-10000-0-0-0-0-PAM4-0-0.4.ini", config_path="data/20241211-105524-128-16384-1-0-0-0-0-PAM4-0-0.ini",
# config_path=[f"data/20241202-*-128-16384-{length}-0-0-17-0-PAM4-0.ini" for length in range(48000, 53000, 1000)], # config_path=[f"data/20241202-*-128-16384-{length}-0-0-17-0-PAM4-0.ini" for length in range(48000, 53000, 1000)],
dtype="complex64", dtype="complex64",
# symbols = (9, 20), # 13 symbol @ 10GBd <-> 1.3ns <-> 0.26m of fiber # symbols = (9, 20), # 13 symbol @ 10GBd <-> 1.3ns <-> 0.26m of fiber
@@ -53,14 +53,14 @@ pytorch_settings = PytorchSettings(
) )
model_settings = ModelSettings( model_settings = ModelSettings(
output_dim=3, output_dim=1,
n_hidden_layers=3, n_hidden_layers=3,
overrides={ overrides={
"n_hidden_nodes_0": 2, "n_hidden_nodes_0": 4,
"n_hidden_nodes_1": 2, "n_hidden_nodes_1": 4,
"n_hidden_nodes_2": 2, "n_hidden_nodes_2": 4,
}, },
dropout_prob=0.01, dropout_prob=0,
model_layer_function="ONNRect", model_layer_function="ONNRect",
model_activation_func="EOActivation", model_activation_func="EOActivation",
model_layer_kwargs={"square": True}, model_layer_kwargs={"square": True},
@@ -110,20 +110,24 @@ model_settings = ModelSettings(
) )
optimizer_settings = OptimizerSettings( optimizer_settings = OptimizerSettings(
optimizer="AdamW", optimizer="RMSprop",
# optimizer="AdamW",
optimizer_kwargs={ optimizer_kwargs={
"lr": 0.005, "lr": 0.01,
"amsgrad": True, "alpha": 0.9,
"momentum": 0.1,
"eps": 1e-8,
"centered": True,
# "amsgrad": True,
# "weight_decay": 1e-7, # "weight_decay": 1e-7,
}, },
# learning_rate=0.05,
scheduler="ReduceLROnPlateau", scheduler="ReduceLROnPlateau",
scheduler_kwargs={ scheduler_kwargs={
"patience": 2**6, "patience": 2**5,
"factor": 0.75, "factor": 0.75,
# "threshold": 1e-3, # "threshold": 1e-3,
"min_lr": 1e-6, "min_lr": 1e-6,
"cooldown": 10, # "cooldown": 10,
}, },
) )

View File

@@ -319,6 +319,29 @@ class normalize_by_first(nn.Module):
def forward(self, data): def forward(self, data):
return data / data[:, 0].unsqueeze(1) return data / data[:, 0].unsqueeze(1)
class rotate(nn.Module):
    """
    Rotate consecutive (x, y) pairs of the input by a per-sample angle.

    The last dimension of `data` is interpreted as n pairs; each pair
    (x, y) is mapped to (x*cos - y*sin, x*sin + y*cos).
    """

    def __init__(self):
        super().__init__()

    def forward(self, data, angle):
        # data  -> (batch, n*2)
        # angle -> (batch,) or (batch, n); a 1-d angle is shared by all pairs of a sample
        n_pairs = data.shape[1] // 2
        per_pair = angle.unsqueeze(1) if angle.ndim == 1 else angle
        per_pair = per_pair.expand(-1, n_pairs)

        cos = torch.cos(per_pair)
        sin = torch.sin(per_pair)
        # one row-major 2x2 matrix [[cos, sin], [-sin, cos]] per pair,
        # applied from the right to the row vector (x, y)
        matrices = torch.stack([cos, sin, -sin, cos], dim=-1).reshape(-1, 2, 2)
        matrices = matrices.to(dtype=data.dtype)

        pairs = data.reshape(-1, 1, 2)
        return torch.bmm(pairs, matrices).reshape(*data.shape)
class photodiode(nn.Module): class photodiode(nn.Module):
def __init__(self, size, bias=True): def __init__(self, size, bias=True):
@@ -418,8 +441,7 @@ class input_rotator(nn.Module):
# return out # return out
#### as defined by zhang et al #### as defined by zhang et alas
class DropoutComplex(nn.Module): class DropoutComplex(nn.Module):
def __init__(self, p=0.5): def __init__(self, p=0.5):
@@ -441,7 +463,7 @@ class Scale(nn.Module):
self.scale = nn.Parameter(torch.ones(size, dtype=torch.float32)) self.scale = nn.Parameter(torch.ones(size, dtype=torch.float32))
def forward(self, x): def forward(self, x):
return x * self.scale return x * torch.sqrt(self.scale)
def __repr__(self): def __repr__(self):
return f"Scale({self.size})" return f"Scale({self.size})"
@@ -458,6 +480,15 @@ class Identity(nn.Module):
def forward(self, x): def forward(self, x):
return x return x
class phase_shift(nn.Module):
    """
    Multiply the input element-wise by exp(1j * phase), where `phase` is a
    trainable parameter of shape `size` initialised uniformly in [0, 1).
    """

    def __init__(self, size):
        super().__init__()
        self.size = size
        self.phase = nn.Parameter(torch.rand(size))

    def forward(self, x):
        # unit-magnitude complex factor; only the phase of x is changed
        phasor = torch.exp(1j * self.phase)
        return x * phasor
class PowRot(nn.Module): class PowRot(nn.Module):
@@ -487,7 +518,7 @@ class MZISingle(nn.Module):
return x * torch.exp(1j * self.phi) * torch.sin(self.omega + self.func(x)) return x * torch.exp(1j * self.phi) * torch.sin(self.omega + self.func(x))
def naive_angle_loss(x: torch.Tensor, target: torch.Tensor, mod=2*torch.pi): def naive_angle_loss(x: torch.Tensor, target: torch.Tensor, mod=2*torch.pi):
return torch.fmod((x - target), mod).square().mean() return torch.fmod((x.abs().real - target.abs().real), mod).abs().mean()
def cosine_loss(x: torch.Tensor, target: torch.Tensor): def cosine_loss(x: torch.Tensor, target: torch.Tensor):
return (2*(1 - torch.cos(x - target))).mean() return (2*(1 - torch.cos(x - target))).mean()
@@ -508,51 +539,46 @@ def angle_mse_loss(x: torch.Tensor, target: torch.Tensor):
class EOActivation(nn.Module): class EOActivation(nn.Module):
def __init__(self, size=None): def __init__(self, size=None):
# 10.1109/SiPhotonics60897.2024.10543376 # 10.1109/JSTQE.2019.2930455
super(EOActivation, self).__init__() super(EOActivation, self).__init__()
if size is None: if size is None:
raise ValueError("Size must be specified") raise ValueError("Size must be specified")
self.size = size self.size = size
self.alpha = nn.Parameter(torch.ones(size)) self.alpha = nn.Parameter(torch.rand(size))
self.V_bias = nn.Parameter(torch.ones(size)) self.gain = nn.Parameter(torch.rand(size))
self.gain = nn.Parameter(torch.ones(size)) self.V_bias = nn.Parameter(torch.rand(size))
# if bias: # self.register_buffer("gain", torch.ones(size))
# self.phase_bias = nn.Parameter(torch.zeros(size)) # self.register_buffer("responsivity", torch.ones(size))
# else: # self.register_buffer("V_pi", torch.ones(size))
# self.register_buffer("phase_bias", torch.zeros(size))
self.register_buffer("phase_bias", torch.clamp(torch.ones(size) + torch.randn(size)*0.1, 0, 1)*torch.pi)
self.register_buffer("responsivity", torch.ones(size)*0.9)
self.register_buffer("V_pi", torch.ones(size)*3)
self.reset_weights() self.reset_weights()
def reset_weights(self): def reset_weights(self):
if "alpha" in self._parameters: if "alpha" in self._parameters:
self.alpha.data = torch.ones(self.size)*0.5 self.alpha.data = torch.rand(self.size)
if "V_pi" in self._parameters: # if "V_pi" in self._parameters:
self.V_pi.data = torch.ones(self.size)*3 # self.V_pi.data = torch.rand(self.size)*3
if "V_bias" in self._parameters: if "V_bias" in self._parameters:
self.V_bias.data = torch.zeros(self.size) self.V_bias.data = torch.randn(self.size)
if "gain" in self._parameters: if "gain" in self._parameters:
self.gain.data = torch.ones(self.size) self.gain.data = torch.rand(self.size)
if "responsivity" in self._parameters: # if "responsivity" in self._parameters:
self.responsivity.data = torch.ones(self.size)*0.9 # self.responsivity.data = torch.ones(self.size)*0.9
if "bias" in self._parameters: # if "bias" in self._parameters:
self.phase_bias.data = torch.zeros(self.size) # self.phase_bias.data = torch.zeros(self.size)
def forward(self, x: torch.Tensor): def forward(self, x: torch.Tensor):
phi_b = torch.pi * self.V_bias / (self.V_pi + 1e-8) phi_b = torch.pi * self.V_bias# / (self.V_pi)
g_phi = torch.pi * (self.alpha * self.gain * self.responsivity) / (self.V_pi + 1e-8) g_phi = torch.pi * (self.alpha * self.gain)# * self.responsivity)# / (self.V_pi)
intermediate = g_phi * x.abs().square() + phi_b intermediate = g_phi * x.abs().square() + phi_b
return ( return (
1j 1j
* torch.sqrt(1 - self.alpha) * torch.sqrt(1 - self.alpha)
* torch.exp(-0.5j * (intermediate + self.phase_bias)) * torch.exp(-0.5j * intermediate)
* torch.cos(0.5 * intermediate) * torch.cos(0.5 * intermediate)
* x * x
) )
class Pow(nn.Module): class Pow(nn.Module):
""" """
implements the activation function implements the activation function
@@ -693,6 +719,7 @@ __all__ = [
MZISingle, MZISingle,
EOActivation, EOActivation,
photodiode, photodiode,
phase_shift,
# SaturableAbsorberLambertW, # SaturableAbsorberLambertW,
# SaturableAbsorber, # SaturableAbsorber,
# SpreadLayer, # SpreadLayer,

View File

@@ -0,0 +1,105 @@
# Copyright (c) 2015, Warren Weckesser. All rights reserved.
# This software is licensed according to the "BSD 2-clause" license.
import hashlib
import h5py
import numpy as _np
from scipy.interpolate import interp1d as _interp1d
from scipy.ndimage import gaussian_filter as _gaussian_filter
from ._brescount import bres_curve_count as _bres_curve_count
from pathlib import Path
__all__ = ['grid_count']
def grid_count(y, window_size, offset=0, size=None, fuzz=True, blur=0, bounds=None):
    """
    Count how often the eye-diagram traces of `y` hit each cell of a grid.

    The result is cached on disk in `~/.eyediagram/.cache`, keyed by a hash
    of the sample data and all parameters. A later call with identical
    arguments returns the cached array instead of recomputing it. Note that
    with `fuzz=True` the computation is random, so a cache hit returns the
    first realisation that was computed.

    Parameters
    ----------
    `y` is the 1-d array of signal samples.

    `window_size` is the number of samples to show horizontally in the
    eye diagram.  Typically this is twice the number of samples in a
    "symbol" (i.e. in a data bit).

    `offset` is the number of initial samples to skip before computing
    the eye diagram.  This allows the overall phase of the diagram to
    be adjusted.

    `size` must be a tuple of two integers.  It sets the size of the
    array of counts, (height, width).  The default is (800, 640).

    `fuzz`: If True, the values in `y` are reinterpolated with a
    random "fuzz factor" before plotting in the eye diagram.  This
    reduces an aliasing-like effect that arises with the use of
    Bresenham's algorithm.

    `blur`: If non-zero, the counts are smoothed with a gaussian filter
    using `blur` as sigma.

    `bounds` must be a tuple of two floating point values, (ymin, ymax).
    These set the y range of the returned array.  If not given, the
    bounds are `(y.min() - 0.05*A, y.max() + 0.05*A)`, where `A` is
    `y.max() - y.min()`, rounded outwards to one decimal place.

    Return Value
    ------------
    Returns a numpy array of integers.
    """
    if size is None:
        size = (800, 640)
    height, width = size

    # Hash the raw sample bytes rather than str(y): the string form of a large
    # array is abbreviated with "...", so different signals could collide.
    y_arr = _np.ascontiguousarray(y)
    hasher = hashlib.md5()
    hasher.update(y_arr.tobytes())
    hasher.update(
        str((str(y_arr.dtype), y_arr.shape, window_size, offset, tuple(size), fuzz, blur, bounds)).encode()
    )
    param_hash = hasher.hexdigest()

    cache_dir = Path.home() / ".eyediagram" / ".cache"
    cache_dir.mkdir(parents=True, exist_ok=True)
    cache_file = cache_dir / param_hash

    if cache_file.is_file():
        cached = None
        try:
            with h5py.File(cache_file, "r") as infile:
                cached = infile["counts"][:]
        except (OSError, KeyError):
            # unreadable or incomplete cache entry -> fall through and recompute
            cached = None
        if cached is not None and cached.size != 0:
            return cached

    dt = width / window_size
    counts = _np.zeros((width, height), dtype=_np.int32)

    if bounds is None:
        ymin = y.min()
        ymax = y.max()
        yamp = ymax - ymin
        ymin = ymin - 0.05 * yamp
        ymax = ymax + 0.05 * yamp
        ymax = _np.ceil(ymax * 10) / 10
        ymin = _np.floor(ymin * 10) / 10
    else:
        ymin, ymax = bounds

    start = offset
    while start + window_size < len(y):
        end = start + window_size
        # one extra sample so that consecutive windows share their end point
        yy = y[start:end + 1]
        k = _np.arange(len(yy))
        xx = dt * k
        if fuzz:
            f = _interp1d(xx, yy, kind='cubic')
            # jitter only the interior points; the end points stay on the grid
            jiggle = dt * (_np.random.beta(a=3, b=3, size=len(xx) - 2) - 0.5)
            xx[1:-1] += jiggle
            yd = f(xx)
        else:
            yd = yy
        iyd = (height * (yd - ymin) / (ymax - ymin)).astype(_np.int32)
        _bres_curve_count(xx.astype(_np.int32), iyd, counts)

        start = end

    if blur != 0:
        counts = _gaussian_filter(counts, sigma=blur)

    # stored under the same dataset name that the cache lookup reads
    with h5py.File(cache_file, "w") as outfile:
        outfile.create_dataset("counts", data=counts)

    return counts

View File

@@ -1,10 +1,12 @@
from pathlib import Path from pathlib import Path
import h5py
import torch import torch
from torch.utils.data import Dataset from torch.utils.data import Dataset
# from torch.utils.data import Sampler # from torch.utils.data import Sampler
import numpy as np import numpy as np
import configparser import configparser
import multiprocessing as mp
# class SubsetSampler(Sampler[int]): # class SubsetSampler(Sampler[int]):
# """ # """
@@ -24,7 +26,22 @@ import configparser
# return len(self.indices) # return len(self.indices)
def load_data(config_path, skipfirst=0, symbols=None, real=False, normalize=False, device=None, dtype=None): def load_from_file(datapath):
if str(datapath).endswith(".h5"):
symbols = None
with h5py.File(datapath, "r") as infile:
data = infile["data"][:]
try:
symbols = np.swapaxes(infile["symbols"][:], 0, 1)
except KeyError:
pass
else:
symbols = None
data = np.load(datapath)
return data, symbols
def load_data(config_path, skipfirst=0, skiplast=0, symbols=None, real=False, normalize=1, device=None, dtype=None):
filepath = Path(config_path) filepath = Path(config_path)
filepath = filepath.parent.glob(filepath.name) filepath = filepath.parent.glob(filepath.name)
config = configparser.ConfigParser() config = configparser.ConfigParser()
@@ -40,14 +57,28 @@ def load_data(config_path, skipfirst=0, symbols=None, real=False, normalize=Fals
if symbols is None: if symbols is None:
symbols = int(config["glova"]["nos"]) - skipfirst symbols = int(config["glova"]["nos"]) - skipfirst
data = np.load(datapath)[int(skipfirst * sps) : int(symbols * sps + skipfirst * sps)] data, orig_symbols = load_from_file(datapath)
timestamps = np.arange(int(skipfirst * sps), int(symbols * sps + skipfirst * sps))
if normalize: data = data[int(skipfirst * sps) : int(symbols * sps + skipfirst * sps - skiplast * sps)]
# square gets normalized to 1, as the power is (proportional to) the square of the amplitude orig_symbols = orig_symbols[skipfirst : symbols + skipfirst - skiplast]
a, b, c, d = np.square(data.T) timestamps = np.arange(int(skipfirst * sps), int(symbols * sps + skipfirst * sps - skiplast * sps))
a, b, c, d = a / np.max(np.abs(a)), b / np.max(np.abs(b)), c / np.max(np.abs(c)), d / np.max(np.abs(d))
data = np.sqrt(np.array([a, b, c, d]).T) data *= np.sqrt(normalize)
launch_power = float(config["signal"]["laser_power"])
output_power = float(config["signal"]["edfa_power"])
target_normalization = 10 ** (output_power / 10) / 10 ** (launch_power / 10)
# target_normalization *= 0.5 # allow 50% power loss, so the network can ignore parts of the signal
data[:, 0:2] *= np.sqrt(target_normalization)
# if normalize:
# # square gets normalized to 1, as the power is (proportional to) the square of the amplitude
# a, b, c, d = data.T
# a, b, c, d = a - np.min(np.abs(a)), b - np.min(np.abs(b)), c - np.min(np.abs(c)), d - np.min(np.abs(d))
# a, b, c, d = a / np.max(np.abs(a)), b / np.max(np.abs(b)), c / np.max(np.abs(c)), d / np.max(np.abs(d))
# data = np.array([a, b, c, d]).T
if real: if real:
data = np.abs(data) data = np.abs(data)
@@ -58,7 +89,7 @@ def load_data(config_path, skipfirst=0, symbols=None, real=False, normalize=Fals
data = torch.tensor(data, device=device, dtype=dtype) data = torch.tensor(data, device=device, dtype=dtype)
return data, config return data, config, orig_symbols
def roll_along(arr, shifts, dim): def roll_along(arr, shifts, dim):
@@ -110,11 +141,15 @@ class FiberRegenerationDataset(Dataset):
target_delay: float | int = 0, target_delay: float | int = 0,
xy_delay: float | int = 0, xy_delay: float | int = 0,
drop_first: float | int = 0, drop_first: float | int = 0,
drop_last=0,
dtype: torch.dtype = None, dtype: torch.dtype = None,
real: bool = False, real: bool = False,
device=None, device=None,
polarisations: tuple | list = (0,), # osnr: float|None = None,
polarisations=None,
randomise_polarisations: bool = False, randomise_polarisations: bool = False,
repeat_randoms: int = 1,
# cross_pol_interference: float = 0,
**kwargs, **kwargs,
): ):
""" """
@@ -148,64 +183,53 @@ class FiberRegenerationDataset(Dataset):
assert drop_first >= 0, "drop_first must be non-negative" assert drop_first >= 0, "drop_first must be non-negative"
self.randomise_polarisations = randomise_polarisations self.randomise_polarisations = randomise_polarisations
# self.cross_pol_interference = cross_pol_interference
faux = kwargs.pop("faux", False) data_raw = None
self.config = None
if faux: files = []
data_raw = np.array( self.orig_symbols = None
[[i + 0.1j, i + 0.2j, i + 1.1j, i + 1.2j] for i in range(12800)], for file_path in file_path if isinstance(file_path, (tuple, list)) else [file_path]:
dtype=np.complex128, data, config, orig_syms = load_data(
file_path,
skipfirst=drop_first,
skiplast=drop_last,
symbols=kwargs.get("num_symbols", None),
real=real,
normalize=1000,
device=device,
dtype=dtype,
) )
data_raw = torch.tensor(data_raw, device=device, dtype=dtype) if orig_syms is not None:
timestamps = torch.arange(12800) if self.orig_symbols is None:
self.orig_symbols = orig_syms
data_raw = torch.concatenate([data_raw, timestamps.reshape(-1, 1)], axis=-1)
self.config = {
"data": {"dir": '"."', "npy_dir": '"."', "file": "faux"},
"glova": {"sps": 128},
}
else:
data_raw = None
self.config = None
files = []
for file_path in file_path if isinstance(file_path, (tuple, list)) else [file_path]:
data, config = load_data(
file_path,
skipfirst=drop_first,
symbols=kwargs.get("num_symbols", None),
real=real,
normalize=True,
device=device,
dtype=dtype,
)
if data_raw is None:
data_raw = data
else: else:
data_raw = torch.cat([data_raw, data], dim=0) self.orig_symbols = np.concat((self.orig_symbols, orig_syms), axis=-1)
if self.config is None:
self.config = config
else:
assert self.config["glova"]["sps"] == config["glova"]["sps"], "samples per symbol must be the same"
files.append(config["data"]["file"].strip('"'))
self.config["data"]["file"] = str(files)
for i, angle in enumerate(torch.tensor(np.array(polarisations))): if data_raw is None:
data_raw_copy = data_raw.clone() data_raw = data
if angle == 0: else:
continue data_raw = torch.cat([data_raw, data], dim=0)
sine = torch.sin(angle) if self.config is None:
cosine = torch.cos(angle) self.config = config
data_raw_copy[:, 2] = data_raw[:, 2] * cosine - data_raw[:, 3] * sine else:
data_raw_copy[:, 3] = data_raw[:, 2] * sine + data_raw[:, 3] * cosine assert self.config["glova"]["sps"] == config["glova"]["sps"], "samples per symbol must be the same"
if i == 0: files.append(config["data"]["file"].strip('"'))
data_raw = data_raw_copy self.config["data"]["file"] = str(files)
else:
data_raw = torch.cat([data_raw, data_raw_copy], dim=0) # if polarisations is not None:
# data_raw_clone = data_raw.clone()
# # rotate the polarisation by 180 degrees
# data_raw_clone[2, :] *= -1
# data_raw_clone[3, :] *= -1
# data_raw = torch.cat([data_raw, data_raw_clone], dim=0)
self.polarisations = bool(polarisations)
self.device = data_raw.device self.device = data_raw.device
self.samples_per_symbol = int(self.config["glova"]["sps"]) self.samples_per_symbol = int(self.config["glova"]["sps"])
# self.num_symbols = int(self.config["glova"]["nos"])
self.samples_per_slice = int(symbols * self.samples_per_symbol) self.samples_per_slice = int(symbols * self.samples_per_symbol)
self.symbols_per_slice = self.samples_per_slice / self.samples_per_symbol self.symbols_per_slice = self.samples_per_slice / self.samples_per_symbol
@@ -278,23 +302,94 @@ class FiberRegenerationDataset(Dataset):
timestamps = data_raw[4, :] timestamps = data_raw[4, :]
data_raw = data_raw[:4, :] data_raw = data_raw[:4, :]
data_raw = data_raw.view(2, 2, -1) data_raw = data_raw.view(2, 2, -1)
timestamps_doubled = torch.cat([timestamps.unsqueeze(dim=0), timestamps.unsqueeze(dim=0)], dim=0).unsqueeze( fiber_in = data_raw[0, :, :]
dim=1 fiber_out = data_raw[1, :, :]
) # timestamps_doubled = torch.cat([timestamps.unsqueeze(dim=0), timestamps.unsqueeze(dim=0)], dim=0).unsqueeze(
data_raw = torch.cat([data_raw, timestamps_doubled], dim=1) # dim=1
# )
fiber_in = torch.cat([fiber_in, timestamps.unsqueeze(0)], dim=0)
fiber_out = torch.cat([fiber_out, timestamps.unsqueeze(0)], dim=0)
# fiber_out: [E_out_x, E_out_y, timestamps]
# add noise related to amplification necessary due to splitting of the signal
# gain_lin = output_dim*2
# gain_lin = 1
# edfa_nf = float(self.config["signal"]["edfa_nf"])
# nf_lin = 10**(edfa_nf/10)
# f0 = float(self.config["glova"]["f0"])
# noise_add = (gain_lin-1)*nf_lin*6.626e-34*f0*12.5e9
# noise = torch.randn_like(fiber_out[:2, :])
# noise_power = torch.sum(torch.mean(noise.abs().square().squeeze(), dim=-1))
# noise = noise * torch.sqrt(noise_add / noise_power)
# fiber_out[:2, :] += noise
# if osnr is None:
# noisy = fiber_out[:2, :]
# else:
# noisy = self.add_noise(fiber_out[:2, :], osnr)
# fiber_out = torch.cat([fiber_out, noisy], dim=0)
# fiber_out: [E_out_x, E_out_y, timestamps, E_out_x_noisy, E_out_y_noisy]
if repeat_randoms > 1:
fiber_in = fiber_in.repeat(1, 1, repeat_randoms)
fiber_out = fiber_out.repeat(1, 1, repeat_randoms)
# review: potential problems with repeated timestamps when plotting
else:
repeat_randoms = 1
if self.randomise_polarisations:
angles = torch.fmod(torch.randperm(data_raw.shape[-1]*repeat_randoms, device=fiber_out.device), 2) * torch.pi
start_angle = torch.rand(1) * 2 * torch.pi
angles = start_angle + torch.cumsum(torch.randn(data_raw.shape[-1])/333, dim=0) * torch.pi * 2 # random walk
angles = torch.randn(data_raw.shape[-1], device=fiber_out.device) * 2*torch.pi / 36 # sigma = 10 degrees
# self.angles = torch.rand(self.data.shape[0]) * 2 * torch.pi
else:
angles = torch.zeros(data_raw.shape[-1], device=fiber_out.device)
sin = torch.sin(angles)
cos = torch.cos(angles)
rot = torch.stack([torch.stack([cos, -sin], dim=1), torch.stack([sin, cos], dim=1)], dim=2)
data_rot = torch.bmm(fiber_out[:2, :].T.unsqueeze(1), rot.to(dtype=data_raw.dtype)).squeeze(1).T
# data_rot_noisy = torch.bmm(fiber_out[3:5, :].T.unsqueeze(1), rot.to(dtype=data_raw.dtype)).squeeze(1).T
fiber_out = torch.cat((fiber_out, data_rot), dim=0)
fiber_out = torch.cat([fiber_out, angles.unsqueeze(0)], dim=0)
# fiber_in:
# 0 E_in_x,
# 1 E_in_y,
# 2 timestamps
# fiber_out:
# 0 E_out_x,
# 1 E_out_y,
# 2 timestamps,
# 3 E_out_x_rot,
# 4 E_out_y_rot,
# 5 angle
# data_raw = torch.cat([data_raw, timestamps_doubled], dim=1) # data_raw = torch.cat([data_raw, timestamps_doubled], dim=1)
# data layout # data layout
# [ [E_in_x, E_in_y, timestamps], # [ [E_in_x, E_in_y, timestamps],
# [E_out_x, E_out_y, timestamps] ] # [E_out_x, E_out_y, timestamps] ]
self.data = data_raw.unfold(dimension=-1, size=self.samples_per_slice, step=1) self.fiber_in = fiber_in.unfold(dimension=-1, size=self.samples_per_slice, step=1)
self.data = self.data.movedim(-2, 0) self.fiber_in = self.fiber_in.movedim(-2, 0)
if randomise_polarisations: self.fiber_out = fiber_out.unfold(dimension=-1, size=self.samples_per_slice, step=1)
self.angles = torch.rand(self.data.shape[0]) * np.pi * 2 self.fiber_out = self.fiber_out.movedim(-2, 0)
# self.data[:, 1, :2, :] = self.rotate(self.data[:, 1, :2, :], self.angles)
else: # if self.randomise_polarisations:
self.angles = torch.zeros(self.data.shape[0]) # self.angles = torch.cumsum((torch.rand(self.fiber_out.shape[0]) - 0.5) * 2 * torch.pi * 2 / 5000, dim=0)
# self.data = data_raw.unfold(dimension=-1, size=self.samples_per_slice, step=1)
# self.data = self.data.movedim(-2, 0)
# self.angles = torch.zeros(self.data.shape[0])
...
# ... # ...
# -> [no_slices, 2, 3, samples_per_slice] # -> [no_slices, 2, 3, samples_per_slice]
@@ -305,77 +400,116 @@ class FiberRegenerationDataset(Dataset):
# ... # ...
# ] -> [no_slices, 2, 3, samples_per_slice] # ] -> [no_slices, 2, 3, samples_per_slice]
...
def __len__(self): def __len__(self):
return self.data.shape[0] return self.fiber_in.shape[0]
def add_noise(self, data, osnr):
    """
    Add white Gaussian noise to ``data`` at a given signal-to-noise ratio.

    The noise is drawn with ``torch.randn_like`` and rescaled per channel so
    that (mean signal power) / (mean noise power), both measured along the
    last dimension, matches ``osnr`` for this noise realisation.

    Args:
        data: tensor whose last dimension is the sample axis; every leading
            index is treated as an independent channel.
        osnr: target signal-to-noise ratio in dB.

    Returns:
        ``data`` plus the scaled noise, same shape as ``data``.
    """
    osnr_lin = 10 ** (osnr / 10)  # dB -> linear power ratio

    # keepdim=True keeps the per-channel powers broadcastable against the
    # samples, so no explicit diagonal scaling matrix is needed and
    # single-channel input ([1, N] or [N]) works as well.
    popt = torch.mean(data.abs().square(), dim=-1, keepdim=True)
    noise = torch.randn_like(data)
    pn = torch.mean(noise.abs().square(), dim=-1, keepdim=True)

    # amplitude factor that brings the noise power to popt / osnr_lin
    mult = torch.sqrt(popt / (pn * osnr_lin)).to(dtype=noise.dtype)
    return data + mult * noise
def __getitem__(self, idx): def __getitem__(self, idx):
if isinstance(idx, slice): if isinstance(idx, slice):
return [self.__getitem__(i) for i in range(*idx.indices(len(self)))] return [self.__getitem__(i) for i in range(*idx.indices(len(self)))]
else: else:
data_slice = self.data[idx].squeeze() # fiber in: [E_in_x, E_in_y, timestamps]
# fiber out: [E_out_x, E_out_y, timestamps, E_out_x_rot, E_out_y_rot, angle]
data_slice = data_slice[:, :, : data_slice.shape[2] // self.output_dim * self.output_dim]
data_slice = data_slice.view(data_slice.shape[0], data_slice.shape[1], self.output_dim, -1) # if self.polarisations:
output_dim = self.output_dim // 2
self.output_dim = output_dim * 2
# if self.randomise_polarisations: if not self.polarisations:
# angle = torch.rand(1) * torch.pi * 2 output_dim = 2 * output_dim
# sine = torch.sin(angle)
# cosine = torch.cos(angle)
# data_slice_ = data_slice[1]
# data_slice[1, 0] = data_slice_[0] * cosine - data_slice_[1] * sine
# data_slice[1,1] = data_slice_[0] * sine + data_slice_[1] * cosine
# else:
# angle = torch.zeros(1)
# data = data_slice[1, :2, :, 0]
angle = self.angles[idx] fiber_in = self.fiber_in[idx].squeeze()
fiber_out = self.fiber_out[idx].squeeze()
data_index = 1 fiber_in = fiber_in[..., : fiber_in.shape[-1] // output_dim * output_dim]
fiber_out = fiber_out[..., : fiber_out.shape[-1] // output_dim * output_dim]
data_slice[1, :2, :, :] = self.rotate(data_slice[data_index, :2, :, :], angle) fiber_in = fiber_in.view(fiber_in.shape[0], output_dim, -1)
fiber_out = fiber_out.view(fiber_out.shape[0], output_dim, -1)
data = data_slice[1, :2, :, 0] center_angle = fiber_out[5, output_dim // 2, 0]
# data = self.rotate(data, angle) angles = fiber_out[5, :, 0]
plot_data_rot = fiber_out[3:5, output_dim // 2, 0].detach().clone()
data = fiber_out[0:2, :, 0]
plot_data = fiber_out[0:2, output_dim // 2, 0].detach().clone()
# for both polarisations (x/y), calculate the mean of the signal around the current symbol (-> corresponding to a lowpass filter) target = fiber_in[:2, output_dim // 2, 0]
angle_data = data_slice[1, :2, :, :].reshape(2, -1).mean(dim=1) plot_target = fiber_in[:2, output_dim // 2, 0].detach().clone()
angle_data2 = self.complex_max(data_slice[1, :2, :, :].reshape(2, -1)) target_timestamp = fiber_in[2, output_dim // 2, 0].real
plot_data = data_slice[1, :2, self.output_dim // 2, 0]
sop = self.polarimeter(plot_data)
# angle_data = data_slice[1, :2, :self.output_dim//2, :].squeeze().reshape(2, -1).mean(dim=-1)
# angle = data_slice[1, 3, self.output_dim // 2, 0].real
target = data_slice[0, :2, self.output_dim // 2, 0]
target_timestamp = data_slice[0, 2, self.output_dim // 2, 0].real
... ...
# data_timestamps = data[-1,:].real if self.polarisations:
# data = data[:-1, :] rot = int(np.random.randint(2) * 2 - 1)
# target_timestamp = target[-1].real data = rot * data
# target = target[:-1] target = rot * target
# plot_data = plot_data[:-1] plot_data_rot = rot * plot_data_rot
center_angle = center_angle + (rot - 1) * torch.pi / 2 # rot: -1 or 1 -> -2 or 0 -> -pi or 0
angles = angles + (rot - 1) * torch.pi / 2
pol_flipped_data = -data
pol_flipped_target = -target
# transpose to interleave the x and y data in the output tensor # transpose to interleave the x and y data in the output tensor
data = data.transpose(0, 1).flatten().squeeze() data = data.transpose(0, 1).flatten().squeeze()
angle_data = angle_data.flatten().squeeze() data = data / torch.sqrt(torch.ones(1) * len(data)) # power loss due to splitting
angle_data2 = angle_data.flatten().squeeze() pol_flipped_data = pol_flipped_data.transpose(0, 1).flatten().squeeze()
angle = angle.flatten().squeeze() pol_flipped_data = pol_flipped_data / torch.sqrt(
torch.ones(1) * len(pol_flipped_data)
) # power loss due to splitting
# angle_data = angle_data.transpose(0, 1).flatten().squeeze()
# angle_data2 = angle_data2.transpose(0,1).flatten().squeeze()
center_angle = center_angle.flatten().squeeze()
angles = angles.flatten().squeeze()
# data_timestamps = data_timestamps.flatten().squeeze() # data_timestamps = data_timestamps.flatten().squeeze()
# target = target.transpose(0,1).flatten().squeeze()
target = target.flatten().squeeze() target = target.flatten().squeeze()
pol_flipped_target = pol_flipped_target.flatten().squeeze()
target_timestamp = target_timestamp.flatten().squeeze() target_timestamp = target_timestamp.flatten().squeeze()
plot_target = plot_target.flatten().squeeze()
plot_data = plot_data.flatten().squeeze()
plot_data_rot = plot_data_rot.flatten().squeeze()
return {
"x": data,
"x_flipped": pol_flipped_data,
"x_stacked": torch.cat([data, pol_flipped_data], dim=-1),
"y": target,
"y_flipped": pol_flipped_target,
"y_stacked": torch.cat([target, pol_flipped_target], dim=-1),
"center_angle": center_angle,
"angles": angles,
"mean_angle": angles.mean(),
# "sop": sop,
# "angle_data": angle_data,
# "angle_data2": angle_data2,
"timestamp": target_timestamp,
"plot_target": plot_target,
"plot_data": plot_data,
"plot_data_rot": plot_data_rot,
# "plot_clean": fiber_out_plot_clean,
}
return {"x": data, "y": target, "angle": angle, "sop": sop, "angle_data": angle_data, "angle_data2": angle_data2, "timestamp": target_timestamp, "plot_data": plot_data}
def complex_max(self, data, dim=-1): def complex_max(self, data, dim=-1):
# returns element(s) with the maximum absolute value along a given dimension # returns element(s) with the maximum absolute value along a given dimension
# ind = torch.argmax(data.abs(), dim=dim, keepdim=True) # ind = torch.argmax(data.abs(), dim=dim, keepdim=True)
# max_values = torch.gather(data, dim, ind).squeeze(dim=dim) # max_values = torch.gather(data, dim, ind).squeeze(dim=dim)
# return max_values # return max_values
return torch.gather(data, dim, torch.argmax(data.abs(), dim=dim, keepdim=True)).squeeze(dim=dim) return torch.gather(data, dim, torch.argmax(data.abs(), dim=dim, keepdim=True)).squeeze(dim=dim)
def rotate(self, data, angle): def rotate(self, data, angle):
# rotates a 2d tensor by a given angle # rotates a 2d tensor by a given angle
@@ -388,7 +522,25 @@ class FiberRegenerationDataset(Dataset):
cosine = torch.cos(angle) cosine = torch.cos(angle)
return torch.stack([data[0] * cosine - data[1] * sine, data[0] * sine + data[1] * cosine], dim=0) return torch.stack([data[0] * cosine - data[1] * sine, data[0] * sine + data[1] * cosine], dim=0)
def rotate_all(self):
    """
    Apply ``self.rotate`` to ``self.data[i, 1, :2, :]`` with ``self.angles[i]`` for every index ``i``.

    Indices are dealt round-robin to one worker process per CPU core
    (``mp.cpu_count()``); the tail that does not divide evenly is processed
    in the calling process afterwards.
    """
    # NOTE(review): the workers assign into self.data from separate processes.
    # Unless the tensor storage is shared between processes, those writes
    # presumably never reach the parent -- confirm before relying on this.
    # NOTE(review): do_rotation is a local closure; this likely only works with
    # the "fork" start method, since closures cannot be pickled -- verify.
    # NOTE(review): presumably `mp` is (torch.)multiprocessing -- confirm.

    def do_rotation(j, num_processes):
        # worker j handles indices j, j + num_processes, j + 2 * num_processes, ...
        for i in range(len(self) // num_processes):
            index = i * num_processes + j
            self.data[index, 1, :2, :] = self.rotate(self.data[index, 1, :2, :], self.angles[index])

    self.processes = []
    for j in range(mp.cpu_count()):
        self.processes.append(mp.Process(target=do_rotation, args=(j, mp.cpu_count())))
        self.processes[-1].start()

    # wait for all workers before touching the remainder
    for p in self.processes:
        p.join()

    # indices past the last full round of cpu_count() items
    for i in range(len(self) // mp.cpu_count() * mp.cpu_count(), len(self)):
        self.data[i, 1, :2, :] = self.rotate(self.data[i, 1, :2, :], self.angles[i])
def polarimeter(self, data): def polarimeter(self, data):
# data: [2, ...] -> x, y # data: [2, ...] -> x, y
# returns [4] -> S0, S1, S2, S3 # returns [4] -> S0, S1, S2, S3
@@ -396,12 +548,12 @@ class FiberRegenerationDataset(Dataset):
y = data[1].mean() y = data[1].mean()
I_X = x.abs().square() I_X = x.abs().square()
I_Y = y.abs().square() I_Y = y.abs().square()
I_45 = (x+y).abs().square() I_45 = (x + y).abs().square()
I_RHC = (x + 1j*y).abs().square() I_RHC = (x + 1j * y).abs().square()
S0 = I_X + I_Y S0 = I_X + I_Y
S1 = (2*I_X - S0) / S0 S1 = (2 * I_X - S0) / S0
S2 = (2*I_45 - S0) / S0 S2 = (2 * I_45 - S0) / S0
S3 = (2*I_RHC - S0) / S0 S3 = (2 * I_RHC - S0) / S0
return torch.stack([S1, S2, S3], dim=0) return torch.stack([S0, S1, S2, S3], dim=0)

View File

@@ -1,16 +1,23 @@
from datetime import datetime
import json
from pathlib import Path
from typing import Literal
import h5py
from matplotlib import pyplot as plt from matplotlib import pyplot as plt
from matplotlib.colors import LinearSegmentedColormap from matplotlib.colors import LinearSegmentedColormap
# from cmap import Colormap as cm
import numpy as np import numpy as np
from scipy.cluster.vq import kmeans2 from scipy.cluster.vq import kmeans2
import warnings import warnings
import multiprocessing import multiprocessing
from rich.traceback import install from rich.traceback import install
from rich import pretty
from rich import print
install() install()
pretty.install() # from rich import pretty
# from rich import print
# pretty.install()
def generate_sample_data(n_symbols=1000, sps=128, noise=0.01, skew=1): def generate_sample_data(n_symbols=1000, sps=128, noise=0.01, skew=1):
@@ -21,6 +28,7 @@ def generate_sample_data(n_symbols=1000, sps=128, noise=0.01, skew=1):
xaxis = np.arange(0, len(signal)) / sps xaxis = np.arange(0, len(signal)) / sps
return np.vstack([xaxis, signal]) return np.vstack([xaxis, signal])
def create_symbol_sequence(n_symbols, skew=1): def create_symbol_sequence(n_symbols, skew=1):
np.random.seed(42) np.random.seed(42)
data = np.random.randint(0, 4, n_symbols) / 4 data = np.random.randint(0, 4, n_symbols) / 4
@@ -39,6 +47,14 @@ def generate_signal(data, sps):
signal = np.convolve(data_padded, wavelet) signal = np.convolve(data_padded, wavelet)
signal = np.cumsum(signal) signal = np.cumsum(signal)
signal = signal[sps // 2 + len(wavelet) // 2 - 1 : -len(wavelet) // 2] signal = signal[sps // 2 + len(wavelet) // 2 - 1 : -len(wavelet) // 2]
mi, ma = np.min(signal), np.max(signal)
signal = (signal - mi) / (ma - mi)
mod = 0.8
signal *= mod
signal += 1 - mod
return signal return signal
@@ -49,8 +65,8 @@ def normalization_with_noise(signal, noise=0):
signal += awgn signal += awgn
# min-max normalization # min-max normalization
signal = signal - np.min(signal) # signal = signal - np.min(signal)
signal = signal / np.max(signal) # signal = signal / np.max(signal)
return signal return signal
@@ -68,98 +84,264 @@ def generate_wavelet(sps, oversample=3):
class eye_diagram: class eye_diagram:
def __init__(self, data, *, channel_names=None, horizontal_bins=256, vertical_bins=1000, n_levels=4, multithreaded=True): def __init__(
self,
data,
*,
channel_names=None,
horizontal_bins=256,
vertical_bins=1000,
n_levels=4,
multithreaded=True,
save_file_or_dir=None,
):
# data has shape [channels, 2, samples] # data has shape [channels, 2, samples]
# each sample has a timestamp and a value # each sample has a timestamp and a value
if data.ndim == 2: if data.ndim == 2:
data = data[np.newaxis, :, :] data = data[np.newaxis, :, :]
self.channel_names = channel_names
self.raw_data = data self.raw_data = data
self.channels = data.shape[0]
self.y_bins = np.zeros(1)
self.x_bins = np.zeros(1)
self.eye_data = np.zeros(1)
self.channel_names = channel_names
self.n_channels = data.shape[0]
self.n_levels = n_levels self.n_levels = n_levels
self.eye_stats = [{"success": False} for _ in range(self.channels)] self.eye_stats = [{"success": False} for _ in range(self.n_channels)]
self.horizontal_bins = horizontal_bins self.horizontal_bins = horizontal_bins
self.vertical_bins = vertical_bins self.vertical_bins = vertical_bins
self.multi_threaded = multithreaded self.multi_threaded = multithreaded
self.analysed = False
self.eye_built = False self.eye_built = False
self.analyse()
def generate_eye_data(self): self.save_file = save_file_or_dir
self.x_bins = np.linspace(0, 2, self.horizontal_bins, endpoint=False)
self.y_bins = np.zeros((self.channels, self.vertical_bins)) def load_data(self, file=None):
self.eye_data = np.zeros((self.channels, self.vertical_bins, self.horizontal_bins)) file = self.save_file if file is None else file
datas = [self.raw_data[i] for i in range(self.channels)]
if self.multi_threaded: if file is None:
with multiprocessing.Pool() as pool: raise FileNotFoundError("No file specified.")
results = pool.map(self.generate_eye_data_single, datas)
for i, result in enumerate(results): self.save_file = str(file)
self.eye_data[i], self.y_bins[i] = result # self.file_or_dir = self.save_file
with h5py.File(file, "r") as infile:
self.y_bins = infile["y_bins"][:]
self.x_bins = infile["x_bins"][:]
self.eye_data = infile["eye_data"][:]
self.channel_names = infile.attrs["channel_names"]
self.n_channels = infile.attrs["n_channels"]
self.n_levels = infile.attrs["n_levels"]
self.eye_stats = infile.attrs["eye_stats"]
self.eye_stats = [json.loads(stat) for stat in self.eye_stats]
self.horizontal_bins = infile.attrs["horizontal_bins"]
self.vertical_bins = infile.attrs["vertical_bins"]
self.multi_threaded = infile.attrs["multithreaded"]
self.analysed = infile.attrs["analysed"]
self.eye_built = infile.attrs["eye_built"]
def save_data(self, file_or_dir=None):
file_or_dir = self.save_file if file_or_dir is None else file_or_dir
if file_or_dir is None:
file = Path(f"{datetime.now().strftime('%Y%m%d_%H%M%S')}_eye_data.eye.h5")
elif Path(file_or_dir).is_dir():
file = Path(file_or_dir) / f"{datetime.now().strftime('%Y%m%d_%H%M%S')}_eye_data.eye.h5"
else: else:
for i, data in enumerate(datas): file = Path(file_or_dir)
self.eye_data[i], self.y_bins[i] = self.generate_eye_data_single(data)
self.eye_built = True # file.parent.mkdir(parents=True, exist_ok=True)
self.save_file = str(file)
with h5py.File(file, "w") as outfile:
outfile.create_dataset("eye_data", data=self.eye_data)
outfile.create_dataset("y_bins", data=self.y_bins)
outfile.create_dataset("x_bins", data=self.x_bins)
outfile.attrs["channel_names"] = self.channel_names
outfile.attrs["n_channels"] = self.n_channels
outfile.attrs["n_levels"] = self.n_levels
self.eye_stats = eye_diagram.convert_arrays(self.eye_stats)
outfile.attrs["eye_stats"] = [json.dumps(stat) for stat in self.eye_stats]
outfile.attrs["horizontal_bins"] = self.horizontal_bins
outfile.attrs["vertical_bins"] = self.vertical_bins
outfile.attrs["multithreaded"] = self.multi_threaded
outfile.attrs["analysed"] = self.analysed
outfile.attrs["eye_built"] = self.eye_built
@staticmethod
def convert_arrays(input_object):
"""
convert ndarrays in (nested) dict to lists
"""
if isinstance(input_object, np.ndarray):
return input_object.tolist()
elif isinstance(input_object, list):
return [eye_diagram.convert_arrays(old) for old in input_object]
elif isinstance(input_object, tuple):
return tuple(eye_diagram.convert_arrays(old) for old in input_object)
elif isinstance(input_object, dict):
dict_out = {}
for key, value in input_object.items():
dict_out[key] = eye_diagram.convert_arrays(value)
return dict_out
return input_object
def generate_eye_data(
self, mode: Literal["default", "load", "save", "nosave"] = "default", file_or_dir: Path | None | str = None
):
# modes:
# default: try to load eye data from file, if not found, generate and save
# load: try to load eye data from file, if not found, generate but don't save
# save: generate eye data and save
update_save = True
if mode == "load":
self.load_data(file_or_dir)
update_save = False
elif mode == "default":
try:
self.load_data(file_or_dir)
update_save = False
except (FileNotFoundError, IsADirectoryError):
pass
if not self.eye_built:
update_save = True
self.x_bins = np.linspace(0, 2, self.horizontal_bins, endpoint=False)
self.y_bins = np.zeros((self.n_channels, self.vertical_bins))
self.eye_data = np.zeros((self.n_channels, self.vertical_bins, self.horizontal_bins))
datas = [self.raw_data[i] for i in range(self.n_channels)]
if self.multi_threaded:
with multiprocessing.Pool() as pool:
results = pool.map(self.generate_eye_data_single, datas)
for i, result in enumerate(results):
self.eye_data[i], self.y_bins[i] = result
else:
for i, data in enumerate(datas):
self.eye_data[i], self.y_bins[i] = self.generate_eye_data_single(data)
self.eye_built = True
if mode == "save" or (mode == "default" and update_save):
self.save_data(file_or_dir)
def generate_eye_data_single(self, data): def generate_eye_data_single(self, data):
eye_data = np.zeros((self.vertical_bins, self.horizontal_bins)) eye_data = np.zeros((self.vertical_bins, self.horizontal_bins))
data_min = np.min(data[1, :]) data_min = np.min(data[1, :])
data_max = np.max(data[1, :]) data_max = np.max(data[1, :])
# round down/up to 1 decimal
data_min = np.floor(data_min*10)/10
data_max = np.ceil(data_max*10)/10
# data_range = data_max - data_min
# data_min -= 0.1 * data_range
# data_max += 0.1 * data_range
# data_min = -0.05
# data_max += 0.05
# data[1,:] -= np.min(data[1, :])
# data[1,:] /= np.max(data[1, :])
# data_min = 0
# data_max = 1
y_bins = np.linspace(data_min, data_max, self.vertical_bins, endpoint=False) y_bins = np.linspace(data_min, data_max, self.vertical_bins, endpoint=False)
t_vals = data[0, :] % 2 t_vals = data[0, :] % 2 # + np.random.randn(*data[0, :].shape) * 1 / (512)
val_vals = data[1, :] val_vals = data[1, :] # + np.random.randn(*data[1, :].shape) * 1 / (320)
x_indices = np.digitize(t_vals, self.x_bins) - 1 x_indices = np.digitize(t_vals, self.x_bins) - 1
y_indices = np.digitize(val_vals, y_bins) - 1 y_indices = np.digitize(val_vals, y_bins) - 1
np.add.at(eye_data, (y_indices, x_indices), 1) np.add.at(eye_data, (y_indices, x_indices), 1)
return eye_data, y_bins return eye_data, y_bins
def plot(self, title="Eye Diagram", stats=True, all_stats=True, show=True): def plot(
self,
title="Eye Diagram",
stats=True,
all_stats=True,
show=True,
mode: Literal["default", "load", "save", "nosave"] = "default",
# save_images = False,
# image_dir = None,
# cmap=None,
):
if stats and not self.analysed:
self.analyse(mode=mode)
if not self.eye_built: if not self.eye_built:
self.generate_eye_data() self.generate_eye_data(mode=mode)
cmap = LinearSegmentedColormap.from_list( cmap = LinearSegmentedColormap.from_list(
"eyemap", "eyemap",
[(0, "white"), (0.1, "blue"), (0.2, "cyan"), (0.5, "green"), (0.8, "yellow"), (0.9, "red"), (1, "magenta")], [
(0, "#FFFFFF00"),
(0.1, "blue"),
(0.2, "cyan"),
(0.5, "green"),
(0.8, "yellow"),
(0.9, "red"),
(1, "magenta"),
],
) )
if self.channels % 2 == 0: # cmap = cm('google:turbo_r' if cmap is None else cmap)
# first = cmap(-1)
# cmap = cmap.to_mpl()
# cmap.set_under(first, alpha=0)
if self.n_channels % 2 == 0:
rows = 2 rows = 2
cols = self.channels // 2 cols = self.n_channels // 2
else: else:
cols = int(np.ceil(np.sqrt(self.channels))) cols = int(np.ceil(np.sqrt(self.n_channels)))
rows = int(np.ceil(self.channels / cols)) rows = int(np.ceil(self.n_channels / cols))
fig, ax = plt.subplots(rows, cols, sharex=True, sharey=False) fig, ax = plt.subplots(rows, cols, sharex=True, sharey=False)
fig.suptitle(title) fig.suptitle(title)
fig.tight_layout()
ax = np.atleast_1d(ax).transpose().flatten() ax = np.atleast_1d(ax).transpose().flatten()
for i in range(self.channels): for i in range(self.n_channels):
ax[i].set_title(self.channel_names[i] if self.channel_names is not None else f"Channel {i+1}") ax[i].set_title(self.channel_names[i] if self.channel_names is not None else f"Channel {i + 1}")
if (i+1) % rows == 0: if (i + 1) % rows == 0:
ax[i].set_xlabel("Symbol") ax[i].set_xlabel("Symbol")
if i < rows: if i < rows:
ax[i].set_ylabel("Amplitude") ax[i].set_ylabel("Amplitude")
ax[i].grid() ax[i].grid()
ax[i].set_axisbelow(True)
ax[i].imshow( ax[i].imshow(
self.eye_data[i], self.eye_data[i] - 0.1,
origin="lower", origin="lower",
aspect="auto", aspect="auto",
cmap=cmap, cmap=cmap,
extent=[self.x_bins[0], self.x_bins[-1], self.y_bins[i][0], self.y_bins[i][-1]], extent=[self.x_bins[0], self.x_bins[-1], self.y_bins[i][0], self.y_bins[i][-1]],
interpolation="gaussian",
vmin=0,
zorder=3,
) )
ax[i].set_xlim((self.x_bins[0], self.x_bins[-1])) ax[i].set_xlim((self.x_bins[0], self.x_bins[-1]))
ymin = np.min(self.y_bins[:, 0]) ymin = np.min(self.y_bins[:, 0])
ymax = np.max(self.y_bins[:, -1]) ymax = np.max(self.y_bins[:, -1])
yspan = ymax - ymin yspan = ymax - ymin
ax[i].set_ylim((ymin - 0.1 * yspan, ymax + 0.1 * yspan)) ax[i].set_ylim((ymin - 0.1 * yspan, ymax + 0.1 * yspan))
# if save_images:
# image_dir = "images_out" if image_dir is None else image_dir
# image_path = Path(image_dir) / (slugify(f"{datetime.now().strftime("%Y%m%d_%H%M%S")}_{title.replace(" ","_")}_{self.channel_names[i].replace(" ", "_") if self.channel_names is not None else f"{i + 1}"}_{ymin:.1f}_{ymax:.1f}") + ".png")
# image_path.parent.mkdir(parents=True, exist_ok=True)
# # plt.imsave(
# # image_path,
# # self.eye_data[i] - 0.1,
# # origin="lower",
# # # aspect="auto",
# # cmap=cmap,
# # # extent=[self.x_bins[0], self.x_bins[-1], self.y_bins[i][0], self.y_bins[i][-1]],
# # # interpolation="gaussian",
# # vmin=0,
# # # zorder=3,
# # )
if stats and self.eye_stats[i]["success"]: if stats and self.eye_stats[i]["success"]:
# add min_area above the plot # # add min_area above the plot
ax[i].annotate( # ax[i].annotate(
f"Min Area: {self.eye_stats[i]['min_area']:.2e}", # f"Min Area: {self.eye_stats[i]['min_area']:.2e}",
xy=(0.05, ymax + 0.05 * yspan), # xy=(0.05, ymax + 0.05 * yspan),
# xycoords="axes fraction", # # xycoords="axes fraction",
ha="left", # ha="left",
va="center", # va="center",
bbox=dict(facecolor="white", alpha=0.5, edgecolor="none"), # bbox=dict(facecolor="white", alpha=0.5, edgecolor="none"),
) # )
if all_stats: if all_stats:
ax[i].plot([0, 2], [self.eye_stats[i]["levels"], self.eye_stats[i]["levels"]], "k--") ax[i].plot([0, 2], [self.eye_stats[i]["levels"], self.eye_stats[i]["levels"]], "k--")
ax[i].set_yticks(self.eye_stats[i]["levels"]) y_ticks = (*self.eye_stats[i]["levels"], *self.eye_stats[i]["thresholds"])
# y_ticks = np.sort(y_ticks)
ax[i].set_yticks(y_ticks)
# add arrows for amplitudes # add arrows for amplitudes
for j in range(len(self.eye_stats[i]["amplitudes"])): for j in range(len(self.eye_stats[i]["amplitudes"])):
ax[i].annotate( ax[i].annotate(
@@ -193,62 +375,69 @@ class eye_diagram:
except (ValueError, IndexError): except (ValueError, IndexError):
pass pass
# add arrows for eye widths # add arrows for eye widths
for j in range(len(self.eye_stats[i]["widths"])): # for j in range(len(self.eye_stats[i]["widths"])):
try: # try:
left = np.max(self.eye_stats[i]["time_clusters"][j][0]) # left = np.max(self.eye_stats[i]["time_clusters"][j][0])
right = np.min(self.eye_stats[i]["time_clusters"][j][1]) # right = np.min(self.eye_stats[i]["time_clusters"][j][1])
vertical = (self.eye_stats[i]["levels"][j] + self.eye_stats[i]["levels"][j + 1]) / 2 # vertical = (self.eye_stats[i]["levels"][j] + self.eye_stats[i]["levels"][j + 1]) / 2
ax[i].annotate( # ax[i].annotate(
"", # "",
xy=(left, vertical), # xy=(left, vertical),
xytext=(right, vertical), # xytext=(right, vertical),
arrowprops=dict(arrowstyle="<->", facecolor="black"), # arrowprops=dict(arrowstyle="<->", facecolor="black"),
) # )
ax[i].annotate( # ax[i].annotate(
f"{self.eye_stats[i]['widths'][j]:.2e}", # f"{self.eye_stats[i]['widths'][j]:.2e}",
xy=((left + right) / 2 - 0.15, vertical + 0.01), # xy=((left + right) / 2 - 0.15, vertical + 0.01),
bbox=dict(facecolor="white", alpha=0.5, edgecolor="none"), # bbox=dict(facecolor="white", alpha=0.5, edgecolor="none"),
) # )
except (ValueError, IndexError): # except (ValueError, IndexError):
pass # pass
# add area # # add area
for j in range(len(self.eye_stats[i]["areas"])): # for j in range(len(self.eye_stats[i]["areas"])):
horizontal = self.eye_stats[i]["time_midpoint"] # horizontal = self.eye_stats[i]["time_midpoint"]
vertical = (self.eye_stats[i]["levels"][j] + self.eye_stats[i]["levels"][j + 1]) / 2 # vertical = (self.eye_stats[i]["levels"][j] + self.eye_stats[i]["levels"][j + 1]) / 2
ax[i].annotate( # ax[i].annotate(
f"{self.eye_stats[i]['areas'][j]:.2e}", # f"{self.eye_stats[i]['areas'][j]:.2e}",
xy=(horizontal + 0.035, vertical - 0.07), # xy=(horizontal + 0.035, vertical - 0.07),
bbox=dict(facecolor="white", alpha=0.5, edgecolor="none"), # bbox=dict(facecolor="white", alpha=0.5, edgecolor="none"),
) # )
fig.tight_layout() fig.tight_layout()
if show: if show:
plt.show() plt.show()
return fig return fig
@staticmethod
def calculate_thresholds(levels):
ret = np.cumsum(levels, dtype=float)
ret[2:] = ret[2:] - ret[:-2]
return ret[1:] / 2
def analyse_single(self, data, index): def analyse_single(self, data, index):
warnings.filterwarnings("error") warnings.filterwarnings("error")
eye_stats = {} eye_stats = {}
eye_stats["channel_name"] = str(index+1) if self.channel_names is None else self.channel_names[index] eye_stats["channel_name"] = str(index + 1) if self.channel_names is None else self.channel_names[index]
try: try:
approx_levels = eye_diagram.approximate_levels(data, self.n_levels) approx_levels = eye_diagram.approximate_levels(data, self.n_levels)
time_bounds = eye_diagram.calculate_time_bounds(data, approx_levels) time_bounds = eye_diagram.calculate_time_bounds(data, approx_levels)
eye_stats["time_midpoint"] = (time_bounds[0] + time_bounds[1]) / 2 eye_stats["time_midpoint"] = float((time_bounds[0] + time_bounds[1]) / 2)
# eye_stats["time_midpoint"] = 1.0
eye_stats["levels"], eye_stats["amplitude_clusters"] = eye_diagram.calculate_levels( eye_stats["levels"], eye_stats["amplitude_clusters"] = eye_diagram.calculate_levels(
data, approx_levels, time_bounds data, approx_levels, time_bounds
) )
eye_stats["thresholds"] = self.calculate_thresholds(eye_stats["levels"])
eye_stats["amplitudes"] = np.diff(eye_stats["levels"]) eye_stats["amplitudes"] = np.diff(eye_stats["levels"])
eye_stats["heights"] = eye_diagram.calculate_eye_heights( eye_stats["heights"] = eye_diagram.calculate_eye_heights(eye_stats["amplitude_clusters"])
eye_stats["amplitude_clusters"]
)
eye_stats["widths"], eye_stats["time_clusters"] = eye_diagram.calculate_eye_widths( eye_stats["widths"], eye_stats["time_clusters"] = eye_diagram.calculate_eye_widths(
data, eye_stats["levels"] data, eye_stats["levels"]
@@ -260,36 +449,59 @@ class eye_diagram:
# if not (np.max(eye_stats['time_clusters'][j][0]) < eye_stats["time_midpoint"] < np.min(eye_stats['time_clusters'][j][1])): # if not (np.max(eye_stats['time_clusters'][j][0]) < eye_stats["time_midpoint"] < np.min(eye_stats['time_clusters'][j][1])):
# raise ValueError # raise ValueError
eye_stats["areas"] = eye_stats["heights"] * eye_stats["widths"] # eye_stats["areas"] = eye_stats["heights"] * eye_stats["widths"]
eye_stats["mean_area"] = np.mean(eye_stats["areas"]) # eye_stats["mean_area"] = np.mean(eye_stats["areas"])
eye_stats["min_area"] = np.min(eye_stats["areas"]) # eye_stats["min_area"] = np.min(eye_stats["areas"])
eye_stats["success"] = True eye_stats["success"] = True
except (RuntimeWarning, UserWarning, ValueError): except (RuntimeWarning, UserWarning, ValueError):
eye_stats["success"] = False eye_stats["success"] = False
eye_stats["time_midpoint"] = 0 eye_stats["time_midpoint"] = None
eye_stats["levels"] = np.zeros(self.n_levels) eye_stats["levels"] = None
eye_stats["amplitude_clusters"] = [] eye_stats["thresholds"] = None
eye_stats["amplitudes"] = np.zeros(self.n_levels - 1) eye_stats["amplitude_clusters"] = None
eye_stats["heights"] = np.zeros(self.n_levels - 1) eye_stats["amplitudes"] = None
eye_stats["widths"] = np.zeros(self.n_levels - 1) eye_stats["heights"] = None
eye_stats["areas"] = np.zeros(self.n_levels - 1) eye_stats["widths"] = None
eye_stats["mean_area"] = 0 # eye_stats["areas"] = np.zeros(self.n_levels - 1)
eye_stats["min_area"] = 0 # eye_stats["mean_area"] = 0
# eye_stats["min_area"] = 0
warnings.resetwarnings() warnings.resetwarnings()
return eye_stats return eye_stats
def analyse(
self, mode: Literal["default", "load", "save", "nosave"] = "default", file_or_dir: Path | None | str = None
):
# modes:
# default: try to load eye data from file, if not found, generate and save
# load: try to load eye data from file, if not found, generate but don't save
# save: generate eye data and save
update_save = True
if mode == "load":
self.load_data(file_or_dir)
update_save = False
elif mode == "default":
try:
self.load_data(file_or_dir)
update_save = False
except (FileNotFoundError, IsADirectoryError):
pass
def analyse(self): if not self.analysed:
self.eye_stats = [] update_save = True
if self.multi_threaded: self.eye_stats = []
with multiprocessing.Pool() as pool: if self.multi_threaded:
results = pool.starmap(self.analyse_single, [(self.raw_data[i], i) for i in range(self.channels)]) with multiprocessing.Pool() as pool:
for i, result in enumerate(results): results = pool.starmap(self.analyse_single, [(self.raw_data[i], i) for i in range(self.n_channels)])
self.eye_stats.append(result) for i, result in enumerate(results):
else: self.eye_stats.append(result)
for i in range(self.channels): else:
self.eye_stats.append(self.analyse_single(self.raw_data[i], i)) for i in range(self.n_channels):
self.eye_stats.append(self.analyse_single(self.raw_data[i], i))
self.analysed = True
if mode == "save" or (mode == "default" and update_save):
self.save_data(file_or_dir)
@staticmethod @staticmethod
def approximate_levels(data, levels): def approximate_levels(data, levels):
@@ -431,7 +643,7 @@ class eye_diagram:
if __name__ == "__main__": if __name__ == "__main__":
length = int(2**14) length = int(2**16)
# data = generate_sample_data(length, noise=1) # data = generate_sample_data(length, noise=1)
# data1 = generate_sample_data(length, noise=0.01) # data1 = generate_sample_data(length, noise=0.01)
# data2 = generate_sample_data(length, noise=0.01, skew=1.2) # data2 = generate_sample_data(length, noise=0.01, skew=1.2)
@@ -439,12 +651,13 @@ if __name__ == "__main__":
# data = np.stack([data, data1, data2, data3]) # data = np.stack([data, data1, data2, data3])
data = generate_sample_data(length, noise=0.005) data = generate_sample_data(length, noise=0.0000)
eye = eye_diagram(data, horizontal_bins=256, vertical_bins=256) eye = eye_diagram(data, horizontal_bins=1056, vertical_bins=1200)
attrs = ("success", "amplitudes", "time_midpoint", "levels", "heights", "widths", "area", "mean_area", "min_area") eye.plot(mode="nosave", stats=False)
for i, channel in enumerate(eye.eye_stats): # attrs = ("success", "amplitudes", "time_midpoint", "levels", "heights", "widths")#, "area", "mean_area", "min_area")
print(f"Channel {i}") # for i, channel in enumerate(eye.eye_stats):
print_data = {attr: channel[attr] for attr in attrs} # print(f"Channel {i}")
print(print_data) # print_data = {attr: channel[attr] for attr in attrs}
# print(print_data)
eye.plot() # eye.plot()

View File

@@ -0,0 +1,122 @@
# Copyright (c) 2015, Warren Weckesser. All rights reserved.
# This software is licensed according to the "BSD 2-clause" license.
# modified by Joseph Hopfmüller in 2025,
# for integration into optical regeneration analysis scripts
from pathlib import Path
from matplotlib.colors import LinearSegmentedColormap
import matplotlib.colors as colors
import numpy as _np
from .core import grid_count as _grid_count
import matplotlib.pyplot as _plt
import numpy as np
from scipy.ndimage import gaussian_filter
# from ._common import _common_doc
__all__ = ["eyediagram"] # , 'eyediagram_lines']
# def eyediagram_lines(y, window_size, offset=0, **plotkwargs):
# """
# Plot an eye diagram using matplotlib by repeatedly calling the `plot`
# function.
# <common>
# """
# start = offset
# while start < len(y):
# end = start + window_size
# if end > len(y):
# end = len(y)
# yy = y[start:end+1]
# _plt.plot(_np.arange(len(yy)), yy, 'k', **plotkwargs)
# start = end
# eyediagram_lines.__doc__ = eyediagram_lines.__doc__.replace("<common>",
# _common_doc)
# Default colormap for eyediagram(). The first stop has an alpha of 00, so the
# low end of the scale starts fully transparent and fades into opaque blue.
eyemap = LinearSegmentedColormap.from_list(
    "eyemap",
    [
        (0, "#0000FF00"),
        (0.1, "blue"),
        (0.2, "cyan"),
        (0.5, "green"),
        (0.8, "yellow"),
        (0.9, "red"),
        (1, "magenta"),
    ],
)
def eyediagram(
    y,
    window_size,
    offset=0,
    colorbar=False,
    show=False,
    save_im=False,
    overwrite=False,
    blur: int | bool = True,
    save_path="out.png",
    bounds=None,
    **imshowkwargs,
):
    """
    Plot an eye diagram as a 2-D hit-count image.

    The samples in `y` are binned by `grid_count` into a fixed 1000 x 1200
    grid. The resulting image can be displayed with matplotlib's `imshow`
    and/or written to an image file with Pillow.

    Parameters
    ----------
    y : array
        Signal samples. Must provide ``max()`` and ``min()`` when `bounds`
        is None.
    window_size
        Forwarded unchanged to `grid_count`.
        NOTE(review): presumably the number of samples per trace - confirm
        against `grid_count`.
    offset
        Forwarded unchanged to `grid_count`.
    colorbar : bool
        Add a colorbar to the displayed image (only used when `show` is set).
    show : bool
        Draw the image with `imshow` and call `matplotlib.pyplot.show()`.
    save_im : bool
        Write the colour-mapped image to `save_path`.
    overwrite : bool
        If False and `save_path` already exists as a file, saving is
        silently skipped.
    blur : int or bool
        Passed to `grid_count` as ``int(blur)``.
    save_path : str or path-like
        Destination of the saved image.
    bounds : (ymin, ymax), optional
        Vertical range of the grid. If None it is derived from the data:
        the data range is widened by 5 % on each side and rounded outwards
        to multiples of 0.1.
    **imshowkwargs
        ``origin`` (default ``"lower"``), ``cmap`` (default `eyemap`; a
        Colormap instance or a registered colormap name), ``vmin``
        (default 1) and ``vmax`` are extracted here; everything else is
        forwarded to `imshow`.

    Returns
    -------
    None
    """
    if bounds is None:
        ymax = y.max()
        ymin = y.min()
        margin = 0.05 * (ymax - ymin)
        ymin = _np.floor((ymin - margin) * 10) / 10
        ymax = _np.ceil((ymax + margin) * 10) / 10
        if not ymax > ymin:
            # A constant signal sitting exactly on a 0.1 multiple would give
            # a zero-height range; open it up so the grid has a real extent.
            ymin, ymax = ymin - 0.1, ymax + 0.1
        bounds = (ymin, ymax)

    counts = _grid_count(y, window_size, offset, bounds=bounds, size=(1000, 1200), blur=int(blur))
    counts = counts.astype(_np.float32)

    origin = imshowkwargs.pop("origin", "lower")
    cmap = imshowkwargs.pop("cmap", eyemap)
    vmin = imshowkwargs.pop("vmin", 1)
    vmax = imshowkwargs.pop("vmax", None)

    # Resolve colormap names and work on a private copy, so that setting the
    # "under" colour neither fails for string names nor mutates the shared
    # module-level `eyemap` / the caller's colormap object.
    cmap: colors.Colormap = _plt.get_cmap(cmap).copy()
    cmap.set_under("white", alpha=0)

    image = counts.T[::-1, :]

    if show:
        _plt.imshow(
            image,
            extent=[0, 2, *bounds],
            origin=origin,
            cmap=cmap,
            vmin=vmin,
            vmax=vmax,
            **imshowkwargs,
        )
        _plt.grid()
        if colorbar:
            _plt.colorbar()

    # Never clobber an existing file unless explicitly asked to.
    if save_im and (overwrite or not Path(save_path).is_file()):
        from PIL import Image

        # imshow flips the rows for origin="lower"; mirror that in the file.
        pixels = image[::-1] if origin == "lower" else image
        low = pixels.min()
        span = pixels.max() - low
        if span > 0:
            pixels = (pixels - low) / span
        else:
            # Constant grid (e.g. no hits at all): avoid 0/0 and map
            # everything to the lowest colormap entry.
            pixels = _np.zeros_like(pixels)
        rgba = (cmap(pixels) * 255).astype(_np.uint8)
        Image.fromarray(rgba).save(save_path)

    if show:
        _plt.show()
# eyediagram.__doc__ = eyediagram.__doc__.replace("<common>", _common_doc)

View File

@@ -1,6 +1,9 @@
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np import numpy as np
from .datasets import load_data if __name__ == "__main__":
from datasets import load_data
else:
from .datasets import load_data
def eye(*, path=None, data=None, sps=None, title=None, symbols=1000, skipfirst=0, width=2, alpha=None, complex=False, show=True): def eye(*, path=None, data=None, sps=None, title=None, symbols=1000, skipfirst=0, width=2, alpha=None, complex=False, show=True):
"""Plot an eye diagram for the data given by filepath. """Plot an eye diagram for the data given by filepath.
@@ -20,6 +23,7 @@ def eye(*, path=None, data=None, sps=None, title=None, symbols=1000, skipfirst=0
raise ValueError("Either path or data and sps must be given.") raise ValueError("Either path or data and sps must be given.")
if path is not None: if path is not None:
data, config = load_data(path, skipfirst, symbols) data, config = load_data(path, skipfirst, symbols)
data = data.detach().cpu().numpy()[:, :4]
sps = int(config["glova"]["sps"]) sps = int(config["glova"]["sps"])
if sps is None: if sps is None:
raise ValueError("sps not set.") raise ValueError("sps not set.")
@@ -71,3 +75,6 @@ def eye(*, path=None, data=None, sps=None, title=None, symbols=1000, skipfirst=0
plt.show() plt.show()
return fig return fig
if __name__ == "__main__":
eye(path="data/20241229-163838-128-16384-50000-0-0.2-16.8-0.058-PAM4-0-0.16.ini", symbols=1000, width=2, alpha=0.1, complex=False)