Compare commits

...

5 Commits

Author SHA1 Message Date
Joseph Hopfmüller
249fe1e940 wip 2025-01-27 21:05:49 +01:00
Joseph Hopfmüller
f38d0ca3bb model robustness testing 2025-01-10 23:40:54 +01:00
Joseph Hopfmüller
3af73343c1 add corrected datasets with PMD and dispersion 2024-12-29 16:41:52 +01:00
Joseph Hopfmüller
7a0b65f82d Merge branch 'machine_learning' of git.suuppl.dev:seppl/optical-regeneration into machine_learning 2024-12-29 16:00:41 +01:00
Joseph Hopfmüller
98305fdf47 update dataset configurations, add rotation module, and refine model settings for training, new hyperparameter tuning run for corrected datasets 2024-12-29 16:00:36 +01:00
56 changed files with 3863 additions and 931 deletions

2
.gitignore vendored
View File

@@ -163,3 +163,5 @@ cython_debug/
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
tolerance_results/*
data/*

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:bc02b0099ea3bb136733e3d20817cad79b6c50c2e4b845f0d206455dde188cc4
size 615

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d80ff6f2a84acf973fbdf81a05ed0b1902f8bf97856cd5132b646f6b1173f496
size 615

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:54ac6b6a452aa6b7d312a4c8ab8f7ebe2f96c1c4170cbc56147e8f2f9d934ad6
size 615

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:bd2c6f4050488b6e857759d48aa1f1f37399d81cee1667d3668145e938d17c83
size 615

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c59b8113092459a8751b385a7b1a6f10828626d2ec2f29775512157fd9bbc75c
size 615

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:04eaa2a29b3302e5de3bf99bf6a57fc8f27361fd3df3cac9245e25ab99324829
size 615

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:704d3b0b17b9d320f4717b5a5a8bdbc5714f3caa4efa7153e980766429e834f4
size 615

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:29304f35a88fd777566105f8666fff1c9927beb32756822365bcf9c159feb98e
size 615

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4a87488c12e0253b2bb5d1ac7aa1536f69ef569e62c2aab6a10149d753e049b4
size 615

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ff47c8d5413881edb03cfadfcde3b550ef7089543615a467c4f0027edaf1455e
size 615

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:9e05a0a54c7e3aaffeca0ac90cce1b274a544d90b329a93d453392f5df4e91a8
size 616

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:25e45b06b551ab8031c2159030d658999fcb3d1f0a34538c90768b94c8116771
size 616

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:707ed73713b6c2d80d4e333b1ccdf4650f50aefefff58ddb471c1d5411954b3d
size 616

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a8c3baf878943741d83835c1afed05c1f9780ff3f0df260c0d92706343e59c50
size 616

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:98878d09510dedc24f1429bbd12cce49c66be6c9d279a28765b120efe820a171
size 616

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:fcbdaffa211d6b0b44b3ae1c66645999e95901bfdb2fffee4c45e34a0d901ee1
size 649

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:73e922310068a66ab1a0c3d39b0b0fd7db798f2637b199a8c4bd113a38bb28c8
size 134217856

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:43e80dc7d21aeff62c73f0ed02b20a4ac9b573d0e131a3ab8d1471077e03634b
size 134217856

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:fb74cfbeec54b4f263c08510312881527fc7e484604fa0a1213b596f175fecc2
size 134217856

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:447cca0af25e309c8be61216cea4fb2d3a8a967b0522760f1dab8f15e6b41574
size 134217856

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:63321906920de825fc7857aa9f2e4944c3b32f3eadf99da06b66f6599924bc4c
size 134217856

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:92e3fcc41f05380cb7b334fefc0a30bb8c1dfd9ca5c3b2cfaad36b0c7093914e
size 134217856

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:0bf671d666d35617edd7bfb58548a5120bd92e5f9cb6edb4b5fc8d3bf5db8987
size 134217856

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:bd429a5776555671b19d42e3ae152026dafd5bf95aeed9847df1432ed37f3eba
size 134217856

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1df90745cc2e6d4b0ad964fca2de1441e6e0b4b8345fbb0fbc1ffe9820674269
size 134481920

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:477365edd696f66610298198257422a017f825e1f8bec4363bdfb0da0d741ebc
size 134217856

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1814f8ae0e6cdb69c0030741a3c6b35c74f2d6985eed34c0d5b4135384014abc
size 134217856

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d53bc0fd0897c7372c2f182b84163bcf401ad91f26b6949c6d7a1d70c5dbb513
size 134217856

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:fc882f29530f7d683be631214f6667611e0aba87453a11157d71c50f3548fb3c
size 134217856

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:0f0bc616c2e581444a6fa658e45ee942c1ef5a4d21f22363518331f8d80cbe62
size 134217856

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ad94f2af43a2a06ebddc78566cbef0ea538c3da9191682e2fde4ecddb061b0f0
size 134217856

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:6930349d1ae1479dcb4c9ee9eaefe2da3eae62539cb4b97de873a7a8b175e809
size 134217856

View File

@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:746ea83013870351296e01e294905ee027291ef79fd78c1e6b69dd9ebaa1cba0
size 10240000
oid sha256:76934d1d202aea1311ba67f5ea35eeb99a9c5c856f491565032e7d54ca6f072d
size 13598720

37
notes/models.md Normal file
View File

@@ -0,0 +1,37 @@
# models
## no polarisation flipping
```py
config_path="data/20241229-163*-128-16384-50000-*.ini"
model=".models/best_20241230_011907.tar"
```
```py
config_path="data/20241229-163*-128-16384-80000-*.ini"
model=".models/best_20241230_103752.tar"
```
```py
config_path="data/20241229-163*-128-16384-100000-*.ini"
model=".models/best_20241230_164534.tar"
```
## with polarisation flipping
polarisation flipping: the signal is randomly rotated by 180°. polarisation rotation can be detected by adding a tone to one of the polarisations, but with a direct-detection setup only modulo 180°. the randomly flipped signal should let the network learn to compensate for dispersion and PMD independently of the polarisation rotation. the training data also contains the flipped signal, with no indication of whether the polarisation is flipped.
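a minimal sketch of this augmentation (illustrative; the dataset code in this repo may implement it differently): a 180° polarisation rotation is a global sign flip of the Jones vector, applied to each training example with probability 0.5 while the target is left unflipped.
```py
import torch

def random_pol_flip(x: torch.Tensor, p: float = 0.5) -> torch.Tensor:
    # x: (batch, ..., 2) complex Jones vectors; R(180°) = -I, i.e. a sign flip
    flip = torch.rand(x.shape[0], device=x.device) < p
    x = x.clone()
    x[flip] = -x[flip]
    return x
```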
```py
config_path="data/20241229-163*-128-16384-50000-*.ini"
model=".models/best_20241231_000328.tar"
```
```py
config_path="data/20241229-163*-128-16384-80000-*.ini"
model=".models/best_20241231_163614.tar"
```
```py
config_path="data/20241229-163*-128-16384-100000-*.ini"
model=".models/best_20241231_170532.tar"
```
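to use one of these pairs, expand the config glob and restore the checkpoint; a minimal sketch following the `model_plot` code in this changeset (paths are the examples listed above):
```py
from pathlib import Path
import torch
from hypertraining import models

configs = sorted(Path.cwd().glob("data/20241229-163*-128-16384-50000-*.ini"))

# torch.serialization.add_safe_globals([...]) may be needed first,
# as done in model_plot, for weights_only loading of custom classes
ckpt = torch.load(".models/best_20241231_000328.tar", weights_only=True)
dims = ckpt["model_kwargs"].pop("dims")
model = models.regenerator(*dims, **ckpt["model_kwargs"])
model.load_state_dict(ckpt["model_state_dict"], strict=False)
model.eval()
```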

View File

@@ -0,0 +1,59 @@
# Baseline Models
## a) D+S, pol_error 0, ortho_error 0, DGD 0
dataset
```raw
data/20250118-225840-128-16384-inf-100000-0-0.2-17.0-0.058-PAM4-0-0-10.ini
```
model
```raw
.models/best_20250118_225918.tar
```
## b) D+S, pol_error 0.4, ortho_error 0, DGD 0
dataset
```raw
data/20250116-214606-128-16384-inf-100000-0-0.2-17.0-0.058-PAM4-0-0-10.ini
```
model
```raw
.models/best_20250116_214816.tar
```
## c) D+S, pol_error 0, ortho_error 0.1, DGD 0
dataset
```raw
data/20250116-214547-128-16384-inf-100000-0-0.2-17.0-0.058-PAM4-0-0-10.ini
```
model
```raw
.models/best_20250117_122319.tar
```
## d) D+S, pol_error 0, ortho_error 0, DGD 10ps (1 T_sym)
birefringence angle π/2 (worst case)
dataset
```raw
data/20250117-143926-128-16384-inf-100000-0-0.2-17.0-0.058-PAM4-0-0-10.ini
```
model
```raw
.models/best_20250117_144001.tar
```
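for reference, a minimal numpy sketch of the first-order PMD element these cases vary (an illustrative frequency-domain Jones model, not the pypho implementation that generated the datasets): rotate into the birefringence axes by θ, delay the two axes by ±DGD/2, rotate back.
```py
import numpy as np

def apply_first_order_pmd(ex, ey, dgd, theta, fs):
    # ex, ey: complex baseband fields; dgd in seconds; fs: sample rate in Hz
    f = np.fft.fftfreq(ex.shape[-1], d=1 / fs)
    c, s = np.cos(theta), np.sin(theta)
    ax, ay = c * ex + s * ey, -s * ex + c * ey          # into principal axes
    ax = np.fft.ifft(np.fft.fft(ax) * np.exp(+1j * np.pi * f * dgd))
    ay = np.fft.ifft(np.fft.fft(ay) * np.exp(-1j * np.pi * f * dgd))
    return c * ax - s * ay, s * ax + c * ay             # rotate back

# case d): dgd = 10e-12 (one symbol period), theta = np.pi / 2
```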

2
pypho

Submodule pypho updated: dd015f4852...e44fc477fe

View File

@@ -26,6 +26,8 @@ import torch
import torch.optim as optim
import torch.utils.data
import hypertraining.models as models
from torch.utils.tensorboard import SummaryWriter
import multiprocessing
@@ -253,14 +255,17 @@ class HyperTraining:
model_kwargs = {
"dims": (input_dim, *hidden_dims, self.model_settings.output_dim),
"layer_function": layer_func,
"layer_parametrizations": layer_parametrizations,
"activation_function": afunc,
"layer_func_kwargs": self.model_settings.model_layer_kwargs,
"act_function": afunc,
"act_func_kwargs": None,
"parametrizations": layer_parametrizations,
"dtype": dtype,
"droupout_prob": self.model_settings.dropout_prob,
"scale": scale_layers,
"dropout_prob": self.model_settings.dropout_prob,
"scale_layers": scale_layers,
"rotate": False,
}
model = util.complexNN.regenerator(**model_kwargs)
model = models.regenerator(*model_kwargs.pop("dims"), **model_kwargs)
n_nodes = sum(hidden_dims)
if writer is not None:
@@ -381,7 +386,10 @@ class HyperTraining:
running_loss = 0.0
model.train()
loader_len = len(train_loader)
for batch_idx, (x, y, _) in enumerate(train_loader):
for batch_idx, batch in enumerate(train_loader):
x = batch["x"]
y = batch["y"]
if batch_idx >= self.optuna_settings._n_train_batches:
break
model.zero_grad(set_to_none=True)
@@ -390,7 +398,7 @@ class HyperTraining:
y.to(self.pytorch_settings.device),
)
y_pred = model(x)
loss = util.complexNN.complex_mse_loss(y_pred, y)
loss = util.complexNN.complex_mse_loss(y_pred, y, power=True)
loss_value = loss.item()
loss.backward()
optimizer.step()
@@ -444,7 +452,9 @@ class HyperTraining:
model.eval()
running_error = 0
with torch.no_grad():
for batch_idx, (x, y, _) in enumerate(valid_loader):
for batch_idx, batch in enumerate(valid_loader):
x = batch["x"]
y = batch["y"]
if batch_idx >= self.optuna_settings._n_valid_batches:
break
x, y = (
@@ -452,50 +462,44 @@ class HyperTraining:
y.to(self.pytorch_settings.device),
)
y_pred = model(x)
error = util.complexNN.complex_mse_loss(y_pred, y)
error = util.complexNN.complex_mse_loss(y_pred, y, power=True)
error_value = error.item()
running_error += error_value
running_error /= min(len(valid_loader), self.optuna_settings._n_valid_batches)
if writer is not None:
title_append, subtitle = self.build_title(trial)
writer.add_figure(
"fiber response",
self.plot_model_response(
trial,
model=model,
title_append=title_append,
subtitle=subtitle,
show=False,
),
epoch + 1,
)
writer.add_figure(
"eye diagram",
self.plot_model_response(
trial,
model=self.model,
title_append=title_append,
subtitle=subtitle,
show=False,
mode="eye",
),
epoch + 1,
writer.add_scalar(
"eval loss",
running_error,
epoch,
)
# if (epoch + 1) % 10 == 0 or epoch < 10:
# # plotting is slow, so only do it every 10 epochs
# title_append, subtitle = self.build_title(trial)
# head_fig, eye_fig, powers_fig = self.plot_model_response(
# model=model,
# title_append=title_append,
# subtitle=subtitle,
# show=False,
# )
# writer.add_figure(
# "fiber response",
# head_fig,
# epoch + 1,
# )
# writer.add_figure(
# "eye diagram",
# eye_fig,
# epoch + 1,
# )
writer.add_figure(
"powers",
self.plot_model_response(
trial,
model=self.model,
title_append=title_append,
subtitle=subtitle,
mode="powers",
show=False,
),
epoch + 1,
)
# writer.add_figure(
# "powers",
# powers_fig,
# epoch + 1,
# )
# writer.flush()
# if enable_progress:
# progress.stop()
@@ -511,15 +515,18 @@ class HyperTraining:
with torch.no_grad():
model = model.to(self.pytorch_settings.device)
for x, y, timestamp in loader:
for batch in loader:
x = batch["x"]
y = batch["y"]
timestamp = batch["timestamp"]
x, y = (
x.to(self.pytorch_settings.device),
y.to(self.pytorch_settings.device),
)
if trace_powers:
y_pred, powers = model(x, trace_powers).cpu()
y_pred, powers = model(x, trace_powers=True).cpu()
else:
y_pred = model(x, trace_powers).cpu()
y_pred = model(x, trace_powers=True).cpu()
# x = x.cpu()
# y = y.cpu()
y_pred = y_pred.view(y_pred.shape[0], -1, 2)
@@ -539,7 +546,7 @@ class HyperTraining:
return fiber_in, fiber_out, regen, timestamps, powers
return fiber_in, fiber_out, regen, timestamps
def objective(self, trial: optuna.Trial, plot_before=False):
def objective(self, trial: optuna.Trial):
if self.stop_study:
trial.study.stop()
model = None
@@ -555,54 +562,54 @@ class HyperTraining:
title_append, subtitle = self.build_title(trial)
writer.add_figure(
"fiber response",
self.plot_model_response(
trial,
model=model,
title_append=title_append,
subtitle=subtitle,
show=False,
),
0,
)
writer.add_figure(
"eye diagram",
self.plot_model_response(
trial,
model=self.model,
title_append=title_append,
subtitle=subtitle,
mode="eye",
show=False,
),
0,
)
# writer.add_figure(
# "fiber response",
# self.plot_model_response(
# trial,
# model=model,
# title_append=title_append,
# subtitle=subtitle,
# show=False,
# ),
# 0,
# )
# writer.add_figure(
# "eye diagram",
# self.plot_model_response(
# trial,
# model=self.model,
# title_append=title_append,
# subtitle=subtitle,
# mode="eye",
# show=False,
# ),
# 0,
# )
writer.add_figure(
"powers",
self.plot_model_response(
trial,
model=self.model,
title_append=title_append,
subtitle=subtitle,
mode="powers",
show=False,
),
0,
)
# writer.add_figure(
# "powers",
# self.plot_model_response(
# trial,
# model=self.model,
# title_append=title_append,
# subtitle=subtitle,
# mode="powers",
# show=False,
# ),
# 0,
# )
train_loader, valid_loader = self.get_sliced_data(trial)
optimizer_name = trial.suggest_categorical_optional("optimizer", self.optimizer_settings.optimizer)
lr = trial.suggest_float_optional("lr", self.optimizer_settings.learning_rate, log=True)
lr = trial.suggest_float_optional("lr", self.optimizer_settings.optimizer_kwargs["lr"], log=True)
optimizer = getattr(optim, optimizer_name)(model.parameters(), lr=lr)
if self.optimizer_settings.scheduler is not None:
scheduler = getattr(optim.lr_scheduler, self.optimizer_settings.scheduler)(
optimizer, **self.optimizer_settings.scheduler_kwargs
)
# if self.optimizer_settings.scheduler is not None:
# scheduler = getattr(optim.lr_scheduler, self.optimizer_settings.scheduler)(
# optimizer, **self.optimizer_settings.scheduler_kwargs
# )
for epoch in range(self.pytorch_settings.epochs):
trial.set_user_attr("epoch", epoch)
@@ -628,8 +635,8 @@ class HyperTraining:
writer,
# enable_progress=enable_progress,
)
if self.optimizer_settings.scheduler is not None:
scheduler.step(error)
# if self.optimizer_settings.scheduler is not None:
# scheduler.step(error)
trial.set_user_attr("mse", error)
trial.set_user_attr("log_mse", np.log10(error + np.finfo(float).eps))
@@ -645,10 +652,10 @@ class HyperTraining:
if self.optuna_settings._multi_objective:
return -np.log10(error + np.finfo(float).eps), trial.user_attrs.get("model_n_nodes", -1)
if self.pytorch_settings.save_models and model is not None:
save_path = Path(self.pytorch_settings.model_dir) / f"{self.optuna_settings.study_name}_{trial.number}.pth"
save_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(model, save_path)
# if self.pytorch_settings.save_models and model is not None:
# save_path = Path(self.pytorch_settings.model_dir) / f"{self.optuna_settings.study_name}_{trial.number}.pth"
# save_path.parent.mkdir(parents=True, exist_ok=True)
# torch.save(model, save_path)
return error

View File

@@ -8,7 +8,8 @@ from util.complexNN import (
photodiode,
EOActivation,
polarimeter,
normalize_by_first
# normalize_by_first,
rotate,
)
@@ -19,11 +20,11 @@ class polarisation_estimator2(Module):
polarimeter(),
torch.nn.Linear(4, 4),
torch.nn.ReLU(),
torch.nn.Dropout(p=0.01),
# torch.nn.Dropout(p=0.01),
torch.nn.Linear(4, 4),
torch.nn.ReLU(),
torch.nn.Dropout(p=0.01),
torch.nn.Linear(4, 4),
# torch.nn.Dropout(p=0.01),
torch.nn.Linear(4, 1),
)
def forward(self, x):
@@ -123,7 +124,8 @@ class regenerator(Module):
parametrizations: list[dict] = None,
dtype=torch.float64,
dropout_prob=0.01,
scale_layers=False,
prescale=1,
rotate=False,
):
super(regenerator, self).__init__()
self._n_hidden_layers = len(dims) - 2
@@ -131,14 +133,15 @@ class regenerator(Module):
layer_func_kwargs = layer_func_kwargs or {}
act_func_kwargs = act_func_kwargs or {}
self.build_model(dims, layer_function, layer_func_kwargs, act_function, act_func_kwargs, parametrizations, dtype, dropout_prob, scale_layers)
self.rotation = rotate
self.prescale = prescale
def build_model(self, dims, layer_function, layer_func_kwargs, act_function, act_func_kwargs, parametrizations, dtype, dropout_prob, scale_layers):
self.build_model(dims, layer_function, layer_func_kwargs, act_function, act_func_kwargs, parametrizations, dtype, dropout_prob)
def build_model(self, dims, layer_function, layer_func_kwargs, act_function, act_func_kwargs, parametrizations, dtype, dropout_prob):
for i in range(0, self._n_hidden_layers):
self.add_module(f"layer_{i}", Sequential())
if scale_layers:
self.get_submodule(f"layer_{i}").add_module("scale", Scale(dims[i]))
module = layer_function(dims[i], dims[i + 1], dtype=dtype, **layer_func_kwargs)
self.get_submodule(f"layer_{i}").add_module("ONN", module)
@@ -146,13 +149,14 @@ class regenerator(Module):
module = act_function(size=dims[i + 1], **act_func_kwargs)
self.get_submodule(f"layer_{i}").add_module("activation", module)
module = DropoutComplex(p=dropout_prob)
self.get_submodule(f"layer_{i}").add_module("dropout", module)
if dropout_prob is not None and dropout_prob > 0:
module = DropoutComplex(p=dropout_prob)
self.get_submodule(f"layer_{i}").add_module("dropout", module)
self.add_module(f"layer_{self._n_hidden_layers}", Sequential())
if scale_layers:
self.get_submodule(f"layer_{self._n_hidden_layers}").add_module("scale", Scale(dims[-2]))
# if scale_layers:
# self.get_submodule(f"layer_{self._n_hidden_layers}").add_module("scale", Scale(dims[-2]))
module = layer_function(dims[-2], dims[-1], dtype=dtype, **layer_func_kwargs)
self.get_submodule(f"layer_{self._n_hidden_layers}").add_module("ONN", module)
@@ -160,6 +164,14 @@ class regenerator(Module):
module = act_function(size=dims[-1], **act_func_kwargs)
self.get_submodule(f"layer_{self._n_hidden_layers}").add_module("activation", module)
module = Scale(size=dims[-1])
self.get_submodule(f"layer_{self._n_hidden_layers}").add_module("out_scale", module)
if self.rotation:
module = rotate()
self.add_module("rotate", module)
# module = Scale(size=dims[-1])
# self.get_submodule(f"layer_{self._n_hidden_layers}").add_module("out_scale", module)
@@ -190,15 +202,28 @@ class regenerator(Module):
powers.append(x.abs().square().sum())
return powers
def forward(self, x, trace_powers=False):
def forward(self, x, angle=None, pre_rot=False, trace_powers=False):
x = x * self.prescale
powers = self._trace_powers(trace_powers, x)
x = self.layer_0(x)
powers = self._trace_powers(trace_powers, x, powers)
for i in range(1, self._n_hidden_layers):
# x = self.layer_0(x)
# powers = self._trace_powers(trace_powers, x, powers)
for i in range(0, self._n_hidden_layers):
x = getattr(self, f"layer_{i}")(x)
powers = self._trace_powers(trace_powers, x, powers)
x = getattr(self, f"layer_{self._n_hidden_layers}")(x)
powers = self._trace_powers(trace_powers, x, powers)
if trace_powers:
return x, powers
return x
if self.rotation:
try:
x_rot = self.rotate(x, angle)
except AttributeError:
pass
powers = self._trace_powers(trace_powers, x_rot, powers)
else:
x_rot = x
if pre_rot and trace_powers:
return x_rot, x, powers
if pre_rot and not trace_powers:
return x_rot, x
if not pre_rot and trace_powers:
return x_rot, powers
return x_rot

View File

@@ -18,10 +18,14 @@ class DataSettings:
shuffle: bool = True
in_out_delay: float = 0
xy_delay: tuple | float | int = 0
drop_first: int = 1000
drop_first: int = 64
drop_last: int = 64
train_split: float = 0.8
polarisations: tuple | list = (0,)
# cross_pol_interference: float = 0
randomise_polarisations: bool = False
osnr: float | int = None
seed: int = None
"""
change to:
@@ -91,6 +95,12 @@ class ModelSettings:
"""
def _early_stop_default_kwargs():
return {
"threshold": 1e-05,
"plateau": 25,
}
@dataclass
class OptimizerSettings:
optimizer: tuple | str = ("Adam", "RMSprop", "SGD")
@@ -99,6 +109,9 @@ class OptimizerSettings:
scheduler: str | None = None
scheduler_kwargs: dict | None = None
early_stopping: bool = False
early_stop_kwargs: dict = field(default_factory=_early_stop_default_kwargs)
"""
change to:

View File

@@ -2,9 +2,9 @@ import copy
from datetime import datetime
from pathlib import Path
import random
from typing import Literal
import matplotlib
from matplotlib.colors import LinearSegmentedColormap
from mpl_toolkits.axes_grid1 import make_axes_locatable
import torch.nn.utils.parametrize
try:
@@ -47,46 +47,107 @@ from .settings import (
PytorchSettings,
)
from cmcrameri import cm
# from matplotlib import colors as mcolors
# alpha_map = mcolors.LinearSegmentedColormap(
# 'alphamap',
# {
# 'red': [(0, 0, 0), (1, 0, 0)],
# 'green': [(0, 0, 0), (1, 0, 0)],
# 'blue': [(0, 0, 0), (1, 0, 0)],
# 'alpha': [
# (0, 1, 1),
# # (0.2, 0.2, 0.1),
# (1, 0, 0)
# ]
# }
# )
# alpha_map.set_bad(color="#AAAAAA")
def pad_to_size(array, size):
if not hasattr(size, "__len__"):
size = (size, size)
left = (
(size[0] - array.shape[0] + 1) // 2 if size[0] is not None else 0
)
right = (
(size[0] - array.shape[0]) // 2 if size[0] is not None else 0
)
top = (
(size[1] - array.shape[1] + 1) // 2 if size[1] is not None else 0
)
bottom = (
(size[1] - array.shape[1]) // 2 if size[1] is not None else 0
)
array: np.ndarray = array
if array.ndim == 2:
return np.pad(
array,
(
(left, right),
(top, bottom),
),
constant_values=(np.nan, np.nan),
)
elif array.ndim == 3:
return np.pad(
array,
(
(left, right),
(top, bottom),
(0,0)
),
constant_values=(np.nan, np.nan),
)
def traverse_dict_update(target, source):
for k, v in source.items():
if isinstance(v, dict):
if k not in target:
target[k] = {}
traverse_dict_update(target[k], v)
try:
if k not in target:
target[k] = {}
traverse_dict_update(target[k], v)
except TypeError:
if k not in target.__dict__:
setattr(target, k, {})
traverse_dict_update(target.__dict__[k], v)
else:
try:
target[k] = v
except TypeError:
target.__dict__[k] = v
def get_parameter_names_and_values(model):
def is_parametrized(module):
if hasattr(module, "parametrizations"):
return True
return False
def _get_param_info(module, prefix='', parametrization=False):
def _get_param_info(module, prefix="", parametrization=False):
param_list = []
for name, param in module.named_parameters(recurse = parametrization):
for name, param in module.named_parameters(recurse=parametrization):
if parametrization and name.startswith("parametrizations"):
name_parts = name.split('.')
name_parts = name.split(".")
name = name_parts[1]
param = getattr(module, name)
full_name = prefix + ('.' if prefix else '') + name
full_name = prefix + ("." if prefix else "") + name
param_value = param.data
param_list.append((full_name, param_value))
for child_name, child_module in module.named_children():
child_prefix = prefix + ('.' if prefix else '') + child_name
child_prefix = prefix + ("." if prefix else "") + child_name
if child_name == "parametrizations":
continue
param_list.extend(_get_param_info(child_module, child_prefix, is_parametrized(child_module)))
return param_list
return _get_param_info(model)
class PolarizationTrainer:
def __init__(
self,
@@ -101,7 +162,7 @@ class PolarizationTrainer:
settings_override=None,
reset_epoch=False,
):
self.mod = torch.pi/2
self.mod = torch.pi / 2
self.resume = checkpoint_path is not None
torch.serialization.add_safe_globals([
*util.complexNN.__all__,
@@ -219,7 +280,7 @@ class PolarizationTrainer:
# dims = self.model_kwargs.pop("dims")
model_kwargs = copy.deepcopy(self.model_kwargs)
self.model = models.polarisation_estimator(*model_kwargs.pop('dims'),**model_kwargs)
self.model = models.polarisation_estimator(*model_kwargs.pop("dims"), **model_kwargs)
# self.model = models.polarisation_estimator2()
if self.writer is not None:
@@ -260,6 +321,7 @@ class PolarizationTrainer:
target_delay=in_out_delay,
xy_delay=xy_delay,
drop_first=self.data_settings.drop_first,
drop_last=self.data_settings.drop_last,
dtype=dtype,
real=not dtype.is_complex,
num_symbols=num_symbols,
@@ -336,17 +398,20 @@ class PolarizationTrainer:
write_div = 0
loss_div = 0
for batch_idx, batch in enumerate(train_loader):
x = batch["x"]
y = batch["sop"]
x = batch["angle_data2"]
y = batch["center_angle"]
self.model.zero_grad(set_to_none=True)
x, y = (
x.to(self.pytorch_settings.device),
y.to(self.pytorch_settings.device),
)
y_pred = self.model(x)
y_pred = self.model(x).abs().real
# y_pred = torch.fmod(y_pred, self.mod)
y = y.abs().real
# y = torch.fmod(y, self.mod)
# loss = torch.nn.functional.mse_loss(torch.cos(y_pred), torch.cos(y))
# loss = torch.nn.functional.smooth_l1_loss(torch.cos(torch.fmod(y_pred, torch.pi/2)).squeeze(), torch.cos(torch.fmod(y, torch.pi/2)).squeeze(), beta=0.5)
loss = torch.nn.functional.mse_loss(y_pred, y)
# loss = util.complexNN.naive_angle_loss(y_pred, y, mod=torch.pi/2)
loss = util.complexNN.naive_angle_loss(y_pred, y, mod=self.mod)
loss_value = loss.item()
loss.backward()
optimizer.step()
@@ -356,7 +421,7 @@ class PolarizationTrainer:
loss_div += 1
if enable_progress:
progress.update(task, advance=1, description=f"{loss_value:.3e}")
progress.update(task, advance=1, description=f"{loss_value/np.pi*180:.3e} °")
if batch_idx % self.pytorch_settings.write_every == 0:
self.writer.add_scalar(
@@ -395,24 +460,28 @@ class PolarizationTrainer:
loss_div = 0
with torch.no_grad():
for _, batch in enumerate(valid_loader):
x = batch["x"]
y = batch["sop"]
# x = batch["angle_data2"]
x = batch["angle_data2"]
y = batch["center_angle"]
x, y = (
x.to(self.pytorch_settings.device),
y.to(self.pytorch_settings.device),
)
y_pred = self.model(x)
y_pred = self.model(x).abs().real
# y_pred = torch.fmod(y_pred, self.mod)
y = y.abs().real
# y = torch.fmod(y, self.mod)
# loss = torch.nn.functional.mse_loss(torch.cos(y_pred), torch.cos(y))
# loss = torch.nn.functional.smooth_l1_loss(torch.cos(torch.fmod(y_pred, torch.pi/2)).squeeze(), torch.cos(torch.fmod(y, torch.pi/2)).squeeze(), beta=0.5)
loss = torch.nn.functional.mse_loss(y_pred, y)
# loss = util.complexNN.naive_angle_loss(y_pred, y, mod=torch.pi/2)
loss = util.complexNN.naive_angle_loss(y_pred, y, mod=self.mod)
loss_value = loss.item()
running_loss += loss_value
loss_div += 1
if enable_progress:
progress.update(task, advance=1, description=f"{loss_value:.3e}")
progress.update(task, advance=1, description=f"{loss_value/np.pi*180:.3e} °")
running_loss = running_loss/loss_div
running_loss = running_loss / loss_div
self.writer.add_scalar(
"eval loss",
@@ -506,19 +575,19 @@ class PolarizationTrainer:
for i, config_path in enumerate(self.data_settings.config_path):
paths = Path.cwd().glob(config_path)
for j, path in enumerate(paths):
text = str(path) + '\n'
with open(path, 'r') as f:
text = str(path) + "\n"
with open(path, "r") as f:
text += f.read()
text += '\n'
self.writer.add_text(f"config_{i*len(self.data_settings.config_path)+j}", text)
text += "\n"
self.writer.add_text(f"config_{i * len(self.data_settings.config_path) + j}", text)
elif isinstance(self.data_settings.config_path, str):
paths = Path.cwd().glob(self.data_settings.config_path)
for j, path in enumerate(paths):
text = str(path) + '\n'
with open(path, 'r') as f:
text = str(path) + "\n"
with open(path, "r") as f:
text += f.read()
text += '\n'
text += "\n"
self.writer.add_text(f"config_{j}", text)
self.writer.flush()
@@ -571,7 +640,8 @@ class PolarizationTrainer:
if loss < self.best["loss"]:
self.best = checkpoint
save_path = (
Path(self.pytorch_settings.model_dir) / f"best_pol_{self.writer.get_logdir().split('/')[-1]}.tar"
Path(self.pytorch_settings.model_dir)
/ f"best_pol_{self.writer.get_logdir().split('/')[-1]}.tar"
)
save_path.parent.mkdir(parents=True, exist_ok=True)
self.save_checkpoint(self.best, save_path)
@@ -580,6 +650,7 @@ class PolarizationTrainer:
self.writer.close()
return self.best
class RegenerationTrainer:
def __init__(
self,
@@ -592,6 +663,7 @@ class RegenerationTrainer:
console=None,
checkpoint_path=None,
settings_override=None,
new_model=False,
reset_epoch=False,
):
self.resume = checkpoint_path is not None
@@ -605,12 +677,23 @@ class RegenerationTrainer:
models.regenerator,
torch.nn.utils.parametrizations.orthogonal,
])
# self.new_model = True
self.model_name = datetime.now().strftime("%Y%m%d_%H%M%S")
if self.resume:
print(f"loading checkpoint from {checkpoint_path}")
self.checkpoint_dict = torch.load(checkpoint_path, weights_only=True)
if settings_override is not None:
traverse_dict_update(self.checkpoint_dict["settings"], settings_override)
if reset_epoch:
if not new_model:
# self.new_model = False
checkpoint_file = checkpoint_path.split("/")[-1].split(".")[0]
if checkpoint_file.startswith("best"):
self.model_name = "_".join(checkpoint_file.split("_")[1:])
else:
self.model_name = "_".join(checkpoint_file.split("_")[:-1])
if new_model or reset_epoch:
self.checkpoint_dict["epoch"] = -1
self.global_settings: GlobalSettings = self.checkpoint_dict["settings"]["global_settings"]
@@ -636,11 +719,15 @@ class RegenerationTrainer:
self.model_settings: ModelSettings = model_settings
self.optimizer_settings: OptimizerSettings = optimizer_settings
if self.global_settings.seed is not None:
random.seed(self.global_settings.seed)
np.random.seed(self.global_settings.seed)
self.console = console or Console()
self.writer = None
def setup_tb_writer(self, append=None):
log_dir = self.pytorch_settings.summary_dir + "/" + (datetime.now().strftime("%Y%m%d_%H%M%S"))
log_dir = self.pytorch_settings.summary_dir + "/" + self.model_name
if append is not None:
log_dir += "_" + str(append)
@@ -669,7 +756,7 @@ class RegenerationTrainer:
def define_model(self, model_kwargs=None):
if self.resume:
model_kwargs = self.checkpoint_dict["model_kwargs"]
model_kwargs = None
else:
model_kwargs = model_kwargs
@@ -678,6 +765,14 @@ class RegenerationTrainer:
input_dim = 2 * self.data_settings.output_size
# if self.data_settings.polarisations:
# input_dim *= 2
output_dim = self.model_settings.output_dim
if self.data_settings.polarisations:
output_dim *= 2
dtype = getattr(torch, self.data_settings.dtype)
afunc = getattr(util.complexNN, self.model_settings.model_activation_func)
@@ -689,7 +784,7 @@ class RegenerationTrainer:
hidden_dims = [self.model_settings.overrides.get(f"n_hidden_nodes_{i}") for i in range(n_hidden_layers)]
self.model_kwargs = {
"dims": (input_dim, *hidden_dims, self.model_settings.output_dim),
"dims": (input_dim, *hidden_dims, output_dim),
"layer_function": layer_func,
"layer_func_kwargs": self.model_settings.model_layer_kwargs,
"act_function": afunc,
@@ -697,7 +792,7 @@ class RegenerationTrainer:
"parametrizations": layer_parametrizations,
"dtype": dtype,
"dropout_prob": self.model_settings.dropout_prob,
"scale_layers": self.model_settings.scale,
"prescale": self.model_settings.scale,
}
else:
self.model_kwargs = model_kwargs
@@ -706,10 +801,12 @@ class RegenerationTrainer:
# dims = self.model_kwargs.pop("dims")
model_kwargs = copy.deepcopy(self.model_kwargs)
self.model = models.regenerator(*model_kwargs.pop('dims'),**model_kwargs)
self.model = models.regenerator(*model_kwargs.pop("dims"), **model_kwargs)
if self.writer is not None:
self.writer.add_graph(self.model, torch.zeros(1, input_dim, dtype=dtype))
self.writer.add_graph(
self.model, (torch.rand(1, input_dim, dtype=dtype), torch.rand(1, 1, dtype=dtype.to_real()))
)
self.model = self.model.to(self.pytorch_settings.device)
if self.resume:
@@ -728,13 +825,16 @@ class RegenerationTrainer:
num_symbols = None
config_path = self.data_settings.config_path
polarisations = self.data_settings.polarisations
randomise_polarisations = self.data_settings.randomise_polarisations
polarisations = self.data_settings.polarisations
osnr = self.data_settings.osnr
# cross_pol_interference = self.data_settings.cross_pol_interference
if override is not None:
num_symbols = override.get("num_symbols", None)
config_path = override.get("config_path", config_path)
polarisations = override.get("polarisations", polarisations)
randomise_polarisations = override.get("randomise_polarisation", randomise_polarisations)
# cross_pol_interference = override.get("angle_var", 0)
# get dataset
dataset = FiberRegenerationDataset(
file_path=config_path,
@@ -743,11 +843,14 @@ class RegenerationTrainer:
target_delay=in_out_delay,
xy_delay=xy_delay,
drop_first=self.data_settings.drop_first,
drop_last=self.data_settings.drop_last,
dtype=dtype,
real=not dtype.is_complex,
num_symbols=num_symbols,
polarisations=polarisations,
randomise_polarisations=randomise_polarisations,
polarisations=polarisations,
# cross_pol_interference=cross_pol_interference,
osnr = osnr,
)
dataset_size = len(dataset)
@@ -816,16 +919,25 @@ class RegenerationTrainer:
running_loss = 0.0
self.model.train()
loader_len = len(train_loader)
# x_key = "x_stacked"# if self.data_settings.polarisations else "x"
# y_key = "y_stacked"# if self.data_settings.polarisations else "y"
x_key = "x"
y_key = "y"
for batch_idx, batch in enumerate(train_loader):
x = batch["x"]
y = batch["y"]
x = batch[x_key]
y = batch[y_key]
angle = batch["mean_angle"]
self.model.zero_grad(set_to_none=True)
x, y = (
x, y, angle = (
x.to(self.pytorch_settings.device),
y.to(self.pytorch_settings.device),
angle.to(self.pytorch_settings.device),
)
y_pred = self.model(x)
y_pred = self.model(x, -angle)
# loss = util.complexNN.complex_mse_loss(y_pred*torch.sqrt(torch.ones(1, device=y.device)*5), y*torch.sqrt(torch.ones(1, device=y.device)*5), power=True)
loss = util.complexNN.complex_mse_loss(y_pred, y, power=True)
loss_value = loss.item()
loss.backward()
optimizer.step()
@@ -868,23 +980,31 @@ class RegenerationTrainer:
self.model.eval()
running_error = 0
# x_key = "x_stacked"# if self.data_settings.polarisations else "x"
# y_key = "y_stacked"# if self.data_settings.polarisations else "y"
x_key = "x"
y_key = "y"
with torch.no_grad():
for _, batch in enumerate(valid_loader):
x = batch["x"]
y = batch["y"]
x, y = (
x = batch[x_key]
y = batch[y_key]
angle = batch["mean_angle"]
x, y, angle = (
x.to(self.pytorch_settings.device),
y.to(self.pytorch_settings.device),
angle.to(self.pytorch_settings.device),
)
y_pred = self.model(x)
y_pred = self.model(x, -angle)
# error = util.complexNN.complex_mse_loss(y_pred*torch.sqrt(torch.ones(1, device=y.device)*5), y*torch.sqrt(torch.ones(1, device=y.device)*5), power=True)
error = util.complexNN.complex_mse_loss(y_pred, y, power=True)
error_value = error.item()
running_error += error_value
if enable_progress:
progress.update(task, advance=1, description=f"{error_value:.3e}")
running_error = running_error/len(valid_loader)
running_error = running_error / len(valid_loader)
self.writer.add_scalar(
"eval loss",
@@ -894,7 +1014,7 @@ class RegenerationTrainer:
if (epoch + 1) % 10 == 0 or epoch < 10:
# plotting is slow, so only do it every 10 epochs
title_append, subtitle = self.build_title(epoch + 1)
head_fig, eye_fig, powers_fig = self.plot_model_response(
head_fig, eye_fig, weight_fig, powers_fig = self.plot_model_response(
model=self.model,
title_append=title_append,
subtitle=subtitle,
@@ -910,6 +1030,11 @@ class RegenerationTrainer:
eye_fig,
epoch + 1,
)
self.writer.add_figure(
"weights",
weight_fig,
epoch + 1,
)
self.writer.add_figure(
"powers",
@@ -928,45 +1053,70 @@ class RegenerationTrainer:
def run_model(self, model, loader, trace_powers=False):
model.eval()
fiber_out = []
fiber_out_rot = []
fiber_in = []
regen = []
timestamps = []
angles = []
# x_key = "x_stacked"# if self.data_settings.polarisations else "x"
# y_key = "y_stacked"# if self.data_settings.polarisations else "y"
x_key = "x"
y_key = "y"
with torch.no_grad():
model = model.to(self.pytorch_settings.device)
for batch in loader:
x = batch["x"]
y = batch["y"]
x = batch[x_key]
y = batch[y_key]
plot_target = batch["plot_target"]
angle = batch["mean_angle"]
# center_angle = batch["center_angle"]
timestamp = batch["timestamp"]
plot_data = batch["plot_data"]
x, y = (
plot_data_rot = batch["plot_data_rot"]
x, y, angle = (
x.to(self.pytorch_settings.device),
y.to(self.pytorch_settings.device),
angle.to(self.pytorch_settings.device),
)
if trace_powers:
y_pred, powers = model(x, trace_powers).cpu()
y_pred, powers = model(x, -angle, True).cpu()
else:
y_pred = model(x, trace_powers).cpu()
y_pred = model(x, -angle).cpu()
# x = x.cpu()
# y = y.cpu()
# if self.data_settings.polarisations:
y_pred = y_pred[:, :2]
y_pred = y_pred.view(y_pred.shape[0], -1, 2)
y = y.view(y.shape[0], -1, 2)
plot_data = plot_data.view(plot_data.shape[0], -1, 2)
y_pred = y_pred[:, y_pred.shape[1]//2, :]
# y = y.view(y.shape[0], -1, 2)
# plot_data = plot_data.view(plot_data.shape[0], -1, 2)
# c = torch.cos(-angle).cpu()
# s = torch.sin(-angle).cpu()
# rot = torch.stack([torch.stack([c, -s], dim=1), torch.stack([s, c], dim=1)], dim=2).squeeze(-1)
# plot_data = torch.bmm(plot_data, rot.to(dtype=plot_data.dtype))
# plot_data = plot_data
# sines = torch.sin(-angle.cpu())
# cosines = torch.cos(-angle.cpu())
# plot_data = torch.stack((plot_data[..., 0] * cosines - plot_data[..., 1] * sines, plot_data[..., 0] * sines + plot_data[..., 1] * cosines), dim=-1)
# x = x.view(x.shape[0], -1, 2)
# timestamp = timestamp.view(-1, 1)
fiber_out.append(plot_data.squeeze())
fiber_in.append(y.squeeze())
fiber_out_rot.append(plot_data_rot.squeeze())
fiber_in.append(plot_target.squeeze())
regen.append(y_pred.squeeze())
timestamps.append(timestamp.squeeze())
angles.append(angle.squeeze())
fiber_out = torch.vstack(fiber_out).cpu()
fiber_out_rot = torch.vstack(fiber_out_rot).cpu()
fiber_in = torch.vstack(fiber_in).cpu()
regen = torch.vstack(regen).cpu()
angles = torch.vstack(angles).cpu()
timestamps = torch.concat(timestamps).cpu()
if trace_powers:
return fiber_in, fiber_out, regen, timestamps, powers
return fiber_in, fiber_out, regen, timestamps
return fiber_in, fiber_out, fiber_out_rot, angles, regen, timestamps, powers
return fiber_in, fiber_out, fiber_out_rot, angles, regen, timestamps
def write_parameters(self, epoch, attributes: list[str] | tuple[str] = None):
parameter_list = get_parameter_names_and_values(self.model)
@@ -998,7 +1148,7 @@ class RegenerationTrainer:
)
title_append, subtitle = self.build_title(0)
head_fig, eye_fig, powers_fig = self.plot_model_response(
head_fig, eye_fig, weight_fig, powers_fig = self.plot_model_response(
model=self.model,
title_append=title_append,
subtitle=subtitle,
@@ -1014,6 +1164,11 @@ class RegenerationTrainer:
eye_fig,
0,
)
self.writer.add_figure(
"weights",
weight_fig,
0,
)
self.writer.add_figure(
"powers",
@@ -1027,24 +1182,27 @@ class RegenerationTrainer:
for i, config_path in enumerate(self.data_settings.config_path):
paths = Path.cwd().glob(config_path)
for j, path in enumerate(paths):
text = str(path) + '\n'
with open(path, 'r') as f:
text = str(path) + "\n"
with open(path, "r") as f:
text += f.read()
text += '\n'
self.writer.add_text(f"config_{i*len(self.data_settings.config_path)+j}", text)
text += "\n"
self.writer.add_text(f"config_{i * len(self.data_settings.config_path) + j}", text)
elif isinstance(self.data_settings.config_path, str):
paths = Path.cwd().glob(self.data_settings.config_path)
for j, path in enumerate(paths):
text = str(path) + '\n'
with open(path, 'r') as f:
text = str(path) + "\n"
with open(path, "r") as f:
text += f.read()
text += '\n'
text += "\n"
self.writer.add_text(f"config_{j}", text)
self.writer.flush()
train_loader, valid_loader = self.get_sliced_data()
# train_loader.dataset.fiber_out.to(self.pytorch_settings.device)
# train_loader.dataset.fiber_in.to(self.pytorch_settings.device)
optimizer_name = self.optimizer_settings.optimizer
# lr = self.optimizer_settings.learning_rate
@@ -1074,6 +1232,7 @@ class RegenerationTrainer:
# except ValueError:
# pass
self.early_stop_vals = {"min_loss": float("inf"), "plateau_cnt": 0}
for epoch in range(self.best["epoch"] + 1, self.pytorch_settings.epochs):
enable_progress = True
if enable_progress:
@@ -1089,33 +1248,69 @@ class RegenerationTrainer:
epoch,
enable_progress=enable_progress,
)
if self.early_stop(loss):
self.save_model_checkpoints(epoch, loss)
break
if self.optimizer_settings.scheduler is not None:
self.scheduler.step(loss)
self.writer.add_scalar("learning rate", self.optimizer.param_groups[0]["lr"], epoch)
if self.pytorch_settings.save_models and self.model is not None:
save_path = (
Path(self.pytorch_settings.model_dir) / f"{self.writer.get_logdir().split('/')[-1]}_{epoch}.tar"
)
save_path.parent.mkdir(parents=True, exist_ok=True)
checkpoint = self.build_checkpoint_dict(loss, epoch)
self.save_checkpoint(checkpoint, save_path)
if loss < self.best["loss"]:
self.best = checkpoint
save_path = (
Path(self.pytorch_settings.model_dir) / f"best_{self.writer.get_logdir().split('/')[-1]}.tar"
)
save_path.parent.mkdir(parents=True, exist_ok=True)
self.save_checkpoint(self.best, save_path)
self.save_model_checkpoints(epoch, loss)
self.writer.flush()
save_path = (Path(self.pytorch_settings.model_dir) / f"best_{self.writer.get_logdir().split('/')[-1]}.tar")
print(f"Training complete. Best checkpoint: {save_path}")
self.writer.close()
return self.best
def early_stop(self, loss):
# not stopping early at all
if not self.optimizer_settings.early_stopping:
return False
# stopping because of abs threshold
if (loss_thr := self.optimizer_settings.early_stop_kwargs.get("threshold", None)) is not None:
if loss <= loss_thr:
print(f"Early stop: loss is below threshold ({loss:.2e} <= {loss_thr:.2e})")
return True
# update vals
if loss < self.early_stop_vals["min_loss"]:
self.early_stop_vals["min_loss"] = loss
self.early_stop_vals["plateau_cnt"] = 0
return False
# stopping because of plateau
if (plateau_thresh := self.optimizer_settings.early_stop_kwargs.get("plateau", None)) is not None:
self.early_stop_vals["plateau_cnt"] += 1
if self.early_stop_vals["plateau_cnt"] >= plateau_thresh:
print(f"Early stop: loss plateau length over threshold ({self.early_stop_vals["plateau_cnt"]} >= {plateau_thresh})")
return True
# no stop
return False
def save_model_checkpoints(self, epoch, loss):
if self.pytorch_settings.save_models and self.model is not None:
save_path = (
Path(self.pytorch_settings.model_dir) / f"{self.writer.get_logdir().split('/')[-1]}_{epoch}.tar"
)
save_path.parent.mkdir(parents=True, exist_ok=True)
checkpoint = self.build_checkpoint_dict(loss, epoch)
self.save_checkpoint(checkpoint, save_path)
if loss < self.best["loss"]:
self.best = checkpoint
save_path = (
Path(self.pytorch_settings.model_dir) / f"best_{self.writer.get_logdir().split('/')[-1]}.tar"
)
save_path.parent.mkdir(parents=True, exist_ok=True)
self.save_checkpoint(self.best, save_path)
def _plot_model_response_powers(self, powers, layer_names, title_append="", subtitle="", show=True):
powers = [power / powers[0] for power in powers]
fig, ax = plt.subplots()
fig.set_figwidth(18)
fig.set_figheight(4)
fig.suptitle(
f"Energy conservation{f' {title_append}' if title_append else ''}{f'\n{subtitle}' if subtitle else ''}"
)
@@ -1131,6 +1326,77 @@ class RegenerationTrainer:
plt.show()
return fig
def _plot_model_weights(self, model, title_append="", subtitle="", show=True):
model_params = []
plots = []
dims = []
for num, (layer_name, layer) in enumerate(model.named_children()):
onn_weights = layer.ONN.weight
onn_weights = onn_weights.detach().cpu().numpy()
onn_values = np.abs(onn_weights).real
onn_angles = np.mod(np.angle(onn_weights), 2*np.pi).real
model_params.append({layer_name: onn_weights})
plots.append({layer_name: (num, onn_values, onn_angles)})
dims.append(onn_weights.shape[0])
max_size = np.max(dims)
for plot in plots:
layer_name, (num, onn_values, onn_angles) = plot.popitem()
if num == 0:
value_img = onn_values
angle_img = onn_angles
onn_angles = pad_to_size(onn_angles, (max_size, None))
onn_values = pad_to_size(onn_values, (max_size, None))
else:
onn_values = pad_to_size(onn_values, (max_size, onn_values.shape[1]+1))
onn_angles = pad_to_size(onn_angles, (max_size, onn_angles.shape[1]+1))
value_img = np.concatenate((value_img, onn_values), axis=1)
angle_img = np.concatenate((angle_img, onn_angles), axis=1)
value_img = np.ma.array(value_img, mask=np.isnan(value_img))
angle_img = np.ma.array(angle_img, mask=np.isnan(angle_img))
fig, axs = plt.subplots(1, 2, sharex=True, sharey=True, figsize=(18, 6.5))
fig.tight_layout()
dividers = map(make_axes_locatable, axs)
caxs = list(map(lambda d: d.append_axes("right", size="5%", pad=0.05), dividers))
masked_value_img = value_img
cmap = cm.batlow
cmap.set_bad(color="#AAAAAA")
im_val = axs[0].imshow(masked_value_img, cmap=cmap, vmin=0, vmax=1)
fig.colorbar(im_val, cax=caxs[0], orientation="vertical")
masked_angle_img = np.ma.masked_where(np.isnan(angle_img), angle_img)
cmap = cm.romaO
cmap.set_bad(color="#AAAAAA")
im_ang = axs[1].imshow(masked_angle_img, cmap=cmap, vmin=0, vmax=2*np.pi)
cbar = fig.colorbar(im_ang, cax=caxs[1], orientation="vertical", ticks=[i/8 * 2*np.pi for i in range(9)])
cbar.ax.set_yticklabels(["0", "π/4", "π/2", "3π/4", "π", "5π/4", "3π/2", "7π/4", ""])
axs[0].axis("off")
axs[1].axis("off")
axs[0].set_title("Values")
axs[1].set_title("Angles")
title = "Layer Weights"
if title_append:
title += f" {title_append}"
if subtitle:
title += f"\n{subtitle}"
fig.suptitle(title)
if show:
plt.show()
return fig
def _plot_model_response_eye(
self, *signals, timestamps, labels=None, sps=None, title_append="", subtitle="", show=True
):
@@ -1184,6 +1450,7 @@ class RegenerationTrainer:
fig, axs = plt.subplots(1, 2 * len(signals), sharex=True, sharey=True)
fig.set_figwidth(18)
fig.set_figheight(4)
fig.suptitle(f"Eye diagram{f' {title_append}' if title_append else ''}{f'\n{subtitle}' if subtitle else ''}")
# xaxis = timestamps / sps
# xaxis = np.arange(2 * sps) / sps
@@ -1253,7 +1520,7 @@ class RegenerationTrainer:
xaxis = timestamps / sps
else:
xaxis = timestamps
ax.plot(xaxis, np.abs(signal[:, i]) ** 2, label=label)
ax.plot(xaxis, np.abs(signal[:, i]) ** 2, label=label, alpha=0.7)
ax.set_xlabel("Sample" if sps is None else "Symbol")
ax.set_ylabel("normalized power")
ax.minorticks_on()
@@ -1269,7 +1536,7 @@ class RegenerationTrainer:
def plot_model_response(
self,
model:torch.nn.Module=None,
model: torch.nn.Module = None,
title_append="",
subtitle="",
# mode: Literal["eye", "head", "powers"] = "head",
@@ -1281,7 +1548,9 @@ class RegenerationTrainer:
model = model.to(self.pytorch_settings.device)
model.eval()
with torch.no_grad():
_, powers = model(input_data, trace_powers=True)
_, powers = model(
input_data, torch.zeros(input_data.shape[0], 1).to(self.pytorch_settings.device), trace_powers=True
)
powers = [power.item() for power in powers]
layer_names = [name for (name, _) in model.named_children()]
@@ -1292,33 +1561,48 @@ class RegenerationTrainer:
data_settings_backup = copy.deepcopy(self.data_settings)
pytorch_settings_backup = copy.deepcopy(self.pytorch_settings)
self.data_settings.drop_first = 99.5 + random.randint(0, 1000)
self.data_settings.drop_first = int(64 + random.randint(0, 1000))
self.data_settings.shuffle = False
self.data_settings.train_split = 1.0
self.pytorch_settings.batchsize = max(self.pytorch_settings.head_symbols, self.pytorch_settings.eye_symbols)
config_path = random.choice(self.data_settings.config_path) if isinstance(self.data_settings.config_path, (list, tuple)) else self.data_settings.config_path
fiber_length = int(float(str(config_path).split('-')[4])/1000)
config_path = (
random.choice(self.data_settings.config_path)
if isinstance(self.data_settings.config_path, (list, tuple))
else self.data_settings.config_path
)
# fiber_length = int(float(str(config_path).split("-")[4]) / 1000)
if not hasattr(self, "_plot_loader"):
self._plot_loader, _ = self.get_sliced_data(
override={
"num_symbols": self.pytorch_settings.batchsize,
"config_path": config_path,
"shuffle": False,
"polarisations": (np.random.rand(1)*np.pi*2,),
"randomise_polarisation": False,
# "polarisations": (np.random.rand(1) * np.pi * 2,),
"polarisations": self.data_settings.polarisations,
"randomise_polarisation": self.data_settings.randomise_polarisations,
}
)
self._sps = self._plot_loader.dataset.samples_per_symbol
fiber_length = float(self._plot_loader.dataset.config["fiber"]["length"])/1000
self.data_settings = data_settings_backup
self.pytorch_settings = pytorch_settings_backup
fiber_in, fiber_out, regen, timestamps = self.run_model(model, self._plot_loader)
fiber_in, fiber_out, fiber_out_rot, angles, regen, timestamps = self.run_model(model, self._plot_loader)
fiber_in = fiber_in.view(-1, 2)
fiber_out = fiber_out.view(-1, 2)
fiber_out_rot = fiber_out_rot.view(-1, 2)
angles = angles.view(-1, 1)
angles = angles.real
angles = torch.fmod(angles, 2*torch.pi)
angles = torch.div(angles, 2*torch.pi)
angles = torch.repeat_interleave(angles, 2, dim=1)
regen = regen.view(-1, 2)
fiber_in = fiber_in.numpy()
fiber_out = fiber_out.numpy()
fiber_out_rot = fiber_out_rot.numpy()
angles = angles.numpy()
regen = regen.numpy()
timestamps = timestamps.numpy()
@@ -1327,31 +1611,34 @@ class RegenerationTrainer:
import gc
head_fig = self._plot_model_response_head(
fiber_in[:self.pytorch_settings.head_symbols*self._sps],
fiber_out[:self.pytorch_settings.head_symbols*self._sps],
regen[:self.pytorch_settings.head_symbols*self._sps],
timestamps=timestamps[:self.pytorch_settings.head_symbols*self._sps],
fiber_out[: self.pytorch_settings.head_symbols * self._sps],
fiber_in[: self.pytorch_settings.head_symbols * self._sps],
regen[: self.pytorch_settings.head_symbols * self._sps],
angles[: self.pytorch_settings.head_symbols * self._sps],
timestamps=timestamps[: self.pytorch_settings.head_symbols * self._sps],
labels=("fiber out", "fiber in", "regen", "normed angle"),
sps=self._sps,
title_append=title_append + f" ({fiber_length} km)",
subtitle=subtitle,
show=show,
)
# raise NotImplementedError("Eye diagram not implemented")
eye_fig = self._plot_model_response_eye(
fiber_in[: self.pytorch_settings.eye_symbols * self._sps],
fiber_out[: self.pytorch_settings.eye_symbols * self._sps],
regen[: self.pytorch_settings.eye_symbols * self._sps],
timestamps=timestamps[: self.pytorch_settings.eye_symbols * self._sps],
labels=("fiber in", "fiber out", "regen"),
sps=self._sps,
title_append=title_append + f" ({fiber_length} km)",
subtitle=subtitle,
show=show,
)
# raise NotImplementedError("Eye diagram not implemented")
eye_fig = self._plot_model_response_eye(
fiber_in[:self.pytorch_settings.eye_symbols*self._sps],
fiber_out[:self.pytorch_settings.eye_symbols*self._sps],
regen[:self.pytorch_settings.eye_symbols*self._sps],
timestamps=timestamps[:self.pytorch_settings.eye_symbols*self._sps],
labels=("fiber in", "fiber out", "regen"),
sps=self._sps,
title_append=title_append + f" ({fiber_length} km)",
subtitle=subtitle,
show=show,
)
weight_fig = self._plot_model_weights(model, title_append=title_append, subtitle=subtitle, show=show)
gc.collect()
return head_fig, eye_fig, power_fig
return head_fig, eye_fig, weight_fig, power_fig
def build_title(self, number: int):
title_append = f"epoch {number}"
@@ -1361,7 +1648,7 @@ class RegenerationTrainer:
self.model_settings.overrides.get(f"n_hidden_nodes_{i}", -1) for i in range(model_n_hidden_layers)
]
model_dims.insert(0, input_dim)
model_dims.append(2)
model_dims.append(self.model_settings.output_dim)
model_dims = [str(dim) for dim in model_dims]
model_activation_func = self.model_settings.model_activation_func
model_dtype = self.data_settings.dtype

View File

@@ -0,0 +1,217 @@
from pathlib import Path
import sys
from matplotlib import pyplot as plt
import numpy as np
import torch
import util
from hypertraining.settings import GlobalSettings, DataSettings, ModelSettings, OptimizerSettings, PytorchSettings
from hypertraining import models
# def move_to_location_in_size(array, location, size):
# array_x, array_y = array.shape
# location_x, location_y = location
# size_x, size_y = size
# left = location_x
# right = size_x - array_x - location_x
# top = location_y
# bottom = size_y - array_y - location_y
# return np.pad(
# array,
# (
# (left, right),
# (top, bottom),
# ),
# constant_values=(-np.inf, -np.inf),
# )
def register_puccs_cmap(puccs_path=None):
puccs_path = Path(__file__).resolve().parent / 'puccs.csv' if puccs_path is None else puccs_path
colors = []
# keys = None
with open(puccs_path, "r") as f:
for i, line in enumerate(f.readlines()):
elements = tuple(line.split(","))
# if i == 0:
# # keys = elements
# continue
# else:
try:
colors.append(tuple(map(float, elements[4:])))
except ValueError:
continue
# colors = []
# for current in puccs_csv_data:
# colors.append(tuple(current[4:]))
from matplotlib.colors import LinearSegmentedColormap
import matplotlib as mpl
mpl.colormaps.register(LinearSegmentedColormap.from_list('puccs', colors))
def pad_to_size(array, size):
if not hasattr(size, "__len__"):
size = (size, size)
left = (
(size[0] - array.shape[0] + 1) // 2 if size[0] is not None else 0
)
right = (
(size[0] - array.shape[0]) // 2 if size[0] is not None else 0
)
top = (
(size[1] - array.shape[1] + 1) // 2 if size[1] is not None else 0
)
bottom = (
(size[1] - array.shape[1]) // 2 if size[1] is not None else 0
)
array: np.ndarray = array
if array.ndim == 2:
return np.pad(
array,
(
(left, right),
(top, bottom),
),
constant_values=(np.nan, np.nan),
)
elif array.ndim == 3:
return np.pad(
array,
(
(left, right),
(top, bottom),
(0,0)
),
constant_values=(np.nan, np.nan),
)
def model_plot(model_path, show=True):
torch.serialization.add_safe_globals([
*util.complexNN.__all__,
GlobalSettings,
DataSettings,
ModelSettings,
OptimizerSettings,
PytorchSettings,
models.regenerator,
torch.nn.utils.parametrizations.orthogonal,
])
checkpoint_dict = torch.load(model_path, weights_only=True)
dims = checkpoint_dict["model_kwargs"].pop("dims")
model = models.regenerator(*dims, **checkpoint_dict["model_kwargs"])
model.load_state_dict(checkpoint_dict["model_state_dict"], strict=False)
model_params = []
plots = []
max_size = np.max(dims)
# max_act_size = np.max(dims[1:])
# angles = [None, None]
# weights = [None, None]
for num, (layer_name, layer) in enumerate(model.named_children()):
# each layer contains an "ONN" layer and an "activation" layer
# activation layer is approximately the same for all layers and nodes -> rotation by 90 degrees
onn_weights = layer.ONN.weight
onn_weights = onn_weights.detach().cpu().numpy()
onn_values = np.abs(onn_weights).real
onn_angles = np.mod(np.angle(onn_weights), 2*np.pi).real
model_params.append({layer_name: onn_weights})
plots.append({layer_name: (num, onn_values, onn_angles)})#, act_values, act_angles)})
# fig, axs = plt.subplots(3, len(model_params)*2-1, figsize=(20, 5))
for plot in plots:
layer_name, (num, onn_values, onn_angles) = plot.popitem()
if num == 0:
value_img = onn_values
angle_img = onn_angles
onn_angles = pad_to_size(onn_angles, (max_size, None))
onn_values = pad_to_size(onn_values, (max_size, None))
else:
onn_values = pad_to_size(onn_values, (max_size, onn_values.shape[1]+1))
onn_angles = pad_to_size(onn_angles, (max_size, onn_angles.shape[1]+1))
value_img = np.concatenate((value_img, onn_values), axis=1)
angle_img = np.concatenate((angle_img, onn_angles), axis=1)
value_img = np.ma.array(value_img, mask=np.isnan(value_img))
angle_img = np.ma.array(angle_img, mask=np.isnan(angle_img))
# from cmcrameri import cm
from cmap import Colormap as cm
import scicomap as sc
# from matplotlib import colors as mcolors
# alpha_map = mcolors.LinearSegmentedColormap(
# 'alphamap',
# {
# 'red': [(0, 0, 0), (1, 0, 0)],
# 'green': [(0, 0, 0), (1, 0, 0)],
# 'blue': [(0, 0, 0), (1, 0, 0)],
# 'alpha': [
# (0, 1, 1),
# # (0.2, 0.2, 0.1),
# (1, 0, 0)
# ]
# }
# )
# alpha_map.set_bad(color="#AAAAAA")
from mpl_toolkits.axes_grid1 import make_axes_locatable
fig, axs = plt.subplots(1, 2, sharex=True, sharey=True, figsize=(16, 5))
# fig.tight_layout()
dividers = map(make_axes_locatable, axs)
caxs = list(map(lambda d: d.append_axes("right", size="5%", pad=0.05), dividers))
# masked_value_img = np.ma.masked_where(np.isnan(value_img), value_img)
masked_value_img = value_img
cmap = cm('google:turbo').to_matplotlib()
# cmap = sc.ScicoSequential("rainbow").get_mpl_color_map()
cmap.set_bad(color="#AAAAAA")
im_val = axs[0].imshow(masked_value_img, cmap=cmap, vmin=0, vmax=1)
fig.colorbar(im_val, cax=caxs[0], orientation="vertical")
masked_angle_img = np.ma.masked_where(np.isnan(angle_img), angle_img)
# cmap = cm('crameri:romao').to_matplotlib()
# cmap = plt.get_cmap('puccs')
# cmap = sc.ScicoCircular("colorwheel").get_mpl_color_map()
cmap = cm('colorcet:CET_C8').to_matplotlib()
cmap.set_bad(color="#AAAAAA")
im_ang = axs[1].imshow(masked_angle_img, cmap=cmap, vmin=0, vmax=2*np.pi)
cbar = fig.colorbar(im_ang, cax=caxs[1], orientation="vertical", ticks=[i/8 * 2*np.pi for i in range(9)])
cbar.ax.set_yticklabels(["0", "π/4", "π/2", "3π/4", "π", "5π/4", "3π/2", "7π/4", "2π"])
# im_ang_w = axs[2].imshow(masked_angle_img, cmap=cmap)
# im_ang_w = axs[2].imshow(masked_value_img, cmap=alpha_map)
axs[0].axis("off")
axs[1].axis("off")
# axs[2].axis("off")
axs[0].set_title("Values")
axs[1].set_title("Angles")
# axs[2].set_title("Values and Angles")
...
if show:
plt.show()
return fig
# model = models.regenerator(*dims, **model_kwargs)
if __name__ == "__main__":
register_puccs_cmap()
if len(sys.argv) > 1:
model_plot(sys.argv[1])
else:
print("Please provide a model path as an argument")
# model_plot(".models/best_20250114_224234.tar")

View File

@@ -0,0 +1,102 @@
"x","L","a","b","R","G","B"
0.,0.5187848173343539,0.6399990176455989,0.67,0.8889427469969852,0.22673227640012172,0.
0.01,0.5374499525557803,0.604014067614707,0.6777967519386492,0.8956274406155226,0.27553288030331824,0.
0.02,0.5560867887452998,0.5680836759482211,0.6855816828789898,0.9019507507843885,0.318608215541461,0.
0.03,0.5746877595125583,0.5322224300667823,0.6933516322080414,0.907905487190649,0.3580633000693721,0.
0.04,0.5932314662487472,0.49647158484797804,0.7010976613543587,0.9134808162089558,0.3949845524063657,0.
0.05,0.6117000836392819,0.46086550613202343,0.7088123243737041,0.918668356138916,0.43002019316005363,0.
0.06,0.6300828534995973,0.4254249348741487,0.7164911273850869,0.923462736751354,0.4635961938811463,0.
0.07,0.6483763163456417,0.3901565406944371,0.7241326253017896,0.9278609626724071,0.49601354353255284,0.
0.08,0.6665840140182806,0.3550534951951814,0.7317382976124045,0.9318616057744784,0.5274983630587982,0.
0.09,0.6847162776119433,0.3200958808181962,0.7393124597949372,0.9354640163365924,0.5582303922647159,0.
0.1,0.7027902128942014,0.2852507189547545,0.7468622572263107,0.9386675557407496,0.5883604892249517,0.004034952213848706
0.11,0.7208298719332069,0.25047163906104203,0.7543977368741345,0.9414708123927996,0.6180221032545026,0.016031521294251994
0.12,0.7388665670611175,0.2156982733607376,0.7619319784446927,0.943870754968487,0.6473392272576862,0.029857267582036696
0.13,0.7569392765472108,0.18085547473834482,0.7694812638396673,0.9458617774020323,0.676432172396153,0.045365670193636125
0.14,0.7750950944867471,0.14585244938794778,0.7770652650825484,0.9474345911958609,0.7054219201084561,0.06017985923530026
0.15,0.793389684293558,0.11058188251425949,0.7847072337503834,0.9485749196617762,0.7344334940032564,0.07418869502646075
0.16,0.8117919447684838,0.07510373484536464,0.792394178330817,0.9492596163836376,0.7634480277996188,0.08767517868137237
0.17,0.8293050962981561,0.03629277424762101,0.799038155466063,0.9462308253550155,0.7922009241807345,0.10066327128139077
0.18,0.8213303100752708,-0.0062517290795987,0.7879999288492758,0.9088702681901394,0.7940579017644396,0.10139639009534024
0.19,0.8134831311534617,-0.048115463155645855,0.7771383286984362,0.8716809050191757,0.7954897210083888,0.10232311621802098
0.2,0.80558613530069,-0.0902449644291895,0.7662077749032042,0.8337524177888596,0.7965471523787845,0.10344968926026826
0.21,0.7975860185564765,-0.13292460297117392,0.7551344872795225,0.7947193410849823,0.7972381033243311,0.10477682283894393
0.22,0.7894147026971006,-0.17651756772919341,0.7438242359834689,0.7540941866826836,0.7975605026647324,0.10631182441371936
0.23,0.7809997374598548,-0.2214103719409295,0.7321767396537806,0.7112894518675287,0.7974995317311054,0.1080672415170634
0.24,0.7722646970273015,-0.2680107379394189,0.7200862142018722,0.6655745739336695,0.7970267795229349,0.11006041388465265
0.25,0.7631307298557146,-0.3167393290089981,0.7074435179925446,0.6160047476007512,0.7960993904970947,0.11231257117602686
0.26,0.7535192192483822,-0.36801555555407994,0.6941398344519211,0.5612859274945571,0.794659599537827,0.11484733363789801
0.27,0.7433557597838075,-0.42223636134393283,0.6800721760037781,0.4994862901720824,0.7926351396848288,0.11768844813479104
0.28,0.732575139048096,-0.479749646583324,0.6651502794883674,0.42731393423789277,0.7899410218414098,0.12085678487511567
0.29,0.7211269294461059,-0.5408244362880141,0.6493043460161184,0.3378265607222193,0.786483110019224,0.124366774034814
0.3,0.7090756028785993,-0.6051167807996883,0.6326236137723747,0.2098475715121697,0.7821998608677176,0.12819222127525928
0.31,0.7094510768540225,-0.6165036055456403,0.5630307498747129,0.15061488620640032,0.7845112116922692,0.21943537230975235
0.32,0.7174669421288304,-0.5917687864932311,0.4797229624661701,0.18766933782916642,0.7905828987725732,0.31091344246312086
0.33,0.7249009746435938,-0.5688293479200438,0.40246208306061504,0.21160609617940718,0.7962175427587832,0.38519766326885596
0.34,0.7317072855135611,-0.5478268906666535,0.3317250285377912,0.22717569971119178,0.8013847719431052,0.4490960048955565
0.35,0.7379328517830899,-0.5286164561226088,0.26702357292455026,0.23690087622812972,0.8061220291668977,0.5056371468159843
0.36,0.7436229063122554,-0.5110584677642499,0.20788761731555405,0.24226377668817778,0.8104638164122776,0.5563570758573497
0.37,0.7488251728809415,-0.4950056627547577,0.15382117501783654,0.24424372086048424,0.8144455902164638,0.6022301663745243
0.38,0.7535943992285348,-0.48028910419451787,0.10425526029155024,0.24352232677523483,0.818107753931944,0.6440238320299774
0.39,0.757994865186593,-0.4667104416936734,0.05852182167144754,0.240562414747303,0.8214980148949816,0.6824536572462205
0.4,0.7620994844391137,-0.4540446830999986,0.015863077249098356,0.2356325204239052,0.8246710357361025,0.7182393675419642
0.41,0.7659871096124125,-0.4420485102716773,-0.024540477496154123,0.22880568593963535,0.8276865975886148,0.7521146815529202
0.42,0.7697410958994951,-0.4304647113488041,-0.06355514164248566,0.21993360985514526,0.8306086550266585,0.7848331944479765
0.43,0.773446484628189,-0.4190308715098135,-0.10206473803580057,0.20858849290850018,0.833503273690861,0.8171544357676854
0.44,0.7771893686864673,-0.4074813310994203,-0.14096401824224686,0.1939295692427068,0.8364382500400466,0.8498448067259188
0.45,0.7810574093604746,-0.3955455908045306,-0.18116403397486242,0.17438366103820427,0.839483669055626,0.8836865023336339
0.46,0.7851360804917298,-0.3829599011818591,-0.2235531031349741,0.14679145002531463,0.8427091517444469,0.9194481212717681
0.47,0.789525027020907,-0.369416784561489,-0.26916682191206776,0.10278921007810798,0.8461971304126237,0.9580316568065935
0.48,0.7942371698732826,-0.35487637041943493,-0.3181394757087982,0.0013920913109500188,0.8499626968466341,0.9995866371771526
0.49,0.7773897680996302,-0.31852357140025195,-0.34537976514700053,0.10740420703601522,0.8254781216972907,1.
0.5,0.7604011244310231,-0.28211213216592784,-0.3722846952738428,0.1581725581872408,0.8008522647497104,1.
0.51,0.7433440454962605,-0.2455540169176899,-0.3992980063927199,0.19300141807932156,0.7761561224913385,1.
0.52,0.7262590833969331,-0.20893614020926626,-0.42635547610418184,0.2194621842292243,0.751443124097109,1.
0.53,0.709058602701224,-0.17207067467417486,-0.453595892719742,0.2405673704012788,0.7265803324554873,1.
0.54,0.6915768892539101,-0.1346024482921609,-0.48128169789479536,0.25788347992973676,0.701321051230534,1.
0.55,0.6736331627810209,-0.09614399811510127,-0.5096991935104321,0.2722888922216317,0.6753950894563805,1.
0.56,0.6551463184003872,-0.05652149358027936,-0.5389768254408652,0.28422807900785235,0.6486730893521468,1.
0.57,0.6361671326276888,-0.01584376303510615,-0.5690341788729347,0.293907374075009,0.6212117649042732,1.
0.58,0.6168396823565967,0.025580396234342995,-0.5996430791016598,0.301442767979156,0.5931976878638505,1.
0.59,0.5973210287815495,0.06741435793529688,-0.6305547881733555,0.30694603901024253,0.5648312189065924,1.
0.6,0.5777303704171711,0.10940264614179468,-0.661580531294122,0.3105418468883679,0.5362525958007331,1.
0.61,0.5581475370499237,0.15137416317967575,-0.6925938819599547,0.3123531986526998,0.5075386530652202,1.
0.62,0.5386227795100639,0.19322120739317136,-0.7235152578861672,0.31248922600720636,0.4787151440558522,1.
0.63,0.5191666876024412,0.23492108185347996,-0.754327887989376,0.31103663081260624,0.44973844514160927,1.
0.64,0.4996990584326256,0.2766456839100268,-0.7851587896650079,0.30803814950244496,0.4204116611935119,1.
0.65,0.479957679121191,0.3189570094767831,-0.8164232296840259,0.30343473603466015,0.390226489453496,1.
0.66,0.4600072725872886,0.3617163391430824,-0.8480187063016573,0.29717122075330515,0.3591178757512998,1.
0.67,0.44600100870220305,0.4113853615984094,-0.8697728377551008,0.3178994129506999,0.3295740682997879,1.
0.68,0.4574651571354146,0.44026390446569547,-0.8504539292487465,0.3842479358768364,0.3280946443367561,1.
0.69,0.4691809168948424,0.46977626401045774,-0.830711015748157,0.44293649140770447,0.3260767554252525,1.
0.7,0.4811696900083858,0.49997635259991063,-0.8105080314416201,0.49708450874457527,0.3234487047238236,1.
0.71,0.49350094811609174,0.5310391714342613,-0.7897279055963483,0.5485591109413528,0.3201099534066949,1.
0.72,0.5062548753068121,0.5631667067020758,-0.7682355153041539,0.5985798481027601,0.3159263917472715,1.
0.73,0.5195243020949684,0.5965928013272943,-0.7458744264238399,0.6480500606439057,0.31071717884730565,1.
0.74,0.5334043922713477,0.6315571758288618,-0.7224842728734379,0.6976685401842261,0.3042411890803418,1.
0.75,0.5479805812358602,0.6682750446095802,-0.697921082452685,0.7479712773579563,0.29618040787504757,1.
0.76,0.5633244502526606,0.7069267230777347,-0.6720642293775535,0.7993701361353484,0.28611136999256687,1.
0.77,0.5794956601139,0.7476624986056212,-0.6448131757501174,0.8521918014427678,0.2734527325942473,1.
0.78,0.5965429098573916,0.7906050455688622,-0.6160858559672187,0.9067003897516911,0.2573693489198746,1.
0.79,0.6145761476424179,0.8360313267658297,-0.5856969899409387,0.963334644317004,0.23648492980159264,1.
0.8,0.6232910688128902,0.859291371252556,-0.5300995185388214,1.,0.21867949406239662,0.9712088595948508
0.81,0.6159984336377875,0.8439887543380684,-0.44635440435952856,1.,0.21606849746358275,0.9041480210597966
0.82,0.6091642745073532,0.8296481879180277,-0.36787420852419694,1.,0.21421830096504035,0.8419706002336461
0.83,0.6025478038652375,0.8157644115969636,-0.2918938425681935,1.,0.21295365915197917,0.7823908751330636
0.84,0.5961857222953111,0.8024144366282877,-0.21883475834162458,0.9971140114799418,0.21220068235083267,0.7256713129328118
0.85,0.5900921771070883,0.7896279492437488,-0.1488594167412921,0.993273906363258,0.2118788857127918,0.671860243327784
0.86,0.5842771639541229,0.7774259239818333,-0.08208260304413262,0.9887084084529413,0.21191070453347688,0.6209624706933893
0.87,0.578741582584259,0.7658102488427286,-0.018514649521559012,0.9835846378805114,0.2122246941077346,0.5728987835613306
0.88,0.5734741590353537,0.7547572669288056,0.04197390858426542,0.9780378159372328,0.21275878699579343,0.5274829957183049
0.89,0.5684517008574971,0.7442183119942206,0.09964940221121898,0.9721670725313721,0.21346242315895625,0.4844270603851604
0.9,0.5636419856510335,0.7341257696545772,0.15488185789614228,0.9660363209686843,0.21429691147008262,0.4433660148378527
0.91,0.5590069340453534,0.7243997354573974,0.20810856081277884,0.9596781387247791,0.2152344151262528,0.4038812338146013
0.92,0.5545051525321143,0.7149533506766244,0.25980485409830323,0.9530986696850675,0.21625626438013962,0.3655130449917989
0.93,0.5500961975299247,0.705701749880514,0.3104351723857584,0.9462863346513658,0.21735046958786286,0.327780364198278
0.94,0.545740378056064,0.6965616468647046,0.36045530782708896,0.93921469089265,0.21851014470332586,0.29014917175372823
0.95,0.5414004092067859,0.6874548042588865,0.41029342232076466,0.9318478255642132,0.21973168075163751,0.2519897371806688
0.96,0.5370416605957644,0.6783085548415655,0.46034719456417006,0.9241434776436454,0.22101341980094052,0.2124579038400577
0.97,0.5326309593934517,0.6690532898786764,0.5109975653738162,0.9160532016485884,0.22235495330179011,0.17018252385769012
0.98,0.5281374148557197,0.6596241892863608,0.5625992691950712,0.90752576202319,0.22375597459867458,0.1223073280126531
0.99,0.5235317096396147,0.6499597345521199,0.615488972291106,0.8985077346125597,0.22521565729028564,0.05933950582860665
1.,0.5187848173343539,0.6399990176455989,0.67,0.8889427469969852,0.22673227640012172,0.

View File

@@ -1,6 +1,8 @@
from datetime import datetime
import optuna
import torch
import util
from hypertraining.hypertraining import HyperTraining
from hypertraining.settings import (
GlobalSettings,
@@ -16,24 +18,29 @@ global_settings = GlobalSettings(
)
data_settings = DataSettings(
config_path="data/*-128-16384-100000-0-0-17-0-PAM4-0.ini",
# config_path="data/*-128-16384-100000-0-0-17-0-PAM4-0.ini",
config_path="data/20241204-131003-128-16384-100000-0-0-17-0-PAM4-0.ini",
dtype="complex64",
# symbols = (9, 20), # 13 symbol @ 10GBd <-> 1.3ns <-> 0.26m of fiber
symbols=13, # study: single_core_regen_20241123_011232
# symbols=13, # study: single_core_regen_20241123_011232
# symbols = (3, 13),
symbols=4,
# output_size = (11, 32), # ballpark 26 taps -> 2 taps per input symbol -> 1 tap every 0.01m (model has 52 inputs)
output_size=26, # study: single_core_regen_20241123_011232 (model_input_dim/2)
# output_size=26, # study: single_core_regen_20241123_011232 (model_input_dim/2)
output_size=(8, 30),
shuffle=True,
in_out_delay=0,
xy_delay=0,
drop_first=128 * 100,
drop_first=256,
train_split=0.8,
randomise_polarisations=False,
)
pytorch_settings = PytorchSettings(
epochs=10000,
epochs=10,
batchsize=2**10,
device="cuda",
dataloader_workers=12,
dataloader_workers=4,
dataloader_prefetch=4,
summary_dir=".runs",
write_every=2**5,
@@ -43,28 +50,70 @@ pytorch_settings = PytorchSettings(
model_settings = ModelSettings(
output_dim=2,
# n_hidden_layers = (3, 8),
n_hidden_layers=4,
overrides={
"n_hidden_nodes_0": 8,
"n_hidden_nodes_1": 6,
"n_hidden_nodes_2": 4,
"n_hidden_nodes_3": 8,
},
model_activation_func="Mag",
# satabsT0=(1e-6, 1),
n_hidden_layers = (2, 5),
n_hidden_nodes=(2, 16),
model_activation_func="EOActivation",
dropout_prob=0,
model_layer_function="ONNRect",
model_layer_kwargs={"square": True},
# scale=(False, True),
scale=False,
model_layer_parametrizations=[
{
"tensor_name": "weight",
"parametrization": util.complexNN.energy_conserving,
},
{
"tensor_name": "alpha",
"parametrization": util.complexNN.clamp,
},
{
"tensor_name": "gain",
"parametrization": util.complexNN.clamp,
"kwargs": {
"min": 0,
"max": float("inf"),
},
},
{
"tensor_name": "phase_bias",
"parametrization": util.complexNN.clamp,
"kwargs": {
"min": 0,
"max": 2 * torch.pi,
},
},
{
"tensor_name": "scales",
"parametrization": util.complexNN.clamp,
},
{
"tensor_name": "angle",
"parametrization": util.complexNN.clamp,
"kwargs": {
"min": -torch.pi,
"max": torch.pi,
},
},
{
"tensor_name": "loss",
"parametrization": util.complexNN.clamp,
},
],
)
optimizer_settings = OptimizerSettings(
optimizer="Adam",
# learning_rate = (1e-5, 1e-1),
learning_rate=5e-3
# learning_rate=5e-4,
optimizer="AdamW",
optimizer_kwargs={
"lr": 5e-3,
"amsgrad": True,
# "weight_decay": 1e-7,
},
)
optuna_settings = OptunaSettings(
n_trials=1,
n_workers=1,
n_trials=1024,
n_workers=8,
timeout=3600,
directions=("minimize",),
metrics_names=("mse",),

View File

@@ -1,4 +1,4 @@
from datetime import datetime
# from datetime import datetime
from pathlib import Path
import matplotlib
import numpy as np
@@ -13,7 +13,7 @@ from hypertraining.settings import (
OptimizerSettings,
)
from hypertraining.training import RegenerationTrainer, PolarizationTrainer
from hypertraining.training import RegenerationTrainer#, PolarizationTrainer
# import torch
import json
@@ -26,25 +26,39 @@ global_settings = GlobalSettings(
)
data_settings = DataSettings(
config_path="data/*-128-16384-10000-0-0-0-0-PAM4-0-0.4.ini",
# config_path=[f"data/20241202-*-128-16384-{length}-0-0-17-0-PAM4-0.ini" for length in range(48000, 53000, 1000)],
# config_path="data/20250115-114319-128-16384-inf-100000-0-0-0-0-PAM4-1-0-10.ini", # no effects enabled - baseline
# config_path = "data/20250115-233553-128-16384-1060.0-100000-0-0.2-17.0-0.058-PAM4-1.0-0.0-10.ini", # dispersion + slope only
# config_path="data/20250115-115836-128-16384-60.0-100000-0-0.2-17-0.058-PAM4-1000-0.2-10.ini", # all linear effects enabled with realistic values + noise + pmd (delta_beta=0.2) + ortho_error = 0.1
# config_path="data/20250118-225840-128-16384-inf-100000-0-0.2-17.0-0.058-PAM4-0-0-10.ini", # a)
# config_path="data/20250116-214606-128-16384-inf-100000-0-0.2-17.0-0.058-PAM4-0-0-10.ini", # b)
# config_path="data/20250116-214547-128-16384-inf-100000-0-0.2-17.0-0.058-PAM4-0-0-10.ini", # c)
# config_path="data/20250117-143926-128-16384-inf-100000-0-0.2-17.0-0.058-PAM4-0-0-10.ini", # d) 10ps dgd
config_path="data/20250120-105720-128-16384-inf-100000-0-0.2-17-0.058-PAM4-0-0-10.ini", # d) 10ns
# config_path="data/20250114-215547-128-16384-60.0-100000-1.15-0.2-17-0.058-PAM4-1-0-10.ini", # with gamma=1.15, 2.5dBm launch power, no pmd
dtype="complex64",
# symbols = (9, 20), # 13 symbol @ 10GBd <-> 1.3ns <-> 0.26m of fiber
symbols=13, # study: single_core_regen_20241123_011232
symbols=4, # study: single_core_regen_20241123_011232 -> taps spread over 4 symbols @ 10GBd
# output_size = (11, 32), # 26 taps -> 2 taps per input symbol -> 1 tap every 1cm (model has 52 inputs (x/y))
output_size=26, # study: single_core_regen_20241123_011232 (model_input_dim/2)
shuffle=True,
drop_first=64,
output_size=20, # = model_input_dim/2 study: single_core_regen_20241123_011232
shuffle=False,
drop_first=256,
drop_last=256,
train_split=0.8,
randomise_polarisations=True,
randomise_polarisations=False,
polarisations=False,
# cross_pol_interference=0.01,
osnr=16, #16dB due to amplification with NF 5
)
pytorch_settings = PytorchSettings(
epochs=10000,
batchsize=2**14,
epochs=1000,
batchsize=2**13,
device="cuda",
dataloader_workers=16,
dataloader_prefetch=8,
dataloader_workers=32,
dataloader_prefetch=4,
summary_dir=".runs",
write_every=2**5,
save_models=True,
@@ -53,80 +67,51 @@ pytorch_settings = PytorchSettings(
model_settings = ModelSettings(
output_dim=2,
n_hidden_layers=5,
n_hidden_layers=3,
overrides={
# "hidden_layer_dims": (8, 8, 4, 4),
"n_hidden_nodes_0": 8,
"n_hidden_nodes_0": 16,
"n_hidden_nodes_1": 8,
"n_hidden_nodes_2": 4,
"n_hidden_nodes_3": 4,
"n_hidden_nodes_4": 2,
"n_hidden_nodes_2": 8,
# "n_hidden_nodes_3": 4,
# "n_hidden_nodes_4": 2,
},
model_activation_func="EOActivation",
dropout_prob=0.01,
dropout_prob=0,
model_layer_function="ONNRect",
model_layer_kwargs={"square": True},
scale=False,
scale=2.0,
model_layer_parametrizations=[
{
"tensor_name": "weight",
"parametrization": util.complexNN.energy_conserving,
},
# EOactivation
{
"tensor_name": "alpha",
"parametrization": util.complexNN.clamp,
"kwargs": {
"min": 0,
"max": 1,
},
},
# ONNRect
{
"tensor_name": "gain",
"tensor_name": "weight",
"parametrization": torch.nn.utils.parametrizations.orthogonal,
},
# Scale
{
"tensor_name": "scale",
"parametrization": util.complexNN.clamp,
"kwargs": {
"min": 0,
"max": float("inf"),
"max": 10,
},
},
{
"tensor_name": "phase_bias",
"parametrization": util.complexNN.clamp,
"kwargs": {
"min": 0,
"max": 2 * torch.pi,
},
},
{
"tensor_name": "scales",
"parametrization": util.complexNN.clamp,
},
{
"tensor_name": "angle",
"parametrization": util.complexNN.clamp,
"kwargs": {
"min": -torch.pi,
"max": torch.pi,
},
},
# {
# "tensor_name": "scale",
# "parametrization": util.complexNN.clamp,
# },
# {
# "tensor_name": "bias",
# "parametrization": util.complexNN.clamp,
# },
# {
# "tensor_name": "V",
# "parametrization": torch.nn.utils.parametrizations.orthogonal,
# },
{
"tensor_name": "loss",
"parametrization": util.complexNN.clamp,
},
}
],
)
optimizer_settings = OptimizerSettings(
optimizer="AdamW",
optimizer_kwargs={
"lr": 0.01,
"lr": 0.005,
"amsgrad": True,
# "weight_decay": 1e-7,
},
@@ -134,175 +119,35 @@ optimizer_settings = OptimizerSettings(
scheduler="ReduceLROnPlateau",
scheduler_kwargs={
"patience": 2**6,
"factor": 0.75,
"factor": 0.5,
# "threshold": 1e-3,
"min_lr": 1e-6,
"cooldown": 10,
},
early_stopping=True,
early_stop_kwargs={
"threshold": 1e-06,
"plateau": 2**7,
}
)
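# Sketch of how the ReduceLROnPlateau settings above behave: after `patience`
# epochs without improvement the learning rate is halved (factor=0.5), a
# cooldown passes before the counter restarts, and the rate never drops below
# min_lr. Assumed wiring, not the trainer's actual code:
_params = [torch.zeros(1, requires_grad=True)]
_opt = torch.optim.AdamW(_params, lr=0.005, amsgrad=True)
_sched = torch.optim.lr_scheduler.ReduceLROnPlateau(
    _opt, factor=0.5, patience=2**6, min_lr=1e-6, cooldown=10
)
for _epoch in range(3):
    _val_loss = 1.0          # placeholder for the epoch's validation loss
    _sched.step(_val_loss)   # scheduler watches the plateauing metric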
def save_dict_to_file(dictionary, filename):
"""
Save a dictionary to a JSON file.
:param dictionary: Dictionary containing the training results to save.
:type dictionary: dict
:param filename: Path to the JSON file where the dictionary will be saved.
:type filename: str
"""
with open(filename, "w") as f:
json.dump(dictionary, f, indent=4)
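# Usage (path assumed):
#   save_dict_to_file({"mse": 3.2e-4, "epochs": 512}, "results/best.json")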
def sweep_lengths(*lengths, model=None, data_glob: str = None, strategy="newest"):
assert model is not None, "Model must be provided."
assert data_glob is not None, "Data glob must be provided."
model = model
fiber_ins = {}
fiber_outs = {}
regens = {}
timestampss = {}
trainer = RegenerationTrainer(
checkpoint_path=model,
)
trainer.define_model()
for length in lengths:
data_glob_length = data_glob.replace("{length}", str(length))
files = list(Path.cwd().glob(data_glob_length))
if len(files) == 0:
continue
if strategy == "newest":
sorted_kwargs = {
"key": lambda x: x.stat().st_mtime,
"reverse": True,
}
elif strategy == "oldest":
sorted_kwargs = {
"key": lambda x: x.stat().st_mtime,
"reverse": False,
}
else:
raise ValueError(f"Unknown strategy {strategy}.")
file = sorted(files, **sorted_kwargs)[0]
loader, _ = trainer.get_sliced_data(override={"config_path": file})
fiber_in, fiber_out, regen, timestamps = trainer.run_model(trainer.model, loader=loader)
fiber_ins[length] = fiber_in
fiber_outs[length] = fiber_out
regens[length] = regen
timestampss[length] = timestamps
data = torch.zeros(2 * len(timestampss.keys()) + 2, 2, tuple(fiber_outs.values())[-1].shape[0])
channel_names = ["" for _ in range(2 * len(timestampss.keys()) + 2)]
data[1, 0, :] = timestampss[tuple(timestampss.keys())[-1]] / 128
data[1, 1, :] = fiber_ins[tuple(timestampss.keys())[-1]][:, 0].abs().square()
channel_names[1] = "fiber in x"
for li, length in enumerate(timestampss.keys()):
data[2 + 2 * li, 0, :] = timestampss[length] / 128
data[2 + 2 * li, 1, :] = fiber_outs[length][:, 0].abs().square()
data[2 + 2 * li + 1, 0, :] = timestampss[length] / 128
data[2 + 2 * li + 1, 1, :] = regens[length][:, 0].abs().square()
channel_names[2 + 2 * li + 1] = f"regen x {length}"
channel_names[2 + 2 * li] = f"fiber out x {length}"
# get current backend
backend = matplotlib.get_backend()
matplotlib.use("TkCairo")
eye = util.eye_diagram.eye_diagram(data.to(dtype=torch.float32).detach().cpu().numpy(), channel_names=channel_names)
print_attrs = ("channel_name", "success", "min_area")
with np.printoptions(precision=3, suppress=True, formatter={"float": "{:0.3e}".format}):
for result in eye.eye_stats:
print_dict = {attr: result[attr] for attr in print_attrs}
rprint(print_dict)
rprint()
eye.plot(all_stats=False)
matplotlib.use(backend)
if __name__ == "__main__":
# lengths = range(90000, 100000+10000, 10000)
# lengths = [100000]
# sweep_lengths(*lengths, model=".models/best_20241204_132605.tar", data_glob="data/202412*-{length}-*-0.ini", strategy="newest")
trainer = RegenerationTrainer(
global_settings=global_settings,
data_settings=data_settings,
pytorch_settings=pytorch_settings,
model_settings=model_settings,
optimizer_settings=optimizer_settings,
# checkpoint_path=".models/best_20241205_235929.tar",
# 20241202_143149
checkpoint_path=".models/best_20250117_144001.tar",
new_model=True,
settings_override={
"data_settings": data_settings.__dict__,
# "optimizer_settings": {
# "early_stop_kwargs":{
# "plateau": 2**8,
# }
# }
}
)
trainer.train()
# from hypertraining.lighning_models import regenerator, regeneratorData
# import lightning as L
# model = regenerator(
# 2 * data_settings.output_size,
# *model_settings.overrides["hidden_layer_dims"],
# model_settings.output_dim,
# layer_function=getattr(util.complexNN, model_settings.model_layer_function),
# layer_func_kwargs=model_settings.model_layer_kwargs,
# act_function=getattr(util.complexNN, model_settings.model_activation_func),
# act_func_kwargs=None,
# parametrizations=model_settings.model_layer_parametrizations,
# dtype=getattr(torch, data_settings.dtype),
# dropout_prob=model_settings.dropout_prob,
# scale_layers=model_settings.scale,
# optimizer=getattr(torch.optim, optimizer_settings.optimizer),
# optimizer_kwargs=optimizer_settings.optimizer_kwargs,
# lr_scheduler=getattr(torch.optim.lr_scheduler, optimizer_settings.scheduler),
# lr_scheduler_kwargs=optimizer_settings.scheduler_kwargs,
# )
# dm = regeneratorData(
# config_globs=data_settings.config_path,
# output_symbols=data_settings.symbols,
# output_dim=data_settings.output_size,
# dtype=getattr(torch, data_settings.dtype),
# drop_first=data_settings.drop_first,
# shuffle=data_settings.shuffle,
# train_split=data_settings.train_split,
# batch_size=pytorch_settings.batchsize,
# loader_settings={
# "num_workers": pytorch_settings.dataloader_workers,
# "prefetch_factor": pytorch_settings.dataloader_prefetch,
# "pin_memory": True,
# "drop_last": True,
# },
# seed=global_settings.seed,
# )
# # writer = L.SummaryWriter(pytorch_settings.summary_dir + f"/{datetime.now().strftime('%Y%m%d_%H%M%S')}")
# # from torch.utils.tensorboard import SummaryWriter
# subdir = f"{datetime.now().strftime('%Y%m%d_%H%M%S')}"
# # writer = SummaryWriter(pytorch_settings.summary_dir + f"/{subdir}")
# logger = L.pytorch.loggers.TensorBoardLogger(pytorch_settings.summary_dir, name=subdir, log_graph=True)
# trainer = L.Trainer(
# fast_dev_run=False,
# # max_epochs=pytorch_settings.epochs,
# max_epochs=2,
# enable_checkpointing=True,
# default_root_dir=f".models/{subdir}/",
# logger=logger,
# )
# trainer.fit(model, dm)

View File

@@ -12,20 +12,22 @@ Full license text in LICENSE file
"""
import configparser
# import copy
from datetime import datetime
import hashlib
from pathlib import Path
import time
import h5py
from matplotlib import pyplot as plt # noqa: F401
import numpy as np
import add_pypho # noqa: F401
from . import add_pypho # noqa: F401
import pypho
default_config = f"""
[glova]
nos = 256
sps = 256
sps = 128
nos = 16384
f0 = 193414489032258.06
symbolrate = 10e9
wisdom_dir = "{str((Path.home() / ".pypho"))}"
@@ -37,9 +39,9 @@ length = 10000
gamma = 1.14
alpha = 0.2
D = 17
S = 0
birefsteps = 0
max_delta_beta = 0.4
S = 0.058
bireflength = 10
pmd_q = 0.2
; birefseed = 0xC0FFEE
[signal]
@@ -47,17 +49,15 @@ max_delta_beta = 0.4
modulation = "pam"
mod_order = 4
mod_depth = 0.8
mod_depth = 1
max_jitter = 0.02
; jitter_seed = 0xC0FFEE
laser_power = 0
edfa_power = 3
edfa_power = 0
edfa_nf = 5
pulse_shape = "gauss"
fwhm = 0.33
osnr = "inf"
[data]
dir = "data"
@@ -71,6 +71,7 @@ def get_config(config_file=None):
"""
if config_file is None:
config_file = Path(__file__).parent / "signal_generation.ini"
config_file = Path(config_file)
if not config_file.exists():
with open(config_file, "w") as f:
f.write(default_config)
@@ -83,7 +84,10 @@ def get_config(config_file=None):
conf[section] = {}
for key in config[section]:
# print(f"{key} = {config[section][key]}")
conf[section][key] = eval(config[section][key])
try:
conf[section][key] = eval(config[section][key])
except NameError:
conf[section][key] = float(config[section][key])
# if isinstance(conf[section][key], str):
# conf[section][key] = config[section][key].strip('"')
return conf
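# Why the eval/NameError fallback above: quoted strings and numerals eval
# cleanly, while bare tokens such as inf are not Python names and fall through
# to float(). A standalone illustration:
for _raw in ('"gauss"', "17", "0.058", "inf"):
    try:
        _val = eval(_raw)
    except NameError:
        _val = float(_raw)
    print(_raw, "->", repr(_val))  # -> 'gauss' (str), 17 (int), 0.058, inf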
@@ -96,7 +100,9 @@ class PDM_IM_IPM:
mod_order=8,
seed=None,
):
assert np.cbrt(mod_order) == int(np.cbrt(mod_order)) and mod_order > 1, "mod_order must be a cube of an integer greater than 1"
assert np.cbrt(mod_order) == int(np.cbrt(mod_order)) and mod_order > 1, (
"mod_order must be a cube of an integer greater than 1"
)
self.glova = glova
self.mod_order = mod_order
self.symbols_per_dim = int(np.cbrt(mod_order))
@@ -106,18 +112,11 @@ class PDM_IM_IPM:
rs = np.random.RandomState(self.seed)
symbols = rs.randint(0, self.mod_order, n)
return symbols
class pam_generator:
def __init__(
self,
glova,
mod_order=None,
mod_depth=0.5,
pulse_shape="gauss",
fwhm=0.33,
seed=None,
single_channel=False
self, glova, mod_order=None, mod_depth=0.5, pulse_shape="gauss", fwhm=0.33, seed=None, single_channel=False
) -> None:
self.glova = glova
self.pulse_shape = pulse_shape
@@ -133,49 +132,43 @@ class pam_generator:
wavelet = self.gauss(oversampling=6)
else:
raise ValueError(f"Unknown pulse shape: {self.pulse_shape}")
# prepare symbols
symbols_x = symbols[0] / (self.mod_order)
diffs_x = np.diff(symbols_x, prepend=symbols_x[0])
digital_x = self.generate_digital_signal(diffs_x, max_jitter)
digital_x = np.pad(
digital_x, (0, self.glova.sps // 2), "constant", constant_values=(0, 0)
)
digital_x = np.pad(digital_x, (0, self.glova.sps // 2), "constant", constant_values=(0, 0))
# create analog signal of diff of symbols
E_x = np.convolve(digital_x, wavelet)
# convert to pam and set modulation depth (scale and move up such that 1 stays at 1)
E_x = np.cumsum(E_x) * self.modulation_depth + 2*(1 - self.modulation_depth)
E_x = np.cumsum(E_x) * self.modulation_depth + (1 - self.modulation_depth)
# cut off the wavelet tails
E_x = E_x[self.glova.sps // 2 + len(wavelet) // 2 - 1 : -len(wavelet) // 2]
# modulate the laser
E[0]["E"][0] = np.sqrt(np.multiply(np.square(E[0]["E"][0]), E_x))
if not self.single_channel:
symbols_y = symbols[1] / (self.mod_order)
diffs_y = np.diff(symbols_y, prepend=symbols_y[0])
digital_y = self.generate_digital_signal(diffs_y, max_jitter)
digital_y = np.pad(
digital_y, (0, self.glova.sps // 2), "constant", constant_values=(0, 0)
)
digital_y = np.pad(digital_y, (0, self.glova.sps // 2), "constant", constant_values=(0, 0))
E_y = np.convolve(digital_y, wavelet)
E_y = np.cumsum(E_y) * self.modulation_depth + (1 - self.modulation_depth)
E_y = E_y[self.glova.sps // 2 + len(wavelet) // 2 - 1 : -len(wavelet) // 2]
E[0]["E"][1] = np.sqrt(np.multiply(np.square(E[0]["E"][1]), E_y))
# rotate the signal on the y-polarisation by 90°
E[0]["E"][1] *= 1j
# E[0]["E"][1] *= 1j
else:
E[0]["E"][1] = np.zeros_like(E[0]["E"][0], dtype=E[0]["E"][0].dtype)
return E
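    # Numeric note on the modulation-depth mapping used above:
    # level -> level * mod_depth + (1 - mod_depth), which pins full scale at 1;
    # e.g. PAM4 levels (0, 0.25, 0.5, 0.75) become (0.2, 0.4, 0.6, 0.8) at depth 0.8.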
def generate_digital_signal(self, symbols, max_jitter=0):
rs = np.random.RandomState(self.seed)
@@ -198,19 +191,19 @@ class pam_generator:
endpoint=True,
)
sigma = self.fwhm / (1 * np.sqrt(2 * np.log(2))) * self.glova.sps
pulse = (
1
/ (sigma * np.sqrt(2 * np.pi))
* np.exp(-np.square(sample_points) / (2 * np.square(sigma)))
)
pulse = 1 / (sigma * np.sqrt(2 * np.pi)) * np.exp(-np.square(sample_points) / (2 * np.square(sigma)))
return pulse
def initialize_fiber_and_data(config, input_data_override=None):
def initialize_fiber_and_data(config):
f0 = config["glova"].get("f0", None)
if f0 is None:
f0 = 299792458/(config["glova"].get("lambda0", 1550)*1e-9)
config["glova"]["f0"] = f0
py_glova = pypho.setup(
nos=config["glova"]["nos"],
sps=config["glova"]["sps"],
f0=config["glova"]["f0"],
f0=f0,
symbolrate=config["glova"]["symbolrate"],
wisdom_dir=config["glova"]["wisdom_dir"],
flags=config["glova"]["flags"],
@@ -221,49 +214,89 @@ def initialize_fiber_and_data(config, input_data_override=None):
c_data = pypho.cfiber.DataWrapper(py_glova.sps * py_glova.nos)
py_edfa = pypho.oamp(py_glova, Pmean=config["signal"]["edfa_power"], NF=config["signal"]["edfa_nf"])
if input_data_override is not None:
c_data.E_in = input_data_override[0]
noise = input_data_override[1]
else:
config["signal"]["seed"] = config["signal"].get(
"seed", (int(time.time() * 1000)) % 2**32
)
config["signal"]["jitter_seed"] = config["signal"].get(
"jitter_seed", (int(time.time() * 1000)) % 2**32
)
symbolsrc = pypho.symbols(
py_glova, py_glova.nos, pattern="ones", p1=config["signal"]["mod_order"], seed=config["signal"]["seed"]
)
laser = pypho.lasmod(
py_glova, power=config["signal"]["laser_power"]+1.5, Df=0, theta=np.pi / 4
)
modulator = pam_generator(
py_glova,
mod_depth=config["signal"]["mod_depth"],
pulse_shape=config["signal"]["pulse_shape"],
fwhm=config["signal"]["fwhm"],
seed=config["signal"]["jitter_seed"],
single_channel=False,
mod_order=config["signal"]["mod_order"],
osnr = config["signal"]["osnr"] = float(config["signal"].get("osnr", "inf"))
config["signal"]["seed"] = config["signal"].get("seed", (int(time.time() * 1000)) % 2**32)
config["signal"]["jitter_seed"] = config["signal"].get("jitter_seed", (int(time.time() * 1000)) % 2**32)
symbolsrc = pypho.symbols(
py_glova, py_glova.nos, pattern="ones", p1=config["signal"]["mod_order"], seed=config["signal"]["seed"]
)
laserx = pypho.lasmod(py_glova, power=0, Df=0, theta=np.pi/4)
# lasery = pypho.lasmod(py_glova, power=0, Df=25, theta=0)
modulator = pam_generator(
py_glova,
mod_depth=config["signal"]["mod_depth"],
pulse_shape=config["signal"]["pulse_shape"],
fwhm=config["signal"]["fwhm"],
seed=config["signal"]["jitter_seed"],
mod_order=config["signal"]["mod_order"],
)
symbols_x = symbolsrc(pattern="random")
symbols_y = symbolsrc(pattern="random")
symbols_x[:3] = 0
symbols_y[:3] = 0
# symbols_x += 1
cw = laserx()
# cwy = lasery()
# cw[0]['E'][0] = cw[0]['E'][0]
# cw[0]['E'][1] = cwy[0]['E'][0]
source_signal = modulator(E=cw, symbols=(symbols_x, symbols_y))
if osnr != float("inf"):
osnr_lin = 10 ** (osnr / 10)
signal_power = np.sum(pypho.functions.getpower_W(source_signal[0]["E"]))
noise_power = signal_power / osnr_lin
noise = np.random.normal(0, 1, source_signal[0]["E"].shape) + 1j * np.random.normal(
0, 1, source_signal[0]["E"].shape
)
noise_power_is = np.sum(pypho.functions.getpower_W(noise))
noise = noise * np.sqrt(noise_power / noise_power_is)
noise_power_is = np.sum(pypho.functions.getpower_W(noise))
source_signal[0]["E"] += noise
source_signal[0]["noise"] = noise_power_is
symbols_x = symbolsrc(pattern="random")
symbols_y = symbolsrc(pattern="random")
symbols_x[:3] = 0
symbols_y[:3] = 0
# symbols_x += 1
# source_signal[0]['E'][1] += source_signal[0]['E'][1][np.argmin(np.abs(source_signal[0]['E'][1]))]
## side channels
# df = 100
# signal_power = pypho.functions.W_to_dBm(np.sum(pypho.functions.getpower_W(source_signal[0]["E"])))
cw = laser()
# symbols_x_side = symbolsrc(pattern="random")
# symbols_y_side = symbolsrc(pattern="random")
# symbols_x_side[:3] = 0
# symbols_y_side[:3] = 0
# cw_left = laser(Df=-df)
# source_signal_left = modulator(E=cw_left, symbols=(symbols_x_side, symbols_y_side))
# cw_right = laser(Df=df)
# source_signal_right = modulator(E=cw_right, symbols=(symbols_y_side, symbols_x_side))
source_signal = modulator(E=cw, symbols=(symbols_x, symbols_y))
E_in_pure = source_signal[0]["E"]
# source_signal[0]['E'][1] += source_signal[0]['E'][1][np.argmin(np.abs(source_signal[0]['E'][1]))]
source_signal = py_edfa(E=source_signal)
nf = py_edfa.NF
pmean = py_edfa.Pmean
c_data.E_in = source_signal[0]["E"]
noise = source_signal[0]["noise"]
# ideal amplification to launch power into fiber
source_signal = py_edfa(E=source_signal, NF=0, Pmean=config["signal"]["laser_power"])
# source_signal_left = py_edfa(E=source_signal_left, NF=0, Pmean=config["signal"]["laser_power"])
# source_signal_right = py_edfa(E=source_signal_right, NF=0, Pmean=config["signal"]["laser_power"])
# source_signal[0]["E"][0] += source_signal_left[0]["E"][0] + source_signal_right[0]["E"][0]
# source_signal[0]["E"][1] += source_signal_left[0]["E"][1] + source_signal_right[0]["E"][1]
c_data.E_in = source_signal[0]["E"]
noise = source_signal[0]["noise"]
py_edfa.NF = nf
py_edfa.Pmean = pmean
py_fiber = pypho.fiber(
glova=py_glova,
@@ -272,27 +305,32 @@ def initialize_fiber_and_data(config, input_data_override=None):
gamma=config["fiber"]["gamma"],
D=config["fiber"]["d"],
S=config["fiber"]["s"],
phi_max=0.02,
)
if config["fiber"].get("birefsteps", 0) > 0:
seed = config["fiber"].get(
"birefseed", (int(time.time() * 1000)) % 2**32
)
config["fiber"]["birefsteps"] = config["fiber"].get(
"birefsteps", config["fiber"]["length"] // config["fiber"].get("bireflength", config["fiber"]["length"])
)
if config["fiber"]["birefsteps"] > 0:
config["fiber"]["bireflength"] = config["fiber"].get("bireflength", config["fiber"]["length"] / config["fiber"]["birefsteps"])
seed = config["fiber"].get("birefseed", (int(time.time() * 1000)) % 2**32)
py_fiber.birefarray = pypho.birefringence_segment.create_pmd_fibre(
py_fiber.l,
py_fiber.l / config["fiber"]["birefsteps"],
# maxDeltaD=config["fiber"]["d"]/5,
maxDeltaBeta = config["fiber"].get("max_delta_beta", 0),
config["fiber"]["length"],
config["fiber"]["bireflength"],
maxDeltaBeta=config["fiber"].get("pmd_q", 0)/np.sqrt(2*config["fiber"]["bireflength"]),
seed=seed,
)
c_params = pypho.cfiber.ParamsWrapper.from_fiber(
py_fiber, max_step=1e3 if py_fiber.gamma == 0 else 200
)
elif (dgd := config['fiber'].get('dgd', 0)) > 0:
py_fiber.birefarray = [
pypho.birefringence_segment(z_point=0, angle=np.pi/2, delta_beta=1000*dgd/config["fiber"]["length"])
]
c_params = pypho.cfiber.ParamsWrapper.from_fiber(py_fiber, max_step=config["fiber"]["length"] if py_fiber.gamma == 0 else 200)
c_fiber = pypho.cfiber.FiberWrapper(c_data, c_params, c_glova)
return c_fiber, c_data, noise, py_edfa
return c_fiber, c_data, noise, py_edfa, (symbols_x, symbols_y), py_glova, E_in_pure
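# Standalone check of the OSNR loading above (all numbers assumed for
# illustration): the complex noise field is rescaled so that its total power
# equals signal_power / 10**(OSNR_dB / 10).
def _osnr_check(osnr_db=16.0, p_signal=1e-3, n=4096):
    rng = np.random.default_rng(0)
    target = p_signal / 10 ** (osnr_db / 10)  # ~2.5e-5 W for 16 dB
    noise = rng.standard_normal(n) + 1j * rng.standard_normal(n)
    noise *= np.sqrt(target / np.mean(np.abs(noise) ** 2))
    return np.mean(np.abs(noise) ** 2), target  # both ~2.5e-5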
def save_data(data, config):
def save_data(data, config, **metadata):
data_dir = Path(config["data"]["dir"])
npy_dir = config["data"].get("npy_dir", "")
save_dir = data_dir / npy_dir if len(npy_dir) else data_dir
@@ -307,6 +345,7 @@ def save_data(data, config):
seed = config["signal"].get("seed", False)
jitter_seed = config["signal"].get("jitter_seed", False)
birefseed = config["fiber"].get("birefseed", False)
osnr = float(config["signal"].get("osnr", "inf"))
config_content = "\n".join((
f"; Generated by {str(Path(__file__).name)} @ {timestamp.strftime('%Y-%m-%d %H:%M:%S')}",
@@ -318,16 +357,19 @@ def save_data(data, config):
f'wisdom_dir = "{config["glova"]["wisdom_dir"]}"',
f'flags = "{config["glova"]["flags"]}"',
f"nthreads = {config['glova']['nthreads']}",
" ",
"",
"[fiber]",
f"length = {config['fiber']['length']}",
f"gamma = {config['fiber']['gamma']}",
f"alpha = {config['fiber']['alpha']}",
f"D = {config['fiber']['d']}",
f"S = {config['fiber']['s']}",
f"birefsteps = {config['fiber'].get('birefsteps',0)}",
f"max_delta_beta = {config['fiber'].get('max_delta_beta', 0)}",
f"birefsteps = {config['fiber'].get('birefsteps', 0)}",
f"pmd_q = {config['fiber'].get('pmd_q', 0)}",
f"birefseed = {hex(birefseed)}" if birefseed else "; birefseed = not set",
f"dgd = {config['fiber'].get('dgd', 0)}",
f"ortho_error = {config['fiber'].get('ortho_error', 0)}",
f"pol_error = {config['fiber'].get('pol_error', 0)}",
"",
"[signal]",
f"seed = {hex(seed)}" if seed else "; seed = not set",
@@ -335,100 +377,93 @@ def save_data(data, config):
f'modulation = "{config["signal"]["modulation"]}"',
f"mod_order = {config['signal']['mod_order']}",
f"mod_depth = {config['signal']['mod_depth']}",
""
"",
f"max_jitter = {config['signal']['max_jitter']}",
f"jitter_seed = {hex(jitter_seed)}" if jitter_seed else "; jitter_seed = not set",
""
"",
f"laser_power = {config['signal']['laser_power']}",
f"edfa_power = {config['signal']['edfa_power']}",
f"edfa_nf = {config['signal']['edfa_nf']}",
""
f"osnr = {osnr}",
"",
f'pulse_shape = "{config["signal"]["pulse_shape"]}"',
f"fwhm = {config['signal']['fwhm']}",
"",
"[data]",
f'dir = "{str(data_dir)}"',
f'npy_dir = "{npy_dir}"',
"file = "
"file = ",
))
config_hash = hashlib.md5(config_content.encode()).hexdigest()
save_file = f"{config_hash}.npy"
save_file = f"{config_hash}.h5"
config_content += f'"{str(save_file)}"\n'
config_filename:Path = create_config_filename(config, data_dir, timestamp)
while config_filename.exists():
time.sleep(1)
config_filename = create_config_filename(config, data_dir=data_dir)
with open(config_filename, "w") as f:
f.write(config_content)
with h5py.File(save_dir / save_file, "w") as outfile:
outfile.create_dataset("data", data=save_data)
outfile.create_dataset("symbols", data=metadata.pop("symbols"))
for key, value in metadata.items():
# if isinstance(value, dict):
# value = json.dumps(model_runner.convert_arrays(value))
outfile.attrs[key] = value
# np.save(save_dir / save_file, save_data)
# print("Saved config to", config_filename)
# print("Saved data to", save_dir / save_file)
return config_filename
def create_config_filename(config, data_dir:Path, timestamp=None):
if timestamp is None:
timestamp = datetime.now()
filename_components = (
timestamp.strftime("%Y%m%d-%H%M%S"),
config["glova"]["sps"],
config["glova"]["nos"],
config["signal"]["osnr"],
config["fiber"]["length"],
config["fiber"]["gamma"],
config["fiber"]["alpha"],
config["fiber"]["d"],
config["fiber"]["s"],
f"{config['signal']['modulation'].upper()}{config['signal']['mod_order']}",
config['fiber'].get('birefsteps',0),
config["fiber"].get("max_delta_beta", 0),
config["fiber"].get("birefsteps", 0),
config["fiber"].get("pmd_q", 0),
int(config["glova"]["symbolrate"] / 1e9),
)
lookup_file = "-".join(map(str, filename_components)) + ".ini"
with open(data_dir / lookup_file, "w") as f:
f.write(config_content)
return data_dir / lookup_file
np.save(save_dir / save_file, save_data)
print("Saved config to", data_dir / lookup_file)
print("Saved data to", save_dir / save_file)
def length_loop(config, lengths, incremental=False, bireflength=None, save=True):
def length_loop(config, lengths, save=True):
lengths = sorted(lengths)
input_override = None
birefsteps_running = 0
for lind, length in enumerate(lengths):
# print(f"\nGenerating data for fiber length {length}")
if lind > 0 and incremental:
# set the length to the difference between the current and previous length -> incremental
length = lengths[lind] - lengths[lind - 1]
if incremental:
print(
f"\nGenerating data for fiber length {lengths[lind]}m [using {length}m increment]"
)
else:
print(f"\nGenerating data for fiber length {length}m")
for length in lengths:
print(f"\nGenerating data for fiber length {length}m")
config["fiber"]["length"] = length
if bireflength is not None and bireflength > 0:
config["fiber"]["birefsteps"] = length // bireflength
birefsteps_running += config["fiber"]["birefsteps"]
# set the input data to the output data of the previous run
cfiber, cdata, noise, edfa = initialize_fiber_and_data(
config, input_data_override=input_override
)
if lind == 0:
cdata_orig = cdata
cfiber, cdata, noise, edfa, symbols, py_glova, _E_in = initialize_fiber_and_data(config)  # unpack all 7 return values; E_in_pure is unused here
mean_power_in = np.sum(pypho.functions.getpower_W(cdata.E_in))
print(
f"Mean power in: {mean_power_in:.3e} W ({pypho.functions.W_to_dBm(mean_power_in):.3e} dBm)"
)
cfiber()
mean_power_out = np.sum(pypho.functions.getpower_W(cdata.E_out))
print(
f"Mean power out: {mean_power_out:.3e} W ({pypho.functions.W_to_dBm(mean_power_out):.3e} dBm)"
)
if incremental:
input_override = (cdata.E_out, noise)
cdata.E_in = cdata_orig.E_in
config["fiber"]["length"] = lengths[lind]
if bireflength is not None:
config["fiber"]["birefsteps"] = birefsteps_running
E_tmp = [{'E': cdata.E_out, 'noise': noise*(-cfiber.params.l*cfiber.params.alpha)}]
print(f"Mean power in: {mean_power_in:.3e} W ({pypho.functions.W_to_dBm(mean_power_in):.3e} dBm)")
print(f"Mean power out: {mean_power_out:.3e} W ({pypho.functions.W_to_dBm(mean_power_out):.3e} dBm)")
E_tmp = [{"E": cdata.E_out, "noise": noise * (-cfiber.params.l * cfiber.params.alpha)}]
E_tmp = edfa(E=E_tmp)
cdata.E_out = E_tmp[0]['E']
cdata.E_out = E_tmp[0]["E"]
mean_power_amp = np.sum(pypho.functions.getpower_W(cdata.E_out))
print(f"Mean power after EDFA: {mean_power_amp:.3e} W ({pypho.functions.W_to_dBm(mean_power_amp):.3e} dBm)")
if save:
save_data(cdata, config)
@@ -436,27 +471,55 @@ def length_loop(config, lengths, incremental=False, bireflength=None, save=True)
def single_run_with_plot(config, save=True):
cfiber, cdata, noise, edfa = initialize_fiber_and_data(config)
mean_power_in = np.sum(pypho.functions.getpower_W(cdata.E_in))
print(
f"Mean power in: {mean_power_in:.3e} W ({pypho.functions.W_to_dBm(mean_power_in):.3e} dBm)"
)
cfiber()
mean_power_out = np.sum(pypho.functions.getpower_W(cdata.E_out))
print(
f"Mean power out: {mean_power_out:.3e} W ({pypho.functions.W_to_dBm(mean_power_out):.3e} dBm)"
)
E_tmp = [{'E': cdata.E_out, 'noise': noise*(-cfiber.params.l*cfiber.params.alpha)}]
E_tmp = edfa(E=E_tmp)
cdata.E_out = E_tmp[0]['E']
if save:
save_data(cdata, config)
cfiber, cdata, config_filename = single_run(config, save)
in_out_eyes(cfiber, cdata, show_pols=False)
return config_filename
def single_run(config, save=True, silent=True):
cfiber, cdata, noise, edfa, symbols, glova, E_in = initialize_fiber_and_data(config)
# transmit
cfiber()
# amplify
E_tmp = [{"E": cdata.E_out, "noise": noise}]
E_tmp = edfa(E=E_tmp)
# rotate
# ortho error
ortho_error = config["fiber"].get("ortho_error", 0)
E_tmp[0]["E"] = np.stack((
E_tmp[0]["E"][0] * np.cos(ortho_error/2) + E_tmp[0]["E"][1] * np.sin(ortho_error/2),
E_tmp[0]["E"][0] * np.sin(ortho_error/2) + E_tmp[0]["E"][1] * np.cos(ortho_error/2)
), axis=0)
pol_error = config['fiber'].get('pol_error', 0)
E_tmp[0]["E"] = np.stack((
E_tmp[0]["E"][0] * np.cos(pol_error) - E_tmp[0]["E"][1] * np.sin(pol_error),
E_tmp[0]["E"][0] * np.sin(pol_error) + E_tmp[0]["E"][1] * np.cos(pol_error)
), axis=0)
# output
cdata.E_out = E_tmp[0]["E"]
config_filename = None
symbols = np.array(symbols)
if save:
config_filename = save_data(cdata, config, **{"symbols": symbols})
if not silent:
print(f"Saved config to {config_filename}")
return cfiber, cdata, config_filename
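# Sketch (numpy only) of the two 2x2 operations applied in single_run: the
# proper rotation (pol_error) preserves total power, while the deliberately
# non-orthogonal ortho_error matrix (both off-diagonals +sin) does not.
import numpy as np

def apply_2x2(E, m):
    # E stacks the two polarisations: E[0] = E_x, E[1] = E_y
    return np.stack((m[0, 0] * E[0] + m[0, 1] * E[1],
                     m[1, 0] * E[0] + m[1, 1] * E[1]), axis=0)

theta = 0.1
rot = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]])
ortho = np.array([[np.cos(theta / 2), np.sin(theta / 2)], [np.sin(theta / 2), np.cos(theta / 2)]])
E = np.array([[1 + 0j, 0.5j], [0.2 + 0j, 1 - 0.3j]])
power = lambda F: float((np.abs(F) ** 2).sum())
print(power(apply_2x2(E, rot)) - power(E))    # ~0: rotation is power-preserving
print(power(apply_2x2(E, ortho)) - power(E))  # nonzero: orthogonality broken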
def in_out_eyes(cfiber, cdata, show_pols=False):
fig, axs = plt.subplots(2, 2, sharex=True, sharey=True)
@@ -620,9 +683,7 @@ def plot_eye_diagram(
signal = signal[: head * eye_width]
if normalize:
signal = signal / np.max(signal)
slices = np.lib.stride_tricks.sliding_window_view(signal, eye_width + 1)[
offset % (eye_width + 1) :: eye_width
]
slices = np.lib.stride_tricks.sliding_window_view(signal, eye_width + 1)[offset % (eye_width + 1) :: eye_width]
plt_ax = np.arange(-eye_width // 2, eye_width // 2 + 1) / samplerate
for slice in slices:
ax.plot(plt_ax, slice, color=color, alpha=0.1)
@@ -642,14 +703,27 @@ if __name__ == "__main__":
# lengths.append(10*max(ranges))
# lengths = [*lengths, *lengths]
lengths = (
# 8000, 9000,
10000, 20000, 30000, 40000, 50000, 60000, 70000, 80000, 90000,
95000, 100000, 105000, 110000, 115000, 120000
# 8000, 9000,
10000,
20000,
30000,
40000,
50000,
60000,
70000,
80000,
90000,
95000,
100000,
105000,
110000,
115000,
120000,
)
lengths = sorted(lengths)
length_loop(config, lengths, incremental=False, bireflength=1000, save=True)
# lengths = (10000,100000)
# length_loop(config, lengths, save=True)
# birefringence is constant over coupling length -> several 100m -> bireflength=1000 (m)
# single_run_with_plot(config, save=True)
single_run_with_plot(config, save=False)

File diff suppressed because one or more lines are too long

File diff suppressed because it is too large

View File

@@ -26,7 +26,7 @@ global_settings = GlobalSettings(
)
data_settings = DataSettings(
config_path="data/*-128-16384-10000-0-0-0-0-PAM4-0-0.4.ini",
config_path="data/20241211-105524-128-16384-1-0-0-0-0-PAM4-0-0.ini",
# config_path=[f"data/20241202-*-128-16384-{length}-0-0-17-0-PAM4-0.ini" for length in range(48000, 53000, 1000)],
dtype="complex64",
# symbols = (9, 20), # 13 symbol @ 10GBd <-> 1.3ns <-> 0.26m of fiber
@@ -53,14 +53,14 @@ pytorch_settings = PytorchSettings(
)
model_settings = ModelSettings(
output_dim=3,
output_dim=1,
n_hidden_layers=3,
overrides={
"n_hidden_nodes_0": 2,
"n_hidden_nodes_1": 2,
"n_hidden_nodes_2": 2,
"n_hidden_nodes_0": 4,
"n_hidden_nodes_1": 4,
"n_hidden_nodes_2": 4,
},
dropout_prob=0.01,
dropout_prob=0,
model_layer_function="ONNRect",
model_activation_func="EOActivation",
model_layer_kwargs={"square": True},
@@ -110,20 +110,24 @@ model_settings = ModelSettings(
)
optimizer_settings = OptimizerSettings(
optimizer="AdamW",
optimizer="RMSprop",
# optimizer="AdamW",
optimizer_kwargs={
"lr": 0.005,
"amsgrad": True,
"lr": 0.01,
"alpha": 0.9,
"momentum": 0.1,
"eps": 1e-8,
"centered": True,
# "amsgrad": True,
# "weight_decay": 1e-7,
},
# learning_rate=0.05,
scheduler="ReduceLROnPlateau",
scheduler_kwargs={
"patience": 2**6,
"patience": 2**5,
"factor": 0.75,
# "threshold": 1e-3,
"min_lr": 1e-6,
"cooldown": 10,
# "cooldown": 10,
},
)

View File

@@ -319,6 +319,29 @@ class normalize_by_first(nn.Module):
def forward(self, data):
return data / data[:, 0].unsqueeze(1)
class rotate(nn.Module):
def __init__(self):
super(rotate, self).__init__()
def forward(self, data, angle):
# data -> (batch, n*2)
# angle -> (batch, n)
data_ = data
if angle.ndim == 1:
angle_ = angle.unsqueeze(1)
else:
angle_ = angle
angle_ = angle_.expand(-1, data_.shape[1]//2)
c = torch.cos(angle_)
s = torch.sin(angle_)
rot = torch.stack([torch.stack([c, -s], dim=2),
torch.stack([s, c], dim=2)], dim=3)
d = torch.bmm(data_.reshape(-1, 1, 2), rot.view(-1, 2, 2).to(dtype=data_.dtype)).reshape(*data.shape)
# d = torch.bmm(data.unsqueeze(-1).mT, rot.to(dtype=data.dtype).mT).mT.squeeze(-1)
return d
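# Usage sketch for rotate: one angle per batch element is broadcast over the
# interleaved (x, y) pairs in the feature dimension, matching the shape
# comments above (data: (batch, n*2), angle: (batch,) or (batch, n)).
import torch
r = rotate()
data = torch.randn(8, 6, dtype=torch.complex64)  # three (x, y) pairs per sample
angle = torch.full((8,), torch.pi / 2)
print(r(data, angle).shape)  # torch.Size([8, 6])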
class photodiode(nn.Module):
def __init__(self, size, bias=True):
@@ -418,8 +441,7 @@ class input_rotator(nn.Module):
# return out
#### as defined by Zhang et al.
class DropoutComplex(nn.Module):
def __init__(self, p=0.5):
@@ -441,7 +463,7 @@ class Scale(nn.Module):
self.scale = nn.Parameter(torch.ones(size, dtype=torch.float32))
def forward(self, x):
return x * self.scale
return x * torch.sqrt(self.scale)
def __repr__(self):
return f"Scale({self.size})"
@@ -458,6 +480,15 @@ class Identity(nn.Module):
def forward(self, x):
return x
class phase_shift(nn.Module):
def __init__(self, size):
super(phase_shift, self).__init__()
self.size = size
self.phase = nn.Parameter(torch.rand(size))
def forward(self, x):
return x * torch.exp(1j*self.phase)
class PowRot(nn.Module):
@@ -487,7 +518,7 @@ class MZISingle(nn.Module):
return x * torch.exp(1j * self.phi) * torch.sin(self.omega + self.func(x))
def naive_angle_loss(x: torch.Tensor, target: torch.Tensor, mod=2*torch.pi):
return torch.fmod((x - target), mod).square().mean()
return torch.fmod((x.abs().real - target.abs().real), mod).abs().mean()
def cosine_loss(x: torch.Tensor, target: torch.Tensor):
return (2*(1 - torch.cos(x - target))).mean()
@@ -508,51 +539,46 @@ def angle_mse_loss(x: torch.Tensor, target: torch.Tensor):
class EOActivation(nn.Module):
def __init__(self, size=None):
# 10.1109/SiPhotonics60897.2024.10543376
# 10.1109/JSTQE.2019.2930455
super(EOActivation, self).__init__()
if size is None:
raise ValueError("Size must be specified")
self.size = size
self.alpha = nn.Parameter(torch.ones(size))
self.V_bias = nn.Parameter(torch.ones(size))
self.gain = nn.Parameter(torch.ones(size))
# if bias:
# self.phase_bias = nn.Parameter(torch.zeros(size))
# else:
# self.register_buffer("phase_bias", torch.zeros(size))
self.register_buffer("phase_bias", torch.clamp(torch.ones(size) + torch.randn(size)*0.1, 0, 1)*torch.pi)
self.register_buffer("responsivity", torch.ones(size)*0.9)
self.register_buffer("V_pi", torch.ones(size)*3)
self.alpha = nn.Parameter(torch.rand(size))
self.gain = nn.Parameter(torch.rand(size))
self.V_bias = nn.Parameter(torch.rand(size))
# self.register_buffer("gain", torch.ones(size))
# self.register_buffer("responsivity", torch.ones(size))
# self.register_buffer("V_pi", torch.ones(size))
self.reset_weights()
def reset_weights(self):
if "alpha" in self._parameters:
self.alpha.data = torch.ones(self.size)*0.5
if "V_pi" in self._parameters:
self.V_pi.data = torch.ones(self.size)*3
self.alpha.data = torch.rand(self.size)
# if "V_pi" in self._parameters:
# self.V_pi.data = torch.rand(self.size)*3
if "V_bias" in self._parameters:
self.V_bias.data = torch.zeros(self.size)
self.V_bias.data = torch.randn(self.size)
if "gain" in self._parameters:
self.gain.data = torch.ones(self.size)
if "responsivity" in self._parameters:
self.responsivity.data = torch.ones(self.size)*0.9
if "bias" in self._parameters:
self.phase_bias.data = torch.zeros(self.size)
self.gain.data = torch.rand(self.size)
# if "responsivity" in self._parameters:
# self.responsivity.data = torch.ones(self.size)*0.9
# if "bias" in self._parameters:
# self.phase_bias.data = torch.zeros(self.size)
def forward(self, x: torch.Tensor):
phi_b = torch.pi * self.V_bias / (self.V_pi + 1e-8)
g_phi = torch.pi * (self.alpha * self.gain * self.responsivity) / (self.V_pi + 1e-8)
phi_b = torch.pi * self.V_bias# / (self.V_pi)
g_phi = torch.pi * (self.alpha * self.gain)# * self.responsivity)# / (self.V_pi)
intermediate = g_phi * x.abs().square() + phi_b
return (
1j
* torch.sqrt(1 - self.alpha)
* torch.exp(-0.5j * (intermediate + self.phase_bias))
* torch.exp(-0.5j * intermediate)
* torch.cos(0.5 * intermediate)
* x
)
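# Usage sketch: EOActivation acts elementwise per feature; the intensity
# |x|^2 of the complex input drives a phase term (see the DOI reference
# above), giving a damped-cosine transfer multiplied onto the field.
import torch
act = EOActivation(size=4)
x = torch.randn(16, 4, dtype=torch.complex64)
y = act(x)
print(y.shape, y.dtype)  # torch.Size([16, 4]) torch.complex64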
class Pow(nn.Module):
"""
implements the activation function
@@ -693,6 +719,7 @@ __all__ = [
MZISingle,
EOActivation,
photodiode,
phase_shift,
# SaturableAbsorberLambertW,
# SaturableAbsorber,
# SpreadLayer,

View File

@@ -0,0 +1,105 @@
# Copyright (c) 2015, Warren Weckesser. All rights reserved.
# This software is licensed according to the "BSD 2-clause" license.
import hashlib
import h5py
import numpy as _np
from scipy.interpolate import interp1d as _interp1d
from scipy.ndimage import gaussian_filter as _gaussian_filter
from ._brescount import bres_curve_count as _bres_curve_count
from pathlib import Path
__all__ = ['grid_count']
def grid_count(y, window_size, offset=0, size=None, fuzz=True, blur=0, bounds=None):
"""
Parameters
----------
`y` is the 1-d array of signal samples.
`window_size` is the number of samples to show horizontally in the
eye diagram. Typically this is twice the number of samples in a
"symbol" (i.e. in a data bit).
`offset` is the number of initial samples to skip before computing
the eye diagram. This allows the overall phase of the diagram to
be adjusted.
`size` must be a tuple of two integers. It sets the size of the
array of counts, (height, width). The default is (800, 640).
`fuzz`: If True, the values in `y` are reinterpolated with a
random "fuzz factor" before plotting in the eye diagram. This
reduces an aliasing-like effect that arises with the use of
Bresenham's algorithm.
`blur` is the sigma (in bins) of an optional Gaussian filter applied
to the array of counts; 0 disables the smoothing.
`bounds` must be a tuple of two floating point values, (ymin, ymax).
These set the y range of the returned array. If not given, the
bounds are `(y.min() - 0.05*A, y.max() + 0.05*A)`, where `A` is
`y.max() - y.min()`.
Return Value
------------
Returns a numpy array of integers.
"""
# hash input params (hash the raw bytes of y: str() elides large arrays,
# which would make the cache key unstable)
param_ob = (y.tobytes(), window_size, offset, size, fuzz, blur, bounds)
param_hash = hashlib.md5(str(param_ob).encode()).hexdigest()
cache_dir = Path.home()/".eyediagram"/".cache"
cache_dir.mkdir(parents=True, exist_ok=True)
if (cache_dir/param_hash).is_file():
try:
with h5py.File(cache_dir/param_hash, "r") as infile:
counts = infile["counts"][:]
if counts.size != 0:
return counts
except (KeyError, OSError):
pass
if size is None:
size = (800, 640)
height, width = size
dt = width / window_size
counts = _np.zeros((width, height), dtype=_np.int32)
if bounds is None:
ymin = y.min()
ymax = y.max()
yamp = ymax - ymin
ymin = ymin - 0.05*yamp
ymax = ymax + 0.05*yamp
ymax = _np.ceil(ymax*10)/10
ymin = _np.floor(ymin*10)/10
else:
ymin, ymax = bounds
start = offset
while start + window_size < len(y):
end = start + window_size
yy = y[start:end+1]
k = _np.arange(len(yy))
xx = dt*k
if fuzz:
f = _interp1d(xx, yy, kind='cubic')
jiggle = dt*(_np.random.beta(a=3, b=3, size=len(xx)-2) - 0.5)
xx[1:-1] += jiggle
yd = f(xx)
else:
yd = yy
iyd = (height * (yd - ymin)/(ymax - ymin)).astype(_np.int32)
_bres_curve_count(xx.astype(_np.int32), iyd, counts)
start = end
if blur != 0:
counts = _gaussian_filter(counts, sigma=blur)
with h5py.File(cache_dir/param_hash, "w") as outfile:
outfile.create_dataset("data", data=counts)
return counts
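# Usage sketch with an illustrative signal: repeated calls with identical
# arguments hit the cache under ~/.eyediagram/.cache instead of re-rasterising.
import numpy as np
y = np.sin(np.linspace(0, 200 * np.pi, 12800)) + 0.05 * np.random.randn(12800)
counts = grid_count(y, window_size=256, size=(400, 320), fuzz=True, blur=1)
print(counts.shape, counts.sum())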

View File

@@ -1,10 +1,12 @@
from pathlib import Path
import h5py
import torch
from torch.utils.data import Dataset
# from torch.utils.data import Sampler
import numpy as np
import configparser
import multiprocessing as mp
# class SubsetSampler(Sampler[int]):
# """
@@ -24,7 +26,22 @@ import configparser
# return len(self.indices)
def load_data(config_path, skipfirst=0, symbols=None, real=False, normalize=False, device=None, dtype=None):
def load_from_file(datapath):
if str(datapath).endswith(".h5"):
symbols = None
with h5py.File(datapath, "r") as infile:
data = infile["data"][:]
try:
symbols = np.swapaxes(infile["symbols"][:], 0, 1)
except KeyError:
pass
else:
symbols = None
data = np.load(datapath)
return data, symbols
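# Sketch of the .h5 layout load_from_file expects (file name illustrative):
# a "data" dataset plus an optional "symbols" dataset stored symbol-major,
# which is swapped to channel-major on load; .npy files skip the symbols.
import h5py
import numpy as np
with h5py.File("example.h5", "w") as f:
    f.create_dataset("data", data=np.zeros((1024, 4), dtype=np.complex64))
    f.create_dataset("symbols", data=np.zeros((100, 2)))
data, symbols = load_from_file("example.h5")
print(data.shape, symbols.shape)  # (1024, 4) (2, 100)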
def load_data(config_path, skipfirst=0, skiplast=0, symbols=None, real=False, normalize=1, device=None, dtype=None):
filepath = Path(config_path)
filepath = filepath.parent.glob(filepath.name)
config = configparser.ConfigParser()
@@ -40,14 +57,28 @@ def load_data(config_path, skipfirst=0, symbols=None, real=False, normalize=Fals
if symbols is None:
symbols = int(config["glova"]["nos"]) - skipfirst
data = np.load(datapath)[int(skipfirst * sps) : int(symbols * sps + skipfirst * sps)]
timestamps = np.arange(int(skipfirst * sps), int(symbols * sps + skipfirst * sps))
data, orig_symbols = load_from_file(datapath)
if normalize:
# square gets normalized to 1, as the power is (proportional to) the square of the amplitude
a, b, c, d = np.square(data.T)
a, b, c, d = a / np.max(np.abs(a)), b / np.max(np.abs(b)), c / np.max(np.abs(c)), d / np.max(np.abs(d))
data = np.sqrt(np.array([a, b, c, d]).T)
data = data[int(skipfirst * sps) : int(symbols * sps + skipfirst * sps - skiplast * sps)]
if orig_symbols is not None:
orig_symbols = orig_symbols[skipfirst : symbols + skipfirst - skiplast]
timestamps = np.arange(int(skipfirst * sps), int(symbols * sps + skipfirst * sps - skiplast * sps))
data *= np.sqrt(normalize)
launch_power = float(config["signal"]["laser_power"])
output_power = float(config["signal"]["edfa_power"])
target_normalization = 10 ** (output_power / 10) / 10 ** (launch_power / 10)
# target_normalization *= 0.5 # allow 50% power loss, so the network can ignore parts of the signal
data[:, 0:2] *= np.sqrt(target_normalization)
# if normalize:
# # square gets normalized to 1, as the power is (proportional to) the square of the amplitude
# a, b, c, d = data.T
# a, b, c, d = a - np.min(np.abs(a)), b - np.min(np.abs(b)), c - np.min(np.abs(c)), d - np.min(np.abs(d))
# a, b, c, d = a / np.max(np.abs(a)), b / np.max(np.abs(b)), c / np.max(np.abs(c)), d / np.max(np.abs(d))
# data = np.array([a, b, c, d]).T
if real:
data = np.abs(data)
@@ -58,7 +89,7 @@ def load_data(config_path, skipfirst=0, symbols=None, real=False, normalize=Fals
data = torch.tensor(data, device=device, dtype=dtype)
return data, config
return data, config, orig_symbols
def roll_along(arr, shifts, dim):
@@ -110,11 +141,15 @@ class FiberRegenerationDataset(Dataset):
target_delay: float | int = 0,
xy_delay: float | int = 0,
drop_first: float | int = 0,
drop_last=0,
dtype: torch.dtype = None,
real: bool = False,
device=None,
polarisations: tuple | list = (0,),
# osnr: float|None = None,
polarisations=None,
randomise_polarisations: bool = False,
repeat_randoms: int = 1,
# cross_pol_interference: float = 0,
**kwargs,
):
"""
@@ -148,64 +183,53 @@ class FiberRegenerationDataset(Dataset):
assert drop_first >= 0, "drop_first must be non-negative"
self.randomise_polarisations = randomise_polarisations
# self.cross_pol_interference = cross_pol_interference
faux = kwargs.pop("faux", False)
if faux:
data_raw = np.array(
[[i + 0.1j, i + 0.2j, i + 1.1j, i + 1.2j] for i in range(12800)],
dtype=np.complex128,
data_raw = None
self.config = None
files = []
self.orig_symbols = None
for file_path in file_path if isinstance(file_path, (tuple, list)) else [file_path]:
data, config, orig_syms = load_data(
file_path,
skipfirst=drop_first,
skiplast=drop_last,
symbols=kwargs.get("num_symbols", None),
real=real,
normalize=1000,
device=device,
dtype=dtype,
)
data_raw = torch.tensor(data_raw, device=device, dtype=dtype)
timestamps = torch.arange(12800)
data_raw = torch.concatenate([data_raw, timestamps.reshape(-1, 1)], axis=-1)
self.config = {
"data": {"dir": '"."', "npy_dir": '"."', "file": "faux"},
"glova": {"sps": 128},
}
else:
data_raw = None
self.config = None
files = []
for file_path in file_path if isinstance(file_path, (tuple, list)) else [file_path]:
data, config = load_data(
file_path,
skipfirst=drop_first,
symbols=kwargs.get("num_symbols", None),
real=real,
normalize=True,
device=device,
dtype=dtype,
)
if data_raw is None:
data_raw = data
if orig_syms is not None:
if self.orig_symbols is None:
self.orig_symbols = orig_syms
else:
data_raw = torch.cat([data_raw, data], dim=0)
if self.config is None:
self.config = config
else:
assert self.config["glova"]["sps"] == config["glova"]["sps"], "samples per symbol must be the same"
files.append(config["data"]["file"].strip('"'))
self.config["data"]["file"] = str(files)
self.orig_symbols = np.concat((self.orig_symbols, orig_syms), axis=-1)
for i, angle in enumerate(torch.tensor(np.array(polarisations))):
data_raw_copy = data_raw.clone()
if angle == 0:
continue
sine = torch.sin(angle)
cosine = torch.cos(angle)
data_raw_copy[:, 2] = data_raw[:, 2] * cosine - data_raw[:, 3] * sine
data_raw_copy[:, 3] = data_raw[:, 2] * sine + data_raw[:, 3] * cosine
if i == 0:
data_raw = data_raw_copy
else:
data_raw = torch.cat([data_raw, data_raw_copy], dim=0)
if data_raw is None:
data_raw = data
else:
data_raw = torch.cat([data_raw, data], dim=0)
if self.config is None:
self.config = config
else:
assert self.config["glova"]["sps"] == config["glova"]["sps"], "samples per symbol must be the same"
files.append(config["data"]["file"].strip('"'))
self.config["data"]["file"] = str(files)
# if polarisations is not None:
# data_raw_clone = data_raw.clone()
# # rotate the polarisation by 180 degrees
# data_raw_clone[2, :] *= -1
# data_raw_clone[3, :] *= -1
# data_raw = torch.cat([data_raw, data_raw_clone], dim=0)
self.polarisations = bool(polarisations)
self.device = data_raw.device
self.samples_per_symbol = int(self.config["glova"]["sps"])
# self.num_symbols = int(self.config["glova"]["nos"])
self.samples_per_slice = int(symbols * self.samples_per_symbol)
self.symbols_per_slice = self.samples_per_slice / self.samples_per_symbol
@@ -278,23 +302,94 @@ class FiberRegenerationDataset(Dataset):
timestamps = data_raw[4, :]
data_raw = data_raw[:4, :]
data_raw = data_raw.view(2, 2, -1)
timestamps_doubled = torch.cat([timestamps.unsqueeze(dim=0), timestamps.unsqueeze(dim=0)], dim=0).unsqueeze(
dim=1
)
data_raw = torch.cat([data_raw, timestamps_doubled], dim=1)
fiber_in = data_raw[0, :, :]
fiber_out = data_raw[1, :, :]
# timestamps_doubled = torch.cat([timestamps.unsqueeze(dim=0), timestamps.unsqueeze(dim=0)], dim=0).unsqueeze(
# dim=1
# )
fiber_in = torch.cat([fiber_in, timestamps.unsqueeze(0)], dim=0)
fiber_out = torch.cat([fiber_out, timestamps.unsqueeze(0)], dim=0)
# fiber_out: [E_out_x, E_out_y, timestamps]
# add noise related to amplification necessary due to splitting of the signal
# gain_lin = output_dim*2
# gain_lin = 1
# edfa_nf = float(self.config["signal"]["edfa_nf"])
# nf_lin = 10**(edfa_nf/10)
# f0 = float(self.config["glova"]["f0"])
# noise_add = (gain_lin-1)*nf_lin*6.626e-34*f0*12.5e9
# noise = torch.randn_like(fiber_out[:2, :])
# noise_power = torch.sum(torch.mean(noise.abs().square().squeeze(), dim=-1))
# noise = noise * torch.sqrt(noise_add / noise_power)
# fiber_out[:2, :] += noise
# if osnr is None:
# noisy = fiber_out[:2, :]
# else:
# noisy = self.add_noise(fiber_out[:2, :], osnr)
# fiber_out = torch.cat([fiber_out, noisy], dim=0)
# fiber_out: [E_out_x, E_out_y, timestamps, E_out_x_noisy, E_out_y_noisy]
if repeat_randoms > 1:
fiber_in = fiber_in.repeat(1, 1, repeat_randoms)
fiber_out = fiber_out.repeat(1, 1, repeat_randoms)
# review: potential problems with repeated timestamps when plotting
else:
repeat_randoms = 1
if self.randomise_polarisations:
angles = torch.fmod(torch.randperm(data_raw.shape[-1]*repeat_randoms, device=fiber_out.device), 2) * torch.pi
start_angle = torch.rand(1) * 2 * torch.pi
angles = start_angle + torch.cumsum(torch.randn(data_raw.shape[-1])/333, dim=0) * torch.pi * 2 # random walk
angles = torch.randn(data_raw.shape[-1], device=fiber_out.device) * 2*torch.pi / 36 # sigma = 10 degrees
# self.angles = torch.rand(self.data.shape[0]) * 2 * torch.pi
else:
angles = torch.zeros(data_raw.shape[-1], device=fiber_out.device)
sin = torch.sin(angles)
cos = torch.cos(angles)
rot = torch.stack([torch.stack([cos, -sin], dim=1), torch.stack([sin, cos], dim=1)], dim=2)
data_rot = torch.bmm(fiber_out[:2, :].T.unsqueeze(1), rot.to(dtype=data_raw.dtype)).squeeze(1).T
# data_rot_noisy = torch.bmm(fiber_out[3:5, :].T.unsqueeze(1), rot.to(dtype=data_raw.dtype)).squeeze(1).T
fiber_out = torch.cat((fiber_out, data_rot), dim=0)
fiber_out = torch.cat([fiber_out, angles.unsqueeze(0)], dim=0)
# fiber_in:
# 0 E_in_x,
# 1 E_in_y,
# 2 timestamps
# fiber_out:
# 0 E_out_x,
# 1 E_out_y,
# 2 timestamps,
# 3 E_out_x_rot,
# 4 E_out_y_rot,
# 5 angle
# data_raw = torch.cat([data_raw, timestamps_doubled], dim=1)
# data layout
# [ [E_in_x, E_in_y, timestamps],
# [E_out_x, E_out_y, timestamps] ]
self.data = data_raw.unfold(dimension=-1, size=self.samples_per_slice, step=1)
self.data = self.data.movedim(-2, 0)
self.fiber_in = fiber_in.unfold(dimension=-1, size=self.samples_per_slice, step=1)
self.fiber_in = self.fiber_in.movedim(-2, 0)
if randomise_polarisations:
self.angles = torch.rand(self.data.shape[0]) * np.pi * 2
# self.data[:, 1, :2, :] = self.rotate(self.data[:, 1, :2, :], self.angles)
else:
self.angles = torch.zeros(self.data.shape[0])
self.fiber_out = fiber_out.unfold(dimension=-1, size=self.samples_per_slice, step=1)
self.fiber_out = self.fiber_out.movedim(-2, 0)
# if self.randomise_polarisations:
# self.angles = torch.cumsum((torch.rand(self.fiber_out.shape[0]) - 0.5) * 2 * torch.pi * 2 / 5000, dim=0)
# self.data = data_raw.unfold(dimension=-1, size=self.samples_per_slice, step=1)
# self.data = self.data.movedim(-2, 0)
# self.angles = torch.zeros(self.data.shape[0])
...
# ...
# -> [no_slices, 2, 3, samples_per_slice]
@@ -305,77 +400,116 @@ class FiberRegenerationDataset(Dataset):
# ...
# ] -> [no_slices, 2, 3, samples_per_slice]
...
def __len__(self):
return self.data.shape[0]
return self.fiber_in.shape[0]
def add_noise(self, data, osnr):
osnr_lin = 10 ** (osnr / 10)
popt = torch.mean(data.abs().square().squeeze(), dim=-1)
noise = torch.randn_like(data)
pn = torch.mean(noise.abs().square().squeeze(), dim=-1)
mult = torch.sqrt(popt / (pn * osnr_lin))
mult = mult * torch.eye(popt.shape[0], device=mult.device)
mult = mult.to(dtype=noise.dtype)
noise = mult @ noise
noisy = data + noise
return noisy
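# Quick check of the OSNR scaling in add_noise: after multiplying the white
# noise by sqrt(P_signal / (P_noise * osnr_lin)) per channel, the measured
# signal-to-noise power ratio equals the requested linear OSNR.
import torch
sig = torch.randn(2, 4096, dtype=torch.complex64)
osnr_lin = 10 ** (20.0 / 10)  # 20 dB
popt = sig.abs().square().mean(dim=-1)
noise = torch.randn_like(sig)
scale = torch.sqrt(popt / (noise.abs().square().mean(dim=-1) * osnr_lin))
noise = noise * scale.unsqueeze(-1).to(noise.dtype)
print(popt / noise.abs().square().mean(dim=-1))  # ~[100., 100.]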
def __getitem__(self, idx):
if isinstance(idx, slice):
return [self.__getitem__(i) for i in range(*idx.indices(len(self)))]
else:
data_slice = self.data[idx].squeeze()
data_slice = data_slice[:, :, : data_slice.shape[2] // self.output_dim * self.output_dim]
# fiber in: [E_in_x, E_in_y, timestamps]
# fiber out: [E_out_x, E_out_y, timestamps, E_out_x_rot, E_out_y_rot, angle]
data_slice = data_slice.view(data_slice.shape[0], data_slice.shape[1], self.output_dim, -1)
# if self.polarisations:
output_dim = self.output_dim // 2
self.output_dim = output_dim * 2
# if self.randomise_polarisations:
# angle = torch.rand(1) * torch.pi * 2
# sine = torch.sin(angle)
# cosine = torch.cos(angle)
# data_slice_ = data_slice[1]
# data_slice[1, 0] = data_slice_[0] * cosine - data_slice_[1] * sine
# data_slice[1,1] = data_slice_[0] * sine + data_slice_[1] * cosine
# else:
# angle = torch.zeros(1)
if not self.polarisations:
output_dim = 2 * output_dim
# data = data_slice[1, :2, :, 0]
angle = self.angles[idx]
fiber_in = self.fiber_in[idx].squeeze()
fiber_out = self.fiber_out[idx].squeeze()
data_index = 1
fiber_in = fiber_in[..., : fiber_in.shape[-1] // output_dim * output_dim]
fiber_out = fiber_out[..., : fiber_out.shape[-1] // output_dim * output_dim]
data_slice[1, :2, :, :] = self.rotate(data_slice[data_index, :2, :, :], angle)
fiber_in = fiber_in.view(fiber_in.shape[0], output_dim, -1)
fiber_out = fiber_out.view(fiber_out.shape[0], output_dim, -1)
data = data_slice[1, :2, :, 0]
# data = self.rotate(data, angle)
center_angle = fiber_out[5, output_dim // 2, 0]
angles = fiber_out[5, :, 0]
plot_data_rot = fiber_out[3:5, output_dim // 2, 0].detach().clone()
data = fiber_out[0:2, :, 0]
plot_data = fiber_out[0:2, output_dim // 2, 0].detach().clone()
# for both polarisations (x/y), calculate the mean of the signal around the current symbol (-> corresponding to a lowpass filter)
angle_data = data_slice[1, :2, :, :].reshape(2, -1).mean(dim=1)
angle_data2 = self.complex_max(data_slice[1, :2, :, :].reshape(2, -1))
plot_data = data_slice[1, :2, self.output_dim // 2, 0]
sop = self.polarimeter(plot_data)
# angle_data = data_slice[1, :2, :self.output_dim//2, :].squeeze().reshape(2, -1).mean(dim=-1)
# angle = data_slice[1, 3, self.output_dim // 2, 0].real
target = data_slice[0, :2, self.output_dim // 2, 0]
target_timestamp = data_slice[0, 2, self.output_dim // 2, 0].real
target = fiber_in[:2, output_dim // 2, 0]
plot_target = fiber_in[:2, output_dim // 2, 0].detach().clone()
target_timestamp = fiber_in[2, output_dim // 2, 0].real
...
# data_timestamps = data[-1,:].real
# data = data[:-1, :]
# target_timestamp = target[-1].real
# target = target[:-1]
# plot_data = plot_data[:-1]
if self.polarisations:
rot = int(np.random.randint(2) * 2 - 1)
data = rot * data
target = rot * target
plot_data_rot = rot * plot_data_rot
center_angle = center_angle + (rot - 1) * torch.pi / 2 # rot: -1 or 1 -> -2 or 0 -> -pi or 0
angles = angles + (rot - 1) * torch.pi / 2
pol_flipped_data = -data
pol_flipped_target = -target
# transpose to interleave the x and y data in the output tensor
data = data.transpose(0, 1).flatten().squeeze()
angle_data = angle_data.flatten().squeeze()
angle_data2 = angle_data.flatten().squeeze()
angle = angle.flatten().squeeze()
data = data / torch.sqrt(torch.ones(1) * len(data)) # power loss due to splitting
pol_flipped_data = pol_flipped_data.transpose(0, 1).flatten().squeeze()
pol_flipped_data = pol_flipped_data / torch.sqrt(
torch.ones(1) * len(pol_flipped_data)
) # power loss due to splitting
# angle_data = angle_data.transpose(0, 1).flatten().squeeze()
# angle_data2 = angle_data2.transpose(0,1).flatten().squeeze()
center_angle = center_angle.flatten().squeeze()
angles = angles.flatten().squeeze()
# data_timestamps = data_timestamps.flatten().squeeze()
# target = target.transpose(0,1).flatten().squeeze()
target = target.flatten().squeeze()
pol_flipped_target = pol_flipped_target.flatten().squeeze()
target_timestamp = target_timestamp.flatten().squeeze()
plot_target = plot_target.flatten().squeeze()
plot_data = plot_data.flatten().squeeze()
plot_data_rot = plot_data_rot.flatten().squeeze()
return {
"x": data,
"x_flipped": pol_flipped_data,
"x_stacked": torch.cat([data, pol_flipped_data], dim=-1),
"y": target,
"y_flipped": pol_flipped_target,
"y_stacked": torch.cat([target, pol_flipped_target], dim=-1),
"center_angle": center_angle,
"angles": angles,
"mean_angle": angles.mean(),
# "sop": sop,
# "angle_data": angle_data,
# "angle_data2": angle_data2,
"timestamp": target_timestamp,
"plot_target": plot_target,
"plot_data": plot_data,
"plot_data_rot": plot_data_rot,
# "plot_clean": fiber_out_plot_clean,
}
return {"x": data, "y": target, "angle": angle, "sop": sop, "angle_data": angle_data, "angle_data2": angle_data2, "timestamp": target_timestamp, "plot_data": plot_data}
def complex_max(self, data, dim=-1):
# returns element(s) with the maximum absolute value along a given dimension
# ind = torch.argmax(data.abs(), dim=dim, keepdim=True)
# max_values = torch.gather(data, dim, ind).squeeze(dim=dim)
# return max_values
return torch.gather(data, dim, torch.argmax(data.abs(), dim=dim, keepdim=True)).squeeze(dim=dim)
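# Usage sketch: complex_max picks, per row, the entry with the largest
# magnitude (not the largest real part) and keeps its complex value; self
# is unused, so the method can be exercised unbound for illustration.
import torch
t = torch.tensor([[1 + 1j, 3 + 0j, -4j], [2 + 0j, -1 + 1j, 0.5j]])
print(FiberRegenerationDataset.complex_max(None, t))  # [-4j, 2+0j]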
def rotate(self, data, angle):
# rotates a 2d tensor by a given angle
@@ -388,7 +522,25 @@ class FiberRegenerationDataset(Dataset):
cosine = torch.cos(angle)
return torch.stack([data[0] * cosine - data[1] * sine, data[0] * sine + data[1] * cosine], dim=0)
def rotate_all(self):
# note: with the fork start method each worker gets its own copy of
# self.data, so the in-place writes below only propagate back if the
# tensor is in shared memory (self.data.share_memory_()) beforehand.
def do_rotation(j, num_processes):
for i in range(len(self) // num_processes):
index = i * num_processes + j
self.data[index, 1, :2, :] = self.rotate(self.data[index, 1, :2, :], self.angles[index])
self.processes = []
for j in range(mp.cpu_count()):
self.processes.append(mp.Process(target=do_rotation, args=(j, mp.cpu_count())))
self.processes[-1].start()
for p in self.processes:
p.join()
for i in range(len(self) // mp.cpu_count() * mp.cpu_count(), len(self)):
self.data[i, 1, :2, :] = self.rotate(self.data[i, 1, :2, :], self.angles[i])
def polarimeter(self, data):
# data: [2, ...] -> x, y
# returns [4] -> S0, S1, S2, S3
@@ -396,12 +548,12 @@ class FiberRegenerationDataset(Dataset):
y = data[1].mean()
I_X = x.abs().square()
I_Y = y.abs().square()
I_45 = (x+y).abs().square()
I_RHC = (x + 1j*y).abs().square()
I_45 = 0.5 * (x + y).abs().square()  # ideal 45-degree polariser transmits half the combined power
I_RHC = 0.5 * (x + 1j * y).abs().square()  # likewise for the circular analyser
S0 = I_X + I_Y
S1 = (2*I_X - S0) / S0
S2 = (2*I_45 - S0) / S0
S3 = (2*I_RHC - S0) / S0
S1 = (2 * I_X - S0) / S0
S2 = (2 * I_45 - S0) / S0
S3 = (2 * I_RHC - S0) / S0
return torch.stack([S1, S2, S3], dim=0)
return torch.stack([S0, S1, S2, S3], dim=0)
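# Worked check of the Stokes computation (with the 1/2 analyser factors):
# pure x-polarised light gives (S0, S1, S2, S3) = (1, 1, 0, 0); circular
# light gives S3 = +/-1 depending on the handedness convention.
import torch
x_pol = torch.tensor([[1 + 0j], [0 + 0j]])
circ = torch.tensor([[1 + 0j], [0 + 1j]]) / 2**0.5
print(FiberRegenerationDataset.polarimeter(None, x_pol))  # [1, 1, 0, 0]
print(FiberRegenerationDataset.polarimeter(None, circ))   # [1, 0, 0, -1]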

View File

@@ -1,16 +1,23 @@
from datetime import datetime
import json
from pathlib import Path
from typing import Literal
import h5py
from matplotlib import pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
# from cmap import Colormap as cm
import numpy as np
from scipy.cluster.vq import kmeans2
import warnings
import multiprocessing
from rich.traceback import install
from rich import pretty
from rich import print
install()
pretty.install()
# from rich import pretty
# from rich import print
# pretty.install()
def generate_sample_data(n_symbols=1000, sps=128, noise=0.01, skew=1):
@@ -21,6 +28,7 @@ def generate_sample_data(n_symbols=1000, sps=128, noise=0.01, skew=1):
xaxis = np.arange(0, len(signal)) / sps
return np.vstack([xaxis, signal])
def create_symbol_sequence(n_symbols, skew=1):
np.random.seed(42)
data = np.random.randint(0, 4, n_symbols) / 4
@@ -39,6 +47,14 @@ def generate_signal(data, sps):
signal = np.convolve(data_padded, wavelet)
signal = np.cumsum(signal)
signal = signal[sps // 2 + len(wavelet) // 2 - 1 : -len(wavelet) // 2]
mi, ma = np.min(signal), np.max(signal)
signal = (signal - mi) / (ma - mi)
mod = 0.8
signal *= mod
signal += 1 - mod
return signal
@@ -49,8 +65,8 @@ def normalization_with_noise(signal, noise=0):
signal += awgn
# min-max normalization
signal = signal - np.min(signal)
signal = signal / np.max(signal)
# signal = signal - np.min(signal)
# signal = signal / np.max(signal)
return signal
@@ -68,98 +84,264 @@ def generate_wavelet(sps, oversample=3):
class eye_diagram:
def __init__(self, data, *, channel_names=None, horizontal_bins=256, vertical_bins=1000, n_levels=4, multithreaded=True):
def __init__(
self,
data,
*,
channel_names=None,
horizontal_bins=256,
vertical_bins=1000,
n_levels=4,
multithreaded=True,
save_file_or_dir=None,
):
# data has shape [channels, 2, samples]
# each sample has a timestamp and a value
if data.ndim == 2:
data = data[np.newaxis, :, :]
self.channel_names = channel_names
self.raw_data = data
self.channels = data.shape[0]
self.y_bins = np.zeros(1)
self.x_bins = np.zeros(1)
self.eye_data = np.zeros(1)
self.channel_names = channel_names
self.n_channels = data.shape[0]
self.n_levels = n_levels
self.eye_stats = [{"success": False} for _ in range(self.channels)]
self.eye_stats = [{"success": False} for _ in range(self.n_channels)]
self.horizontal_bins = horizontal_bins
self.vertical_bins = vertical_bins
self.multi_threaded = multithreaded
self.analysed = False
self.eye_built = False
self.analyse()
def generate_eye_data(self):
self.x_bins = np.linspace(0, 2, self.horizontal_bins, endpoint=False)
self.y_bins = np.zeros((self.channels, self.vertical_bins))
self.eye_data = np.zeros((self.channels, self.vertical_bins, self.horizontal_bins))
datas = [self.raw_data[i] for i in range(self.channels)]
if self.multi_threaded:
with multiprocessing.Pool() as pool:
results = pool.map(self.generate_eye_data_single, datas)
for i, result in enumerate(results):
self.eye_data[i], self.y_bins[i] = result
self.save_file = save_file_or_dir
def load_data(self, file=None):
file = self.save_file if file is None else file
if file is None:
raise FileNotFoundError("No file specified.")
self.save_file = str(file)
# self.file_or_dir = self.save_file
with h5py.File(file, "r") as infile:
self.y_bins = infile["y_bins"][:]
self.x_bins = infile["x_bins"][:]
self.eye_data = infile["eye_data"][:]
self.channel_names = infile.attrs["channel_names"]
self.n_channels = infile.attrs["n_channels"]
self.n_levels = infile.attrs["n_levels"]
self.eye_stats = infile.attrs["eye_stats"]
self.eye_stats = [json.loads(stat) for stat in self.eye_stats]
self.horizontal_bins = infile.attrs["horizontal_bins"]
self.vertical_bins = infile.attrs["vertical_bins"]
self.multi_threaded = infile.attrs["multithreaded"]
self.analysed = infile.attrs["analysed"]
self.eye_built = infile.attrs["eye_built"]
def save_data(self, file_or_dir=None):
file_or_dir = self.save_file if file_or_dir is None else file_or_dir
if file_or_dir is None:
file = Path(f"{datetime.now().strftime('%Y%m%d_%H%M%S')}_eye_data.eye.h5")
elif Path(file_or_dir).is_dir():
file = Path(file_or_dir) / f"{datetime.now().strftime('%Y%m%d_%H%M%S')}_eye_data.eye.h5"
else:
for i, data in enumerate(datas):
self.eye_data[i], self.y_bins[i] = self.generate_eye_data_single(data)
self.eye_built = True
file = Path(file_or_dir)
# file.parent.mkdir(parents=True, exist_ok=True)
self.save_file = str(file)
with h5py.File(file, "w") as outfile:
outfile.create_dataset("eye_data", data=self.eye_data)
outfile.create_dataset("y_bins", data=self.y_bins)
outfile.create_dataset("x_bins", data=self.x_bins)
outfile.attrs["channel_names"] = self.channel_names
outfile.attrs["n_channels"] = self.n_channels
outfile.attrs["n_levels"] = self.n_levels
self.eye_stats = eye_diagram.convert_arrays(self.eye_stats)
outfile.attrs["eye_stats"] = [json.dumps(stat) for stat in self.eye_stats]
outfile.attrs["horizontal_bins"] = self.horizontal_bins
outfile.attrs["vertical_bins"] = self.vertical_bins
outfile.attrs["multithreaded"] = self.multi_threaded
outfile.attrs["analysed"] = self.analysed
outfile.attrs["eye_built"] = self.eye_built
@staticmethod
def convert_arrays(input_object):
"""
convert ndarrays in (nested) dict to lists
"""
if isinstance(input_object, np.ndarray):
return input_object.tolist()
elif isinstance(input_object, list):
return [eye_diagram.convert_arrays(old) for old in input_object]
elif isinstance(input_object, tuple):
return tuple(eye_diagram.convert_arrays(old) for old in input_object)
elif isinstance(input_object, dict):
dict_out = {}
for key, value in input_object.items():
dict_out[key] = eye_diagram.convert_arrays(value)
return dict_out
return input_object
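# Usage sketch: convert_arrays makes the (nested) eye_stats JSON-safe by
# turning every ndarray into a plain list while leaving other values alone,
# which is what allows the json.dumps call in save_data above.
import json
import numpy as np
stats = {"levels": np.array([0.0, 0.33, 0.66, 1.0]), "success": True,
         "clusters": [np.array([1, 2]), (np.array([3]),)]}
print(json.dumps(eye_diagram.convert_arrays(stats)))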
def generate_eye_data(
self, mode: Literal["default", "load", "save", "nosave"] = "default", file_or_dir: Path | None | str = None
):
# modes:
# default: try to load eye data from file, if not found, generate and save
# load: try to load eye data from file, if not found, generate but don't save
# save: generate eye data and save
update_save = True
if mode == "load":
self.load_data(file_or_dir)
update_save = False
elif mode == "default":
try:
self.load_data(file_or_dir)
update_save = False
except (FileNotFoundError, IsADirectoryError):
pass
if not self.eye_built:
update_save = True
self.x_bins = np.linspace(0, 2, self.horizontal_bins, endpoint=False)
self.y_bins = np.zeros((self.n_channels, self.vertical_bins))
self.eye_data = np.zeros((self.n_channels, self.vertical_bins, self.horizontal_bins))
datas = [self.raw_data[i] for i in range(self.n_channels)]
if self.multi_threaded:
with multiprocessing.Pool() as pool:
results = pool.map(self.generate_eye_data_single, datas)
for i, result in enumerate(results):
self.eye_data[i], self.y_bins[i] = result
else:
for i, data in enumerate(datas):
self.eye_data[i], self.y_bins[i] = self.generate_eye_data_single(data)
self.eye_built = True
if mode == "save" or (mode == "default" and update_save):
self.save_data(file_or_dir)
def generate_eye_data_single(self, data):
eye_data = np.zeros((self.vertical_bins, self.horizontal_bins))
data_min = np.min(data[1, :])
data_max = np.max(data[1, :])
# round down/up to 1 decimal
data_min = np.floor(data_min*10)/10
data_max = np.ceil(data_max*10)/10
# data_range = data_max - data_min
# data_min -= 0.1 * data_range
# data_max += 0.1 * data_range
# data_min = -0.05
# data_max += 0.05
# data[1,:] -= np.min(data[1, :])
# data[1,:] /= np.max(data[1, :])
# data_min = 0
# data_max = 1
y_bins = np.linspace(data_min, data_max, self.vertical_bins, endpoint=False)
t_vals = data[0, :] % 2
val_vals = data[1, :]
t_vals = data[0, :] % 2 # + np.random.randn(*data[0, :].shape) * 1 / (512)
val_vals = data[1, :] # + np.random.randn(*data[1, :].shape) * 1 / (320)
x_indices = np.digitize(t_vals, self.x_bins) - 1
y_indices = np.digitize(val_vals, y_bins) - 1
np.add.at(eye_data, (y_indices, x_indices), 1)
return eye_data, y_bins
def plot(self, title="Eye Diagram", stats=True, all_stats=True, show=True):
def plot(
self,
title="Eye Diagram",
stats=True,
all_stats=True,
show=True,
mode: Literal["default", "load", "save", "nosave"] = "default",
# save_images = False,
# image_dir = None,
# cmap=None,
):
if stats and not self.analysed:
self.analyse(mode=mode)
if not self.eye_built:
self.generate_eye_data()
self.generate_eye_data(mode=mode)
cmap = LinearSegmentedColormap.from_list(
"eyemap",
[(0, "white"), (0.1, "blue"), (0.2, "cyan"), (0.5, "green"), (0.8, "yellow"), (0.9, "red"), (1, "magenta")],
[
(0, "#FFFFFF00"),
(0.1, "blue"),
(0.2, "cyan"),
(0.5, "green"),
(0.8, "yellow"),
(0.9, "red"),
(1, "magenta"),
],
)
if self.channels % 2 == 0:
# cmap = cm('google:turbo_r' if cmap is None else cmap)
# first = cmap(-1)
# cmap = cmap.to_mpl()
# cmap.set_under(first, alpha=0)
if self.n_channels % 2 == 0:
rows = 2
cols = self.channels // 2
cols = self.n_channels // 2
else:
cols = int(np.ceil(np.sqrt(self.channels)))
rows = int(np.ceil(self.channels / cols))
cols = int(np.ceil(np.sqrt(self.n_channels)))
rows = int(np.ceil(self.n_channels / cols))
fig, ax = plt.subplots(rows, cols, sharex=True, sharey=False)
fig.suptitle(title)
fig.tight_layout()
ax = np.atleast_1d(ax).transpose().flatten()
for i in range(self.channels):
ax[i].set_title(self.channel_names[i] if self.channel_names is not None else f"Channel {i+1}")
if (i+1) % rows == 0:
for i in range(self.n_channels):
ax[i].set_title(self.channel_names[i] if self.channel_names is not None else f"Channel {i + 1}")
if (i + 1) % rows == 0:
ax[i].set_xlabel("Symbol")
if i < rows:
ax[i].set_ylabel("Amplitude")
ax[i].grid()
ax[i].set_axisbelow(True)
ax[i].imshow(
self.eye_data[i],
self.eye_data[i] - 0.1,
origin="lower",
aspect="auto",
cmap=cmap,
extent=[self.x_bins[0], self.x_bins[-1], self.y_bins[i][0], self.y_bins[i][-1]],
interpolation="gaussian",
vmin=0,
zorder=3,
)
ax[i].set_xlim((self.x_bins[0], self.x_bins[-1]))
ymin = np.min(self.y_bins[:, 0])
ymax = np.max(self.y_bins[:, -1])
yspan = ymax - ymin
ax[i].set_ylim((ymin - 0.1 * yspan, ymax + 0.1 * yspan))
# if save_images:
# image_dir = "images_out" if image_dir is None else image_dir
# image_path = Path(image_dir) / (slugify(f"{datetime.now().strftime("%Y%m%d_%H%M%S")}_{title.replace(" ","_")}_{self.channel_names[i].replace(" ", "_") if self.channel_names is not None else f"{i + 1}"}_{ymin:.1f}_{ymax:.1f}") + ".png")
# image_path.parent.mkdir(parents=True, exist_ok=True)
# # plt.imsave(
# # image_path,
# # self.eye_data[i] - 0.1,
# # origin="lower",
# # # aspect="auto",
# # cmap=cmap,
# # # extent=[self.x_bins[0], self.x_bins[-1], self.y_bins[i][0], self.y_bins[i][-1]],
# # # interpolation="gaussian",
# # vmin=0,
# # # zorder=3,
# # )
if stats and self.eye_stats[i]["success"]:
# add min_area above the plot
ax[i].annotate(
f"Min Area: {self.eye_stats[i]['min_area']:.2e}",
xy=(0.05, ymax + 0.05 * yspan),
# xycoords="axes fraction",
ha="left",
va="center",
bbox=dict(facecolor="white", alpha=0.5, edgecolor="none"),
)
# # add min_area above the plot
# ax[i].annotate(
# f"Min Area: {self.eye_stats[i]['min_area']:.2e}",
# xy=(0.05, ymax + 0.05 * yspan),
# # xycoords="axes fraction",
# ha="left",
# va="center",
# bbox=dict(facecolor="white", alpha=0.5, edgecolor="none"),
# )
if all_stats:
ax[i].plot([0, 2], [self.eye_stats[i]["levels"], self.eye_stats[i]["levels"]], "k--")
ax[i].set_yticks(self.eye_stats[i]["levels"])
y_ticks = (*self.eye_stats[i]["levels"], *self.eye_stats[i]["thresholds"])
# y_ticks = np.sort(y_ticks)
ax[i].set_yticks(y_ticks)
# add arrows for amplitudes
for j in range(len(self.eye_stats[i]["amplitudes"])):
ax[i].annotate(
@@ -193,62 +375,69 @@ class eye_diagram:
except (ValueError, IndexError):
pass
# add arrows for eye widths
for j in range(len(self.eye_stats[i]["widths"])):
try:
left = np.max(self.eye_stats[i]["time_clusters"][j][0])
right = np.min(self.eye_stats[i]["time_clusters"][j][1])
vertical = (self.eye_stats[i]["levels"][j] + self.eye_stats[i]["levels"][j + 1]) / 2
# for j in range(len(self.eye_stats[i]["widths"])):
# try:
# left = np.max(self.eye_stats[i]["time_clusters"][j][0])
# right = np.min(self.eye_stats[i]["time_clusters"][j][1])
# vertical = (self.eye_stats[i]["levels"][j] + self.eye_stats[i]["levels"][j + 1]) / 2
ax[i].annotate(
"",
xy=(left, vertical),
xytext=(right, vertical),
arrowprops=dict(arrowstyle="<->", facecolor="black"),
)
ax[i].annotate(
f"{self.eye_stats[i]['widths'][j]:.2e}",
xy=((left + right) / 2 - 0.15, vertical + 0.01),
bbox=dict(facecolor="white", alpha=0.5, edgecolor="none"),
)
except (ValueError, IndexError):
pass
# ax[i].annotate(
# "",
# xy=(left, vertical),
# xytext=(right, vertical),
# arrowprops=dict(arrowstyle="<->", facecolor="black"),
# )
# ax[i].annotate(
# f"{self.eye_stats[i]['widths'][j]:.2e}",
# xy=((left + right) / 2 - 0.15, vertical + 0.01),
# bbox=dict(facecolor="white", alpha=0.5, edgecolor="none"),
# )
# except (ValueError, IndexError):
# pass
# add area
for j in range(len(self.eye_stats[i]["areas"])):
horizontal = self.eye_stats[i]["time_midpoint"]
vertical = (self.eye_stats[i]["levels"][j] + self.eye_stats[i]["levels"][j + 1]) / 2
ax[i].annotate(
f"{self.eye_stats[i]['areas'][j]:.2e}",
xy=(horizontal + 0.035, vertical - 0.07),
bbox=dict(facecolor="white", alpha=0.5, edgecolor="none"),
)
# # add area
# for j in range(len(self.eye_stats[i]["areas"])):
# horizontal = self.eye_stats[i]["time_midpoint"]
# vertical = (self.eye_stats[i]["levels"][j] + self.eye_stats[i]["levels"][j + 1]) / 2
# ax[i].annotate(
# f"{self.eye_stats[i]['areas'][j]:.2e}",
# xy=(horizontal + 0.035, vertical - 0.07),
# bbox=dict(facecolor="white", alpha=0.5, edgecolor="none"),
# )
fig.tight_layout()
if show:
plt.show()
return fig
@staticmethod
def calculate_thresholds(levels):
ret = np.cumsum(levels, dtype=float)
ret[2:] = ret[2:] - ret[:-2]
return ret[1:] / 2
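# Worked example: the cumsum trick in calculate_thresholds reduces to the
# pairwise means of neighbouring levels, i.e. the decision thresholds
# between adjacent PAM levels.
import numpy as np
print(eye_diagram.calculate_thresholds(np.array([0.0, 1.0, 2.0, 3.0])))  # [0.5 1.5 2.5]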
def analyse_single(self, data, index):
warnings.filterwarnings("error")
eye_stats = {}
eye_stats["channel_name"] = str(index+1) if self.channel_names is None else self.channel_names[index]
eye_stats["channel_name"] = str(index + 1) if self.channel_names is None else self.channel_names[index]
try:
approx_levels = eye_diagram.approximate_levels(data, self.n_levels)
time_bounds = eye_diagram.calculate_time_bounds(data, approx_levels)
eye_stats["time_midpoint"] = (time_bounds[0] + time_bounds[1]) / 2
eye_stats["time_midpoint"] = float((time_bounds[0] + time_bounds[1]) / 2)
# eye_stats["time_midpoint"] = 1.0
eye_stats["levels"], eye_stats["amplitude_clusters"] = eye_diagram.calculate_levels(
data, approx_levels, time_bounds
)
eye_stats["thresholds"] = self.calculate_thresholds(eye_stats["levels"])
eye_stats["amplitudes"] = np.diff(eye_stats["levels"])
eye_stats["heights"] = eye_diagram.calculate_eye_heights(
eye_stats["amplitude_clusters"]
)
eye_stats["heights"] = eye_diagram.calculate_eye_heights(eye_stats["amplitude_clusters"])
eye_stats["widths"], eye_stats["time_clusters"] = eye_diagram.calculate_eye_widths(
data, eye_stats["levels"]
@@ -260,36 +449,59 @@ class eye_diagram:
# if not (np.max(eye_stats['time_clusters'][j][0]) < eye_stats["time_midpoint"] < np.min(eye_stats['time_clusters'][j][1])):
# raise ValueError
eye_stats["areas"] = eye_stats["heights"] * eye_stats["widths"]
eye_stats["mean_area"] = np.mean(eye_stats["areas"])
eye_stats["min_area"] = np.min(eye_stats["areas"])
# eye_stats["areas"] = eye_stats["heights"] * eye_stats["widths"]
# eye_stats["mean_area"] = np.mean(eye_stats["areas"])
# eye_stats["min_area"] = np.min(eye_stats["areas"])
eye_stats["success"] = True
except (RuntimeWarning, UserWarning, ValueError):
eye_stats["success"] = False
eye_stats["time_midpoint"] = 0
eye_stats["levels"] = np.zeros(self.n_levels)
eye_stats["amplitude_clusters"] = []
eye_stats["amplitudes"] = np.zeros(self.n_levels - 1)
eye_stats["heights"] = np.zeros(self.n_levels - 1)
eye_stats["widths"] = np.zeros(self.n_levels - 1)
eye_stats["areas"] = np.zeros(self.n_levels - 1)
eye_stats["mean_area"] = 0
eye_stats["min_area"] = 0
eye_stats["time_midpoint"] = None
eye_stats["levels"] = None
eye_stats["thresholds"] = None
eye_stats["amplitude_clusters"] = None
eye_stats["amplitudes"] = None
eye_stats["heights"] = None
eye_stats["widths"] = None
# eye_stats["areas"] = np.zeros(self.n_levels - 1)
# eye_stats["mean_area"] = 0
# eye_stats["min_area"] = 0
warnings.resetwarnings()
return eye_stats
def analyse(
self, mode: Literal["default", "load", "save", "nosave"] = "default", file_or_dir: Path | None | str = None
):
# modes:
# default: try to load eye data from file, if not found, generate and save
# load: try to load eye data from file, if not found, generate but don't save
# save: generate eye data and save
update_save = True
if mode == "load":
self.load_data(file_or_dir)
update_save = False
elif mode == "default":
try:
self.load_data(file_or_dir)
update_save = False
except (FileNotFoundError, IsADirectoryError):
pass
def analyse(self):
self.eye_stats = []
if self.multi_threaded:
with multiprocessing.Pool() as pool:
results = pool.starmap(self.analyse_single, [(self.raw_data[i], i) for i in range(self.channels)])
for i, result in enumerate(results):
self.eye_stats.append(result)
else:
for i in range(self.channels):
self.eye_stats.append(self.analyse_single(self.raw_data[i], i))
if not self.analysed:
update_save = True
self.eye_stats = []
if self.multi_threaded:
with multiprocessing.Pool() as pool:
results = pool.starmap(self.analyse_single, [(self.raw_data[i], i) for i in range(self.n_channels)])
for i, result in enumerate(results):
self.eye_stats.append(result)
else:
for i in range(self.n_channels):
self.eye_stats.append(self.analyse_single(self.raw_data[i], i))
self.analysed = True
if mode == "save" or (mode == "default" and update_save):
self.save_data(file_or_dir)
@staticmethod
def approximate_levels(data, levels):
@@ -431,7 +643,7 @@ class eye_diagram:
if __name__ == "__main__":
length = int(2**14)
length = int(2**16)
# data = generate_sample_data(length, noise=1)
# data1 = generate_sample_data(length, noise=0.01)
# data2 = generate_sample_data(length, noise=0.01, skew=1.2)
@@ -439,12 +651,13 @@ if __name__ == "__main__":
# data = np.stack([data, data1, data2, data3])
data = generate_sample_data(length, noise=0.005)
eye = eye_diagram(data, horizontal_bins=256, vertical_bins=256)
attrs = ("success", "amplitudes", "time_midpoint", "levels", "heights", "widths", "area", "mean_area", "min_area")
for i, channel in enumerate(eye.eye_stats):
print(f"Channel {i}")
print_data = {attr: channel[attr] for attr in attrs}
print(print_data)
data = generate_sample_data(length, noise=0.0000)
eye = eye_diagram(data, horizontal_bins=1056, vertical_bins=1200)
eye.plot(mode="nosave", stats=False)
# attrs = ("success", "amplitudes", "time_midpoint", "levels", "heights", "widths")#, "area", "mean_area", "min_area")
# for i, channel in enumerate(eye.eye_stats):
# print(f"Channel {i}")
# print_data = {attr: channel[attr] for attr in attrs}
# print(print_data)
eye.plot()
# eye.plot()

View File

@@ -0,0 +1,122 @@
# Copyright (c) 2015, Warren Weckesser. All rights reserved.
# This software is licensed according to the "BSD 2-clause" license.
# modified by Joseph Hopfmüller in 2025,
# for integration into optical regeneration analysis scripts
from pathlib import Path
from matplotlib.colors import LinearSegmentedColormap
import matplotlib.colors as colors
import numpy as _np
from .core import grid_count as _grid_count
import matplotlib.pyplot as _plt
import numpy as np
from scipy.ndimage import gaussian_filter
# from ._common import _common_doc
__all__ = ["eyediagram"] # , 'eyediagram_lines']
# def eyediagram_lines(y, window_size, offset=0, **plotkwargs):
# """
# Plot an eye diagram using matplotlib by repeatedly calling the `plot`
# function.
# <common>
# """
# start = offset
# while start < len(y):
# end = start + window_size
# if end > len(y):
# end = len(y)
# yy = y[start:end+1]
# _plt.plot(_np.arange(len(yy)), yy, 'k', **plotkwargs)
# start = end
# eyediagram_lines.__doc__ = eyediagram_lines.__doc__.replace("<common>",
# _common_doc)
eyemap = LinearSegmentedColormap.from_list(
"eyemap",
[
(0, "#0000FF00"),
(0.1, "blue"),
(0.2, "cyan"),
(0.5, "green"),
(0.8, "yellow"),
(0.9, "red"),
(1, "magenta"),
],
)
def eyediagram(
y,
window_size,
offset=0,
colorbar=False,
show=False,
save_im=False,
overwrite=False,
blur: int | bool = True,
save_path="out.png",
bounds=None,
**imshowkwargs,
):
"""
Plot an eye diagram using matplotlib by creating an image and calling
the `imshow` function.
<common>
"""
if bounds is None:
ymax = y.max()
ymin = y.min()
yamp = ymax - ymin
ymin = ymin - 0.05 * yamp
ymax = ymax + 0.05 * yamp
ymin = np.floor(ymin * 10) / 10
ymax = np.ceil(ymax * 10) / 10
bounds = (ymin, ymax)
counts = _grid_count(y, window_size, offset, bounds=bounds, size=(1000, 1200), blur=int(blur))
counts = counts.astype(_np.float32)
origin = imshowkwargs.pop("origin", "lower")
cmap: colors.Colormap = imshowkwargs.pop("cmap", eyemap)
vmin = imshowkwargs.pop("vmin", 1)
vmax = imshowkwargs.pop("vmax", None)
cmap.set_under("white", alpha=0)
if show:
_plt.imshow(
counts.T[::-1, :],
extent=[0, 2, *bounds],
origin=origin,
cmap=cmap,
vmin=vmin,
vmax=vmax,
**imshowkwargs,
)
_plt.grid()
if colorbar:
_plt.colorbar()
if Path(save_path).is_file() and not overwrite:
save_im = False
if save_im:
from PIL import Image
arr = counts.T[::-1, :]
if origin == "lower":
arr = arr[::-1]
arr = (arr-arr.min())/(arr.max()-arr.min())
image = Image.fromarray((cmap(arr)[:, :, :] * 255).astype(np.uint8))
image.save(save_path)
# print("-")
if show:
_plt.show()
# eyediagram.__doc__ = eyediagram.__doc__.replace("<common>", _common_doc)

View File

@@ -1,6 +1,9 @@
import matplotlib.pyplot as plt
import numpy as np
from .datasets import load_data
if __name__ == "__main__":
from datasets import load_data
else:
from .datasets import load_data
def eye(*, path=None, data=None, sps=None, title=None, symbols=1000, skipfirst=0, width=2, alpha=None, complex=False, show=True):
"""Plot an eye diagram for the data given by filepath.
@@ -20,6 +23,7 @@ def eye(*, path=None, data=None, sps=None, title=None, symbols=1000, skipfirst=0
raise ValueError("Either path or data and sps must be given.")
if path is not None:
data, config = load_data(path, skipfirst, symbols)
data = data.detach().cpu().numpy()[:, :4]
sps = int(config["glova"]["sps"])
if sps is None:
raise ValueError("sps not set.")
@@ -71,3 +75,6 @@ def eye(*, path=None, data=None, sps=None, title=None, symbols=1000, skipfirst=0
plt.show()
return fig
if __name__ == "__main__":
eye(path="data/20241229-163838-128-16384-50000-0-0.2-16.8-0.058-PAM4-0-0.16.ini", symbols=1000, width=2, alpha=0.1, complex=False)