Compare commits
5 Commits
33141bdf41
...
machine_le
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
249fe1e940 | ||
|
|
f38d0ca3bb | ||
|
|
3af73343c1 | ||
|
|
7a0b65f82d | ||
|
|
98305fdf47 |
2
.gitignore
vendored
2
.gitignore
vendored
@@ -163,3 +163,5 @@ cython_debug/
|
||||
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
||||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||
#.idea/
|
||||
tolerance_results/*
|
||||
data/*
|
||||
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:bc02b0099ea3bb136733e3d20817cad79b6c50c2e4b845f0d206455dde188cc4
|
||||
size 615
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:d80ff6f2a84acf973fbdf81a05ed0b1902f8bf97856cd5132b646f6b1173f496
|
||||
size 615
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:54ac6b6a452aa6b7d312a4c8ab8f7ebe2f96c1c4170cbc56147e8f2f9d934ad6
|
||||
size 615
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:bd2c6f4050488b6e857759d48aa1f1f37399d81cee1667d3668145e938d17c83
|
||||
size 615
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:c59b8113092459a8751b385a7b1a6f10828626d2ec2f29775512157fd9bbc75c
|
||||
size 615
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:04eaa2a29b3302e5de3bf99bf6a57fc8f27361fd3df3cac9245e25ab99324829
|
||||
size 615
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:704d3b0b17b9d320f4717b5a5a8bdbc5714f3caa4efa7153e980766429e834f4
|
||||
size 615
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:29304f35a88fd777566105f8666fff1c9927beb32756822365bcf9c159feb98e
|
||||
size 615
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:4a87488c12e0253b2bb5d1ac7aa1536f69ef569e62c2aab6a10149d753e049b4
|
||||
size 615
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:ff47c8d5413881edb03cfadfcde3b550ef7089543615a467c4f0027edaf1455e
|
||||
size 615
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:9e05a0a54c7e3aaffeca0ac90cce1b274a544d90b329a93d453392f5df4e91a8
|
||||
size 616
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:25e45b06b551ab8031c2159030d658999fcb3d1f0a34538c90768b94c8116771
|
||||
size 616
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:707ed73713b6c2d80d4e333b1ccdf4650f50aefefff58ddb471c1d5411954b3d
|
||||
size 616
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:a8c3baf878943741d83835c1afed05c1f9780ff3f0df260c0d92706343e59c50
|
||||
size 616
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:98878d09510dedc24f1429bbd12cce49c66be6c9d279a28765b120efe820a171
|
||||
size 616
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:fcbdaffa211d6b0b44b3ae1c66645999e95901bfdb2fffee4c45e34a0d901ee1
|
||||
size 649
|
||||
3
data/npys/0d56eb7933de7bd4a847eaa1ba4fd2c4.npy
Normal file
3
data/npys/0d56eb7933de7bd4a847eaa1ba4fd2c4.npy
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:73e922310068a66ab1a0c3d39b0b0fd7db798f2637b199a8c4bd113a38bb28c8
|
||||
size 134217856
|
||||
3
data/npys/0fb99fd81fd3076612f6aca9e2926e36.npy
Normal file
3
data/npys/0fb99fd81fd3076612f6aca9e2926e36.npy
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:43e80dc7d21aeff62c73f0ed02b20a4ac9b573d0e131a3ab8d1471077e03634b
|
||||
size 134217856
|
||||
3
data/npys/2a7f021c1c1112554d469b4e7bc3081f.npy
Normal file
3
data/npys/2a7f021c1c1112554d469b4e7bc3081f.npy
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:fb74cfbeec54b4f263c08510312881527fc7e484604fa0a1213b596f175fecc2
|
||||
size 134217856
|
||||
3
data/npys/3f6b6dfaecf6f1c003c6d9dfe950c6d0.npy
Normal file
3
data/npys/3f6b6dfaecf6f1c003c6d9dfe950c6d0.npy
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:447cca0af25e309c8be61216cea4fb2d3a8a967b0522760f1dab8f15e6b41574
|
||||
size 134217856
|
||||
3
data/npys/4365258197699ae6a83e846b9985f606.npy
Normal file
3
data/npys/4365258197699ae6a83e846b9985f606.npy
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:63321906920de825fc7857aa9f2e4944c3b32f3eadf99da06b66f6599924bc4c
|
||||
size 134217856
|
||||
3
data/npys/484ffd154056e75885022c286dbfafa7.npy
Normal file
3
data/npys/484ffd154056e75885022c286dbfafa7.npy
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:92e3fcc41f05380cb7b334fefc0a30bb8c1dfd9ca5c3b2cfaad36b0c7093914e
|
||||
size 134217856
|
||||
3
data/npys/55aac65b78a8181629f963815a7cdf5c.npy
Normal file
3
data/npys/55aac65b78a8181629f963815a7cdf5c.npy
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:0bf671d666d35617edd7bfb58548a5120bd92e5f9cb6edb4b5fc8d3bf5db8987
|
||||
size 134217856
|
||||
3
data/npys/5aaa36769f52738e1fa37f4c79b9eed2.npy
Normal file
3
data/npys/5aaa36769f52738e1fa37f4c79b9eed2.npy
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:bd429a5776555671b19d42e3ae152026dafd5bf95aeed9847df1432ed37f3eba
|
||||
size 134217856
|
||||
3
data/npys/6789fdea2609799ef2e975907625b79a.h5
Normal file
3
data/npys/6789fdea2609799ef2e975907625b79a.h5
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:1df90745cc2e6d4b0ad964fca2de1441e6e0b4b8345fbb0fbc1ffe9820674269
|
||||
size 134481920
|
||||
3
data/npys/78f8295da779018e431af641e9f3eb76.npy
Normal file
3
data/npys/78f8295da779018e431af641e9f3eb76.npy
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:477365edd696f66610298198257422a017f825e1f8bec4363bdfb0da0d741ebc
|
||||
size 134217856
|
||||
3
data/npys/7a63d865c33ab1bab2697ee5a9c2ce3b.npy
Normal file
3
data/npys/7a63d865c33ab1bab2697ee5a9c2ce3b.npy
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:1814f8ae0e6cdb69c0030741a3c6b35c74f2d6985eed34c0d5b4135384014abc
|
||||
size 134217856
|
||||
3
data/npys/a0ddcc9a8f2fc25a1152ec78d54d9676.npy
Normal file
3
data/npys/a0ddcc9a8f2fc25a1152ec78d54d9676.npy
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:d53bc0fd0897c7372c2f182b84163bcf401ad91f26b6949c6d7a1d70c5dbb513
|
||||
size 134217856
|
||||
3
data/npys/cf33366e5842c443d7e67f89947a380f.npy
Normal file
3
data/npys/cf33366e5842c443d7e67f89947a380f.npy
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:fc882f29530f7d683be631214f6667611e0aba87453a11157d71c50f3548fb3c
|
||||
size 134217856
|
||||
3
data/npys/d30358afa19a86066b4fda6adb74c1b4.npy
Normal file
3
data/npys/d30358afa19a86066b4fda6adb74c1b4.npy
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:0f0bc616c2e581444a6fa658e45ee942c1ef5a4d21f22363518331f8d80cbe62
|
||||
size 134217856
|
||||
3
data/npys/ec3a1c902e17ca5cb5d85398cea182f7.npy
Normal file
3
data/npys/ec3a1c902e17ca5cb5d85398cea182f7.npy
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:ad94f2af43a2a06ebddc78566cbef0ea538c3da9191682e2fde4ecddb061b0f0
|
||||
size 134217856
|
||||
3
data/npys/f4a8fbbeced371651d851440bd565ee4.npy
Normal file
3
data/npys/f4a8fbbeced371651d851440bd565ee4.npy
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:6930349d1ae1479dcb4c9ee9eaefe2da3eae62539cb4b97de873a7a8b175e809
|
||||
size 134217856
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:746ea83013870351296e01e294905ee027291ef79fd78c1e6b69dd9ebaa1cba0
|
||||
size 10240000
|
||||
oid sha256:76934d1d202aea1311ba67f5ea35eeb99a9c5c856f491565032e7d54ca6f072d
|
||||
size 13598720
|
||||
|
||||
37
notes/models.md
Normal file
37
notes/models.md
Normal file
@@ -0,0 +1,37 @@
|
||||
# models
|
||||
|
||||
## no polarisation flipping
|
||||
|
||||
```py
|
||||
config_path="data/20241229-163*-128-16384-50000-*.ini"
|
||||
model=".models/best_20241230_011907.tar"
|
||||
```
|
||||
|
||||
```py
|
||||
config_path="data/20241229-163*-128-16384-80000-*.ini"
|
||||
model=".models/best_20241230_103752.tar"
|
||||
```
|
||||
|
||||
```py
|
||||
config_path="data/20241229-163*-128-16384-100000-*.ini"
|
||||
model=".models/best_20241230_164534.tar"
|
||||
```
|
||||
|
||||
## with polarisation flipping
|
||||
|
||||
polarisation flipping: signal is randomly rotated by 180°. polarization rotation can be detected by adding a tone on one of the polarisations, but only to mod 180° with a direct detection setup. the randomly flipped signal should allow the network to hopefully learn to compensate for dispersion, pmd independently from the polarization rot. the training data includes the flipped signal as well, but no indication if the polarisation is flipped.
|
||||
|
||||
```py
|
||||
config_path="data/20241229-163*-128-16384-50000-*.ini"
|
||||
model=".models/best_20241231_000328.tar"
|
||||
```
|
||||
|
||||
```py
|
||||
config_path="data/20241229-163*-128-16384-80000-*.ini"
|
||||
model=".models/best_20241231_163614.tar"
|
||||
```
|
||||
|
||||
```py
|
||||
config_path="data/20241229-163*-128-16384-100000-*.ini"
|
||||
model=".models/best_20241231_170532.tar"
|
||||
```
|
||||
59
notes/tolerance_testing.md
Normal file
59
notes/tolerance_testing.md
Normal file
@@ -0,0 +1,59 @@
|
||||
# Baseline Models
|
||||
|
||||
## a) D+S, pol_error 0, ortho_error 0, DGD 0
|
||||
|
||||
dataset
|
||||
|
||||
```raw
|
||||
data/20250118-225840-128-16384-inf-100000-0-0.2-17.0-0.058-PAM4-0-0-10.ini
|
||||
```
|
||||
|
||||
model
|
||||
|
||||
```raw
|
||||
.models/best_20250118_225918.tar
|
||||
```
|
||||
|
||||
## b) D+S, pol_error 0.4, ortho_error 0, DGD 0
|
||||
|
||||
dataset
|
||||
|
||||
```raw
|
||||
data/20250116-214606-128-16384-inf-100000-0-0.2-17.0-0.058-PAM4-0-0-10.ini
|
||||
```
|
||||
|
||||
model
|
||||
|
||||
```raw
|
||||
.models/best_20250116_214816.tar
|
||||
```
|
||||
|
||||
## c) D+S, pol_error 0, ortho_error 0.1, DGD 0
|
||||
|
||||
dataset
|
||||
|
||||
```raw
|
||||
data/20250116-214547-128-16384-inf-100000-0-0.2-17.0-0.058-PAM4-0-0-10.ini
|
||||
```
|
||||
|
||||
model
|
||||
|
||||
```raw
|
||||
.models/best_20250117_122319.tar
|
||||
```
|
||||
|
||||
## d) D+S, pol_error 0, ortho_error 0, DGD 10ps (1 T_sym)
|
||||
|
||||
birefringence angle pi/2 (worst case)
|
||||
|
||||
dataset
|
||||
|
||||
```raw
|
||||
data/20250117-143926-128-16384-inf-100000-0-0.2-17.0-0.058-PAM4-0-0-10.ini
|
||||
```
|
||||
|
||||
model
|
||||
|
||||
```raw
|
||||
.models/best_20250117_144001.tar
|
||||
```
|
||||
2
pypho
2
pypho
Submodule pypho updated: dd015f4852...e44fc477fe
@@ -26,6 +26,8 @@ import torch
|
||||
import torch.optim as optim
|
||||
import torch.utils.data
|
||||
|
||||
import hypertraining.models as models
|
||||
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
import multiprocessing
|
||||
@@ -253,14 +255,17 @@ class HyperTraining:
|
||||
model_kwargs = {
|
||||
"dims": (input_dim, *hidden_dims, self.model_settings.output_dim),
|
||||
"layer_function": layer_func,
|
||||
"layer_parametrizations": layer_parametrizations,
|
||||
"activation_function": afunc,
|
||||
"layer_func_kwargs": self.model_settings.model_layer_kwargs,
|
||||
"act_function": afunc,
|
||||
"act_func_kwargs": None,
|
||||
"parametrizations": layer_parametrizations,
|
||||
"dtype": dtype,
|
||||
"droupout_prob": self.model_settings.dropout_prob,
|
||||
"scale": scale_layers,
|
||||
"dropout_prob": self.model_settings.dropout_prob,
|
||||
"scale_layers": scale_layers,
|
||||
"rotate": False,
|
||||
}
|
||||
|
||||
model = util.complexNN.regenerator(**model_kwargs)
|
||||
model = models.regenerator(*model_kwargs.pop("dims"), **model_kwargs)
|
||||
n_nodes = sum(hidden_dims)
|
||||
|
||||
if writer is not None:
|
||||
@@ -381,7 +386,10 @@ class HyperTraining:
|
||||
running_loss = 0.0
|
||||
model.train()
|
||||
loader_len = len(train_loader)
|
||||
for batch_idx, (x, y, _) in enumerate(train_loader):
|
||||
for batch_idx, batch in enumerate(train_loader):
|
||||
x = batch["x"]
|
||||
y = batch["y"]
|
||||
|
||||
if batch_idx >= self.optuna_settings._n_train_batches:
|
||||
break
|
||||
model.zero_grad(set_to_none=True)
|
||||
@@ -390,7 +398,7 @@ class HyperTraining:
|
||||
y.to(self.pytorch_settings.device),
|
||||
)
|
||||
y_pred = model(x)
|
||||
loss = util.complexNN.complex_mse_loss(y_pred, y)
|
||||
loss = util.complexNN.complex_mse_loss(y_pred, y, power=True)
|
||||
loss_value = loss.item()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
@@ -444,7 +452,9 @@ class HyperTraining:
|
||||
model.eval()
|
||||
running_error = 0
|
||||
with torch.no_grad():
|
||||
for batch_idx, (x, y, _) in enumerate(valid_loader):
|
||||
for batch_idx, batch in enumerate(valid_loader):
|
||||
x = batch["x"]
|
||||
y = batch["y"]
|
||||
if batch_idx >= self.optuna_settings._n_valid_batches:
|
||||
break
|
||||
x, y = (
|
||||
@@ -452,50 +462,44 @@ class HyperTraining:
|
||||
y.to(self.pytorch_settings.device),
|
||||
)
|
||||
y_pred = model(x)
|
||||
error = util.complexNN.complex_mse_loss(y_pred, y)
|
||||
error = util.complexNN.complex_mse_loss(y_pred, y, power=True)
|
||||
error_value = error.item()
|
||||
running_error += error_value
|
||||
|
||||
running_error /= min(len(valid_loader), self.optuna_settings._n_valid_batches)
|
||||
|
||||
if writer is not None:
|
||||
title_append, subtitle = self.build_title(trial)
|
||||
writer.add_figure(
|
||||
"fiber response",
|
||||
self.plot_model_response(
|
||||
trial,
|
||||
model=model,
|
||||
title_append=title_append,
|
||||
subtitle=subtitle,
|
||||
show=False,
|
||||
),
|
||||
epoch + 1,
|
||||
)
|
||||
writer.add_figure(
|
||||
"eye diagram",
|
||||
self.plot_model_response(
|
||||
trial,
|
||||
model=self.model,
|
||||
title_append=title_append,
|
||||
subtitle=subtitle,
|
||||
show=False,
|
||||
mode="eye",
|
||||
),
|
||||
epoch + 1,
|
||||
writer.add_scalar(
|
||||
"eval loss",
|
||||
running_error,
|
||||
epoch,
|
||||
)
|
||||
# if (epoch + 1) % 10 == 0 or epoch < 10:
|
||||
# # plotting is slow, so only do it every 10 epochs
|
||||
# title_append, subtitle = self.build_title(trial)
|
||||
# head_fig, eye_fig, powers_fig = self.plot_model_response(
|
||||
# model=model,
|
||||
# title_append=title_append,
|
||||
# subtitle=subtitle,
|
||||
# show=False,
|
||||
# )
|
||||
# writer.add_figure(
|
||||
# "fiber response",
|
||||
# head_fig,
|
||||
# epoch + 1,
|
||||
# )
|
||||
# writer.add_figure(
|
||||
# "eye diagram",
|
||||
# eye_fig,
|
||||
# epoch + 1,
|
||||
# )
|
||||
|
||||
writer.add_figure(
|
||||
"powers",
|
||||
self.plot_model_response(
|
||||
trial,
|
||||
model=self.model,
|
||||
title_append=title_append,
|
||||
subtitle=subtitle,
|
||||
mode="powers",
|
||||
show=False,
|
||||
),
|
||||
epoch + 1,
|
||||
)
|
||||
# writer.add_figure(
|
||||
# "powers",
|
||||
# powers_fig,
|
||||
# epoch + 1,
|
||||
# )
|
||||
# writer.flush()
|
||||
|
||||
# if enable_progress:
|
||||
# progress.stop()
|
||||
@@ -511,15 +515,18 @@ class HyperTraining:
|
||||
|
||||
with torch.no_grad():
|
||||
model = model.to(self.pytorch_settings.device)
|
||||
for x, y, timestamp in loader:
|
||||
for batch in loader:
|
||||
x = batch["x"]
|
||||
y = batch["y"]
|
||||
timestamp = batch["timestamp"]
|
||||
x, y = (
|
||||
x.to(self.pytorch_settings.device),
|
||||
y.to(self.pytorch_settings.device),
|
||||
)
|
||||
if trace_powers:
|
||||
y_pred, powers = model(x, trace_powers).cpu()
|
||||
y_pred, powers = model(x, trace_powers=True).cpu()
|
||||
else:
|
||||
y_pred = model(x, trace_powers).cpu()
|
||||
y_pred = model(x, trace_powers=True).cpu()
|
||||
# x = x.cpu()
|
||||
# y = y.cpu()
|
||||
y_pred = y_pred.view(y_pred.shape[0], -1, 2)
|
||||
@@ -539,7 +546,7 @@ class HyperTraining:
|
||||
return fiber_in, fiber_out, regen, timestamps, powers
|
||||
return fiber_in, fiber_out, regen, timestamps
|
||||
|
||||
def objective(self, trial: optuna.Trial, plot_before=False):
|
||||
def objective(self, trial: optuna.Trial):
|
||||
if self.stop_study:
|
||||
trial.study.stop()
|
||||
model = None
|
||||
@@ -555,54 +562,54 @@ class HyperTraining:
|
||||
|
||||
title_append, subtitle = self.build_title(trial)
|
||||
|
||||
writer.add_figure(
|
||||
"fiber response",
|
||||
self.plot_model_response(
|
||||
trial,
|
||||
model=model,
|
||||
title_append=title_append,
|
||||
subtitle=subtitle,
|
||||
show=False,
|
||||
),
|
||||
0,
|
||||
)
|
||||
writer.add_figure(
|
||||
"eye diagram",
|
||||
self.plot_model_response(
|
||||
trial,
|
||||
model=self.model,
|
||||
title_append=title_append,
|
||||
subtitle=subtitle,
|
||||
mode="eye",
|
||||
show=False,
|
||||
),
|
||||
0,
|
||||
)
|
||||
# writer.add_figure(
|
||||
# "fiber response",
|
||||
# self.plot_model_response(
|
||||
# trial,
|
||||
# model=model,
|
||||
# title_append=title_append,
|
||||
# subtitle=subtitle,
|
||||
# show=False,
|
||||
# ),
|
||||
# 0,
|
||||
# )
|
||||
# writer.add_figure(
|
||||
# "eye diagram",
|
||||
# self.plot_model_response(
|
||||
# trial,
|
||||
# model=self.model,
|
||||
# title_append=title_append,
|
||||
# subtitle=subtitle,
|
||||
# mode="eye",
|
||||
# show=False,
|
||||
# ),
|
||||
# 0,
|
||||
# )
|
||||
|
||||
writer.add_figure(
|
||||
"powers",
|
||||
self.plot_model_response(
|
||||
trial,
|
||||
model=self.model,
|
||||
title_append=title_append,
|
||||
subtitle=subtitle,
|
||||
mode="powers",
|
||||
show=False,
|
||||
),
|
||||
0,
|
||||
)
|
||||
# writer.add_figure(
|
||||
# "powers",
|
||||
# self.plot_model_response(
|
||||
# trial,
|
||||
# model=self.model,
|
||||
# title_append=title_append,
|
||||
# subtitle=subtitle,
|
||||
# mode="powers",
|
||||
# show=False,
|
||||
# ),
|
||||
# 0,
|
||||
# )
|
||||
|
||||
train_loader, valid_loader = self.get_sliced_data(trial)
|
||||
|
||||
optimizer_name = trial.suggest_categorical_optional("optimizer", self.optimizer_settings.optimizer)
|
||||
|
||||
lr = trial.suggest_float_optional("lr", self.optimizer_settings.learning_rate, log=True)
|
||||
lr = trial.suggest_float_optional("lr", self.optimizer_settings.optimizer_kwargs["lr"], log=True)
|
||||
|
||||
optimizer = getattr(optim, optimizer_name)(model.parameters(), lr=lr)
|
||||
if self.optimizer_settings.scheduler is not None:
|
||||
scheduler = getattr(optim.lr_scheduler, self.optimizer_settings.scheduler)(
|
||||
optimizer, **self.optimizer_settings.scheduler_kwargs
|
||||
)
|
||||
# if self.optimizer_settings.scheduler is not None:
|
||||
# scheduler = getattr(optim.lr_scheduler, self.optimizer_settings.scheduler)(
|
||||
# optimizer, **self.optimizer_settings.scheduler_kwargs
|
||||
# )
|
||||
|
||||
for epoch in range(self.pytorch_settings.epochs):
|
||||
trial.set_user_attr("epoch", epoch)
|
||||
@@ -628,8 +635,8 @@ class HyperTraining:
|
||||
writer,
|
||||
# enable_progress=enable_progress,
|
||||
)
|
||||
if self.optimizer_settings.scheduler is not None:
|
||||
scheduler.step(error)
|
||||
# if self.optimizer_settings.scheduler is not None:
|
||||
# scheduler.step(error)
|
||||
|
||||
trial.set_user_attr("mse", error)
|
||||
trial.set_user_attr("log_mse", np.log10(error + np.finfo(float).eps))
|
||||
@@ -645,10 +652,10 @@ class HyperTraining:
|
||||
if self.optuna_settings._multi_objective:
|
||||
return -np.log10(error + np.finfo(float).eps), trial.user_attrs.get("model_n_nodes", -1)
|
||||
|
||||
if self.pytorch_settings.save_models and model is not None:
|
||||
save_path = Path(self.pytorch_settings.model_dir) / f"{self.optuna_settings.study_name}_{trial.number}.pth"
|
||||
save_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
torch.save(model, save_path)
|
||||
# if self.pytorch_settings.save_models and model is not None:
|
||||
# save_path = Path(self.pytorch_settings.model_dir) / f"{self.optuna_settings.study_name}_{trial.number}.pth"
|
||||
# save_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
# torch.save(model, save_path)
|
||||
|
||||
return error
|
||||
|
||||
|
||||
@@ -8,7 +8,8 @@ from util.complexNN import (
|
||||
photodiode,
|
||||
EOActivation,
|
||||
polarimeter,
|
||||
normalize_by_first
|
||||
# normalize_by_first,
|
||||
rotate,
|
||||
)
|
||||
|
||||
|
||||
@@ -19,11 +20,11 @@ class polarisation_estimator2(Module):
|
||||
polarimeter(),
|
||||
torch.nn.Linear(4, 4),
|
||||
torch.nn.ReLU(),
|
||||
torch.nn.Dropout(p=0.01),
|
||||
# torch.nn.Dropout(p=0.01),
|
||||
torch.nn.Linear(4, 4),
|
||||
torch.nn.ReLU(),
|
||||
torch.nn.Dropout(p=0.01),
|
||||
torch.nn.Linear(4, 4),
|
||||
# torch.nn.Dropout(p=0.01),
|
||||
torch.nn.Linear(4, 1),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
@@ -123,7 +124,8 @@ class regenerator(Module):
|
||||
parametrizations: list[dict] = None,
|
||||
dtype=torch.float64,
|
||||
dropout_prob=0.01,
|
||||
scale_layers=False,
|
||||
prescale=1,
|
||||
rotate=False,
|
||||
):
|
||||
super(regenerator, self).__init__()
|
||||
self._n_hidden_layers = len(dims) - 2
|
||||
@@ -131,14 +133,15 @@ class regenerator(Module):
|
||||
layer_func_kwargs = layer_func_kwargs or {}
|
||||
act_func_kwargs = act_func_kwargs or {}
|
||||
|
||||
self.build_model(dims, layer_function, layer_func_kwargs, act_function, act_func_kwargs, parametrizations, dtype, dropout_prob, scale_layers)
|
||||
self.rotation = rotate
|
||||
self.prescale = prescale
|
||||
|
||||
def build_model(self, dims, layer_function, layer_func_kwargs, act_function, act_func_kwargs, parametrizations, dtype, dropout_prob, scale_layers):
|
||||
self.build_model(dims, layer_function, layer_func_kwargs, act_function, act_func_kwargs, parametrizations, dtype, dropout_prob)
|
||||
|
||||
def build_model(self, dims, layer_function, layer_func_kwargs, act_function, act_func_kwargs, parametrizations, dtype, dropout_prob):
|
||||
for i in range(0, self._n_hidden_layers):
|
||||
self.add_module(f"layer_{i}", Sequential())
|
||||
|
||||
if scale_layers:
|
||||
self.get_submodule(f"layer_{i}").add_module("scale", Scale(dims[i]))
|
||||
|
||||
module = layer_function(dims[i], dims[i + 1], dtype=dtype, **layer_func_kwargs)
|
||||
self.get_submodule(f"layer_{i}").add_module("ONN", module)
|
||||
@@ -146,13 +149,14 @@ class regenerator(Module):
|
||||
module = act_function(size=dims[i + 1], **act_func_kwargs)
|
||||
self.get_submodule(f"layer_{i}").add_module("activation", module)
|
||||
|
||||
module = DropoutComplex(p=dropout_prob)
|
||||
self.get_submodule(f"layer_{i}").add_module("dropout", module)
|
||||
if dropout_prob is not None and dropout_prob > 0:
|
||||
module = DropoutComplex(p=dropout_prob)
|
||||
self.get_submodule(f"layer_{i}").add_module("dropout", module)
|
||||
|
||||
self.add_module(f"layer_{self._n_hidden_layers}", Sequential())
|
||||
|
||||
if scale_layers:
|
||||
self.get_submodule(f"layer_{self._n_hidden_layers}").add_module("scale", Scale(dims[-2]))
|
||||
# if scale_layers:
|
||||
# self.get_submodule(f"layer_{self._n_hidden_layers}").add_module("scale", Scale(dims[-2]))
|
||||
|
||||
module = layer_function(dims[-2], dims[-1], dtype=dtype, **layer_func_kwargs)
|
||||
self.get_submodule(f"layer_{self._n_hidden_layers}").add_module("ONN", module)
|
||||
@@ -160,6 +164,14 @@ class regenerator(Module):
|
||||
module = act_function(size=dims[-1], **act_func_kwargs)
|
||||
self.get_submodule(f"layer_{self._n_hidden_layers}").add_module("activation", module)
|
||||
|
||||
module = Scale(size=dims[-1])
|
||||
self.get_submodule(f"layer_{self._n_hidden_layers}").add_module("out_scale", module)
|
||||
|
||||
if self.rotation:
|
||||
module = rotate()
|
||||
self.add_module("rotate", module)
|
||||
|
||||
|
||||
# module = Scale(size=dims[-1])
|
||||
# self.get_submodule(f"layer_{self._n_hidden_layers}").add_module("out_scale", module)
|
||||
|
||||
@@ -190,15 +202,28 @@ class regenerator(Module):
|
||||
powers.append(x.abs().square().sum())
|
||||
return powers
|
||||
|
||||
def forward(self, x, trace_powers=False):
|
||||
def forward(self, x, angle=None, pre_rot=False, trace_powers=False):
|
||||
x = x * self.prescale
|
||||
powers = self._trace_powers(trace_powers, x)
|
||||
x = self.layer_0(x)
|
||||
powers = self._trace_powers(trace_powers, x, powers)
|
||||
for i in range(1, self._n_hidden_layers):
|
||||
# x = self.layer_0(x)
|
||||
# powers = self._trace_powers(trace_powers, x, powers)
|
||||
for i in range(0, self._n_hidden_layers):
|
||||
x = getattr(self, f"layer_{i}")(x)
|
||||
powers = self._trace_powers(trace_powers, x, powers)
|
||||
x = getattr(self, f"layer_{self._n_hidden_layers}")(x)
|
||||
powers = self._trace_powers(trace_powers, x, powers)
|
||||
if trace_powers:
|
||||
return x, powers
|
||||
return x
|
||||
if self.rotation:
|
||||
try:
|
||||
x_rot = self.rotate(x, angle)
|
||||
except AttributeError:
|
||||
pass
|
||||
powers = self._trace_powers(trace_powers, x_rot, powers)
|
||||
else:
|
||||
x_rot = x
|
||||
|
||||
if pre_rot and trace_powers:
|
||||
return x_rot, x, powers
|
||||
if pre_rot and not trace_powers:
|
||||
return x_rot, x
|
||||
if not pre_rot and trace_powers:
|
||||
return x_rot, powers
|
||||
return x_rot
|
||||
@@ -18,10 +18,14 @@ class DataSettings:
|
||||
shuffle: bool = True
|
||||
in_out_delay: float = 0
|
||||
xy_delay: tuple | float | int = 0
|
||||
drop_first: int = 1000
|
||||
drop_first: int = 64
|
||||
drop_last: int = 64
|
||||
train_split: float = 0.8
|
||||
polarisations: tuple | list = (0,)
|
||||
# cross_pol_interference: float = 0
|
||||
randomise_polarisations: bool = False
|
||||
osnr: float | int = None
|
||||
seed: int = None
|
||||
|
||||
"""
|
||||
change to:
|
||||
@@ -91,6 +95,12 @@ class ModelSettings:
|
||||
"""
|
||||
|
||||
|
||||
def _early_stop_default_kwargs():
|
||||
return {
|
||||
"threshold": 1e-05,
|
||||
"plateau": 25,
|
||||
}
|
||||
|
||||
@dataclass
|
||||
class OptimizerSettings:
|
||||
optimizer: tuple | str = ("Adam", "RMSprop", "SGD")
|
||||
@@ -99,6 +109,9 @@ class OptimizerSettings:
|
||||
scheduler: str | None = None
|
||||
scheduler_kwargs: dict | None = None
|
||||
|
||||
early_stopping: bool = False
|
||||
early_stop_kwargs: dict = field(default_factory=_early_stop_default_kwargs)
|
||||
|
||||
"""
|
||||
change to:
|
||||
|
||||
|
||||
@@ -2,9 +2,9 @@ import copy
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
import random
|
||||
from typing import Literal
|
||||
import matplotlib
|
||||
from matplotlib.colors import LinearSegmentedColormap
|
||||
from mpl_toolkits.axes_grid1 import make_axes_locatable
|
||||
import torch.nn.utils.parametrize
|
||||
|
||||
try:
|
||||
@@ -47,38 +47,98 @@ from .settings import (
|
||||
PytorchSettings,
|
||||
)
|
||||
|
||||
from cmcrameri import cm
|
||||
# from matplotlib import colors as mcolors
|
||||
# alpha_map = mcolors.LinearSegmentedColormap(
|
||||
# 'alphamap',
|
||||
# {
|
||||
# 'red': [(0, 0, 0), (1, 0, 0)],
|
||||
# 'green': [(0, 0, 0), (1, 0, 0)],
|
||||
# 'blue': [(0, 0, 0), (1, 0, 0)],
|
||||
# 'alpha': [
|
||||
# (0, 1, 1),
|
||||
# # (0.2, 0.2, 0.1),
|
||||
# (1, 0, 0)
|
||||
# ]
|
||||
# }
|
||||
# )
|
||||
# alpha_map.set_bad(color="#AAAAAA")
|
||||
|
||||
def pad_to_size(array, size):
|
||||
if not hasattr(size, "__len__"):
|
||||
size = (size, size)
|
||||
|
||||
left = (
|
||||
(size[0] - array.shape[0] + 1) // 2 if size[0] is not None else 0
|
||||
)
|
||||
right = (
|
||||
(size[0] - array.shape[0]) // 2 if size[0] is not None else 0
|
||||
)
|
||||
top = (
|
||||
(size[1] - array.shape[1] + 1) // 2 if size[1] is not None else 0
|
||||
)
|
||||
bottom = (
|
||||
(size[1] - array.shape[1]) // 2 if size[1] is not None else 0
|
||||
)
|
||||
|
||||
array: np.ndarray = array
|
||||
if array.ndim == 2:
|
||||
return np.pad(
|
||||
array,
|
||||
(
|
||||
(left, right),
|
||||
(top, bottom),
|
||||
),
|
||||
constant_values=(np.nan, np.nan),
|
||||
)
|
||||
elif array.ndim == 3:
|
||||
return np.pad(
|
||||
array,
|
||||
(
|
||||
(left, right),
|
||||
(top, bottom),
|
||||
(0,0)
|
||||
),
|
||||
constant_values=(np.nan, np.nan),
|
||||
)
|
||||
|
||||
def traverse_dict_update(target, source):
|
||||
for k, v in source.items():
|
||||
if isinstance(v, dict):
|
||||
if k not in target:
|
||||
target[k] = {}
|
||||
traverse_dict_update(target[k], v)
|
||||
try:
|
||||
if k not in target:
|
||||
target[k] = {}
|
||||
traverse_dict_update(target[k], v)
|
||||
except TypeError:
|
||||
if k not in target.__dict__:
|
||||
setattr(target, k, {})
|
||||
traverse_dict_update(target.__dict__[k], v)
|
||||
else:
|
||||
try:
|
||||
target[k] = v
|
||||
except TypeError:
|
||||
target.__dict__[k] = v
|
||||
|
||||
|
||||
def get_parameter_names_and_values(model):
|
||||
def is_parametrized(module):
|
||||
if hasattr(module, "parametrizations"):
|
||||
return True
|
||||
return False
|
||||
|
||||
def _get_param_info(module, prefix='', parametrization=False):
|
||||
def _get_param_info(module, prefix="", parametrization=False):
|
||||
param_list = []
|
||||
for name, param in module.named_parameters(recurse = parametrization):
|
||||
for name, param in module.named_parameters(recurse=parametrization):
|
||||
if parametrization and name.startswith("parametrizations"):
|
||||
name_parts = name.split('.')
|
||||
name_parts = name.split(".")
|
||||
name = name_parts[1]
|
||||
param = getattr(module, name)
|
||||
full_name = prefix + ('.' if prefix else '') + name
|
||||
full_name = prefix + ("." if prefix else "") + name
|
||||
param_value = param.data
|
||||
param_list.append((full_name, param_value))
|
||||
|
||||
for child_name, child_module in module.named_children():
|
||||
child_prefix = prefix + ('.' if prefix else '') + child_name
|
||||
child_prefix = prefix + ("." if prefix else "") + child_name
|
||||
if child_name == "parametrizations":
|
||||
continue
|
||||
param_list.extend(_get_param_info(child_module, child_prefix, is_parametrized(child_module)))
|
||||
@@ -87,6 +147,7 @@ def get_parameter_names_and_values(model):
|
||||
|
||||
return _get_param_info(model)
|
||||
|
||||
|
||||
class PolarizationTrainer:
|
||||
def __init__(
|
||||
self,
|
||||
@@ -101,7 +162,7 @@ class PolarizationTrainer:
|
||||
settings_override=None,
|
||||
reset_epoch=False,
|
||||
):
|
||||
self.mod = torch.pi/2
|
||||
self.mod = torch.pi / 2
|
||||
self.resume = checkpoint_path is not None
|
||||
torch.serialization.add_safe_globals([
|
||||
*util.complexNN.__all__,
|
||||
@@ -219,7 +280,7 @@ class PolarizationTrainer:
|
||||
|
||||
# dims = self.model_kwargs.pop("dims")
|
||||
model_kwargs = copy.deepcopy(self.model_kwargs)
|
||||
self.model = models.polarisation_estimator(*model_kwargs.pop('dims'),**model_kwargs)
|
||||
self.model = models.polarisation_estimator(*model_kwargs.pop("dims"), **model_kwargs)
|
||||
# self.model = models.polarisation_estimator2()
|
||||
|
||||
if self.writer is not None:
|
||||
@@ -260,6 +321,7 @@ class PolarizationTrainer:
|
||||
target_delay=in_out_delay,
|
||||
xy_delay=xy_delay,
|
||||
drop_first=self.data_settings.drop_first,
|
||||
drop_last=self.data_settings.drop_last,
|
||||
dtype=dtype,
|
||||
real=not dtype.is_complex,
|
||||
num_symbols=num_symbols,
|
||||
@@ -336,17 +398,20 @@ class PolarizationTrainer:
|
||||
write_div = 0
|
||||
loss_div = 0
|
||||
for batch_idx, batch in enumerate(train_loader):
|
||||
x = batch["x"]
|
||||
y = batch["sop"]
|
||||
x = batch["angle_data2"]
|
||||
y = batch["center_angle"]
|
||||
self.model.zero_grad(set_to_none=True)
|
||||
x, y = (
|
||||
x.to(self.pytorch_settings.device),
|
||||
y.to(self.pytorch_settings.device),
|
||||
)
|
||||
y_pred = self.model(x)
|
||||
y_pred = self.model(x).abs().real
|
||||
# y_pred = torch.fmod(y_pred, self.mod)
|
||||
y = y.abs().real
|
||||
# y = torch.fmod(y, self.mod)
|
||||
# loss = torch.nn.functional.mse_loss(torch.cos(y_pred), torch.cos(y))
|
||||
# loss = torch.nn.functional.smooth_l1_loss(torch.cos(torch.fmod(y_pred, torch.pi/2)).squeeze(), torch.cos(torch.fmod(y, torch.pi/2)).squeeze(), beta=0.5)
|
||||
loss = torch.nn.functional.mse_loss(y_pred, y)
|
||||
# loss = util.complexNN.naive_angle_loss(y_pred, y, mod=torch.pi/2)
|
||||
loss = util.complexNN.naive_angle_loss(y_pred, y, mod=self.mod)
|
||||
loss_value = loss.item()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
@@ -356,7 +421,7 @@ class PolarizationTrainer:
|
||||
loss_div += 1
|
||||
|
||||
if enable_progress:
|
||||
progress.update(task, advance=1, description=f"{loss_value:.3e}")
|
||||
progress.update(task, advance=1, description=f"{loss_value/np.pi*180:.3e} °")
|
||||
|
||||
if batch_idx % self.pytorch_settings.write_every == 0:
|
||||
self.writer.add_scalar(
|
||||
@@ -395,24 +460,28 @@ class PolarizationTrainer:
|
||||
loss_div = 0
|
||||
with torch.no_grad():
|
||||
for _, batch in enumerate(valid_loader):
|
||||
x = batch["x"]
|
||||
y = batch["sop"]
|
||||
# x = batch["angle_data2"]
|
||||
x = batch["angle_data2"]
|
||||
y = batch["center_angle"]
|
||||
x, y = (
|
||||
x.to(self.pytorch_settings.device),
|
||||
y.to(self.pytorch_settings.device),
|
||||
)
|
||||
y_pred = self.model(x)
|
||||
y_pred = self.model(x).abs().real
|
||||
# y_pred = torch.fmod(y_pred, self.mod)
|
||||
y = y.abs().real
|
||||
# y = torch.fmod(y, self.mod)
|
||||
# loss = torch.nn.functional.mse_loss(torch.cos(y_pred), torch.cos(y))
|
||||
# loss = torch.nn.functional.smooth_l1_loss(torch.cos(torch.fmod(y_pred, torch.pi/2)).squeeze(), torch.cos(torch.fmod(y, torch.pi/2)).squeeze(), beta=0.5)
|
||||
loss = torch.nn.functional.mse_loss(y_pred, y)
|
||||
# loss = util.complexNN.naive_angle_loss(y_pred, y, mod=torch.pi/2)
|
||||
loss = util.complexNN.naive_angle_loss(y_pred, y, mod=self.mod)
|
||||
loss_value = loss.item()
|
||||
running_loss += loss_value
|
||||
loss_div += 1
|
||||
|
||||
if enable_progress:
|
||||
progress.update(task, advance=1, description=f"{loss_value:.3e}")
|
||||
progress.update(task, advance=1, description=f"{loss_value/np.pi*180:.3e} °")
|
||||
|
||||
running_loss = running_loss/loss_div
|
||||
running_loss = running_loss / loss_div
|
||||
|
||||
self.writer.add_scalar(
|
||||
"eval loss",
|
||||
@@ -506,19 +575,19 @@ class PolarizationTrainer:
|
||||
for i, config_path in enumerate(self.data_settings.config_path):
|
||||
paths = Path.cwd().glob(config_path)
|
||||
for j, path in enumerate(paths):
|
||||
text = str(path) + '\n'
|
||||
with open(path, 'r') as f:
|
||||
text = str(path) + "\n"
|
||||
with open(path, "r") as f:
|
||||
text += f.read()
|
||||
text += '\n'
|
||||
self.writer.add_text(f"config_{i*len(self.data_settings.config_path)+j}", text)
|
||||
text += "\n"
|
||||
self.writer.add_text(f"config_{i * len(self.data_settings.config_path) + j}", text)
|
||||
|
||||
elif isinstance(self.data_settings.config_path, str):
|
||||
paths = Path.cwd().glob(self.data_settings.config_path)
|
||||
for j, path in enumerate(paths):
|
||||
text = str(path) + '\n'
|
||||
with open(path, 'r') as f:
|
||||
text = str(path) + "\n"
|
||||
with open(path, "r") as f:
|
||||
text += f.read()
|
||||
text += '\n'
|
||||
text += "\n"
|
||||
self.writer.add_text(f"config_{j}", text)
|
||||
|
||||
self.writer.flush()
|
||||
@@ -571,7 +640,8 @@ class PolarizationTrainer:
|
||||
if loss < self.best["loss"]:
|
||||
self.best = checkpoint
|
||||
save_path = (
|
||||
Path(self.pytorch_settings.model_dir) / f"best_pol_{self.writer.get_logdir().split('/')[-1]}.tar"
|
||||
Path(self.pytorch_settings.model_dir)
|
||||
/ f"best_pol_{self.writer.get_logdir().split('/')[-1]}.tar"
|
||||
)
|
||||
save_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
self.save_checkpoint(self.best, save_path)
|
||||
@@ -580,6 +650,7 @@ class PolarizationTrainer:
|
||||
self.writer.close()
|
||||
return self.best
|
||||
|
||||
|
||||
class RegenerationTrainer:
|
||||
def __init__(
|
||||
self,
|
||||
@@ -592,6 +663,7 @@ class RegenerationTrainer:
|
||||
console=None,
|
||||
checkpoint_path=None,
|
||||
settings_override=None,
|
||||
new_model=False,
|
||||
reset_epoch=False,
|
||||
):
|
||||
self.resume = checkpoint_path is not None
|
||||
@@ -605,12 +677,23 @@ class RegenerationTrainer:
|
||||
models.regenerator,
|
||||
torch.nn.utils.parametrizations.orthogonal,
|
||||
])
|
||||
# self.new_model = True
|
||||
self.model_name = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
if self.resume:
|
||||
print(f"loading checkpoint from {checkpoint_path}")
|
||||
self.checkpoint_dict = torch.load(checkpoint_path, weights_only=True)
|
||||
if settings_override is not None:
|
||||
traverse_dict_update(self.checkpoint_dict["settings"], settings_override)
|
||||
if reset_epoch:
|
||||
|
||||
if not new_model:
|
||||
# self.new_model = False
|
||||
checkpoint_file = checkpoint_path.split("/")[-1].split(".")[0]
|
||||
if checkpoint_file.startswith("best"):
|
||||
self.model_name = "_".join(checkpoint_file.split("_")[1:])
|
||||
else:
|
||||
self.model_name = "_".join(checkpoint_file.split("_")[:-1])
|
||||
|
||||
if new_model or reset_epoch:
|
||||
self.checkpoint_dict["epoch"] = -1
|
||||
|
||||
self.global_settings: GlobalSettings = self.checkpoint_dict["settings"]["global_settings"]
|
||||
@@ -636,11 +719,15 @@ class RegenerationTrainer:
|
||||
self.model_settings: ModelSettings = model_settings
|
||||
self.optimizer_settings: OptimizerSettings = optimizer_settings
|
||||
|
||||
if self.global_settings.seed is not None:
|
||||
random.seed(self.global_settings.seed)
|
||||
np.random.seed(self.global_settings.seed)
|
||||
|
||||
self.console = console or Console()
|
||||
self.writer = None
|
||||
|
||||
def setup_tb_writer(self, append=None):
|
||||
log_dir = self.pytorch_settings.summary_dir + "/" + (datetime.now().strftime("%Y%m%d_%H%M%S"))
|
||||
log_dir = self.pytorch_settings.summary_dir + "/" + self.model_name
|
||||
if append is not None:
|
||||
log_dir += "_" + str(append)
|
||||
|
||||
@@ -669,7 +756,7 @@ class RegenerationTrainer:
|
||||
|
||||
def define_model(self, model_kwargs=None):
|
||||
if self.resume:
|
||||
model_kwargs = self.checkpoint_dict["model_kwargs"]
|
||||
model_kwargs = None
|
||||
else:
|
||||
model_kwargs = model_kwargs
|
||||
|
||||
@@ -678,6 +765,14 @@ class RegenerationTrainer:
|
||||
|
||||
input_dim = 2 * self.data_settings.output_size
|
||||
|
||||
# if self.data_settings.polarisations:
|
||||
# input_dim *= 2
|
||||
|
||||
output_dim = self.model_settings.output_dim
|
||||
|
||||
if self.data_settings.polarisations:
|
||||
output_dim *= 2
|
||||
|
||||
dtype = getattr(torch, self.data_settings.dtype)
|
||||
|
||||
afunc = getattr(util.complexNN, self.model_settings.model_activation_func)
|
||||
@@ -689,7 +784,7 @@ class RegenerationTrainer:
|
||||
hidden_dims = [self.model_settings.overrides.get(f"n_hidden_nodes_{i}") for i in range(n_hidden_layers)]
|
||||
|
||||
self.model_kwargs = {
|
||||
"dims": (input_dim, *hidden_dims, self.model_settings.output_dim),
|
||||
"dims": (input_dim, *hidden_dims, output_dim),
|
||||
"layer_function": layer_func,
|
||||
"layer_func_kwargs": self.model_settings.model_layer_kwargs,
|
||||
"act_function": afunc,
|
||||
@@ -697,7 +792,7 @@ class RegenerationTrainer:
|
||||
"parametrizations": layer_parametrizations,
|
||||
"dtype": dtype,
|
||||
"dropout_prob": self.model_settings.dropout_prob,
|
||||
"scale_layers": self.model_settings.scale,
|
||||
"prescale": self.model_settings.scale,
|
||||
}
|
||||
else:
|
||||
self.model_kwargs = model_kwargs
|
||||
@@ -706,10 +801,12 @@ class RegenerationTrainer:
|
||||
|
||||
# dims = self.model_kwargs.pop("dims")
|
||||
model_kwargs = copy.deepcopy(self.model_kwargs)
|
||||
self.model = models.regenerator(*model_kwargs.pop('dims'),**model_kwargs)
|
||||
self.model = models.regenerator(*model_kwargs.pop("dims"), **model_kwargs)
|
||||
|
||||
if self.writer is not None:
|
||||
self.writer.add_graph(self.model, torch.zeros(1, input_dim, dtype=dtype))
|
||||
self.writer.add_graph(
|
||||
self.model, (torch.rand(1, input_dim, dtype=dtype), torch.rand(1, 1, dtype=dtype.to_real()))
|
||||
)
|
||||
|
||||
self.model = self.model.to(self.pytorch_settings.device)
|
||||
if self.resume:
|
||||
@@ -728,13 +825,16 @@ class RegenerationTrainer:
|
||||
|
||||
num_symbols = None
|
||||
config_path = self.data_settings.config_path
|
||||
polarisations = self.data_settings.polarisations
|
||||
randomise_polarisations = self.data_settings.randomise_polarisations
|
||||
polarisations = self.data_settings.polarisations
|
||||
osnr = self.data_settings.osnr
|
||||
# cross_pol_interference = self.data_settings.cross_pol_interference
|
||||
if override is not None:
|
||||
num_symbols = override.get("num_symbols", None)
|
||||
config_path = override.get("config_path", config_path)
|
||||
polarisations = override.get("polarisations", polarisations)
|
||||
randomise_polarisations = override.get("randomise_polarisation", randomise_polarisations)
|
||||
# cross_pol_interference = override.get("angle_var", 0)
|
||||
# get dataset
|
||||
dataset = FiberRegenerationDataset(
|
||||
file_path=config_path,
|
||||
@@ -743,11 +843,14 @@ class RegenerationTrainer:
|
||||
target_delay=in_out_delay,
|
||||
xy_delay=xy_delay,
|
||||
drop_first=self.data_settings.drop_first,
|
||||
drop_last=self.data_settings.drop_last,
|
||||
dtype=dtype,
|
||||
real=not dtype.is_complex,
|
||||
num_symbols=num_symbols,
|
||||
polarisations=polarisations,
|
||||
randomise_polarisations=randomise_polarisations,
|
||||
polarisations=polarisations,
|
||||
# cross_pol_interference=cross_pol_interference,
|
||||
osnr = osnr,
|
||||
)
|
||||
|
||||
dataset_size = len(dataset)
|
||||
@@ -816,16 +919,25 @@ class RegenerationTrainer:
|
||||
running_loss = 0.0
|
||||
self.model.train()
|
||||
loader_len = len(train_loader)
|
||||
# x_key = "x_stacked"# if self.data_settings.polarisations else "x"
|
||||
# y_key = "y_stacked"# if self.data_settings.polarisations else "y"
|
||||
x_key = "x"
|
||||
y_key = "y"
|
||||
for batch_idx, batch in enumerate(train_loader):
|
||||
x = batch["x"]
|
||||
y = batch["y"]
|
||||
x = batch[x_key]
|
||||
y = batch[y_key]
|
||||
angle = batch["mean_angle"]
|
||||
self.model.zero_grad(set_to_none=True)
|
||||
x, y = (
|
||||
x, y, angle = (
|
||||
x.to(self.pytorch_settings.device),
|
||||
y.to(self.pytorch_settings.device),
|
||||
angle.to(self.pytorch_settings.device),
|
||||
)
|
||||
y_pred = self.model(x)
|
||||
y_pred = self.model(x, -angle)
|
||||
# loss = util.complexNN.complex_mse_loss(y_pred*torch.sqrt(torch.ones(1, device=y.device)*5), y*torch.sqrt(torch.ones(1, device=y.device)*5), power=True)
|
||||
loss = util.complexNN.complex_mse_loss(y_pred, y, power=True)
|
||||
|
||||
|
||||
loss_value = loss.item()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
@@ -868,23 +980,31 @@ class RegenerationTrainer:
|
||||
|
||||
self.model.eval()
|
||||
running_error = 0
|
||||
# x_key = "x_stacked"# if self.data_settings.polarisations else "x"
|
||||
# y_key = "y_stacked"# if self.data_settings.polarisations else "y"
|
||||
x_key = "x"
|
||||
y_key = "y"
|
||||
with torch.no_grad():
|
||||
for _, batch in enumerate(valid_loader):
|
||||
x = batch["x"]
|
||||
y = batch["y"]
|
||||
x, y = (
|
||||
x = batch[x_key]
|
||||
y = batch[y_key]
|
||||
angle = batch["mean_angle"]
|
||||
x, y, angle = (
|
||||
x.to(self.pytorch_settings.device),
|
||||
y.to(self.pytorch_settings.device),
|
||||
angle.to(self.pytorch_settings.device),
|
||||
)
|
||||
y_pred = self.model(x)
|
||||
y_pred = self.model(x, -angle)
|
||||
# error = util.complexNN.complex_mse_loss(y_pred*torch.sqrt(torch.ones(1, device=y.device)*5), y*torch.sqrt(torch.ones(1, device=y.device)*5), power=True)
|
||||
error = util.complexNN.complex_mse_loss(y_pred, y, power=True)
|
||||
|
||||
error_value = error.item()
|
||||
running_error += error_value
|
||||
|
||||
if enable_progress:
|
||||
progress.update(task, advance=1, description=f"{error_value:.3e}")
|
||||
|
||||
running_error = running_error/len(valid_loader)
|
||||
running_error = running_error / len(valid_loader)
|
||||
|
||||
self.writer.add_scalar(
|
||||
"eval loss",
|
||||
@@ -894,7 +1014,7 @@ class RegenerationTrainer:
|
||||
if (epoch + 1) % 10 == 0 or epoch < 10:
|
||||
# plotting is slow, so only do it every 10 epochs
|
||||
title_append, subtitle = self.build_title(epoch + 1)
|
||||
head_fig, eye_fig, powers_fig = self.plot_model_response(
|
||||
head_fig, eye_fig, weight_fig, powers_fig = self.plot_model_response(
|
||||
model=self.model,
|
||||
title_append=title_append,
|
||||
subtitle=subtitle,
|
||||
@@ -910,6 +1030,11 @@ class RegenerationTrainer:
|
||||
eye_fig,
|
||||
epoch + 1,
|
||||
)
|
||||
self.writer.add_figure(
|
||||
"weights",
|
||||
weight_fig,
|
||||
epoch + 1,
|
||||
)
|
||||
|
||||
self.writer.add_figure(
|
||||
"powers",
|
||||
@@ -928,45 +1053,70 @@ class RegenerationTrainer:
|
||||
def run_model(self, model, loader, trace_powers=False):
|
||||
model.eval()
|
||||
fiber_out = []
|
||||
fiber_out_rot = []
|
||||
fiber_in = []
|
||||
regen = []
|
||||
timestamps = []
|
||||
|
||||
angles = []
|
||||
# x_key = "x_stacked"# if self.data_settings.polarisations else "x"
|
||||
# y_key = "y_stacked"# if self.data_settings.polarisations else "y"
|
||||
x_key = "x"
|
||||
y_key = "y"
|
||||
with torch.no_grad():
|
||||
model = model.to(self.pytorch_settings.device)
|
||||
for batch in loader:
|
||||
x = batch["x"]
|
||||
y = batch["y"]
|
||||
x = batch[x_key]
|
||||
y = batch[y_key]
|
||||
plot_target = batch["plot_target"]
|
||||
angle = batch["mean_angle"]
|
||||
# center_angle = batch["center_angle"]
|
||||
timestamp = batch["timestamp"]
|
||||
plot_data = batch["plot_data"]
|
||||
x, y = (
|
||||
plot_data_rot = batch["plot_data_rot"]
|
||||
x, y, angle = (
|
||||
x.to(self.pytorch_settings.device),
|
||||
y.to(self.pytorch_settings.device),
|
||||
angle.to(self.pytorch_settings.device),
|
||||
)
|
||||
if trace_powers:
|
||||
y_pred, powers = model(x, trace_powers).cpu()
|
||||
y_pred, powers = model(x, -angle, True).cpu()
|
||||
else:
|
||||
y_pred = model(x, trace_powers).cpu()
|
||||
y_pred = model(x, -angle).cpu()
|
||||
# x = x.cpu()
|
||||
# y = y.cpu()
|
||||
# if self.data_settings.polarisations:
|
||||
y_pred = y_pred[:, :2]
|
||||
y_pred = y_pred.view(y_pred.shape[0], -1, 2)
|
||||
y = y.view(y.shape[0], -1, 2)
|
||||
plot_data = plot_data.view(plot_data.shape[0], -1, 2)
|
||||
y_pred = y_pred[:, y_pred.shape[1]//2, :]
|
||||
# y = y.view(y.shape[0], -1, 2)
|
||||
# plot_data = plot_data.view(plot_data.shape[0], -1, 2)
|
||||
# c = torch.cos(-angle).cpu()
|
||||
# s = torch.sin(-angle).cpu()
|
||||
# rot = torch.stack([torch.stack([c, -s], dim=1), torch.stack([s, c], dim=1)], dim=2).squeeze(-1)
|
||||
# plot_data = torch.bmm(plot_data, rot.to(dtype=plot_data.dtype))
|
||||
# plot_data = plot_data
|
||||
# sines = torch.sin(-angle.cpu())
|
||||
# cosines = torch.cos(-angle.cpu())
|
||||
# plot_data = torch.stack((plot_data[..., 0] * cosines - plot_data[..., 1] * sines, plot_data[..., 0] * sines + plot_data[..., 1] * cosines), dim=-1)
|
||||
# x = x.view(x.shape[0], -1, 2)
|
||||
|
||||
# timestamp = timestamp.view(-1, 1)
|
||||
fiber_out.append(plot_data.squeeze())
|
||||
fiber_in.append(y.squeeze())
|
||||
fiber_out_rot.append(plot_data_rot.squeeze())
|
||||
fiber_in.append(plot_target.squeeze())
|
||||
regen.append(y_pred.squeeze())
|
||||
timestamps.append(timestamp.squeeze())
|
||||
angles.append(angle.squeeze())
|
||||
|
||||
fiber_out = torch.vstack(fiber_out).cpu()
|
||||
fiber_out_rot = torch.vstack(fiber_out_rot).cpu()
|
||||
fiber_in = torch.vstack(fiber_in).cpu()
|
||||
regen = torch.vstack(regen).cpu()
|
||||
angles = torch.vstack(angles).cpu()
|
||||
timestamps = torch.concat(timestamps).cpu()
|
||||
if trace_powers:
|
||||
return fiber_in, fiber_out, regen, timestamps, powers
|
||||
return fiber_in, fiber_out, regen, timestamps
|
||||
return fiber_in, fiber_out, fiber_out_rot, angles, regen, timestamps, powers
|
||||
return fiber_in, fiber_out, fiber_out_rot, angles, regen, timestamps
|
||||
|
||||
def write_parameters(self, epoch, attributes: list[str] | tuple[str] = None):
|
||||
parameter_list = get_parameter_names_and_values(self.model)
|
||||
@@ -998,7 +1148,7 @@ class RegenerationTrainer:
|
||||
)
|
||||
|
||||
title_append, subtitle = self.build_title(0)
|
||||
head_fig, eye_fig, powers_fig = self.plot_model_response(
|
||||
head_fig, eye_fig, weight_fig, powers_fig = self.plot_model_response(
|
||||
model=self.model,
|
||||
title_append=title_append,
|
||||
subtitle=subtitle,
|
||||
@@ -1014,6 +1164,11 @@ class RegenerationTrainer:
|
||||
eye_fig,
|
||||
0,
|
||||
)
|
||||
self.writer.add_figure(
|
||||
"weights",
|
||||
weight_fig,
|
||||
0,
|
||||
)
|
||||
|
||||
self.writer.add_figure(
|
||||
"powers",
|
||||
@@ -1027,24 +1182,27 @@ class RegenerationTrainer:
|
||||
for i, config_path in enumerate(self.data_settings.config_path):
|
||||
paths = Path.cwd().glob(config_path)
|
||||
for j, path in enumerate(paths):
|
||||
text = str(path) + '\n'
|
||||
with open(path, 'r') as f:
|
||||
text = str(path) + "\n"
|
||||
with open(path, "r") as f:
|
||||
text += f.read()
|
||||
text += '\n'
|
||||
self.writer.add_text(f"config_{i*len(self.data_settings.config_path)+j}", text)
|
||||
text += "\n"
|
||||
self.writer.add_text(f"config_{i * len(self.data_settings.config_path) + j}", text)
|
||||
elif isinstance(self.data_settings.config_path, str):
|
||||
paths = Path.cwd().glob(self.data_settings.config_path)
|
||||
for j, path in enumerate(paths):
|
||||
text = str(path) + '\n'
|
||||
with open(path, 'r') as f:
|
||||
text = str(path) + "\n"
|
||||
with open(path, "r") as f:
|
||||
text += f.read()
|
||||
text += '\n'
|
||||
text += "\n"
|
||||
self.writer.add_text(f"config_{j}", text)
|
||||
|
||||
self.writer.flush()
|
||||
|
||||
train_loader, valid_loader = self.get_sliced_data()
|
||||
|
||||
# train_loader.dataset.fiber_out.to(self.pytorch_settings.device)
|
||||
# train_loader.dataset.fiber_in.to(self.pytorch_settings.device)
|
||||
|
||||
optimizer_name = self.optimizer_settings.optimizer
|
||||
|
||||
# lr = self.optimizer_settings.learning_rate
|
||||
@@ -1074,6 +1232,7 @@ class RegenerationTrainer:
|
||||
# except ValueError:
|
||||
# pass
|
||||
|
||||
self.early_stop_vals = {"min_loss": float("inf"), "plateau_cnt": 0}
|
||||
for epoch in range(self.best["epoch"] + 1, self.pytorch_settings.epochs):
|
||||
enable_progress = True
|
||||
if enable_progress:
|
||||
@@ -1089,33 +1248,69 @@ class RegenerationTrainer:
|
||||
epoch,
|
||||
enable_progress=enable_progress,
|
||||
)
|
||||
if self.early_stop(loss):
|
||||
self.save_model_checkpoints(epoch, loss)
|
||||
break
|
||||
if self.optimizer_settings.scheduler is not None:
|
||||
self.scheduler.step(loss)
|
||||
self.writer.add_scalar("learning rate", self.optimizer.param_groups[0]["lr"], epoch)
|
||||
if self.pytorch_settings.save_models and self.model is not None:
|
||||
save_path = (
|
||||
Path(self.pytorch_settings.model_dir) / f"{self.writer.get_logdir().split('/')[-1]}_{epoch}.tar"
|
||||
)
|
||||
save_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
checkpoint = self.build_checkpoint_dict(loss, epoch)
|
||||
self.save_checkpoint(checkpoint, save_path)
|
||||
|
||||
if loss < self.best["loss"]:
|
||||
self.best = checkpoint
|
||||
save_path = (
|
||||
Path(self.pytorch_settings.model_dir) / f"best_{self.writer.get_logdir().split('/')[-1]}.tar"
|
||||
)
|
||||
save_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
self.save_checkpoint(self.best, save_path)
|
||||
self.save_model_checkpoints(epoch, loss)
|
||||
self.writer.flush()
|
||||
|
||||
save_path = (Path(self.pytorch_settings.model_dir) / f"best_{self.writer.get_logdir().split('/')[-1]}.tar")
|
||||
print(f"Training complete. Best checkpoint: {save_path}")
|
||||
self.writer.close()
|
||||
return self.best
|
||||
|
||||
def early_stop(self, loss):
|
||||
# not stopping early at all
|
||||
if not self.optimizer_settings.early_stopping:
|
||||
return False
|
||||
|
||||
# stopping because of abs threshold
|
||||
if (loss_thr := self.optimizer_settings.early_stop_kwargs.get("threshold", None)) is not None:
|
||||
if loss <= loss_thr:
|
||||
print(f"Early stop: loss is below threshold ({loss:.2e} <= {loss_thr:.2e})")
|
||||
return True
|
||||
|
||||
# update vals
|
||||
if loss < self.early_stop_vals["min_loss"]:
|
||||
self.early_stop_vals["min_loss"] = loss
|
||||
self.early_stop_vals["plateau_cnt"] = 0
|
||||
return False
|
||||
|
||||
# stopping because of plateau
|
||||
if (plateau_thresh := self.optimizer_settings.early_stop_kwargs.get("plateau", None)) is not None:
|
||||
self.early_stop_vals["plateau_cnt"] += 1
|
||||
if self.early_stop_vals["plateau_cnt"] >= plateau_thresh:
|
||||
print(f"Early stop: loss plateau length over threshold ({self.early_stop_vals["plateau_cnt"]} >= {plateau_thresh})")
|
||||
return True
|
||||
|
||||
# no stop
|
||||
return False
|
||||
|
||||
def save_model_checkpoints(self, epoch, loss):
|
||||
if self.pytorch_settings.save_models and self.model is not None:
|
||||
save_path = (
|
||||
Path(self.pytorch_settings.model_dir) / f"{self.writer.get_logdir().split('/')[-1]}_{epoch}.tar"
|
||||
)
|
||||
save_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
checkpoint = self.build_checkpoint_dict(loss, epoch)
|
||||
self.save_checkpoint(checkpoint, save_path)
|
||||
|
||||
if loss < self.best["loss"]:
|
||||
self.best = checkpoint
|
||||
save_path = (
|
||||
Path(self.pytorch_settings.model_dir) / f"best_{self.writer.get_logdir().split('/')[-1]}.tar"
|
||||
)
|
||||
save_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
self.save_checkpoint(self.best, save_path)
|
||||
|
||||
def _plot_model_response_powers(self, powers, layer_names, title_append="", subtitle="", show=True):
|
||||
powers = [power / powers[0] for power in powers]
|
||||
fig, ax = plt.subplots()
|
||||
fig.set_figwidth(18)
|
||||
fig.set_figheight(4)
|
||||
fig.suptitle(
|
||||
f"Energy conservation{f' {title_append}' if title_append else ''}{f'\n{subtitle}' if subtitle else ''}"
|
||||
)
|
||||
@@ -1131,6 +1326,77 @@ class RegenerationTrainer:
|
||||
plt.show()
|
||||
return fig
|
||||
|
||||
def _plot_model_weights(self, model, title_append="", subtitle="", show=True):
|
||||
model_params = []
|
||||
plots = []
|
||||
dims = []
|
||||
for num, (layer_name, layer) in enumerate(model.named_children()):
|
||||
onn_weights = layer.ONN.weight
|
||||
onn_weights = onn_weights.detach().cpu().numpy()
|
||||
onn_values = np.abs(onn_weights).real
|
||||
onn_angles = np.mod(np.angle(onn_weights), 2*np.pi).real
|
||||
|
||||
model_params.append({layer_name: onn_weights})
|
||||
plots.append({layer_name: (num, onn_values, onn_angles)})
|
||||
dims.append(onn_weights.shape[0])
|
||||
|
||||
max_size = np.max(dims)
|
||||
|
||||
for plot in plots:
|
||||
layer_name, (num, onn_values, onn_angles) = plot.popitem()
|
||||
|
||||
if num == 0:
|
||||
value_img = onn_values
|
||||
angle_img = onn_angles
|
||||
onn_angles = pad_to_size(onn_angles, (max_size, None))
|
||||
onn_values = pad_to_size(onn_values, (max_size, None))
|
||||
else:
|
||||
onn_values = pad_to_size(onn_values, (max_size, onn_values.shape[1]+1))
|
||||
onn_angles = pad_to_size(onn_angles, (max_size, onn_angles.shape[1]+1))
|
||||
value_img = np.concatenate((value_img, onn_values), axis=1)
|
||||
angle_img = np.concatenate((angle_img, onn_angles), axis=1)
|
||||
|
||||
value_img = np.ma.array(value_img, mask=np.isnan(value_img))
|
||||
angle_img = np.ma.array(angle_img, mask=np.isnan(angle_img))
|
||||
|
||||
fig, axs = plt.subplots(1, 2, sharex=True, sharey=True, figsize=(18, 6.5))
|
||||
fig.tight_layout()
|
||||
|
||||
dividers = map(make_axes_locatable, axs)
|
||||
caxs = list(map(lambda d: d.append_axes("right", size="5%", pad=0.05), dividers))
|
||||
|
||||
masked_value_img = value_img
|
||||
cmap = cm.batlow
|
||||
cmap.set_bad(color="#AAAAAA")
|
||||
im_val = axs[0].imshow(masked_value_img, cmap=cmap, vmin=0, vmax=1)
|
||||
fig.colorbar(im_val, cax=caxs[0], orientation="vertical")
|
||||
|
||||
masked_angle_img = np.ma.masked_where(np.isnan(angle_img), angle_img)
|
||||
cmap = cm.romaO
|
||||
cmap.set_bad(color="#AAAAAA")
|
||||
im_ang = axs[1].imshow(masked_angle_img, cmap=cmap, vmin=0, vmax=2*np.pi)
|
||||
cbar = fig.colorbar(im_ang, cax=caxs[1], orientation="vertical", ticks=[i/8 * 2*np.pi for i in range(9)])
|
||||
cbar.ax.set_yticklabels(["0", "π/4", "π/2", "3π/4", "π", "5π/4", "3π/2", "7π/4", "2π"])
|
||||
|
||||
|
||||
axs[0].axis("off")
|
||||
axs[1].axis("off")
|
||||
|
||||
axs[0].set_title("Values")
|
||||
axs[1].set_title("Angles")
|
||||
|
||||
title = "Layer Weights"
|
||||
if title_append:
|
||||
title += f" {title_append}"
|
||||
if subtitle:
|
||||
title += f"\n{subtitle}"
|
||||
fig.suptitle(title)
|
||||
|
||||
|
||||
if show:
|
||||
plt.show()
|
||||
return fig
|
||||
|
||||
def _plot_model_response_eye(
|
||||
self, *signals, timestamps, labels=None, sps=None, title_append="", subtitle="", show=True
|
||||
):
|
||||
@@ -1184,6 +1450,7 @@ class RegenerationTrainer:
|
||||
|
||||
fig, axs = plt.subplots(1, 2 * len(signals), sharex=True, sharey=True)
|
||||
fig.set_figwidth(18)
|
||||
fig.set_figheight(4)
|
||||
fig.suptitle(f"Eye diagram{f' {title_append}' if title_append else ''}{f'\n{subtitle}' if subtitle else ''}")
|
||||
# xaxis = timestamps / sps
|
||||
# xaxis = np.arange(2 * sps) / sps
|
||||
@@ -1253,7 +1520,7 @@ class RegenerationTrainer:
|
||||
xaxis = timestamps / sps
|
||||
else:
|
||||
xaxis = timestamps
|
||||
ax.plot(xaxis, np.abs(signal[:, i]) ** 2, label=label)
|
||||
ax.plot(xaxis, np.abs(signal[:, i]) ** 2, label=label, alpha=0.7)
|
||||
ax.set_xlabel("Sample" if sps is None else "Symbol")
|
||||
ax.set_ylabel("normalized power")
|
||||
ax.minorticks_on()
|
||||
@@ -1269,7 +1536,7 @@ class RegenerationTrainer:
|
||||
|
||||
def plot_model_response(
|
||||
self,
|
||||
model:torch.nn.Module=None,
|
||||
model: torch.nn.Module = None,
|
||||
title_append="",
|
||||
subtitle="",
|
||||
# mode: Literal["eye", "head", "powers"] = "head",
|
||||
@@ -1281,7 +1548,9 @@ class RegenerationTrainer:
|
||||
model = model.to(self.pytorch_settings.device)
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
_, powers = model(input_data, trace_powers=True)
|
||||
_, powers = model(
|
||||
input_data, torch.zeros(input_data.shape[0], 1).to(self.pytorch_settings.device), trace_powers=True
|
||||
)
|
||||
|
||||
powers = [power.item() for power in powers]
|
||||
layer_names = [name for (name, _) in model.named_children()]
|
||||
@@ -1292,33 +1561,48 @@ class RegenerationTrainer:
|
||||
|
||||
data_settings_backup = copy.deepcopy(self.data_settings)
|
||||
pytorch_settings_backup = copy.deepcopy(self.pytorch_settings)
|
||||
self.data_settings.drop_first = 99.5 + random.randint(0, 1000)
|
||||
self.data_settings.drop_first = int(64 + random.randint(0, 1000))
|
||||
self.data_settings.shuffle = False
|
||||
self.data_settings.train_split = 1.0
|
||||
self.pytorch_settings.batchsize = max(self.pytorch_settings.head_symbols, self.pytorch_settings.eye_symbols)
|
||||
config_path = random.choice(self.data_settings.config_path) if isinstance(self.data_settings.config_path, (list, tuple)) else self.data_settings.config_path
|
||||
fiber_length = int(float(str(config_path).split('-')[4])/1000)
|
||||
config_path = (
|
||||
random.choice(self.data_settings.config_path)
|
||||
if isinstance(self.data_settings.config_path, (list, tuple))
|
||||
else self.data_settings.config_path
|
||||
)
|
||||
# fiber_length = int(float(str(config_path).split("-")[4]) / 1000)
|
||||
if not hasattr(self, "_plot_loader"):
|
||||
self._plot_loader, _ = self.get_sliced_data(
|
||||
override={
|
||||
"num_symbols": self.pytorch_settings.batchsize,
|
||||
"config_path": config_path,
|
||||
"shuffle": False,
|
||||
"polarisations": (np.random.rand(1)*np.pi*2,),
|
||||
"randomise_polarisation": False,
|
||||
# "polarisations": (np.random.rand(1) * np.pi * 2,),
|
||||
"polarisations": self.data_settings.polarisations,
|
||||
"randomise_polarisation": self.data_settings.randomise_polarisations,
|
||||
}
|
||||
)
|
||||
self._sps = self._plot_loader.dataset.samples_per_symbol
|
||||
fiber_length = float(self._plot_loader.dataset.config["fiber"]["length"])/1000
|
||||
self.data_settings = data_settings_backup
|
||||
self.pytorch_settings = pytorch_settings_backup
|
||||
|
||||
fiber_in, fiber_out, regen, timestamps = self.run_model(model, self._plot_loader)
|
||||
fiber_in, fiber_out, fiber_out_rot, angles, regen, timestamps = self.run_model(model, self._plot_loader)
|
||||
fiber_in = fiber_in.view(-1, 2)
|
||||
fiber_out = fiber_out.view(-1, 2)
|
||||
fiber_out_rot = fiber_out_rot.view(-1, 2)
|
||||
angles = angles.view(-1, 1)
|
||||
angles = angles.real
|
||||
angles = torch.fmod(angles, 2*torch.pi)
|
||||
angles = torch.div(angles, 2*torch.pi)
|
||||
angles = torch.repeat_interleave(angles, 2, dim=1)
|
||||
|
||||
regen = regen.view(-1, 2)
|
||||
|
||||
fiber_in = fiber_in.numpy()
|
||||
fiber_out = fiber_out.numpy()
|
||||
fiber_out_rot = fiber_out_rot.numpy()
|
||||
angles = angles.numpy()
|
||||
regen = regen.numpy()
|
||||
timestamps = timestamps.numpy()
|
||||
|
||||
@@ -1327,31 +1611,34 @@ class RegenerationTrainer:
|
||||
import gc
|
||||
|
||||
head_fig = self._plot_model_response_head(
|
||||
fiber_in[:self.pytorch_settings.head_symbols*self._sps],
|
||||
fiber_out[:self.pytorch_settings.head_symbols*self._sps],
|
||||
regen[:self.pytorch_settings.head_symbols*self._sps],
|
||||
timestamps=timestamps[:self.pytorch_settings.head_symbols*self._sps],
|
||||
fiber_out[: self.pytorch_settings.head_symbols * self._sps],
|
||||
fiber_in[: self.pytorch_settings.head_symbols * self._sps],
|
||||
regen[: self.pytorch_settings.head_symbols * self._sps],
|
||||
angles[: self.pytorch_settings.head_symbols * self._sps],
|
||||
timestamps=timestamps[: self.pytorch_settings.head_symbols * self._sps],
|
||||
labels=("fiber out", "fiber in", "regen", "normed angle"),
|
||||
sps=self._sps,
|
||||
title_append=title_append + f" ({fiber_length} km)",
|
||||
subtitle=subtitle,
|
||||
show=show,
|
||||
)
|
||||
# raise NotImplementedError("Eye diagram not implemented")
|
||||
eye_fig = self._plot_model_response_eye(
|
||||
fiber_in[: self.pytorch_settings.eye_symbols * self._sps],
|
||||
fiber_out[: self.pytorch_settings.eye_symbols * self._sps],
|
||||
regen[: self.pytorch_settings.eye_symbols * self._sps],
|
||||
timestamps=timestamps[: self.pytorch_settings.eye_symbols * self._sps],
|
||||
labels=("fiber in", "fiber out", "regen"),
|
||||
sps=self._sps,
|
||||
title_append=title_append + f" ({fiber_length} km)",
|
||||
subtitle=subtitle,
|
||||
show=show,
|
||||
)
|
||||
# raise NotImplementedError("Eye diagram not implemented")
|
||||
eye_fig = self._plot_model_response_eye(
|
||||
fiber_in[:self.pytorch_settings.eye_symbols*self._sps],
|
||||
fiber_out[:self.pytorch_settings.eye_symbols*self._sps],
|
||||
regen[:self.pytorch_settings.eye_symbols*self._sps],
|
||||
timestamps=timestamps[:self.pytorch_settings.eye_symbols*self._sps],
|
||||
labels=("fiber in", "fiber out", "regen"),
|
||||
sps=self._sps,
|
||||
title_append=title_append + f" ({fiber_length} km)",
|
||||
subtitle=subtitle,
|
||||
show=show,
|
||||
)
|
||||
|
||||
weight_fig = self._plot_model_weights(model, title_append=title_append, subtitle=subtitle, show=show)
|
||||
gc.collect()
|
||||
|
||||
return head_fig, eye_fig, power_fig
|
||||
return head_fig, eye_fig, weight_fig, power_fig
|
||||
|
||||
def build_title(self, number: int):
|
||||
title_append = f"epoch {number}"
|
||||
@@ -1361,7 +1648,7 @@ class RegenerationTrainer:
|
||||
self.model_settings.overrides.get(f"n_hidden_nodes_{i}", -1) for i in range(model_n_hidden_layers)
|
||||
]
|
||||
model_dims.insert(0, input_dim)
|
||||
model_dims.append(2)
|
||||
model_dims.append(self.model_settings.output_dim)
|
||||
model_dims = [str(dim) for dim in model_dims]
|
||||
model_activation_func = self.model_settings.model_activation_func
|
||||
model_dtype = self.data_settings.dtype
|
||||
|
||||
217
src/single-core-regen/plot_model.py
Normal file
217
src/single-core-regen/plot_model.py
Normal file
@@ -0,0 +1,217 @@
|
||||
from pathlib import Path
|
||||
import sys
|
||||
|
||||
from matplotlib import pyplot as plt
|
||||
import numpy as np
|
||||
import torch
|
||||
import util
|
||||
from hypertraining.settings import GlobalSettings, DataSettings, ModelSettings, OptimizerSettings, PytorchSettings
|
||||
from hypertraining import models
|
||||
|
||||
# def move_to_location_in_size(array, location, size):
|
||||
# array_x, array_y = array.shape
|
||||
# location_x, location_y = location
|
||||
# size_x, size_y = size
|
||||
|
||||
# left = location_x
|
||||
# right = size_x - array_x - location_x
|
||||
|
||||
# top = location_y
|
||||
# bottom = size_y - array_y - location_y
|
||||
|
||||
# return np.pad(
|
||||
# array,
|
||||
# (
|
||||
# (left, right),
|
||||
# (top, bottom),
|
||||
# ),
|
||||
# constant_values=(-np.inf, -np.inf),
|
||||
# )
|
||||
|
||||
def register_puccs_cmap(puccs_path=None):
|
||||
puccs_path = Path(__file__).resolve().parent / 'puccs.csv' if puccs_path is None else puccs_path
|
||||
|
||||
colors = []
|
||||
# keys = None
|
||||
with open(puccs_path, "r") as f:
|
||||
for i, line in enumerate(f.readlines()):
|
||||
elements = tuple(line.split(","))
|
||||
# if i == 0:
|
||||
# # keys = elements
|
||||
# continue
|
||||
# else:
|
||||
try:
|
||||
colors.append(tuple(map(float, elements[4:])))
|
||||
except ValueError:
|
||||
continue
|
||||
# colors = []
|
||||
# for current in puccs_csv_data:
|
||||
# colors.append(tuple(current[4:]))
|
||||
from matplotlib.colors import LinearSegmentedColormap
|
||||
import matplotlib as mpl
|
||||
mpl.colormaps.register(LinearSegmentedColormap.from_list('puccs', colors))
|
||||
|
||||
def pad_to_size(array, size):
|
||||
if not hasattr(size, "__len__"):
|
||||
size = (size, size)
|
||||
|
||||
left = (
|
||||
(size[0] - array.shape[0] + 1) // 2 if size[0] is not None else 0
|
||||
)
|
||||
right = (
|
||||
(size[0] - array.shape[0]) // 2 if size[0] is not None else 0
|
||||
)
|
||||
top = (
|
||||
(size[1] - array.shape[1] + 1) // 2 if size[1] is not None else 0
|
||||
)
|
||||
bottom = (
|
||||
(size[1] - array.shape[1]) // 2 if size[1] is not None else 0
|
||||
)
|
||||
|
||||
array: np.ndarray = array
|
||||
if array.ndim == 2:
|
||||
return np.pad(
|
||||
array,
|
||||
(
|
||||
(left, right),
|
||||
(top, bottom),
|
||||
),
|
||||
constant_values=(np.nan, np.nan),
|
||||
)
|
||||
elif array.ndim == 3:
|
||||
return np.pad(
|
||||
array,
|
||||
(
|
||||
(left, right),
|
||||
(top, bottom),
|
||||
(0,0)
|
||||
),
|
||||
constant_values=(np.nan, np.nan),
|
||||
)
|
||||
|
||||
def model_plot(model_path, show=True):
|
||||
torch.serialization.add_safe_globals([
|
||||
*util.complexNN.__all__,
|
||||
GlobalSettings,
|
||||
DataSettings,
|
||||
ModelSettings,
|
||||
OptimizerSettings,
|
||||
PytorchSettings,
|
||||
models.regenerator,
|
||||
torch.nn.utils.parametrizations.orthogonal,
|
||||
])
|
||||
checkpoint_dict = torch.load(model_path, weights_only=True)
|
||||
|
||||
dims = checkpoint_dict["model_kwargs"].pop("dims")
|
||||
|
||||
model = models.regenerator(*dims, **checkpoint_dict["model_kwargs"])
|
||||
model.load_state_dict(checkpoint_dict["model_state_dict"], strict=False)
|
||||
|
||||
model_params = []
|
||||
plots = []
|
||||
max_size = np.max(dims)
|
||||
# max_act_size = np.max(dims[1:])
|
||||
|
||||
# angles = [None, None]
|
||||
# weights = [None, None]
|
||||
|
||||
for num, (layer_name, layer) in enumerate(model.named_children()):
|
||||
# each layer contains an "ONN" layer and an "activation" layer
|
||||
# activation layer is approximately the same for all layers and nodes -> rotation by 90 degrees
|
||||
onn_weights = layer.ONN.weight
|
||||
onn_weights = onn_weights.detach().cpu().numpy()
|
||||
onn_values = np.abs(onn_weights).real
|
||||
onn_angles = np.mod(np.angle(onn_weights), 2*np.pi).real
|
||||
|
||||
model_params.append({layer_name: onn_weights})
|
||||
plots.append({layer_name: (num, onn_values, onn_angles)})#, act_values, act_angles)})
|
||||
|
||||
# fig, axs = plt.subplots(3, len(model_params)*2-1, figsize=(20, 5))
|
||||
|
||||
for plot in plots:
|
||||
layer_name, (num, onn_values, onn_angles) = plot.popitem()
|
||||
|
||||
if num == 0:
|
||||
value_img = onn_values
|
||||
angle_img = onn_angles
|
||||
onn_angles = pad_to_size(onn_angles, (max_size, None))
|
||||
onn_values = pad_to_size(onn_values, (max_size, None))
|
||||
else:
|
||||
onn_values = pad_to_size(onn_values, (max_size, onn_values.shape[1]+1))
|
||||
onn_angles = pad_to_size(onn_angles, (max_size, onn_angles.shape[1]+1))
|
||||
value_img = np.concatenate((value_img, onn_values), axis=1)
|
||||
angle_img = np.concatenate((angle_img, onn_angles), axis=1)
|
||||
|
||||
value_img = np.ma.array(value_img, mask=np.isnan(value_img))
|
||||
angle_img = np.ma.array(angle_img, mask=np.isnan(angle_img))
|
||||
|
||||
# from cmcrameri import cm
|
||||
from cmap import Colormap as cm
|
||||
import scicomap as sc
|
||||
# from matplotlib import colors as mcolors
|
||||
# alpha_map = mcolors.LinearSegmentedColormap(
|
||||
# 'alphamap',
|
||||
# {
|
||||
# 'red': [(0, 0, 0), (1, 0, 0)],
|
||||
# 'green': [(0, 0, 0), (1, 0, 0)],
|
||||
# 'blue': [(0, 0, 0), (1, 0, 0)],
|
||||
# 'alpha': [
|
||||
# (0, 1, 1),
|
||||
# # (0.2, 0.2, 0.1),
|
||||
# (1, 0, 0)
|
||||
# ]
|
||||
# }
|
||||
# )
|
||||
# alpha_map.set_bad(color="#AAAAAA")
|
||||
|
||||
|
||||
from mpl_toolkits.axes_grid1 import make_axes_locatable
|
||||
|
||||
fig, axs = plt.subplots(1, 2, sharex=True, sharey=True, figsize=(16, 5))
|
||||
# fig.tight_layout()
|
||||
dividers = map(make_axes_locatable, axs)
|
||||
caxs = list(map(lambda d: d.append_axes("right", size="5%", pad=0.05), dividers))
|
||||
# masked_value_img = np.ma.masked_where(np.isnan(value_img), value_img)
|
||||
masked_value_img = value_img
|
||||
cmap = cm('google:turbo').to_matplotlib()
|
||||
# cmap = sc.ScicoSequential("rainbow").get_mpl_color_map()
|
||||
cmap.set_bad(color="#AAAAAA")
|
||||
im_val = axs[0].imshow(masked_value_img, cmap=cmap, vmin=0, vmax=1)
|
||||
fig.colorbar(im_val, cax=caxs[0], orientation="vertical")
|
||||
|
||||
|
||||
masked_angle_img = np.ma.masked_where(np.isnan(angle_img), angle_img)
|
||||
# cmap = cm('crameri:romao').to_matplotlib()
|
||||
# cmap = plt.get_cmap('puccs')
|
||||
# cmap = sc.ScicoCircular("colorwheel").get_mpl_color_map()
|
||||
cmap = cm('colorcet:CET_C8').to_matplotlib()
|
||||
cmap.set_bad(color="#AAAAAA")
|
||||
im_ang = axs[1].imshow(masked_angle_img, cmap=cmap, vmin=0, vmax=2*np.pi)
|
||||
cbar = fig.colorbar(im_ang, cax=caxs[1], orientation="vertical", ticks=[i/8 * 2*np.pi for i in range(9)])
|
||||
cbar.ax.set_yticklabels(["0", "π/4", "π/2", "3π/4", "π", "5π/4", "3π/2", "7π/4", "2π"])
|
||||
# im_ang_w = axs[2].imshow(masked_angle_img, cmap=cmap)
|
||||
# im_ang_w = axs[2].imshow(masked_value_img, cmap=alpha_map)
|
||||
|
||||
axs[0].axis("off")
|
||||
axs[1].axis("off")
|
||||
# axs[2].axis("off")
|
||||
|
||||
axs[0].set_title("Values")
|
||||
axs[1].set_title("Angles")
|
||||
# axs[2].set_title("Values and Angles")
|
||||
|
||||
|
||||
...
|
||||
if show:
|
||||
plt.show()
|
||||
return fig
|
||||
|
||||
# model = models.regenerator(*dims, **model_kwargs)
|
||||
|
||||
if __name__ == "__main__":
|
||||
register_puccs_cmap()
|
||||
if len(sys.argv) > 1:
|
||||
model_plot(sys.argv[1])
|
||||
else:
|
||||
print("Please provide a model path as an argument")
|
||||
# model_plot(".models/best_20250114_224234.tar")
|
||||
102
src/single-core-regen/puccs.csv
Normal file
102
src/single-core-regen/puccs.csv
Normal file
@@ -0,0 +1,102 @@
|
||||
"x","L","a","b","R","G","B"
|
||||
0.,0.5187848173343539,0.6399990176455989,0.67,0.8889427469969852,0.22673227640012172,0.
|
||||
0.01,0.5374499525557803,0.604014067614707,0.6777967519386492,0.8956274406155226,0.27553288030331824,0.
|
||||
0.02,0.5560867887452998,0.5680836759482211,0.6855816828789898,0.9019507507843885,0.318608215541461,0.
|
||||
0.03,0.5746877595125583,0.5322224300667823,0.6933516322080414,0.907905487190649,0.3580633000693721,0.
|
||||
0.04,0.5932314662487472,0.49647158484797804,0.7010976613543587,0.9134808162089558,0.3949845524063657,0.
|
||||
0.05,0.6117000836392819,0.46086550613202343,0.7088123243737041,0.918668356138916,0.43002019316005363,0.
|
||||
0.06,0.6300828534995973,0.4254249348741487,0.7164911273850869,0.923462736751354,0.4635961938811463,0.
|
||||
0.07,0.6483763163456417,0.3901565406944371,0.7241326253017896,0.9278609626724071,0.49601354353255284,0.
|
||||
0.08,0.6665840140182806,0.3550534951951814,0.7317382976124045,0.9318616057744784,0.5274983630587982,0.
|
||||
0.09,0.6847162776119433,0.3200958808181962,0.7393124597949372,0.9354640163365924,0.5582303922647159,0.
|
||||
0.1,0.7027902128942014,0.2852507189547545,0.7468622572263107,0.9386675557407496,0.5883604892249517,0.004034952213848706
|
||||
0.11,0.7208298719332069,0.25047163906104203,0.7543977368741345,0.9414708123927996,0.6180221032545026,0.016031521294251994
|
||||
0.12,0.7388665670611175,0.2156982733607376,0.7619319784446927,0.943870754968487,0.6473392272576862,0.029857267582036696
|
||||
0.13,0.7569392765472108,0.18085547473834482,0.7694812638396673,0.9458617774020323,0.676432172396153,0.045365670193636125
|
||||
0.14,0.7750950944867471,0.14585244938794778,0.7770652650825484,0.9474345911958609,0.7054219201084561,0.06017985923530026
|
||||
0.15,0.793389684293558,0.11058188251425949,0.7847072337503834,0.9485749196617762,0.7344334940032564,0.07418869502646075
|
||||
0.16,0.8117919447684838,0.07510373484536464,0.792394178330817,0.9492596163836376,0.7634480277996188,0.08767517868137237
|
||||
0.17,0.8293050962981561,0.03629277424762101,0.799038155466063,0.9462308253550155,0.7922009241807345,0.10066327128139077
|
||||
0.18,0.8213303100752708,-0.0062517290795987,0.7879999288492758,0.9088702681901394,0.7940579017644396,0.10139639009534024
|
||||
0.19,0.8134831311534617,-0.048115463155645855,0.7771383286984362,0.8716809050191757,0.7954897210083888,0.10232311621802098
|
||||
0.2,0.80558613530069,-0.0902449644291895,0.7662077749032042,0.8337524177888596,0.7965471523787845,0.10344968926026826
|
||||
0.21,0.7975860185564765,-0.13292460297117392,0.7551344872795225,0.7947193410849823,0.7972381033243311,0.10477682283894393
|
||||
0.22,0.7894147026971006,-0.17651756772919341,0.7438242359834689,0.7540941866826836,0.7975605026647324,0.10631182441371936
|
||||
0.23,0.7809997374598548,-0.2214103719409295,0.7321767396537806,0.7112894518675287,0.7974995317311054,0.1080672415170634
|
||||
0.24,0.7722646970273015,-0.2680107379394189,0.7200862142018722,0.6655745739336695,0.7970267795229349,0.11006041388465265
|
||||
0.25,0.7631307298557146,-0.3167393290089981,0.7074435179925446,0.6160047476007512,0.7960993904970947,0.11231257117602686
|
||||
0.26,0.7535192192483822,-0.36801555555407994,0.6941398344519211,0.5612859274945571,0.794659599537827,0.11484733363789801
|
||||
0.27,0.7433557597838075,-0.42223636134393283,0.6800721760037781,0.4994862901720824,0.7926351396848288,0.11768844813479104
|
||||
0.28,0.732575139048096,-0.479749646583324,0.6651502794883674,0.42731393423789277,0.7899410218414098,0.12085678487511567
|
||||
0.29,0.7211269294461059,-0.5408244362880141,0.6493043460161184,0.3378265607222193,0.786483110019224,0.124366774034814
|
||||
0.3,0.7090756028785993,-0.6051167807996883,0.6326236137723747,0.2098475715121697,0.7821998608677176,0.12819222127525928
|
||||
0.31,0.7094510768540225,-0.6165036055456403,0.5630307498747129,0.15061488620640032,0.7845112116922692,0.21943537230975235
|
||||
0.32,0.7174669421288304,-0.5917687864932311,0.4797229624661701,0.18766933782916642,0.7905828987725732,0.31091344246312086
|
||||
0.33,0.7249009746435938,-0.5688293479200438,0.40246208306061504,0.21160609617940718,0.7962175427587832,0.38519766326885596
|
||||
0.34,0.7317072855135611,-0.5478268906666535,0.3317250285377912,0.22717569971119178,0.8013847719431052,0.4490960048955565
|
||||
0.35,0.7379328517830899,-0.5286164561226088,0.26702357292455026,0.23690087622812972,0.8061220291668977,0.5056371468159843
|
||||
0.36,0.7436229063122554,-0.5110584677642499,0.20788761731555405,0.24226377668817778,0.8104638164122776,0.5563570758573497
|
||||
0.37,0.7488251728809415,-0.4950056627547577,0.15382117501783654,0.24424372086048424,0.8144455902164638,0.6022301663745243
|
||||
0.38,0.7535943992285348,-0.48028910419451787,0.10425526029155024,0.24352232677523483,0.818107753931944,0.6440238320299774
|
||||
0.39,0.757994865186593,-0.4667104416936734,0.05852182167144754,0.240562414747303,0.8214980148949816,0.6824536572462205
|
||||
0.4,0.7620994844391137,-0.4540446830999986,0.015863077249098356,0.2356325204239052,0.8246710357361025,0.7182393675419642
|
||||
0.41,0.7659871096124125,-0.4420485102716773,-0.024540477496154123,0.22880568593963535,0.8276865975886148,0.7521146815529202
|
||||
0.42,0.7697410958994951,-0.4304647113488041,-0.06355514164248566,0.21993360985514526,0.8306086550266585,0.7848331944479765
|
||||
0.43,0.773446484628189,-0.4190308715098135,-0.10206473803580057,0.20858849290850018,0.833503273690861,0.8171544357676854
|
||||
0.44,0.7771893686864673,-0.4074813310994203,-0.14096401824224686,0.1939295692427068,0.8364382500400466,0.8498448067259188
|
||||
0.45,0.7810574093604746,-0.3955455908045306,-0.18116403397486242,0.17438366103820427,0.839483669055626,0.8836865023336339
|
||||
0.46,0.7851360804917298,-0.3829599011818591,-0.2235531031349741,0.14679145002531463,0.8427091517444469,0.9194481212717681
|
||||
0.47,0.789525027020907,-0.369416784561489,-0.26916682191206776,0.10278921007810798,0.8461971304126237,0.9580316568065935
|
||||
0.48,0.7942371698732826,-0.35487637041943493,-0.3181394757087982,0.0013920913109500188,0.8499626968466341,0.9995866371771526
|
||||
0.49,0.7773897680996302,-0.31852357140025195,-0.34537976514700053,0.10740420703601522,0.8254781216972907,1.
|
||||
0.5,0.7604011244310231,-0.28211213216592784,-0.3722846952738428,0.1581725581872408,0.8008522647497104,1.
|
||||
0.51,0.7433440454962605,-0.2455540169176899,-0.3992980063927199,0.19300141807932156,0.7761561224913385,1.
|
||||
0.52,0.7262590833969331,-0.20893614020926626,-0.42635547610418184,0.2194621842292243,0.751443124097109,1.
|
||||
0.53,0.709058602701224,-0.17207067467417486,-0.453595892719742,0.2405673704012788,0.7265803324554873,1.
|
||||
0.54,0.6915768892539101,-0.1346024482921609,-0.48128169789479536,0.25788347992973676,0.701321051230534,1.
|
||||
0.55,0.6736331627810209,-0.09614399811510127,-0.5096991935104321,0.2722888922216317,0.6753950894563805,1.
|
||||
0.56,0.6551463184003872,-0.05652149358027936,-0.5389768254408652,0.28422807900785235,0.6486730893521468,1.
|
||||
0.57,0.6361671326276888,-0.01584376303510615,-0.5690341788729347,0.293907374075009,0.6212117649042732,1.
|
||||
0.58,0.6168396823565967,0.025580396234342995,-0.5996430791016598,0.301442767979156,0.5931976878638505,1.
|
||||
0.59,0.5973210287815495,0.06741435793529688,-0.6305547881733555,0.30694603901024253,0.5648312189065924,1.
|
||||
0.6,0.5777303704171711,0.10940264614179468,-0.661580531294122,0.3105418468883679,0.5362525958007331,1.
|
||||
0.61,0.5581475370499237,0.15137416317967575,-0.6925938819599547,0.3123531986526998,0.5075386530652202,1.
|
||||
0.62,0.5386227795100639,0.19322120739317136,-0.7235152578861672,0.31248922600720636,0.4787151440558522,1.
|
||||
0.63,0.5191666876024412,0.23492108185347996,-0.754327887989376,0.31103663081260624,0.44973844514160927,1.
|
||||
0.64,0.4996990584326256,0.2766456839100268,-0.7851587896650079,0.30803814950244496,0.4204116611935119,1.
|
||||
0.65,0.479957679121191,0.3189570094767831,-0.8164232296840259,0.30343473603466015,0.390226489453496,1.
|
||||
0.66,0.4600072725872886,0.3617163391430824,-0.8480187063016573,0.29717122075330515,0.3591178757512998,1.
|
||||
0.67,0.44600100870220305,0.4113853615984094,-0.8697728377551008,0.3178994129506999,0.3295740682997879,1.
|
||||
0.68,0.4574651571354146,0.44026390446569547,-0.8504539292487465,0.3842479358768364,0.3280946443367561,1.
|
||||
0.69,0.4691809168948424,0.46977626401045774,-0.830711015748157,0.44293649140770447,0.3260767554252525,1.
|
||||
0.7,0.4811696900083858,0.49997635259991063,-0.8105080314416201,0.49708450874457527,0.3234487047238236,1.
|
||||
0.71,0.49350094811609174,0.5310391714342613,-0.7897279055963483,0.5485591109413528,0.3201099534066949,1.
|
||||
0.72,0.5062548753068121,0.5631667067020758,-0.7682355153041539,0.5985798481027601,0.3159263917472715,1.
|
||||
0.73,0.5195243020949684,0.5965928013272943,-0.7458744264238399,0.6480500606439057,0.31071717884730565,1.
|
||||
0.74,0.5334043922713477,0.6315571758288618,-0.7224842728734379,0.6976685401842261,0.3042411890803418,1.
|
||||
0.75,0.5479805812358602,0.6682750446095802,-0.697921082452685,0.7479712773579563,0.29618040787504757,1.
|
||||
0.76,0.5633244502526606,0.7069267230777347,-0.6720642293775535,0.7993701361353484,0.28611136999256687,1.
|
||||
0.77,0.5794956601139,0.7476624986056212,-0.6448131757501174,0.8521918014427678,0.2734527325942473,1.
|
||||
0.78,0.5965429098573916,0.7906050455688622,-0.6160858559672187,0.9067003897516911,0.2573693489198746,1.
|
||||
0.79,0.6145761476424179,0.8360313267658297,-0.5856969899409387,0.963334644317004,0.23648492980159264,1.
|
||||
0.8,0.6232910688128902,0.859291371252556,-0.5300995185388214,1.,0.21867949406239662,0.9712088595948508
|
||||
0.81,0.6159984336377875,0.8439887543380684,-0.44635440435952856,1.,0.21606849746358275,0.9041480210597966
|
||||
0.82,0.6091642745073532,0.8296481879180277,-0.36787420852419694,1.,0.21421830096504035,0.8419706002336461
|
||||
0.83,0.6025478038652375,0.8157644115969636,-0.2918938425681935,1.,0.21295365915197917,0.7823908751330636
|
||||
0.84,0.5961857222953111,0.8024144366282877,-0.21883475834162458,0.9971140114799418,0.21220068235083267,0.7256713129328118
|
||||
0.85,0.5900921771070883,0.7896279492437488,-0.1488594167412921,0.993273906363258,0.2118788857127918,0.671860243327784
|
||||
0.86,0.5842771639541229,0.7774259239818333,-0.08208260304413262,0.9887084084529413,0.21191070453347688,0.6209624706933893
|
||||
0.87,0.578741582584259,0.7658102488427286,-0.018514649521559012,0.9835846378805114,0.2122246941077346,0.5728987835613306
|
||||
0.88,0.5734741590353537,0.7547572669288056,0.04197390858426542,0.9780378159372328,0.21275878699579343,0.5274829957183049
|
||||
0.89,0.5684517008574971,0.7442183119942206,0.09964940221121898,0.9721670725313721,0.21346242315895625,0.4844270603851604
|
||||
0.9,0.5636419856510335,0.7341257696545772,0.15488185789614228,0.9660363209686843,0.21429691147008262,0.4433660148378527
|
||||
0.91,0.5590069340453534,0.7243997354573974,0.20810856081277884,0.9596781387247791,0.2152344151262528,0.4038812338146013
|
||||
0.92,0.5545051525321143,0.7149533506766244,0.25980485409830323,0.9530986696850675,0.21625626438013962,0.3655130449917989
|
||||
0.93,0.5500961975299247,0.705701749880514,0.3104351723857584,0.9462863346513658,0.21735046958786286,0.327780364198278
|
||||
0.94,0.545740378056064,0.6965616468647046,0.36045530782708896,0.93921469089265,0.21851014470332586,0.29014917175372823
|
||||
0.95,0.5414004092067859,0.6874548042588865,0.41029342232076466,0.9318478255642132,0.21973168075163751,0.2519897371806688
|
||||
0.96,0.5370416605957644,0.6783085548415655,0.46034719456417006,0.9241434776436454,0.22101341980094052,0.2124579038400577
|
||||
0.97,0.5326309593934517,0.6690532898786764,0.5109975653738162,0.9160532016485884,0.22235495330179011,0.17018252385769012
|
||||
0.98,0.5281374148557197,0.6596241892863608,0.5625992691950712,0.90752576202319,0.22375597459867458,0.1223073280126531
|
||||
0.99,0.5235317096396147,0.6499597345521199,0.615488972291106,0.8985077346125597,0.22521565729028564,0.05933950582860665
|
||||
1.,0.5187848173343539,0.6399990176455989,0.67,0.8889427469969852,0.22673227640012172,0.
|
||||
|
@@ -1,6 +1,8 @@
|
||||
from datetime import datetime
|
||||
|
||||
import optuna
|
||||
import torch
|
||||
import util
|
||||
from hypertraining.hypertraining import HyperTraining
|
||||
from hypertraining.settings import (
|
||||
GlobalSettings,
|
||||
@@ -16,24 +18,29 @@ global_settings = GlobalSettings(
|
||||
)
|
||||
|
||||
data_settings = DataSettings(
|
||||
config_path="data/*-128-16384-100000-0-0-17-0-PAM4-0.ini",
|
||||
# config_path="data/*-128-16384-100000-0-0-17-0-PAM4-0.ini",
|
||||
config_path="data/20241204-131003-128-16384-100000-0-0-17-0-PAM4-0.ini",
|
||||
dtype="complex64",
|
||||
# symbols = (9, 20), # 13 symbol @ 10GBd <-> 1.3ns <-> 0.26m of fiber
|
||||
symbols=13, # study: single_core_regen_20241123_011232
|
||||
# symbols=13, # study: single_core_regen_20241123_011232
|
||||
# symbols = (3, 13),
|
||||
symbols=4,
|
||||
# output_size = (11, 32), # ballpark 26 taps -> 2 taps per input symbol -> 1 tap every 0.01m (model has 52 inputs)
|
||||
output_size=26, # study: single_core_regen_20241123_011232 (model_input_dim/2)
|
||||
# output_size=26, # study: single_core_regen_20241123_011232 (model_input_dim/2)
|
||||
output_size=(8, 30),
|
||||
shuffle=True,
|
||||
in_out_delay=0,
|
||||
xy_delay=0,
|
||||
drop_first=128 * 100,
|
||||
drop_first=256,
|
||||
train_split=0.8,
|
||||
randomise_polarisations=False,
|
||||
)
|
||||
|
||||
pytorch_settings = PytorchSettings(
|
||||
epochs=10000,
|
||||
epochs=10,
|
||||
batchsize=2**10,
|
||||
device="cuda",
|
||||
dataloader_workers=12,
|
||||
dataloader_workers=4,
|
||||
dataloader_prefetch=4,
|
||||
summary_dir=".runs",
|
||||
write_every=2**5,
|
||||
@@ -43,28 +50,70 @@ pytorch_settings = PytorchSettings(
|
||||
|
||||
model_settings = ModelSettings(
|
||||
output_dim=2,
|
||||
# n_hidden_layers = (3, 8),
|
||||
n_hidden_layers=4,
|
||||
overrides={
|
||||
"n_hidden_nodes_0": 8,
|
||||
"n_hidden_nodes_1": 6,
|
||||
"n_hidden_nodes_2": 4,
|
||||
"n_hidden_nodes_3": 8,
|
||||
},
|
||||
model_activation_func="Mag",
|
||||
# satabsT0=(1e-6, 1),
|
||||
n_hidden_layers = (2, 5),
|
||||
n_hidden_nodes=(2, 16),
|
||||
model_activation_func="EOActivation",
|
||||
dropout_prob=0,
|
||||
model_layer_function="ONNRect",
|
||||
model_layer_kwargs={"square": True},
|
||||
# scale=(False, True),
|
||||
scale=False,
|
||||
model_layer_parametrizations=[
|
||||
{
|
||||
"tensor_name": "weight",
|
||||
"parametrization": util.complexNN.energy_conserving,
|
||||
},
|
||||
{
|
||||
"tensor_name": "alpha",
|
||||
"parametrization": util.complexNN.clamp,
|
||||
},
|
||||
{
|
||||
"tensor_name": "gain",
|
||||
"parametrization": util.complexNN.clamp,
|
||||
"kwargs": {
|
||||
"min": 0,
|
||||
"max": float("inf"),
|
||||
},
|
||||
},
|
||||
{
|
||||
"tensor_name": "phase_bias",
|
||||
"parametrization": util.complexNN.clamp,
|
||||
"kwargs": {
|
||||
"min": 0,
|
||||
"max": 2 * torch.pi,
|
||||
},
|
||||
},
|
||||
{
|
||||
"tensor_name": "scales",
|
||||
"parametrization": util.complexNN.clamp,
|
||||
},
|
||||
{
|
||||
"tensor_name": "angle",
|
||||
"parametrization": util.complexNN.clamp,
|
||||
"kwargs": {
|
||||
"min": -torch.pi,
|
||||
"max": torch.pi,
|
||||
},
|
||||
},
|
||||
{
|
||||
"tensor_name": "loss",
|
||||
"parametrization": util.complexNN.clamp,
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
optimizer_settings = OptimizerSettings(
|
||||
optimizer="Adam",
|
||||
# learning_rate = (1e-5, 1e-1),
|
||||
learning_rate=5e-3
|
||||
# learning_rate=5e-4,
|
||||
optimizer="AdamW",
|
||||
optimizer_kwargs={
|
||||
"lr": 5e-3,
|
||||
"amsgrad": True,
|
||||
# "weight_decay": 1e-7,
|
||||
},
|
||||
)
|
||||
|
||||
optuna_settings = OptunaSettings(
|
||||
n_trials=1,
|
||||
n_workers=1,
|
||||
n_trials=1024,
|
||||
n_workers=8,
|
||||
timeout=3600,
|
||||
directions=("minimize",),
|
||||
metrics_names=("mse",),
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from datetime import datetime
|
||||
# from datetime import datetime
|
||||
from pathlib import Path
|
||||
import matplotlib
|
||||
import numpy as np
|
||||
@@ -13,7 +13,7 @@ from hypertraining.settings import (
|
||||
OptimizerSettings,
|
||||
)
|
||||
|
||||
from hypertraining.training import RegenerationTrainer, PolarizationTrainer
|
||||
from hypertraining.training import RegenerationTrainer#, PolarizationTrainer
|
||||
|
||||
# import torch
|
||||
import json
|
||||
@@ -26,25 +26,39 @@ global_settings = GlobalSettings(
|
||||
)
|
||||
|
||||
data_settings = DataSettings(
|
||||
config_path="data/*-128-16384-10000-0-0-0-0-PAM4-0-0.4.ini",
|
||||
# config_path=[f"data/20241202-*-128-16384-{length}-0-0-17-0-PAM4-0.ini" for length in range(48000, 53000, 1000)],
|
||||
# config_path="data/20250115-114319-128-16384-inf-100000-0-0-0-0-PAM4-1-0-10.ini", # no effects enabled - baseline
|
||||
# config_path = "data/20250115-233553-128-16384-1060.0-100000-0-0.2-17.0-0.058-PAM4-1.0-0.0-10.ini", # dispersion + slope only
|
||||
# config_path="data/20250115-115836-128-16384-60.0-100000-0-0.2-17-0.058-PAM4-1000-0.2-10.ini", # all linear effects enabled with realistic values + noise + pmd (delta_beta=0.2) + ortho_error = 0.1
|
||||
# config_path="data/20250118-225840-128-16384-inf-100000-0-0.2-17.0-0.058-PAM4-0-0-10.ini", # a)
|
||||
# config_path="data/20250116-214606-128-16384-inf-100000-0-0.2-17.0-0.058-PAM4-0-0-10.ini", # b)
|
||||
# config_path="data/20250116-214547-128-16384-inf-100000-0-0.2-17.0-0.058-PAM4-0-0-10.ini", # c)
|
||||
# config_path="data/20250117-143926-128-16384-inf-100000-0-0.2-17.0-0.058-PAM4-0-0-10.ini", # d) 10ps dgd
|
||||
config_path="data/20250120-105720-128-16384-inf-100000-0-0.2-17-0.058-PAM4-0-0-10.ini", # d) 10ns
|
||||
|
||||
# config_path="data/20250114-215547-128-16384-60.0-100000-1.15-0.2-17-0.058-PAM4-1-0-10.ini", # with gamma=1.15, 2.5dBm launch power, no pmd
|
||||
|
||||
|
||||
dtype="complex64",
|
||||
# symbols = (9, 20), # 13 symbol @ 10GBd <-> 1.3ns <-> 0.26m of fiber
|
||||
symbols=13, # study: single_core_regen_20241123_011232
|
||||
symbols=4, # study: single_core_regen_20241123_011232 -> taps spread over 4 symbols @ 10GBd
|
||||
# output_size = (11, 32), # 26 taps -> 2 taps per input symbol -> 1 tap every 1cm (model has 52 inputs (x/y))
|
||||
output_size=26, # study: single_core_regen_20241123_011232 (model_input_dim/2)
|
||||
shuffle=True,
|
||||
drop_first=64,
|
||||
output_size=20, # = model_input_dim/2 study: single_core_regen_20241123_011232
|
||||
shuffle=False,
|
||||
drop_first=256,
|
||||
drop_last=256,
|
||||
train_split=0.8,
|
||||
randomise_polarisations=True,
|
||||
randomise_polarisations=False,
|
||||
polarisations=False,
|
||||
# cross_pol_interference=0.01,
|
||||
osnr=16, #16dB due to amplification with NF 5
|
||||
)
|
||||
|
||||
pytorch_settings = PytorchSettings(
|
||||
epochs=10000,
|
||||
batchsize=2**14,
|
||||
epochs=1000,
|
||||
batchsize=2**13,
|
||||
device="cuda",
|
||||
dataloader_workers=16,
|
||||
dataloader_prefetch=8,
|
||||
dataloader_workers=32,
|
||||
dataloader_prefetch=4,
|
||||
summary_dir=".runs",
|
||||
write_every=2**5,
|
||||
save_models=True,
|
||||
@@ -53,80 +67,51 @@ pytorch_settings = PytorchSettings(
|
||||
|
||||
model_settings = ModelSettings(
|
||||
output_dim=2,
|
||||
n_hidden_layers=5,
|
||||
n_hidden_layers=3,
|
||||
overrides={
|
||||
# "hidden_layer_dims": (8, 8, 4, 4),
|
||||
"n_hidden_nodes_0": 8,
|
||||
"n_hidden_nodes_0": 16,
|
||||
"n_hidden_nodes_1": 8,
|
||||
"n_hidden_nodes_2": 4,
|
||||
"n_hidden_nodes_3": 4,
|
||||
"n_hidden_nodes_4": 2,
|
||||
"n_hidden_nodes_2": 8,
|
||||
# "n_hidden_nodes_3": 4,
|
||||
# "n_hidden_nodes_4": 2,
|
||||
},
|
||||
model_activation_func="EOActivation",
|
||||
dropout_prob=0.01,
|
||||
dropout_prob=0,
|
||||
model_layer_function="ONNRect",
|
||||
model_layer_kwargs={"square": True},
|
||||
scale=False,
|
||||
scale=2.0,
|
||||
model_layer_parametrizations=[
|
||||
{
|
||||
"tensor_name": "weight",
|
||||
"parametrization": util.complexNN.energy_conserving,
|
||||
},
|
||||
# EOactivation
|
||||
{
|
||||
"tensor_name": "alpha",
|
||||
"parametrization": util.complexNN.clamp,
|
||||
"kwargs": {
|
||||
"min": 0,
|
||||
"max": 1,
|
||||
},
|
||||
},
|
||||
# ONNRect
|
||||
{
|
||||
"tensor_name": "gain",
|
||||
"tensor_name": "weight",
|
||||
"parametrization": torch.nn.utils.parametrizations.orthogonal,
|
||||
},
|
||||
# Scale
|
||||
{
|
||||
"tensor_name": "scale",
|
||||
"parametrization": util.complexNN.clamp,
|
||||
"kwargs": {
|
||||
"min": 0,
|
||||
"max": float("inf"),
|
||||
"max": 10,
|
||||
},
|
||||
},
|
||||
{
|
||||
"tensor_name": "phase_bias",
|
||||
"parametrization": util.complexNN.clamp,
|
||||
"kwargs": {
|
||||
"min": 0,
|
||||
"max": 2 * torch.pi,
|
||||
},
|
||||
},
|
||||
{
|
||||
"tensor_name": "scales",
|
||||
"parametrization": util.complexNN.clamp,
|
||||
},
|
||||
{
|
||||
"tensor_name": "angle",
|
||||
"parametrization": util.complexNN.clamp,
|
||||
"kwargs": {
|
||||
"min": -torch.pi,
|
||||
"max": torch.pi,
|
||||
},
|
||||
},
|
||||
# {
|
||||
# "tensor_name": "scale",
|
||||
# "parametrization": util.complexNN.clamp,
|
||||
# },
|
||||
# {
|
||||
# "tensor_name": "bias",
|
||||
# "parametrization": util.complexNN.clamp,
|
||||
# },
|
||||
# {
|
||||
# "tensor_name": "V",
|
||||
# "parametrization": torch.nn.utils.parametrizations.orthogonal,
|
||||
# },
|
||||
{
|
||||
"tensor_name": "loss",
|
||||
"parametrization": util.complexNN.clamp,
|
||||
},
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
optimizer_settings = OptimizerSettings(
|
||||
optimizer="AdamW",
|
||||
optimizer_kwargs={
|
||||
"lr": 0.01,
|
||||
"lr": 0.005,
|
||||
"amsgrad": True,
|
||||
# "weight_decay": 1e-7,
|
||||
},
|
||||
@@ -134,107 +119,19 @@ optimizer_settings = OptimizerSettings(
|
||||
scheduler="ReduceLROnPlateau",
|
||||
scheduler_kwargs={
|
||||
"patience": 2**6,
|
||||
"factor": 0.75,
|
||||
"factor": 0.5,
|
||||
# "threshold": 1e-3,
|
||||
"min_lr": 1e-6,
|
||||
"cooldown": 10,
|
||||
},
|
||||
early_stopping=True,
|
||||
early_stop_kwargs={
|
||||
"threshold": 1e-06,
|
||||
"plateau": 2**7,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def save_dict_to_file(dictionary, filename):
|
||||
"""
|
||||
Save the best dictionary to a JSON file.
|
||||
|
||||
:param best: Dictionary containing the best training results.
|
||||
:type best: dict
|
||||
:param filename: Path to the JSON file where the dictionary will be saved.
|
||||
:type filename: str
|
||||
"""
|
||||
with open(filename, "w") as f:
|
||||
json.dump(dictionary, f, indent=4)
|
||||
|
||||
|
||||
def sweep_lengths(*lengths, model=None, data_glob: str = None, strategy="newest"):
|
||||
assert model is not None, "Model must be provided."
|
||||
assert data_glob is not None, "Data glob must be provided."
|
||||
model = model
|
||||
|
||||
fiber_ins = {}
|
||||
fiber_outs = {}
|
||||
regens = {}
|
||||
timestampss = {}
|
||||
|
||||
trainer = RegenerationTrainer(
|
||||
checkpoint_path=model,
|
||||
)
|
||||
trainer.define_model()
|
||||
|
||||
for length in lengths:
|
||||
data_glob_length = data_glob.replace("{length}", str(length))
|
||||
files = list(Path.cwd().glob(data_glob_length))
|
||||
if len(files) == 0:
|
||||
continue
|
||||
if strategy == "newest":
|
||||
sorted_kwargs = {
|
||||
"key": lambda x: x.stat().st_mtime,
|
||||
"reverse": True,
|
||||
}
|
||||
elif strategy == "oldest":
|
||||
sorted_kwargs = {
|
||||
"key": lambda x: x.stat().st_mtime,
|
||||
"reverse": False,
|
||||
}
|
||||
else:
|
||||
raise ValueError(f"Unknown strategy {strategy}.")
|
||||
file = sorted(files, **sorted_kwargs)[0]
|
||||
|
||||
loader, _ = trainer.get_sliced_data(override={"config_path": file})
|
||||
fiber_in, fiber_out, regen, timestamps = trainer.run_model(trainer.model, loader=loader)
|
||||
|
||||
fiber_ins[length] = fiber_in
|
||||
fiber_outs[length] = fiber_out
|
||||
regens[length] = regen
|
||||
timestampss[length] = timestamps
|
||||
|
||||
data = torch.zeros(2 * len(timestampss.keys()) + 2, 2, tuple(fiber_outs.values())[-1].shape[0])
|
||||
channel_names = ["" for _ in range(2 * len(timestampss.keys()) + 2)]
|
||||
|
||||
data[1, 0, :] = timestampss[tuple(timestampss.keys())[-1]] / 128
|
||||
data[1, 1, :] = fiber_ins[tuple(timestampss.keys())[-1]][:, 0].abs().square()
|
||||
|
||||
channel_names[1] = "fiber in x"
|
||||
|
||||
for li, length in enumerate(timestampss.keys()):
|
||||
data[2 + 2 * li, 0, :] = timestampss[length] / 128
|
||||
data[2 + 2 * li, 1, :] = fiber_outs[length][:, 0].abs().square()
|
||||
data[2 + 2 * li + 1, 0, :] = timestampss[length] / 128
|
||||
data[2 + 2 * li + 1, 1, :] = regens[length][:, 0].abs().square()
|
||||
|
||||
channel_names[2 + 2 * li + 1] = f"regen x {length}"
|
||||
channel_names[2 + 2 * li] = f"fiber out x {length}"
|
||||
|
||||
# get current backend
|
||||
backend = matplotlib.get_backend()
|
||||
|
||||
matplotlib.use("TkCairo")
|
||||
eye = util.eye_diagram.eye_diagram(data.to(dtype=torch.float32).detach().cpu().numpy(), channel_names=channel_names)
|
||||
|
||||
print_attrs = ("channel_name", "success", "min_area")
|
||||
with np.printoptions(precision=3, suppress=True, formatter={"float": "{:0.3e}".format}):
|
||||
for result in eye.eye_stats:
|
||||
print_dict = {attr: result[attr] for attr in print_attrs}
|
||||
rprint(print_dict)
|
||||
rprint()
|
||||
|
||||
eye.plot(all_stats=False)
|
||||
matplotlib.use(backend)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# lengths = range(90000, 100000+10000, 10000)
|
||||
# lengths = [100000]
|
||||
# sweep_lengths(*lengths, model=".models/best_20241204_132605.tar", data_glob="data/202412*-{length}-*-0.ini", strategy="newest")
|
||||
|
||||
trainer = RegenerationTrainer(
|
||||
global_settings=global_settings,
|
||||
@@ -242,67 +139,15 @@ if __name__ == "__main__":
|
||||
pytorch_settings=pytorch_settings,
|
||||
model_settings=model_settings,
|
||||
optimizer_settings=optimizer_settings,
|
||||
# checkpoint_path=".models/best_20241205_235929.tar",
|
||||
# 20241202_143149
|
||||
checkpoint_path=".models/best_20250117_144001.tar",
|
||||
new_model=True,
|
||||
settings_override={
|
||||
"data_settings": data_settings.__dict__,
|
||||
# "optimizer_settings": {
|
||||
# "early_stop_kwargs":{
|
||||
# "plateau": 2**8,
|
||||
# }
|
||||
# }
|
||||
}
|
||||
)
|
||||
trainer.train()
|
||||
|
||||
# from hypertraining.lighning_models import regenerator, regeneratorData
|
||||
# import lightning as L
|
||||
|
||||
# model = regenerator(
|
||||
# 2 * data_settings.output_size,
|
||||
# *model_settings.overrides["hidden_layer_dims"],
|
||||
# model_settings.output_dim,
|
||||
# layer_function=getattr(util.complexNN, model_settings.model_layer_function),
|
||||
# layer_func_kwargs=model_settings.model_layer_kwargs,
|
||||
# act_function=getattr(util.complexNN, model_settings.model_activation_func),
|
||||
# act_func_kwargs=None,
|
||||
# parametrizations=model_settings.model_layer_parametrizations,
|
||||
# dtype=getattr(torch, data_settings.dtype),
|
||||
# dropout_prob=model_settings.dropout_prob,
|
||||
# scale_layers=model_settings.scale,
|
||||
# optimizer=getattr(torch.optim, optimizer_settings.optimizer),
|
||||
# optimizer_kwargs=optimizer_settings.optimizer_kwargs,
|
||||
# lr_scheduler=getattr(torch.optim.lr_scheduler, optimizer_settings.scheduler),
|
||||
# lr_scheduler_kwargs=optimizer_settings.scheduler_kwargs,
|
||||
# )
|
||||
|
||||
# dm = regeneratorData(
|
||||
# config_globs=data_settings.config_path,
|
||||
# output_symbols=data_settings.symbols,
|
||||
# output_dim=data_settings.output_size,
|
||||
# dtype=getattr(torch, data_settings.dtype),
|
||||
# drop_first=data_settings.drop_first,
|
||||
# shuffle=data_settings.shuffle,
|
||||
# train_split=data_settings.train_split,
|
||||
# batch_size=pytorch_settings.batchsize,
|
||||
# loader_settings={
|
||||
# "num_workers": pytorch_settings.dataloader_workers,
|
||||
# "prefetch_factor": pytorch_settings.dataloader_prefetch,
|
||||
# "pin_memory": True,
|
||||
# "drop_last": True,
|
||||
# },
|
||||
# seed=global_settings.seed,
|
||||
# )
|
||||
|
||||
# # writer = L.SummaryWriter(pytorch_settings.summary_dir + f"/{datetime.now().strftime('%Y%m%d_%H%M%S')}")
|
||||
|
||||
# # from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
# subdir = f"{datetime.now().strftime('%Y%m%d_%H%M%S')}"
|
||||
|
||||
# # writer = SummaryWriter(pytorch_settings.summary_dir + f"/{subdir}")
|
||||
|
||||
# logger = L.pytorch.loggers.TensorBoardLogger(pytorch_settings.summary_dir, name=subdir, log_graph=True)
|
||||
|
||||
# trainer = L.Trainer(
|
||||
# fast_dev_run=False,
|
||||
# # max_epochs=pytorch_settings.epochs,
|
||||
# max_epochs=2,
|
||||
# enable_checkpointing=True,
|
||||
# default_root_dir=f".models/{subdir}/",
|
||||
# logger=logger,
|
||||
# )
|
||||
|
||||
# trainer.fit(model, dm)
|
||||
|
||||
@@ -12,20 +12,22 @@ Full license text in LICENSE file
|
||||
"""
|
||||
|
||||
import configparser
|
||||
# import copy
|
||||
from datetime import datetime
|
||||
import hashlib
|
||||
from pathlib import Path
|
||||
import time
|
||||
import h5py
|
||||
from matplotlib import pyplot as plt # noqa: F401
|
||||
import numpy as np
|
||||
|
||||
import add_pypho # noqa: F401
|
||||
from . import add_pypho # noqa: F401
|
||||
import pypho
|
||||
|
||||
default_config = f"""
|
||||
[glova]
|
||||
nos = 256
|
||||
sps = 256
|
||||
sps = 128
|
||||
nos = 16384
|
||||
f0 = 193414489032258.06
|
||||
symbolrate = 10e9
|
||||
wisdom_dir = "{str((Path.home() / ".pypho"))}"
|
||||
@@ -37,9 +39,9 @@ length = 10000
|
||||
gamma = 1.14
|
||||
alpha = 0.2
|
||||
D = 17
|
||||
S = 0
|
||||
birefsteps = 0
|
||||
max_delta_beta = 0.4
|
||||
S = 0.058
|
||||
bireflength = 10
|
||||
pmd_q = 0.2
|
||||
; birefseed = 0xC0FFEE
|
||||
|
||||
[signal]
|
||||
@@ -47,17 +49,15 @@ max_delta_beta = 0.4
|
||||
|
||||
modulation = "pam"
|
||||
mod_order = 4
|
||||
mod_depth = 0.8
|
||||
|
||||
mod_depth = 1
|
||||
max_jitter = 0.02
|
||||
; jitter_seed = 0xC0FFEE
|
||||
|
||||
laser_power = 0
|
||||
edfa_power = 3
|
||||
edfa_power = 0
|
||||
edfa_nf = 5
|
||||
|
||||
pulse_shape = "gauss"
|
||||
fwhm = 0.33
|
||||
osnr = "inf"
|
||||
|
||||
[data]
|
||||
dir = "data"
|
||||
@@ -71,6 +71,7 @@ def get_config(config_file=None):
|
||||
"""
|
||||
if config_file is None:
|
||||
config_file = Path(__file__).parent / "signal_generation.ini"
|
||||
config_file = Path(config_file)
|
||||
if not config_file.exists():
|
||||
with open(config_file, "w") as f:
|
||||
f.write(default_config)
|
||||
@@ -83,7 +84,10 @@ def get_config(config_file=None):
|
||||
conf[section] = {}
|
||||
for key in config[section]:
|
||||
# print(f"{key} = {config[section][key]}")
|
||||
conf[section][key] = eval(config[section][key])
|
||||
try:
|
||||
conf[section][key] = eval(config[section][key])
|
||||
except NameError:
|
||||
conf[section][key] = float(config[section][key])
|
||||
# if isinstance(conf[section][key], str):
|
||||
# conf[section][key] = config[section][key].strip('"')
|
||||
return conf
|
||||
@@ -96,7 +100,9 @@ class PDM_IM_IPM:
|
||||
mod_order=8,
|
||||
seed=None,
|
||||
):
|
||||
assert np.cbrt(mod_order) == int(np.cbrt(mod_order)) and mod_order > 1, "mod_order must be a cube of an integer greater than 1"
|
||||
assert np.cbrt(mod_order) == int(np.cbrt(mod_order)) and mod_order > 1, (
|
||||
"mod_order must be a cube of an integer greater than 1"
|
||||
)
|
||||
self.glova = glova
|
||||
self.mod_order = mod_order
|
||||
self.symbols_per_dim = int(np.cbrt(mod_order))
|
||||
@@ -110,14 +116,7 @@ class PDM_IM_IPM:
|
||||
|
||||
class pam_generator:
|
||||
def __init__(
|
||||
self,
|
||||
glova,
|
||||
mod_order=None,
|
||||
mod_depth=0.5,
|
||||
pulse_shape="gauss",
|
||||
fwhm=0.33,
|
||||
seed=None,
|
||||
single_channel=False
|
||||
self, glova, mod_order=None, mod_depth=0.5, pulse_shape="gauss", fwhm=0.33, seed=None, single_channel=False
|
||||
) -> None:
|
||||
self.glova = glova
|
||||
self.pulse_shape = pulse_shape
|
||||
@@ -138,15 +137,13 @@ class pam_generator:
|
||||
symbols_x = symbols[0] / (self.mod_order)
|
||||
diffs_x = np.diff(symbols_x, prepend=symbols_x[0])
|
||||
digital_x = self.generate_digital_signal(diffs_x, max_jitter)
|
||||
digital_x = np.pad(
|
||||
digital_x, (0, self.glova.sps // 2), "constant", constant_values=(0, 0)
|
||||
)
|
||||
digital_x = np.pad(digital_x, (0, self.glova.sps // 2), "constant", constant_values=(0, 0))
|
||||
|
||||
# create analog signal of diff of symbols
|
||||
E_x = np.convolve(digital_x, wavelet)
|
||||
|
||||
# convert to pam and set modulation depth (scale and move up such that 1 stays at 1)
|
||||
E_x = np.cumsum(E_x) * self.modulation_depth + 2*(1 - self.modulation_depth)
|
||||
E_x = np.cumsum(E_x) * self.modulation_depth + (1 - self.modulation_depth)
|
||||
|
||||
# cut off the wavelet tails
|
||||
E_x = E_x[self.glova.sps // 2 + len(wavelet) // 2 - 1 : -len(wavelet) // 2]
|
||||
@@ -158,25 +155,21 @@ class pam_generator:
|
||||
symbols_y = symbols[1] / (self.mod_order)
|
||||
diffs_y = np.diff(symbols_y, prepend=symbols_y[0])
|
||||
digital_y = self.generate_digital_signal(diffs_y, max_jitter)
|
||||
digital_y = np.pad(
|
||||
digital_y, (0, self.glova.sps // 2), "constant", constant_values=(0, 0)
|
||||
)
|
||||
digital_y = np.pad(digital_y, (0, self.glova.sps // 2), "constant", constant_values=(0, 0))
|
||||
E_y = np.convolve(digital_y, wavelet)
|
||||
|
||||
E_y = np.cumsum(E_y) * self.modulation_depth + (1 - self.modulation_depth)
|
||||
|
||||
E_y = E_y[self.glova.sps // 2 + len(wavelet) // 2 - 1 : -len(wavelet) // 2]
|
||||
|
||||
|
||||
E[0]["E"][1] = np.sqrt(np.multiply(np.square(E[0]["E"][1]), E_y))
|
||||
|
||||
# rotate the signal on the y-polarisation by 90°
|
||||
E[0]["E"][1] *= 1j
|
||||
# E[0]["E"][1] *= 1j
|
||||
else:
|
||||
E[0]["E"][1] = np.zeros_like(E[0]["E"][0], dtype=E[0]["E"][0].dtype)
|
||||
return E
|
||||
|
||||
|
||||
def generate_digital_signal(self, symbols, max_jitter=0):
|
||||
rs = np.random.RandomState(self.seed)
|
||||
signal = np.zeros(self.glova.nos * self.glova.sps)
|
||||
@@ -198,19 +191,19 @@ class pam_generator:
|
||||
endpoint=True,
|
||||
)
|
||||
sigma = self.fwhm / (1 * np.sqrt(2 * np.log(2))) * self.glova.sps
|
||||
pulse = (
|
||||
1
|
||||
/ (sigma * np.sqrt(2 * np.pi))
|
||||
* np.exp(-np.square(sample_points) / (2 * np.square(sigma)))
|
||||
)
|
||||
pulse = 1 / (sigma * np.sqrt(2 * np.pi)) * np.exp(-np.square(sample_points) / (2 * np.square(sigma)))
|
||||
return pulse
|
||||
|
||||
|
||||
def initialize_fiber_and_data(config, input_data_override=None):
|
||||
def initialize_fiber_and_data(config):
|
||||
f0 = config["glova"].get("f0", None)
|
||||
if f0 is None:
|
||||
f0 = 299792458/(config["glova"].get("lambda0", 1550)*1e-9)
|
||||
config["glova"]["f0"] = f0
|
||||
py_glova = pypho.setup(
|
||||
nos=config["glova"]["nos"],
|
||||
sps=config["glova"]["sps"],
|
||||
f0=config["glova"]["f0"],
|
||||
f0=f0,
|
||||
symbolrate=config["glova"]["symbolrate"],
|
||||
wisdom_dir=config["glova"]["wisdom_dir"],
|
||||
flags=config["glova"]["flags"],
|
||||
@@ -221,49 +214,89 @@ def initialize_fiber_and_data(config, input_data_override=None):
|
||||
c_data = pypho.cfiber.DataWrapper(py_glova.sps * py_glova.nos)
|
||||
py_edfa = pypho.oamp(py_glova, Pmean=config["signal"]["edfa_power"], NF=config["signal"]["edfa_nf"])
|
||||
|
||||
if input_data_override is not None:
|
||||
c_data.E_in = input_data_override[0]
|
||||
noise = input_data_override[1]
|
||||
else:
|
||||
config["signal"]["seed"] = config["signal"].get(
|
||||
"seed", (int(time.time() * 1000)) % 2**32
|
||||
)
|
||||
config["signal"]["jitter_seed"] = config["signal"].get(
|
||||
"jitter_seed", (int(time.time() * 1000)) % 2**32
|
||||
)
|
||||
symbolsrc = pypho.symbols(
|
||||
py_glova, py_glova.nos, pattern="ones", p1=config["signal"]["mod_order"], seed=config["signal"]["seed"]
|
||||
)
|
||||
laser = pypho.lasmod(
|
||||
py_glova, power=config["signal"]["laser_power"]+1.5, Df=0, theta=np.pi / 4
|
||||
)
|
||||
modulator = pam_generator(
|
||||
py_glova,
|
||||
mod_depth=config["signal"]["mod_depth"],
|
||||
pulse_shape=config["signal"]["pulse_shape"],
|
||||
fwhm=config["signal"]["fwhm"],
|
||||
seed=config["signal"]["jitter_seed"],
|
||||
single_channel=False,
|
||||
mod_order=config["signal"]["mod_order"],
|
||||
osnr = config["signal"]["osnr"] = float(config["signal"].get("osnr", "inf"))
|
||||
|
||||
config["signal"]["seed"] = config["signal"].get("seed", (int(time.time() * 1000)) % 2**32)
|
||||
config["signal"]["jitter_seed"] = config["signal"].get("jitter_seed", (int(time.time() * 1000)) % 2**32)
|
||||
symbolsrc = pypho.symbols(
|
||||
py_glova, py_glova.nos, pattern="ones", p1=config["signal"]["mod_order"], seed=config["signal"]["seed"]
|
||||
)
|
||||
laserx = pypho.lasmod(py_glova, power=0, Df=0, theta=np.pi/4)
|
||||
# lasery = pypho.lasmod(py_glova, power=0, Df=25, theta=0)
|
||||
|
||||
modulator = pam_generator(
|
||||
py_glova,
|
||||
mod_depth=config["signal"]["mod_depth"],
|
||||
pulse_shape=config["signal"]["pulse_shape"],
|
||||
fwhm=config["signal"]["fwhm"],
|
||||
seed=config["signal"]["jitter_seed"],
|
||||
mod_order=config["signal"]["mod_order"],
|
||||
)
|
||||
|
||||
symbols_x = symbolsrc(pattern="random")
|
||||
symbols_y = symbolsrc(pattern="random")
|
||||
symbols_x[:3] = 0
|
||||
symbols_y[:3] = 0
|
||||
# symbols_x += 1
|
||||
|
||||
|
||||
cw = laserx()
|
||||
# cwy = lasery()
|
||||
# cw[0]['E'][0] = cw[0]['E'][0]
|
||||
# cw[0]['E'][1] = cwy[0]['E'][0]
|
||||
|
||||
|
||||
source_signal = modulator(E=cw, symbols=(symbols_x, symbols_y))
|
||||
|
||||
if osnr != float("inf"):
|
||||
osnr_lin = 10 ** (osnr / 10)
|
||||
signal_power = np.sum(pypho.functions.getpower_W(source_signal[0]["E"]))
|
||||
noise_power = signal_power / osnr_lin
|
||||
noise = np.random.normal(0, 1, source_signal[0]["E"].shape) + 1j * np.random.normal(
|
||||
0, 1, source_signal[0]["E"].shape
|
||||
)
|
||||
noise_power_is = np.sum(pypho.functions.getpower_W(noise))
|
||||
noise = noise * np.sqrt(noise_power / noise_power_is)
|
||||
noise_power_is = np.sum(pypho.functions.getpower_W(noise))
|
||||
source_signal[0]["E"] += noise
|
||||
source_signal[0]["noise"] = noise_power_is
|
||||
|
||||
symbols_x = symbolsrc(pattern="random")
|
||||
symbols_y = symbolsrc(pattern="random")
|
||||
symbols_x[:3] = 0
|
||||
symbols_y[:3] = 0
|
||||
# symbols_x += 1
|
||||
# source_signal[0]['E'][1] += source_signal[0]['E'][1][np.argmin(np.abs(source_signal[0]['E'][1]))]
|
||||
|
||||
## side channels
|
||||
# df = 100
|
||||
# signal_power = pypho.functions.W_to_dBm(np.sum(pypho.functions.getpower_W(source_signal[0]["E"])))
|
||||
|
||||
|
||||
cw = laser()
|
||||
# symbols_x_side = symbolsrc(pattern="random")
|
||||
# symbols_y_side = symbolsrc(pattern="random")
|
||||
# symbols_x_side[:3] = 0
|
||||
# symbols_y_side[:3] = 0
|
||||
|
||||
source_signal = modulator(E=cw, symbols=(symbols_x, symbols_y))
|
||||
# cw_left = laser(Df=-df)
|
||||
# source_signal_left = modulator(E=cw_left, symbols=(symbols_x_side, symbols_y_side))
|
||||
|
||||
# source_signal[0]['E'][1] += source_signal[0]['E'][1][np.argmin(np.abs(source_signal[0]['E'][1]))]
|
||||
# cw_right = laser(Df=df)
|
||||
# source_signal_right = modulator(E=cw_right, symbols=(symbols_y_side, symbols_x_side))
|
||||
|
||||
source_signal = py_edfa(E=source_signal)
|
||||
E_in_pure = source_signal[0]["E"]
|
||||
|
||||
c_data.E_in = source_signal[0]["E"]
|
||||
noise = source_signal[0]["noise"]
|
||||
nf = py_edfa.NF
|
||||
pmean = py_edfa.Pmean
|
||||
|
||||
# ideal amplification to launch power into fiber
|
||||
source_signal = py_edfa(E=source_signal, NF=0, Pmean=config["signal"]["laser_power"])
|
||||
# source_signal_left = py_edfa(E=source_signal_left, NF=0, Pmean=config["signal"]["laser_power"])
|
||||
# source_signal_right = py_edfa(E=source_signal_right, NF=0, Pmean=config["signal"]["laser_power"])
|
||||
|
||||
# source_signal[0]["E"][0] += source_signal_left[0]["E"][0] + source_signal_right[0]["E"][0]
|
||||
# source_signal[0]["E"][1] += source_signal_left[0]["E"][1] + source_signal_right[0]["E"][1]
|
||||
|
||||
c_data.E_in = source_signal[0]["E"]
|
||||
noise = source_signal[0]["noise"]
|
||||
|
||||
py_edfa.NF = nf
|
||||
py_edfa.Pmean = pmean
|
||||
|
||||
py_fiber = pypho.fiber(
|
||||
glova=py_glova,
|
||||
@@ -272,27 +305,32 @@ def initialize_fiber_and_data(config, input_data_override=None):
|
||||
gamma=config["fiber"]["gamma"],
|
||||
D=config["fiber"]["d"],
|
||||
S=config["fiber"]["s"],
|
||||
phi_max=0.02,
|
||||
)
|
||||
if config["fiber"].get("birefsteps", 0) > 0:
|
||||
seed = config["fiber"].get(
|
||||
"birefseed", (int(time.time() * 1000)) % 2**32
|
||||
)
|
||||
|
||||
config["fiber"]["birefsteps"] = config["fiber"].get(
|
||||
"birefsteps", config["fiber"]["length"] // config["fiber"].get("bireflength", config["fiber"]["length"])
|
||||
)
|
||||
if config["fiber"]["birefsteps"] > 0:
|
||||
config["fiber"]["bireflength"] = config["fiber"].get("bireflength", config["fiber"]["length"] / config["fiber"]["birefsteps"])
|
||||
seed = config["fiber"].get("birefseed", (int(time.time() * 1000)) % 2**32)
|
||||
py_fiber.birefarray = pypho.birefringence_segment.create_pmd_fibre(
|
||||
py_fiber.l,
|
||||
py_fiber.l / config["fiber"]["birefsteps"],
|
||||
# maxDeltaD=config["fiber"]["d"]/5,
|
||||
maxDeltaBeta = config["fiber"].get("max_delta_beta", 0),
|
||||
config["fiber"]["length"],
|
||||
config["fiber"]["bireflength"],
|
||||
maxDeltaBeta=config["fiber"].get("pmd_q", 0)/np.sqrt(2*config["fiber"]["bireflength"]),
|
||||
seed=seed,
|
||||
)
|
||||
c_params = pypho.cfiber.ParamsWrapper.from_fiber(
|
||||
py_fiber, max_step=1e3 if py_fiber.gamma == 0 else 200
|
||||
)
|
||||
elif (dgd := config['fiber'].get('dgd', 0)) > 0:
|
||||
py_fiber.birefarray = [
|
||||
pypho.birefringence_segment(z_point=0, angle=np.pi/2, delta_beta=1000*dgd/config["fiber"]["length"])
|
||||
]
|
||||
c_params = pypho.cfiber.ParamsWrapper.from_fiber(py_fiber, max_step=config["fiber"]["length"] if py_fiber.gamma == 0 else 200)
|
||||
c_fiber = pypho.cfiber.FiberWrapper(c_data, c_params, c_glova)
|
||||
|
||||
return c_fiber, c_data, noise, py_edfa
|
||||
return c_fiber, c_data, noise, py_edfa, (symbols_x, symbols_y), py_glova, E_in_pure
|
||||
|
||||
|
||||
def save_data(data, config):
|
||||
def save_data(data, config, **metadata):
|
||||
data_dir = Path(config["data"]["dir"])
|
||||
npy_dir = config["data"].get("npy_dir", "")
|
||||
save_dir = data_dir / npy_dir if len(npy_dir) else data_dir
|
||||
@@ -307,6 +345,7 @@ def save_data(data, config):
|
||||
seed = config["signal"].get("seed", False)
|
||||
jitter_seed = config["signal"].get("jitter_seed", False)
|
||||
birefseed = config["fiber"].get("birefseed", False)
|
||||
osnr = float(config["signal"].get("osnr", "inf"))
|
||||
|
||||
config_content = "\n".join((
|
||||
f"; Generated by {str(Path(__file__).name)} @ {timestamp.strftime('%Y-%m-%d %H:%M:%S')}",
|
||||
@@ -318,16 +357,19 @@ def save_data(data, config):
|
||||
f'wisdom_dir = "{config["glova"]["wisdom_dir"]}"',
|
||||
f'flags = "{config["glova"]["flags"]}"',
|
||||
f"nthreads = {config['glova']['nthreads']}",
|
||||
" ",
|
||||
"",
|
||||
"[fiber]",
|
||||
f"length = {config['fiber']['length']}",
|
||||
f"gamma = {config['fiber']['gamma']}",
|
||||
f"alpha = {config['fiber']['alpha']}",
|
||||
f"D = {config['fiber']['d']}",
|
||||
f"S = {config['fiber']['s']}",
|
||||
f"birefsteps = {config['fiber'].get('birefsteps',0)}",
|
||||
f"max_delta_beta = {config['fiber'].get('max_delta_beta', 0)}",
|
||||
f"birefsteps = {config['fiber'].get('birefsteps', 0)}",
|
||||
f"pmd_q = {config['fiber'].get('pmd_q', 0)}",
|
||||
f"birefseed = {hex(birefseed)}" if birefseed else "; birefseed = not set",
|
||||
f"dgd = {config['fiber'].get('dgd', 0)}",
|
||||
f"ortho_error = {config['fiber'].get('ortho_error', 0)}",
|
||||
f"pol_error = {config['fiber'].get('pol_error', 0)}",
|
||||
"",
|
||||
"[signal]",
|
||||
f"seed = {hex(seed)}" if seed else "; seed = not set",
|
||||
@@ -335,100 +377,93 @@ def save_data(data, config):
|
||||
f'modulation = "{config["signal"]["modulation"]}"',
|
||||
f"mod_order = {config['signal']['mod_order']}",
|
||||
f"mod_depth = {config['signal']['mod_depth']}",
|
||||
""
|
||||
"",
|
||||
f"max_jitter = {config['signal']['max_jitter']}",
|
||||
f"jitter_seed = {hex(jitter_seed)}" if jitter_seed else "; jitter_seed = not set",
|
||||
""
|
||||
"",
|
||||
f"laser_power = {config['signal']['laser_power']}",
|
||||
f"edfa_power = {config['signal']['edfa_power']}",
|
||||
f"edfa_nf = {config['signal']['edfa_nf']}",
|
||||
""
|
||||
f"osnr = {osnr}",
|
||||
"",
|
||||
f'pulse_shape = "{config["signal"]["pulse_shape"]}"',
|
||||
f"fwhm = {config['signal']['fwhm']}",
|
||||
"",
|
||||
"[data]",
|
||||
f'dir = "{str(data_dir)}"',
|
||||
f'npy_dir = "{npy_dir}"',
|
||||
"file = "
|
||||
"file = ",
|
||||
))
|
||||
config_hash = hashlib.md5(config_content.encode()).hexdigest()
|
||||
save_file = f"{config_hash}.npy"
|
||||
save_file = f"{config_hash}.h5"
|
||||
config_content += f'"{str(save_file)}"\n'
|
||||
|
||||
config_filename:Path = create_config_filename(config, data_dir, timestamp)
|
||||
while config_filename.exists():
|
||||
time.sleep(1)
|
||||
config_filename = create_config_filename(config, data_dir=data_dir)
|
||||
|
||||
|
||||
with open(config_filename, "w") as f:
|
||||
f.write(config_content)
|
||||
|
||||
with h5py.File(save_dir / save_file, "w") as outfile:
|
||||
outfile.create_dataset("data", data=save_data)
|
||||
outfile.create_dataset("symbols", data=metadata.pop("symbols"))
|
||||
for key, value in metadata.items():
|
||||
# if isinstance(value, dict):
|
||||
# value = json.dumps(model_runner.convert_arrays(value))
|
||||
outfile.attrs[key] = value
|
||||
# np.save(save_dir / save_file, save_data)
|
||||
|
||||
# print("Saved config to", config_filename)
|
||||
# print("Saved data to", save_dir / save_file)
|
||||
|
||||
return config_filename
|
||||
|
||||
def create_config_filename(config, data_dir:Path, timestamp=None):
|
||||
if timestamp is None:
|
||||
timestamp = datetime.now()
|
||||
filename_components = (
|
||||
timestamp.strftime("%Y%m%d-%H%M%S"),
|
||||
config["glova"]["sps"],
|
||||
config["glova"]["nos"],
|
||||
config["signal"]["osnr"],
|
||||
config["fiber"]["length"],
|
||||
config["fiber"]["gamma"],
|
||||
config["fiber"]["alpha"],
|
||||
config["fiber"]["d"],
|
||||
config["fiber"]["s"],
|
||||
f"{config['signal']['modulation'].upper()}{config['signal']['mod_order']}",
|
||||
config['fiber'].get('birefsteps',0),
|
||||
config["fiber"].get("max_delta_beta", 0),
|
||||
config["fiber"].get("birefsteps", 0),
|
||||
config["fiber"].get("pmd_q", 0),
|
||||
int(config["glova"]["symbolrate"] / 1e9),
|
||||
)
|
||||
|
||||
lookup_file = "-".join(map(str, filename_components)) + ".ini"
|
||||
with open(data_dir / lookup_file, "w") as f:
|
||||
f.write(config_content)
|
||||
return data_dir / lookup_file
|
||||
|
||||
np.save(save_dir / save_file, save_data)
|
||||
|
||||
print("Saved config to", data_dir / lookup_file)
|
||||
print("Saved data to", save_dir / save_file)
|
||||
|
||||
|
||||
def length_loop(config, lengths, incremental=False, bireflength=None, save=True):
|
||||
def length_loop(config, lengths, save=True):
|
||||
lengths = sorted(lengths)
|
||||
input_override = None
|
||||
birefsteps_running = 0
|
||||
for lind, length in enumerate(lengths):
|
||||
# print(f"\nGenerating data for fiber length {length}")
|
||||
if lind > 0 and incremental:
|
||||
# set the length to the difference between the current and previous length -> incremental
|
||||
length = lengths[lind] - lengths[lind - 1]
|
||||
if incremental:
|
||||
print(
|
||||
f"\nGenerating data for fiber length {lengths[lind]}m [using {length}m increment]"
|
||||
)
|
||||
else:
|
||||
print(f"\nGenerating data for fiber length {length}m")
|
||||
for length in lengths:
|
||||
print(f"\nGenerating data for fiber length {length}m")
|
||||
config["fiber"]["length"] = length
|
||||
if bireflength is not None and bireflength > 0:
|
||||
config["fiber"]["birefsteps"] = length // bireflength
|
||||
birefsteps_running += config["fiber"]["birefsteps"]
|
||||
# set the input data to the output data of the previous run
|
||||
cfiber, cdata, noise, edfa = initialize_fiber_and_data(
|
||||
config, input_data_override=input_override
|
||||
)
|
||||
|
||||
if lind == 0:
|
||||
cdata_orig = cdata
|
||||
cfiber, cdata, noise, edfa, symbols, py_glova = initialize_fiber_and_data(config)
|
||||
|
||||
mean_power_in = np.sum(pypho.functions.getpower_W(cdata.E_in))
|
||||
print(
|
||||
f"Mean power in: {mean_power_in:.3e} W ({pypho.functions.W_to_dBm(mean_power_in):.3e} dBm)"
|
||||
)
|
||||
|
||||
cfiber()
|
||||
|
||||
|
||||
mean_power_out = np.sum(pypho.functions.getpower_W(cdata.E_out))
|
||||
print(
|
||||
f"Mean power out: {mean_power_out:.3e} W ({pypho.functions.W_to_dBm(mean_power_out):.3e} dBm)"
|
||||
)
|
||||
|
||||
if incremental:
|
||||
input_override = (cdata.E_out, noise)
|
||||
cdata.E_in = cdata_orig.E_in
|
||||
config["fiber"]["length"] = lengths[lind]
|
||||
if bireflength is not None:
|
||||
config["fiber"]["birefsteps"] = birefsteps_running
|
||||
print(f"Mean power in: {mean_power_in:.3e} W ({pypho.functions.W_to_dBm(mean_power_in):.3e} dBm)")
|
||||
print(f"Mean power out: {mean_power_out:.3e} W ({pypho.functions.W_to_dBm(mean_power_out):.3e} dBm)")
|
||||
|
||||
E_tmp = [{'E': cdata.E_out, 'noise': noise*(-cfiber.params.l*cfiber.params.alpha)}]
|
||||
E_tmp = [{"E": cdata.E_out, "noise": noise * (-cfiber.params.l * cfiber.params.alpha)}]
|
||||
E_tmp = edfa(E=E_tmp)
|
||||
cdata.E_out = E_tmp[0]['E']
|
||||
cdata.E_out = E_tmp[0]["E"]
|
||||
|
||||
mean_power_amp = np.sum(pypho.functions.getpower_W(cdata.E_out))
|
||||
print(f"Mean power after EDFA: {mean_power_amp:.3e} W ({pypho.functions.W_to_dBm(mean_power_amp):.3e} dBm)")
|
||||
|
||||
if save:
|
||||
save_data(cdata, config)
|
||||
|
||||
@@ -436,27 +471,55 @@ def length_loop(config, lengths, incremental=False, bireflength=None, save=True)
|
||||
|
||||
|
||||
def single_run_with_plot(config, save=True):
|
||||
cfiber, cdata, noise, edfa = initialize_fiber_and_data(config)
|
||||
|
||||
mean_power_in = np.sum(pypho.functions.getpower_W(cdata.E_in))
|
||||
print(
|
||||
f"Mean power in: {mean_power_in:.3e} W ({pypho.functions.W_to_dBm(mean_power_in):.3e} dBm)"
|
||||
)
|
||||
|
||||
cfiber()
|
||||
|
||||
mean_power_out = np.sum(pypho.functions.getpower_W(cdata.E_out))
|
||||
print(
|
||||
f"Mean power out: {mean_power_out:.3e} W ({pypho.functions.W_to_dBm(mean_power_out):.3e} dBm)"
|
||||
)
|
||||
|
||||
E_tmp = [{'E': cdata.E_out, 'noise': noise*(-cfiber.params.l*cfiber.params.alpha)}]
|
||||
E_tmp = edfa(E=E_tmp)
|
||||
cdata.E_out = E_tmp[0]['E']
|
||||
if save:
|
||||
save_data(cdata, config)
|
||||
cfiber, cdata, config_filename = single_run(config, save)
|
||||
|
||||
in_out_eyes(cfiber, cdata, show_pols=False)
|
||||
return config_filename
|
||||
|
||||
|
||||
def single_run(config, save=True, silent=True):
|
||||
cfiber, cdata, noise, edfa, symbols, glova, E_in = initialize_fiber_and_data(config)
|
||||
|
||||
# transmit
|
||||
cfiber()
|
||||
|
||||
# amplify
|
||||
E_tmp = [{"E": cdata.E_out, "noise": noise}]
|
||||
|
||||
E_tmp = edfa(E=E_tmp)
|
||||
|
||||
|
||||
# rotate
|
||||
# ortho error
|
||||
ortho_error = config["fiber"].get("ortho_error", 0)
|
||||
|
||||
E_tmp[0]["E"] = np.stack((
|
||||
E_tmp[0]["E"][0] * np.cos(ortho_error/2) + E_tmp[0]["E"][1] * np.sin(ortho_error/2),
|
||||
E_tmp[0]["E"][0] * np.sin(ortho_error/2) + E_tmp[0]["E"][1] * np.cos(ortho_error/2)
|
||||
), axis=0)
|
||||
|
||||
|
||||
pol_error = config['fiber'].get('pol_error', 0)
|
||||
|
||||
E_tmp[0]["E"] = np.stack((
|
||||
E_tmp[0]["E"][0] * np.cos(pol_error) - E_tmp[0]["E"][1] * np.sin(pol_error),
|
||||
E_tmp[0]["E"][0] * np.sin(pol_error) + E_tmp[0]["E"][1] * np.cos(pol_error)
|
||||
), axis=0)
|
||||
|
||||
|
||||
|
||||
|
||||
# output
|
||||
cdata.E_out = E_tmp[0]["E"]
|
||||
|
||||
config_filename = None
|
||||
symbols = np.array(symbols)
|
||||
if save:
|
||||
config_filename = save_data(cdata, config, **{"symbols": symbols})
|
||||
if not silent:
|
||||
print(f"Saved config to {config_filename}")
|
||||
return cfiber, cdata, config_filename
|
||||
|
||||
|
||||
def in_out_eyes(cfiber, cdata, show_pols=False):
|
||||
fig, axs = plt.subplots(2, 2, sharex=True, sharey=True)
|
||||
@@ -620,9 +683,7 @@ def plot_eye_diagram(
|
||||
signal = signal[: head * eye_width]
|
||||
if normalize:
|
||||
signal = signal / np.max(signal)
|
||||
slices = np.lib.stride_tricks.sliding_window_view(signal, eye_width + 1)[
|
||||
offset % (eye_width + 1) :: eye_width
|
||||
]
|
||||
slices = np.lib.stride_tricks.sliding_window_view(signal, eye_width + 1)[offset % (eye_width + 1) :: eye_width]
|
||||
plt_ax = np.arange(-eye_width // 2, eye_width // 2 + 1) / samplerate
|
||||
for slice in slices:
|
||||
ax.plot(plt_ax, slice, color=color, alpha=0.1)
|
||||
@@ -643,13 +704,26 @@ if __name__ == "__main__":
|
||||
# lengths = [*lengths, *lengths]
|
||||
lengths = (
|
||||
# 8000, 9000,
|
||||
10000, 20000, 30000, 40000, 50000, 60000, 70000, 80000, 90000,
|
||||
95000, 100000, 105000, 110000, 115000, 120000
|
||||
10000,
|
||||
20000,
|
||||
30000,
|
||||
40000,
|
||||
50000,
|
||||
60000,
|
||||
70000,
|
||||
80000,
|
||||
90000,
|
||||
95000,
|
||||
100000,
|
||||
105000,
|
||||
110000,
|
||||
115000,
|
||||
120000,
|
||||
)
|
||||
lengths = sorted(lengths)
|
||||
|
||||
length_loop(config, lengths, incremental=False, bireflength=1000, save=True)
|
||||
# lengths = (10000,100000)
|
||||
|
||||
# length_loop(config, lengths, save=True)
|
||||
# birefringence is constant over coupling length -> several 100m -> bireflength=1000 (m)
|
||||
|
||||
# single_run_with_plot(config, save=True)
|
||||
|
||||
single_run_with_plot(config, save=False)
|
||||
138
src/single-core-regen/testing/prob_dens.ipynb
Normal file
138
src/single-core-regen/testing/prob_dens.ipynb
Normal file
File diff suppressed because one or more lines are too long
1351
src/single-core-regen/tolerance_testing.py
Normal file
1351
src/single-core-regen/tolerance_testing.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -26,7 +26,7 @@ global_settings = GlobalSettings(
|
||||
)
|
||||
|
||||
data_settings = DataSettings(
|
||||
config_path="data/*-128-16384-10000-0-0-0-0-PAM4-0-0.4.ini",
|
||||
config_path="data/20241211-105524-128-16384-1-0-0-0-0-PAM4-0-0.ini",
|
||||
# config_path=[f"data/20241202-*-128-16384-{length}-0-0-17-0-PAM4-0.ini" for length in range(48000, 53000, 1000)],
|
||||
dtype="complex64",
|
||||
# symbols = (9, 20), # 13 symbol @ 10GBd <-> 1.3ns <-> 0.26m of fiber
|
||||
@@ -53,14 +53,14 @@ pytorch_settings = PytorchSettings(
|
||||
)
|
||||
|
||||
model_settings = ModelSettings(
|
||||
output_dim=3,
|
||||
output_dim=1,
|
||||
n_hidden_layers=3,
|
||||
overrides={
|
||||
"n_hidden_nodes_0": 2,
|
||||
"n_hidden_nodes_1": 2,
|
||||
"n_hidden_nodes_2": 2,
|
||||
"n_hidden_nodes_0": 4,
|
||||
"n_hidden_nodes_1": 4,
|
||||
"n_hidden_nodes_2": 4,
|
||||
},
|
||||
dropout_prob=0.01,
|
||||
dropout_prob=0,
|
||||
model_layer_function="ONNRect",
|
||||
model_activation_func="EOActivation",
|
||||
model_layer_kwargs={"square": True},
|
||||
@@ -110,20 +110,24 @@ model_settings = ModelSettings(
|
||||
)
|
||||
|
||||
optimizer_settings = OptimizerSettings(
|
||||
optimizer="AdamW",
|
||||
optimizer="RMSprop",
|
||||
# optimizer="AdamW",
|
||||
optimizer_kwargs={
|
||||
"lr": 0.005,
|
||||
"amsgrad": True,
|
||||
"lr": 0.01,
|
||||
"alpha": 0.9,
|
||||
"momentum": 0.1,
|
||||
"eps": 1e-8,
|
||||
"centered": True,
|
||||
# "amsgrad": True,
|
||||
# "weight_decay": 1e-7,
|
||||
},
|
||||
# learning_rate=0.05,
|
||||
scheduler="ReduceLROnPlateau",
|
||||
scheduler_kwargs={
|
||||
"patience": 2**6,
|
||||
"patience": 2**5,
|
||||
"factor": 0.75,
|
||||
# "threshold": 1e-3,
|
||||
"min_lr": 1e-6,
|
||||
"cooldown": 10,
|
||||
# "cooldown": 10,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@@ -320,6 +320,29 @@ class normalize_by_first(nn.Module):
|
||||
def forward(self, data):
|
||||
return data / data[:, 0].unsqueeze(1)
|
||||
|
||||
class rotate(nn.Module):
|
||||
def __init__(self):
|
||||
super(rotate, self).__init__()
|
||||
|
||||
def forward(self, data, angle):
|
||||
# data -> (batch, n*2)
|
||||
# angle -> (batch, n)
|
||||
data_ = data
|
||||
if angle.ndim == 1:
|
||||
angle_ = angle.unsqueeze(1)
|
||||
else:
|
||||
angle_ = angle
|
||||
angle_ = angle_.expand(-1, data_.shape[1]//2)
|
||||
c = torch.cos(angle_)
|
||||
s = torch.sin(angle_)
|
||||
rot = torch.stack([torch.stack([c, -s], dim=2),
|
||||
torch.stack([s, c], dim=2)], dim=3)
|
||||
d = torch.bmm(data_.reshape(-1, 1, 2), rot.view(-1, 2, 2).to(dtype=data_.dtype)).reshape(*data.shape)
|
||||
# d = torch.bmm(data.unsqueeze(-1).mT, rot.to(dtype=data.dtype).mT).mT.squeeze(-1)
|
||||
|
||||
return d
|
||||
|
||||
|
||||
class photodiode(nn.Module):
|
||||
def __init__(self, size, bias=True):
|
||||
super(photodiode, self).__init__()
|
||||
@@ -418,8 +441,7 @@ class input_rotator(nn.Module):
|
||||
# return out
|
||||
|
||||
|
||||
#### as defined by zhang et al
|
||||
|
||||
#### as defined by zhang et alas
|
||||
|
||||
class DropoutComplex(nn.Module):
|
||||
def __init__(self, p=0.5):
|
||||
@@ -441,7 +463,7 @@ class Scale(nn.Module):
|
||||
self.scale = nn.Parameter(torch.ones(size, dtype=torch.float32))
|
||||
|
||||
def forward(self, x):
|
||||
return x * self.scale
|
||||
return x * torch.sqrt(self.scale)
|
||||
|
||||
def __repr__(self):
|
||||
return f"Scale({self.size})"
|
||||
@@ -459,6 +481,15 @@ class Identity(nn.Module):
|
||||
def forward(self, x):
|
||||
return x
|
||||
|
||||
class phase_shift(nn.Module):
|
||||
def __init__(self, size):
|
||||
super(phase_shift, self).__init__()
|
||||
self.size = size
|
||||
self.phase = nn.Parameter(torch.rand(size))
|
||||
|
||||
def forward(self, x):
|
||||
return x * torch.exp(1j*self.phase)
|
||||
|
||||
|
||||
class PowRot(nn.Module):
|
||||
def __init__(self, bias=False):
|
||||
@@ -487,7 +518,7 @@ class MZISingle(nn.Module):
|
||||
return x * torch.exp(1j * self.phi) * torch.sin(self.omega + self.func(x))
|
||||
|
||||
def naive_angle_loss(x: torch.Tensor, target: torch.Tensor, mod=2*torch.pi):
|
||||
return torch.fmod((x - target), mod).square().mean()
|
||||
return torch.fmod((x.abs().real - target.abs().real), mod).abs().mean()
|
||||
|
||||
def cosine_loss(x: torch.Tensor, target: torch.Tensor):
|
||||
return (2*(1 - torch.cos(x - target))).mean()
|
||||
@@ -508,51 +539,46 @@ def angle_mse_loss(x: torch.Tensor, target: torch.Tensor):
|
||||
|
||||
class EOActivation(nn.Module):
|
||||
def __init__(self, size=None):
|
||||
# 10.1109/SiPhotonics60897.2024.10543376
|
||||
# 10.1109/JSTQE.2019.2930455
|
||||
super(EOActivation, self).__init__()
|
||||
if size is None:
|
||||
raise ValueError("Size must be specified")
|
||||
self.size = size
|
||||
self.alpha = nn.Parameter(torch.ones(size))
|
||||
self.V_bias = nn.Parameter(torch.ones(size))
|
||||
self.gain = nn.Parameter(torch.ones(size))
|
||||
# if bias:
|
||||
# self.phase_bias = nn.Parameter(torch.zeros(size))
|
||||
# else:
|
||||
# self.register_buffer("phase_bias", torch.zeros(size))
|
||||
self.register_buffer("phase_bias", torch.clamp(torch.ones(size) + torch.randn(size)*0.1, 0, 1)*torch.pi)
|
||||
self.register_buffer("responsivity", torch.ones(size)*0.9)
|
||||
self.register_buffer("V_pi", torch.ones(size)*3)
|
||||
self.alpha = nn.Parameter(torch.rand(size))
|
||||
self.gain = nn.Parameter(torch.rand(size))
|
||||
self.V_bias = nn.Parameter(torch.rand(size))
|
||||
# self.register_buffer("gain", torch.ones(size))
|
||||
# self.register_buffer("responsivity", torch.ones(size))
|
||||
# self.register_buffer("V_pi", torch.ones(size))
|
||||
|
||||
self.reset_weights()
|
||||
|
||||
def reset_weights(self):
|
||||
if "alpha" in self._parameters:
|
||||
self.alpha.data = torch.ones(self.size)*0.5
|
||||
if "V_pi" in self._parameters:
|
||||
self.V_pi.data = torch.ones(self.size)*3
|
||||
self.alpha.data = torch.rand(self.size)
|
||||
# if "V_pi" in self._parameters:
|
||||
# self.V_pi.data = torch.rand(self.size)*3
|
||||
if "V_bias" in self._parameters:
|
||||
self.V_bias.data = torch.zeros(self.size)
|
||||
self.V_bias.data = torch.randn(self.size)
|
||||
if "gain" in self._parameters:
|
||||
self.gain.data = torch.ones(self.size)
|
||||
if "responsivity" in self._parameters:
|
||||
self.responsivity.data = torch.ones(self.size)*0.9
|
||||
if "bias" in self._parameters:
|
||||
self.phase_bias.data = torch.zeros(self.size)
|
||||
self.gain.data = torch.rand(self.size)
|
||||
# if "responsivity" in self._parameters:
|
||||
# self.responsivity.data = torch.ones(self.size)*0.9
|
||||
# if "bias" in self._parameters:
|
||||
# self.phase_bias.data = torch.zeros(self.size)
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
phi_b = torch.pi * self.V_bias / (self.V_pi + 1e-8)
|
||||
g_phi = torch.pi * (self.alpha * self.gain * self.responsivity) / (self.V_pi + 1e-8)
|
||||
phi_b = torch.pi * self.V_bias# / (self.V_pi)
|
||||
g_phi = torch.pi * (self.alpha * self.gain)# * self.responsivity)# / (self.V_pi)
|
||||
intermediate = g_phi * x.abs().square() + phi_b
|
||||
return (
|
||||
1j
|
||||
* torch.sqrt(1 - self.alpha)
|
||||
* torch.exp(-0.5j * (intermediate + self.phase_bias))
|
||||
* torch.exp(-0.5j * intermediate)
|
||||
* torch.cos(0.5 * intermediate)
|
||||
* x
|
||||
)
|
||||
|
||||
|
||||
class Pow(nn.Module):
|
||||
"""
|
||||
implements the activation function
|
||||
@@ -693,6 +719,7 @@ __all__ = [
|
||||
MZISingle,
|
||||
EOActivation,
|
||||
photodiode,
|
||||
phase_shift,
|
||||
# SaturableAbsorberLambertW,
|
||||
# SaturableAbsorber,
|
||||
# SpreadLayer,
|
||||
|
||||
105
src/single-core-regen/util/core.py
Normal file
105
src/single-core-regen/util/core.py
Normal file
@@ -0,0 +1,105 @@
|
||||
# Copyright (c) 2015, Warren Weckesser. All rights reserved.
|
||||
# This software is licensed according to the "BSD 2-clause" license.
|
||||
|
||||
import hashlib
|
||||
import h5py
|
||||
import numpy as _np
|
||||
from scipy.interpolate import interp1d as _interp1d
|
||||
from scipy.ndimage import gaussian_filter as _gaussian_filter
|
||||
from ._brescount import bres_curve_count as _bres_curve_count
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
__all__ = ['grid_count']
|
||||
|
||||
|
||||
def grid_count(y, window_size, offset=0, size=None, fuzz=True, blur=0, bounds=None):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
`y` is the 1-d array of signal samples.
|
||||
|
||||
`window_size` is the number of samples to show horizontally in the
|
||||
eye diagram. Typically this is twice the number of samples in a
|
||||
"symbol" (i.e. in a data bit).
|
||||
|
||||
`offset` is the number of initial samples to skip before computing
|
||||
the eye diagram. This allows the overall phase of the diagram to
|
||||
be adjusted.
|
||||
|
||||
`size` must be a tuple of two integers. It sets the size of the
|
||||
array of counts, (height, width). The default is (800, 640).
|
||||
|
||||
`fuzz`: If True, the values in `y` are reinterpolated with a
|
||||
random "fuzz factor" before plotting in the eye diagram. This
|
||||
reduces an aliasing-like effect that arises with the use of
|
||||
Bresenham's algorithm.
|
||||
|
||||
`bounds` must be a tuple of two floating point values, (ymin, ymax).
|
||||
These set the y range of the returned array. If not given, the
|
||||
bounds are `(y.min() - 0.05*A, y.max() + 0.05*A)`, where `A` is
|
||||
`y.max() - y.min()`.
|
||||
|
||||
Return Value
|
||||
------------
|
||||
Returns a numpy array of integers.
|
||||
|
||||
"""
|
||||
# hash input params
|
||||
param_ob = (y, window_size, offset, size, fuzz, blur, bounds)
|
||||
param_hash = hashlib.md5(str(param_ob).encode()).hexdigest()
|
||||
cache_dir = Path.home()/".eyediagram"/".cache"
|
||||
cache_dir.mkdir(parents=True, exist_ok=True)
|
||||
if (cache_dir/param_hash).is_file():
|
||||
try:
|
||||
with h5py.File(cache_dir/param_hash, "r") as infile:
|
||||
counts = infile["counts"][:]
|
||||
if counts.len() != 0:
|
||||
return counts
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
|
||||
if size is None:
|
||||
size = (800, 640)
|
||||
height, width = size
|
||||
dt = width / window_size
|
||||
counts = _np.zeros((width, height), dtype=_np.int32)
|
||||
|
||||
if bounds is None:
|
||||
ymin = y.min()
|
||||
ymax = y.max()
|
||||
yamp = ymax - ymin
|
||||
ymin = ymin - 0.05*yamp
|
||||
ymax = ymax + 0.05*yamp
|
||||
ymax = _np.ceil(ymax*10)/10
|
||||
ymin = _np.floor(ymin*10)/10
|
||||
else:
|
||||
ymin, ymax = bounds
|
||||
|
||||
start = offset
|
||||
while start + window_size < len(y):
|
||||
end = start + window_size
|
||||
yy = y[start:end+1]
|
||||
k = _np.arange(len(yy))
|
||||
xx = dt*k
|
||||
if fuzz:
|
||||
f = _interp1d(xx, yy, kind='cubic')
|
||||
jiggle = dt*(_np.random.beta(a=3, b=3, size=len(xx)-2) - 0.5)
|
||||
xx[1:-1] += jiggle
|
||||
yd = f(xx)
|
||||
else:
|
||||
yd = yy
|
||||
iyd = (height * (yd - ymin)/(ymax - ymin)).astype(_np.int32)
|
||||
_bres_curve_count(xx.astype(_np.int32), iyd, counts)
|
||||
|
||||
start = end
|
||||
|
||||
if blur != 0:
|
||||
counts = _gaussian_filter(counts, sigma=blur)
|
||||
|
||||
with h5py.File(cache_dir/param_hash, "w") as outfile:
|
||||
outfile.create_dataset("data", data=counts)
|
||||
|
||||
return counts
|
||||
@@ -1,10 +1,12 @@
|
||||
from pathlib import Path
|
||||
import h5py
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
# from torch.utils.data import Sampler
|
||||
import numpy as np
|
||||
import configparser
|
||||
import multiprocessing as mp
|
||||
|
||||
# class SubsetSampler(Sampler[int]):
|
||||
# """
|
||||
@@ -24,7 +26,22 @@ import configparser
|
||||
# return len(self.indices)
|
||||
|
||||
|
||||
def load_data(config_path, skipfirst=0, symbols=None, real=False, normalize=False, device=None, dtype=None):
|
||||
def load_from_file(datapath):
|
||||
if str(datapath).endswith(".h5"):
|
||||
symbols = None
|
||||
with h5py.File(datapath, "r") as infile:
|
||||
data = infile["data"][:]
|
||||
try:
|
||||
symbols = np.swapaxes(infile["symbols"][:], 0, 1)
|
||||
except KeyError:
|
||||
pass
|
||||
else:
|
||||
symbols = None
|
||||
data = np.load(datapath)
|
||||
return data, symbols
|
||||
|
||||
|
||||
def load_data(config_path, skipfirst=0, skiplast=0, symbols=None, real=False, normalize=1, device=None, dtype=None):
|
||||
filepath = Path(config_path)
|
||||
filepath = filepath.parent.glob(filepath.name)
|
||||
config = configparser.ConfigParser()
|
||||
@@ -40,14 +57,28 @@ def load_data(config_path, skipfirst=0, symbols=None, real=False, normalize=Fals
|
||||
if symbols is None:
|
||||
symbols = int(config["glova"]["nos"]) - skipfirst
|
||||
|
||||
data = np.load(datapath)[int(skipfirst * sps) : int(symbols * sps + skipfirst * sps)]
|
||||
timestamps = np.arange(int(skipfirst * sps), int(symbols * sps + skipfirst * sps))
|
||||
data, orig_symbols = load_from_file(datapath)
|
||||
|
||||
if normalize:
|
||||
# square gets normalized to 1, as the power is (proportional to) the square of the amplitude
|
||||
a, b, c, d = np.square(data.T)
|
||||
a, b, c, d = a / np.max(np.abs(a)), b / np.max(np.abs(b)), c / np.max(np.abs(c)), d / np.max(np.abs(d))
|
||||
data = np.sqrt(np.array([a, b, c, d]).T)
|
||||
data = data[int(skipfirst * sps) : int(symbols * sps + skipfirst * sps - skiplast * sps)]
|
||||
orig_symbols = orig_symbols[skipfirst : symbols + skipfirst - skiplast]
|
||||
timestamps = np.arange(int(skipfirst * sps), int(symbols * sps + skipfirst * sps - skiplast * sps))
|
||||
|
||||
data *= np.sqrt(normalize)
|
||||
|
||||
launch_power = float(config["signal"]["laser_power"])
|
||||
output_power = float(config["signal"]["edfa_power"])
|
||||
|
||||
target_normalization = 10 ** (output_power / 10) / 10 ** (launch_power / 10)
|
||||
# target_normalization *= 0.5 # allow 50% power loss, so the network can ignore parts of the signal
|
||||
|
||||
data[:, 0:2] *= np.sqrt(target_normalization)
|
||||
|
||||
# if normalize:
|
||||
# # square gets normalized to 1, as the power is (proportional to) the square of the amplitude
|
||||
# a, b, c, d = data.T
|
||||
# a, b, c, d = a - np.min(np.abs(a)), b - np.min(np.abs(b)), c - np.min(np.abs(c)), d - np.min(np.abs(d))
|
||||
# a, b, c, d = a / np.max(np.abs(a)), b / np.max(np.abs(b)), c / np.max(np.abs(c)), d / np.max(np.abs(d))
|
||||
# data = np.array([a, b, c, d]).T
|
||||
|
||||
if real:
|
||||
data = np.abs(data)
|
||||
@@ -58,7 +89,7 @@ def load_data(config_path, skipfirst=0, symbols=None, real=False, normalize=Fals
|
||||
|
||||
data = torch.tensor(data, device=device, dtype=dtype)
|
||||
|
||||
return data, config
|
||||
return data, config, orig_symbols
|
||||
|
||||
|
||||
def roll_along(arr, shifts, dim):
|
||||
@@ -110,11 +141,15 @@ class FiberRegenerationDataset(Dataset):
|
||||
target_delay: float | int = 0,
|
||||
xy_delay: float | int = 0,
|
||||
drop_first: float | int = 0,
|
||||
drop_last=0,
|
||||
dtype: torch.dtype = None,
|
||||
real: bool = False,
|
||||
device=None,
|
||||
polarisations: tuple | list = (0,),
|
||||
# osnr: float|None = None,
|
||||
polarisations=None,
|
||||
randomise_polarisations: bool = False,
|
||||
repeat_randoms: int = 1,
|
||||
# cross_pol_interference: float = 0,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
@@ -148,64 +183,53 @@ class FiberRegenerationDataset(Dataset):
|
||||
assert drop_first >= 0, "drop_first must be non-negative"
|
||||
|
||||
self.randomise_polarisations = randomise_polarisations
|
||||
# self.cross_pol_interference = cross_pol_interference
|
||||
|
||||
faux = kwargs.pop("faux", False)
|
||||
|
||||
if faux:
|
||||
data_raw = np.array(
|
||||
[[i + 0.1j, i + 0.2j, i + 1.1j, i + 1.2j] for i in range(12800)],
|
||||
dtype=np.complex128,
|
||||
data_raw = None
|
||||
self.config = None
|
||||
files = []
|
||||
self.orig_symbols = None
|
||||
for file_path in file_path if isinstance(file_path, (tuple, list)) else [file_path]:
|
||||
data, config, orig_syms = load_data(
|
||||
file_path,
|
||||
skipfirst=drop_first,
|
||||
skiplast=drop_last,
|
||||
symbols=kwargs.get("num_symbols", None),
|
||||
real=real,
|
||||
normalize=1000,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
data_raw = torch.tensor(data_raw, device=device, dtype=dtype)
|
||||
timestamps = torch.arange(12800)
|
||||
|
||||
data_raw = torch.concatenate([data_raw, timestamps.reshape(-1, 1)], axis=-1)
|
||||
|
||||
self.config = {
|
||||
"data": {"dir": '"."', "npy_dir": '"."', "file": "faux"},
|
||||
"glova": {"sps": 128},
|
||||
}
|
||||
else:
|
||||
data_raw = None
|
||||
self.config = None
|
||||
files = []
|
||||
for file_path in file_path if isinstance(file_path, (tuple, list)) else [file_path]:
|
||||
data, config = load_data(
|
||||
file_path,
|
||||
skipfirst=drop_first,
|
||||
symbols=kwargs.get("num_symbols", None),
|
||||
real=real,
|
||||
normalize=True,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
if data_raw is None:
|
||||
data_raw = data
|
||||
if orig_syms is not None:
|
||||
if self.orig_symbols is None:
|
||||
self.orig_symbols = orig_syms
|
||||
else:
|
||||
data_raw = torch.cat([data_raw, data], dim=0)
|
||||
if self.config is None:
|
||||
self.config = config
|
||||
else:
|
||||
assert self.config["glova"]["sps"] == config["glova"]["sps"], "samples per symbol must be the same"
|
||||
files.append(config["data"]["file"].strip('"'))
|
||||
self.config["data"]["file"] = str(files)
|
||||
self.orig_symbols = np.concat((self.orig_symbols, orig_syms), axis=-1)
|
||||
|
||||
for i, angle in enumerate(torch.tensor(np.array(polarisations))):
|
||||
data_raw_copy = data_raw.clone()
|
||||
if angle == 0:
|
||||
continue
|
||||
sine = torch.sin(angle)
|
||||
cosine = torch.cos(angle)
|
||||
data_raw_copy[:, 2] = data_raw[:, 2] * cosine - data_raw[:, 3] * sine
|
||||
data_raw_copy[:, 3] = data_raw[:, 2] * sine + data_raw[:, 3] * cosine
|
||||
if i == 0:
|
||||
data_raw = data_raw_copy
|
||||
else:
|
||||
data_raw = torch.cat([data_raw, data_raw_copy], dim=0)
|
||||
if data_raw is None:
|
||||
data_raw = data
|
||||
else:
|
||||
data_raw = torch.cat([data_raw, data], dim=0)
|
||||
if self.config is None:
|
||||
self.config = config
|
||||
else:
|
||||
assert self.config["glova"]["sps"] == config["glova"]["sps"], "samples per symbol must be the same"
|
||||
files.append(config["data"]["file"].strip('"'))
|
||||
self.config["data"]["file"] = str(files)
|
||||
|
||||
# if polarisations is not None:
|
||||
# data_raw_clone = data_raw.clone()
|
||||
# # rotate the polarisation by 180 degrees
|
||||
# data_raw_clone[2, :] *= -1
|
||||
# data_raw_clone[3, :] *= -1
|
||||
# data_raw = torch.cat([data_raw, data_raw_clone], dim=0)
|
||||
|
||||
self.polarisations = bool(polarisations)
|
||||
|
||||
self.device = data_raw.device
|
||||
|
||||
self.samples_per_symbol = int(self.config["glova"]["sps"])
|
||||
# self.num_symbols = int(self.config["glova"]["nos"])
|
||||
self.samples_per_slice = int(symbols * self.samples_per_symbol)
|
||||
self.symbols_per_slice = self.samples_per_slice / self.samples_per_symbol
|
||||
|
||||
@@ -278,23 +302,94 @@ class FiberRegenerationDataset(Dataset):
|
||||
timestamps = data_raw[4, :]
|
||||
data_raw = data_raw[:4, :]
|
||||
data_raw = data_raw.view(2, 2, -1)
|
||||
timestamps_doubled = torch.cat([timestamps.unsqueeze(dim=0), timestamps.unsqueeze(dim=0)], dim=0).unsqueeze(
|
||||
dim=1
|
||||
)
|
||||
data_raw = torch.cat([data_raw, timestamps_doubled], dim=1)
|
||||
fiber_in = data_raw[0, :, :]
|
||||
fiber_out = data_raw[1, :, :]
|
||||
# timestamps_doubled = torch.cat([timestamps.unsqueeze(dim=0), timestamps.unsqueeze(dim=0)], dim=0).unsqueeze(
|
||||
# dim=1
|
||||
# )
|
||||
fiber_in = torch.cat([fiber_in, timestamps.unsqueeze(0)], dim=0)
|
||||
fiber_out = torch.cat([fiber_out, timestamps.unsqueeze(0)], dim=0)
|
||||
|
||||
# fiber_out: [E_out_x, E_out_y, timestamps]
|
||||
|
||||
# add noise related to amplification necessary due to splitting of the signal
|
||||
# gain_lin = output_dim*2
|
||||
# gain_lin = 1
|
||||
# edfa_nf = float(self.config["signal"]["edfa_nf"])
|
||||
# nf_lin = 10**(edfa_nf/10)
|
||||
# f0 = float(self.config["glova"]["f0"])
|
||||
|
||||
# noise_add = (gain_lin-1)*nf_lin*6.626e-34*f0*12.5e9
|
||||
|
||||
# noise = torch.randn_like(fiber_out[:2, :])
|
||||
# noise_power = torch.sum(torch.mean(noise.abs().square().squeeze(), dim=-1))
|
||||
# noise = noise * torch.sqrt(noise_add / noise_power)
|
||||
# fiber_out[:2, :] += noise
|
||||
|
||||
# if osnr is None:
|
||||
# noisy = fiber_out[:2, :]
|
||||
# else:
|
||||
# noisy = self.add_noise(fiber_out[:2, :], osnr)
|
||||
|
||||
# fiber_out = torch.cat([fiber_out, noisy], dim=0)
|
||||
|
||||
# fiber_out: [E_out_x, E_out_y, timestamps, E_out_x_noisy, E_out_y_noisy]
|
||||
|
||||
if repeat_randoms > 1:
|
||||
fiber_in = fiber_in.repeat(1, 1, repeat_randoms)
|
||||
fiber_out = fiber_out.repeat(1, 1, repeat_randoms)
|
||||
# review: potential problems with repeated timestamps when plotting
|
||||
else:
|
||||
repeat_randoms = 1
|
||||
|
||||
if self.randomise_polarisations:
|
||||
angles = torch.fmod(torch.randperm(data_raw.shape[-1]*repeat_randoms, device=fiber_out.device), 2) * torch.pi
|
||||
start_angle = torch.rand(1) * 2 * torch.pi
|
||||
angles = start_angle + torch.cumsum(torch.randn(data_raw.shape[-1])/333, dim=0) * torch.pi * 2 # random walk
|
||||
angles = torch.randn(data_raw.shape[-1], device=fiber_out.device) * 2*torch.pi / 36 # sigma = 10 degrees
|
||||
# self.angles = torch.rand(self.data.shape[0]) * 2 * torch.pi
|
||||
else:
|
||||
angles = torch.zeros(data_raw.shape[-1], device=fiber_out.device)
|
||||
|
||||
sin = torch.sin(angles)
|
||||
cos = torch.cos(angles)
|
||||
rot = torch.stack([torch.stack([cos, -sin], dim=1), torch.stack([sin, cos], dim=1)], dim=2)
|
||||
data_rot = torch.bmm(fiber_out[:2, :].T.unsqueeze(1), rot.to(dtype=data_raw.dtype)).squeeze(1).T
|
||||
# data_rot_noisy = torch.bmm(fiber_out[3:5, :].T.unsqueeze(1), rot.to(dtype=data_raw.dtype)).squeeze(1).T
|
||||
fiber_out = torch.cat((fiber_out, data_rot), dim=0)
|
||||
fiber_out = torch.cat([fiber_out, angles.unsqueeze(0)], dim=0)
|
||||
|
||||
# fiber_in:
|
||||
# 0 E_in_x,
|
||||
# 1 E_in_y,
|
||||
# 2 timestamps
|
||||
|
||||
# fiber_out:
|
||||
# 0 E_out_x,
|
||||
# 1 E_out_y,
|
||||
# 2 timestamps,
|
||||
# 3 E_out_x_rot,
|
||||
# 4 E_out_y_rot,
|
||||
# 5 angle
|
||||
|
||||
# data_raw = torch.cat([data_raw, timestamps_doubled], dim=1)
|
||||
# data layout
|
||||
# [ [E_in_x, E_in_y, timestamps],
|
||||
# [E_out_x, E_out_y, timestamps] ]
|
||||
|
||||
self.data = data_raw.unfold(dimension=-1, size=self.samples_per_slice, step=1)
|
||||
self.data = self.data.movedim(-2, 0)
|
||||
self.fiber_in = fiber_in.unfold(dimension=-1, size=self.samples_per_slice, step=1)
|
||||
self.fiber_in = self.fiber_in.movedim(-2, 0)
|
||||
|
||||
if randomise_polarisations:
|
||||
self.angles = torch.rand(self.data.shape[0]) * np.pi * 2
|
||||
# self.data[:, 1, :2, :] = self.rotate(self.data[:, 1, :2, :], self.angles)
|
||||
else:
|
||||
self.angles = torch.zeros(self.data.shape[0])
|
||||
self.fiber_out = fiber_out.unfold(dimension=-1, size=self.samples_per_slice, step=1)
|
||||
self.fiber_out = self.fiber_out.movedim(-2, 0)
|
||||
|
||||
# if self.randomise_polarisations:
|
||||
# self.angles = torch.cumsum((torch.rand(self.fiber_out.shape[0]) - 0.5) * 2 * torch.pi * 2 / 5000, dim=0)
|
||||
|
||||
# self.data = data_raw.unfold(dimension=-1, size=self.samples_per_slice, step=1)
|
||||
# self.data = self.data.movedim(-2, 0)
|
||||
# self.angles = torch.zeros(self.data.shape[0])
|
||||
...
|
||||
# ...
|
||||
# -> [no_slices, 2, 3, samples_per_slice]
|
||||
|
||||
@@ -305,69 +400,109 @@ class FiberRegenerationDataset(Dataset):
|
||||
# ...
|
||||
# ] -> [no_slices, 2, 3, samples_per_slice]
|
||||
|
||||
...
|
||||
|
||||
def __len__(self):
|
||||
return self.data.shape[0]
|
||||
return self.fiber_in.shape[0]
|
||||
|
||||
def add_noise(self, data, osnr):
|
||||
osnr_lin = 10 ** (osnr / 10)
|
||||
popt = torch.mean(data.abs().square().squeeze(), dim=-1)
|
||||
noise = torch.randn_like(data)
|
||||
pn = torch.mean(noise.abs().square().squeeze(), dim=-1)
|
||||
|
||||
mult = torch.sqrt(popt / (pn * osnr_lin))
|
||||
mult = mult * torch.eye(popt.shape[0], device=mult.device)
|
||||
mult = mult.to(dtype=noise.dtype)
|
||||
|
||||
noise = mult @ noise
|
||||
pn = torch.mean(noise.abs().square().squeeze(), dim=-1)
|
||||
noisy = data + noise
|
||||
return noisy
|
||||
|
||||
def __getitem__(self, idx):
|
||||
if isinstance(idx, slice):
|
||||
return [self.__getitem__(i) for i in range(*idx.indices(len(self)))]
|
||||
else:
|
||||
data_slice = self.data[idx].squeeze()
|
||||
# fiber in: [E_in_x, E_in_y, timestamps]
|
||||
# fiber out: [E_out_x, E_out_y, timestamps, E_out_x_rot, E_out_y_rot, angle]
|
||||
|
||||
data_slice = data_slice[:, :, : data_slice.shape[2] // self.output_dim * self.output_dim]
|
||||
# if self.polarisations:
|
||||
output_dim = self.output_dim // 2
|
||||
self.output_dim = output_dim * 2
|
||||
|
||||
data_slice = data_slice.view(data_slice.shape[0], data_slice.shape[1], self.output_dim, -1)
|
||||
if not self.polarisations:
|
||||
output_dim = 2 * output_dim
|
||||
|
||||
# if self.randomise_polarisations:
|
||||
# angle = torch.rand(1) * torch.pi * 2
|
||||
# sine = torch.sin(angle)
|
||||
# cosine = torch.cos(angle)
|
||||
# data_slice_ = data_slice[1]
|
||||
# data_slice[1, 0] = data_slice_[0] * cosine - data_slice_[1] * sine
|
||||
# data_slice[1,1] = data_slice_[0] * sine + data_slice_[1] * cosine
|
||||
# else:
|
||||
# angle = torch.zeros(1)
|
||||
|
||||
# data = data_slice[1, :2, :, 0]
|
||||
fiber_in = self.fiber_in[idx].squeeze()
|
||||
fiber_out = self.fiber_out[idx].squeeze()
|
||||
|
||||
angle = self.angles[idx]
|
||||
fiber_in = fiber_in[..., : fiber_in.shape[-1] // output_dim * output_dim]
|
||||
fiber_out = fiber_out[..., : fiber_out.shape[-1] // output_dim * output_dim]
|
||||
|
||||
data_index = 1
|
||||
fiber_in = fiber_in.view(fiber_in.shape[0], output_dim, -1)
|
||||
fiber_out = fiber_out.view(fiber_out.shape[0], output_dim, -1)
|
||||
|
||||
data_slice[1, :2, :, :] = self.rotate(data_slice[data_index, :2, :, :], angle)
|
||||
center_angle = fiber_out[5, output_dim // 2, 0]
|
||||
angles = fiber_out[5, :, 0]
|
||||
plot_data_rot = fiber_out[3:5, output_dim // 2, 0].detach().clone()
|
||||
data = fiber_out[0:2, :, 0]
|
||||
plot_data = fiber_out[0:2, output_dim // 2, 0].detach().clone()
|
||||
|
||||
data = data_slice[1, :2, :, 0]
|
||||
# data = self.rotate(data, angle)
|
||||
|
||||
# for both polarisations (x/y), calculate the mean of the signal around the current symbol (-> corresponding to a lowpass filter)
|
||||
angle_data = data_slice[1, :2, :, :].reshape(2, -1).mean(dim=1)
|
||||
angle_data2 = self.complex_max(data_slice[1, :2, :, :].reshape(2, -1))
|
||||
plot_data = data_slice[1, :2, self.output_dim // 2, 0]
|
||||
sop = self.polarimeter(plot_data)
|
||||
# angle_data = data_slice[1, :2, :self.output_dim//2, :].squeeze().reshape(2, -1).mean(dim=-1)
|
||||
# angle = data_slice[1, 3, self.output_dim // 2, 0].real
|
||||
target = data_slice[0, :2, self.output_dim // 2, 0]
|
||||
target_timestamp = data_slice[0, 2, self.output_dim // 2, 0].real
|
||||
target = fiber_in[:2, output_dim // 2, 0]
|
||||
plot_target = fiber_in[:2, output_dim // 2, 0].detach().clone()
|
||||
target_timestamp = fiber_in[2, output_dim // 2, 0].real
|
||||
...
|
||||
|
||||
# data_timestamps = data[-1,:].real
|
||||
# data = data[:-1, :]
|
||||
# target_timestamp = target[-1].real
|
||||
# target = target[:-1]
|
||||
# plot_data = plot_data[:-1]
|
||||
if self.polarisations:
|
||||
rot = int(np.random.randint(2) * 2 - 1)
|
||||
data = rot * data
|
||||
target = rot * target
|
||||
plot_data_rot = rot * plot_data_rot
|
||||
center_angle = center_angle + (rot - 1) * torch.pi / 2 # rot: -1 or 1 -> -2 or 0 -> -pi or 0
|
||||
angles = angles + (rot - 1) * torch.pi / 2
|
||||
|
||||
pol_flipped_data = -data
|
||||
pol_flipped_target = -target
|
||||
|
||||
# transpose to interleave the x and y data in the output tensor
|
||||
data = data.transpose(0, 1).flatten().squeeze()
|
||||
angle_data = angle_data.flatten().squeeze()
|
||||
angle_data2 = angle_data.flatten().squeeze()
|
||||
angle = angle.flatten().squeeze()
|
||||
data = data / torch.sqrt(torch.ones(1) * len(data)) # power loss due to splitting
|
||||
pol_flipped_data = pol_flipped_data.transpose(0, 1).flatten().squeeze()
|
||||
pol_flipped_data = pol_flipped_data / torch.sqrt(
|
||||
torch.ones(1) * len(pol_flipped_data)
|
||||
) # power loss due to splitting
|
||||
# angle_data = angle_data.transpose(0, 1).flatten().squeeze()
|
||||
# angle_data2 = angle_data2.transpose(0,1).flatten().squeeze()
|
||||
center_angle = center_angle.flatten().squeeze()
|
||||
angles = angles.flatten().squeeze()
|
||||
# data_timestamps = data_timestamps.flatten().squeeze()
|
||||
# target = target.transpose(0,1).flatten().squeeze()
|
||||
target = target.flatten().squeeze()
|
||||
pol_flipped_target = pol_flipped_target.flatten().squeeze()
|
||||
target_timestamp = target_timestamp.flatten().squeeze()
|
||||
plot_target = plot_target.flatten().squeeze()
|
||||
plot_data = plot_data.flatten().squeeze()
|
||||
plot_data_rot = plot_data_rot.flatten().squeeze()
|
||||
|
||||
return {"x": data, "y": target, "angle": angle, "sop": sop, "angle_data": angle_data, "angle_data2": angle_data2, "timestamp": target_timestamp, "plot_data": plot_data}
|
||||
return {
|
||||
"x": data,
|
||||
"x_flipped": pol_flipped_data,
|
||||
"x_stacked": torch.cat([data, pol_flipped_data], dim=-1),
|
||||
"y": target,
|
||||
"y_flipped": pol_flipped_target,
|
||||
"y_stacked": torch.cat([target, pol_flipped_target], dim=-1),
|
||||
"center_angle": center_angle,
|
||||
"angles": angles,
|
||||
"mean_angle": angles.mean(),
|
||||
# "sop": sop,
|
||||
# "angle_data": angle_data,
|
||||
# "angle_data2": angle_data2,
|
||||
"timestamp": target_timestamp,
|
||||
"plot_target": plot_target,
|
||||
"plot_data": plot_data,
|
||||
"plot_data_rot": plot_data_rot,
|
||||
# "plot_clean": fiber_out_plot_clean,
|
||||
}
|
||||
|
||||
def complex_max(self, data, dim=-1):
|
||||
# returns element(s) with the maximum absolute value along a given dimension
|
||||
@@ -376,7 +511,6 @@ class FiberRegenerationDataset(Dataset):
|
||||
# return max_values
|
||||
return torch.gather(data, dim, torch.argmax(data.abs(), dim=dim, keepdim=True)).squeeze(dim=dim)
|
||||
|
||||
|
||||
def rotate(self, data, angle):
|
||||
# rotates a 2d tensor by a given angle
|
||||
# data: [2, ...]
|
||||
@@ -389,6 +523,24 @@ class FiberRegenerationDataset(Dataset):
|
||||
|
||||
return torch.stack([data[0] * cosine - data[1] * sine, data[0] * sine + data[1] * cosine], dim=0)
|
||||
|
||||
def rotate_all(self):
|
||||
def do_rotation(j, num_processes):
|
||||
for i in range(len(self) // num_processes):
|
||||
index = i * num_processes + j
|
||||
self.data[index, 1, :2, :] = self.rotate(self.data[index, 1, :2, :], self.angles[index])
|
||||
|
||||
self.processes = []
|
||||
|
||||
for j in range(mp.cpu_count()):
|
||||
self.processes.append(mp.Process(target=do_rotation, args=(j, mp.cpu_count())))
|
||||
self.processes[-1].start()
|
||||
|
||||
for p in self.processes:
|
||||
p.join()
|
||||
|
||||
for i in range(len(self) // mp.cpu_count() * mp.cpu_count(), len(self)):
|
||||
self.data[i, 1, :2, :] = self.rotate(self.data[i, 1, :2, :], self.angles[i])
|
||||
|
||||
def polarimeter(self, data):
|
||||
# data: [2, ...] -> x, y
|
||||
# returns [4] -> S0, S1, S2, S3
|
||||
@@ -396,12 +548,12 @@ class FiberRegenerationDataset(Dataset):
|
||||
y = data[1].mean()
|
||||
I_X = x.abs().square()
|
||||
I_Y = y.abs().square()
|
||||
I_45 = (x+y).abs().square()
|
||||
I_RHC = (x + 1j*y).abs().square()
|
||||
I_45 = (x + y).abs().square()
|
||||
I_RHC = (x + 1j * y).abs().square()
|
||||
|
||||
S0 = I_X + I_Y
|
||||
S1 = (2*I_X - S0) / S0
|
||||
S2 = (2*I_45 - S0) / S0
|
||||
S3 = (2*I_RHC - S0) / S0
|
||||
S1 = (2 * I_X - S0) / S0
|
||||
S2 = (2 * I_45 - S0) / S0
|
||||
S3 = (2 * I_RHC - S0) / S0
|
||||
|
||||
return torch.stack([S1, S2, S3], dim=0)
|
||||
return torch.stack([S0, S1, S2, S3], dim=0)
|
||||
|
||||
@@ -1,16 +1,23 @@
|
||||
from datetime import datetime
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Literal
|
||||
import h5py
|
||||
from matplotlib import pyplot as plt
|
||||
from matplotlib.colors import LinearSegmentedColormap
|
||||
# from cmap import Colormap as cm
|
||||
import numpy as np
|
||||
from scipy.cluster.vq import kmeans2
|
||||
import warnings
|
||||
import multiprocessing
|
||||
|
||||
from rich.traceback import install
|
||||
from rich import pretty
|
||||
from rich import print
|
||||
|
||||
install()
|
||||
pretty.install()
|
||||
# from rich import pretty
|
||||
# from rich import print
|
||||
|
||||
# pretty.install()
|
||||
|
||||
|
||||
def generate_sample_data(n_symbols=1000, sps=128, noise=0.01, skew=1):
|
||||
@@ -21,6 +28,7 @@ def generate_sample_data(n_symbols=1000, sps=128, noise=0.01, skew=1):
|
||||
xaxis = np.arange(0, len(signal)) / sps
|
||||
return np.vstack([xaxis, signal])
|
||||
|
||||
|
||||
def create_symbol_sequence(n_symbols, skew=1):
|
||||
np.random.seed(42)
|
||||
data = np.random.randint(0, 4, n_symbols) / 4
|
||||
@@ -39,6 +47,14 @@ def generate_signal(data, sps):
|
||||
signal = np.convolve(data_padded, wavelet)
|
||||
signal = np.cumsum(signal)
|
||||
signal = signal[sps // 2 + len(wavelet) // 2 - 1 : -len(wavelet) // 2]
|
||||
mi, ma = np.min(signal), np.max(signal)
|
||||
|
||||
signal = (signal - mi) / (ma - mi)
|
||||
|
||||
mod = 0.8
|
||||
|
||||
signal *= mod
|
||||
signal += 1 - mod
|
||||
|
||||
return signal
|
||||
|
||||
@@ -49,8 +65,8 @@ def normalization_with_noise(signal, noise=0):
|
||||
signal += awgn
|
||||
|
||||
# min-max normalization
|
||||
signal = signal - np.min(signal)
|
||||
signal = signal / np.max(signal)
|
||||
# signal = signal - np.min(signal)
|
||||
# signal = signal / np.max(signal)
|
||||
return signal
|
||||
|
||||
|
||||
@@ -68,98 +84,264 @@ def generate_wavelet(sps, oversample=3):
|
||||
|
||||
|
||||
class eye_diagram:
|
||||
def __init__(self, data, *, channel_names=None, horizontal_bins=256, vertical_bins=1000, n_levels=4, multithreaded=True):
|
||||
def __init__(
|
||||
self,
|
||||
data,
|
||||
*,
|
||||
channel_names=None,
|
||||
horizontal_bins=256,
|
||||
vertical_bins=1000,
|
||||
n_levels=4,
|
||||
multithreaded=True,
|
||||
save_file_or_dir=None,
|
||||
):
|
||||
# data has shape [channels, 2, samples]
|
||||
# each sample has a timestamp and a value
|
||||
if data.ndim == 2:
|
||||
data = data[np.newaxis, :, :]
|
||||
self.channel_names = channel_names
|
||||
self.raw_data = data
|
||||
self.channels = data.shape[0]
|
||||
|
||||
self.y_bins = np.zeros(1)
|
||||
self.x_bins = np.zeros(1)
|
||||
self.eye_data = np.zeros(1)
|
||||
self.channel_names = channel_names
|
||||
self.n_channels = data.shape[0]
|
||||
self.n_levels = n_levels
|
||||
self.eye_stats = [{"success": False} for _ in range(self.channels)]
|
||||
self.eye_stats = [{"success": False} for _ in range(self.n_channels)]
|
||||
self.horizontal_bins = horizontal_bins
|
||||
self.vertical_bins = vertical_bins
|
||||
self.multi_threaded = multithreaded
|
||||
self.analysed = False
|
||||
self.eye_built = False
|
||||
self.analyse()
|
||||
|
||||
def generate_eye_data(self):
|
||||
self.x_bins = np.linspace(0, 2, self.horizontal_bins, endpoint=False)
|
||||
self.y_bins = np.zeros((self.channels, self.vertical_bins))
|
||||
self.eye_data = np.zeros((self.channels, self.vertical_bins, self.horizontal_bins))
|
||||
datas = [self.raw_data[i] for i in range(self.channels)]
|
||||
if self.multi_threaded:
|
||||
with multiprocessing.Pool() as pool:
|
||||
results = pool.map(self.generate_eye_data_single, datas)
|
||||
for i, result in enumerate(results):
|
||||
self.eye_data[i], self.y_bins[i] = result
|
||||
self.save_file = save_file_or_dir
|
||||
|
||||
def load_data(self, file=None):
|
||||
file = self.save_file if file is None else file
|
||||
|
||||
if file is None:
|
||||
raise FileNotFoundError("No file specified.")
|
||||
|
||||
self.save_file = str(file)
|
||||
# self.file_or_dir = self.save_file
|
||||
with h5py.File(file, "r") as infile:
|
||||
self.y_bins = infile["y_bins"][:]
|
||||
self.x_bins = infile["x_bins"][:]
|
||||
self.eye_data = infile["eye_data"][:]
|
||||
self.channel_names = infile.attrs["channel_names"]
|
||||
self.n_channels = infile.attrs["n_channels"]
|
||||
self.n_levels = infile.attrs["n_levels"]
|
||||
self.eye_stats = infile.attrs["eye_stats"]
|
||||
self.eye_stats = [json.loads(stat) for stat in self.eye_stats]
|
||||
self.horizontal_bins = infile.attrs["horizontal_bins"]
|
||||
self.vertical_bins = infile.attrs["vertical_bins"]
|
||||
self.multi_threaded = infile.attrs["multithreaded"]
|
||||
self.analysed = infile.attrs["analysed"]
|
||||
self.eye_built = infile.attrs["eye_built"]
|
||||
|
||||
def save_data(self, file_or_dir=None):
|
||||
file_or_dir = self.save_file if file_or_dir is None else file_or_dir
|
||||
if file_or_dir is None:
|
||||
file = Path(f"{datetime.now().strftime('%Y%m%d_%H%M%S')}_eye_data.eye.h5")
|
||||
elif Path(file_or_dir).is_dir():
|
||||
file = Path(file_or_dir) / f"{datetime.now().strftime('%Y%m%d_%H%M%S')}_eye_data.eye.h5"
|
||||
else:
|
||||
for i, data in enumerate(datas):
|
||||
self.eye_data[i], self.y_bins[i] = self.generate_eye_data_single(data)
|
||||
self.eye_built = True
|
||||
file = Path(file_or_dir)
|
||||
|
||||
# file.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self.save_file = str(file)
|
||||
|
||||
with h5py.File(file, "w") as outfile:
|
||||
outfile.create_dataset("eye_data", data=self.eye_data)
|
||||
outfile.create_dataset("y_bins", data=self.y_bins)
|
||||
outfile.create_dataset("x_bins", data=self.x_bins)
|
||||
outfile.attrs["channel_names"] = self.channel_names
|
||||
outfile.attrs["n_channels"] = self.n_channels
|
||||
outfile.attrs["n_levels"] = self.n_levels
|
||||
self.eye_stats = eye_diagram.convert_arrays(self.eye_stats)
|
||||
outfile.attrs["eye_stats"] = [json.dumps(stat) for stat in self.eye_stats]
|
||||
outfile.attrs["horizontal_bins"] = self.horizontal_bins
|
||||
outfile.attrs["vertical_bins"] = self.vertical_bins
|
||||
outfile.attrs["multithreaded"] = self.multi_threaded
|
||||
outfile.attrs["analysed"] = self.analysed
|
||||
outfile.attrs["eye_built"] = self.eye_built
|
||||
|
||||
@staticmethod
|
||||
def convert_arrays(input_object):
|
||||
"""
|
||||
convert ndarrays in (nested) dict to lists
|
||||
"""
|
||||
|
||||
if isinstance(input_object, np.ndarray):
|
||||
return input_object.tolist()
|
||||
elif isinstance(input_object, list):
|
||||
return [eye_diagram.convert_arrays(old) for old in input_object]
|
||||
elif isinstance(input_object, tuple):
|
||||
return tuple(eye_diagram.convert_arrays(old) for old in input_object)
|
||||
elif isinstance(input_object, dict):
|
||||
dict_out = {}
|
||||
for key, value in input_object.items():
|
||||
dict_out[key] = eye_diagram.convert_arrays(value)
|
||||
return dict_out
|
||||
return input_object
|
||||
|
||||
def generate_eye_data(
|
||||
self, mode: Literal["default", "load", "save", "nosave"] = "default", file_or_dir: Path | None | str = None
|
||||
):
|
||||
# modes:
|
||||
# default: try to load eye data from file, if not found, generate and save
|
||||
# load: try to load eye data from file, if not found, generate but don't save
|
||||
# save: generate eye data and save
|
||||
update_save = True
|
||||
if mode == "load":
|
||||
self.load_data(file_or_dir)
|
||||
update_save = False
|
||||
elif mode == "default":
|
||||
try:
|
||||
self.load_data(file_or_dir)
|
||||
update_save = False
|
||||
except (FileNotFoundError, IsADirectoryError):
|
||||
pass
|
||||
|
||||
if not self.eye_built:
|
||||
update_save = True
|
||||
self.x_bins = np.linspace(0, 2, self.horizontal_bins, endpoint=False)
|
||||
self.y_bins = np.zeros((self.n_channels, self.vertical_bins))
|
||||
self.eye_data = np.zeros((self.n_channels, self.vertical_bins, self.horizontal_bins))
|
||||
datas = [self.raw_data[i] for i in range(self.n_channels)]
|
||||
if self.multi_threaded:
|
||||
with multiprocessing.Pool() as pool:
|
||||
results = pool.map(self.generate_eye_data_single, datas)
|
||||
for i, result in enumerate(results):
|
||||
self.eye_data[i], self.y_bins[i] = result
|
||||
else:
|
||||
for i, data in enumerate(datas):
|
||||
self.eye_data[i], self.y_bins[i] = self.generate_eye_data_single(data)
|
||||
self.eye_built = True
|
||||
|
||||
if mode == "save" or (mode == "default" and update_save):
|
||||
self.save_data(file_or_dir)
|
||||
|
||||
def generate_eye_data_single(self, data):
|
||||
eye_data = np.zeros((self.vertical_bins, self.horizontal_bins))
|
||||
data_min = np.min(data[1, :])
|
||||
data_max = np.max(data[1, :])
|
||||
# round down/up to 1 decimal
|
||||
data_min = np.floor(data_min*10)/10
|
||||
data_max = np.ceil(data_max*10)/10
|
||||
# data_range = data_max - data_min
|
||||
# data_min -= 0.1 * data_range
|
||||
# data_max += 0.1 * data_range
|
||||
# data_min = -0.05
|
||||
# data_max += 0.05
|
||||
# data[1,:] -= np.min(data[1, :])
|
||||
# data[1,:] /= np.max(data[1, :])
|
||||
# data_min = 0
|
||||
# data_max = 1
|
||||
y_bins = np.linspace(data_min, data_max, self.vertical_bins, endpoint=False)
|
||||
t_vals = data[0, :] % 2
|
||||
val_vals = data[1, :]
|
||||
t_vals = data[0, :] % 2 # + np.random.randn(*data[0, :].shape) * 1 / (512)
|
||||
val_vals = data[1, :] # + np.random.randn(*data[1, :].shape) * 1 / (320)
|
||||
x_indices = np.digitize(t_vals, self.x_bins) - 1
|
||||
y_indices = np.digitize(val_vals, y_bins) - 1
|
||||
np.add.at(eye_data, (y_indices, x_indices), 1)
|
||||
return eye_data, y_bins
|
||||
|
||||
def plot(self, title="Eye Diagram", stats=True, all_stats=True, show=True):
|
||||
def plot(
|
||||
self,
|
||||
title="Eye Diagram",
|
||||
stats=True,
|
||||
all_stats=True,
|
||||
show=True,
|
||||
mode: Literal["default", "load", "save", "nosave"] = "default",
|
||||
# save_images = False,
|
||||
# image_dir = None,
|
||||
# cmap=None,
|
||||
):
|
||||
if stats and not self.analysed:
|
||||
self.analyse(mode=mode)
|
||||
if not self.eye_built:
|
||||
self.generate_eye_data()
|
||||
self.generate_eye_data(mode=mode)
|
||||
cmap = LinearSegmentedColormap.from_list(
|
||||
"eyemap",
|
||||
[(0, "white"), (0.1, "blue"), (0.2, "cyan"), (0.5, "green"), (0.8, "yellow"), (0.9, "red"), (1, "magenta")],
|
||||
[
|
||||
(0, "#FFFFFF00"),
|
||||
(0.1, "blue"),
|
||||
(0.2, "cyan"),
|
||||
(0.5, "green"),
|
||||
(0.8, "yellow"),
|
||||
(0.9, "red"),
|
||||
(1, "magenta"),
|
||||
],
|
||||
)
|
||||
if self.channels % 2 == 0:
|
||||
# cmap = cm('google:turbo_r' if cmap is None else cmap)
|
||||
# first = cmap(-1)
|
||||
# cmap = cmap.to_mpl()
|
||||
# cmap.set_under(first, alpha=0)
|
||||
if self.n_channels % 2 == 0:
|
||||
rows = 2
|
||||
cols = self.channels // 2
|
||||
cols = self.n_channels // 2
|
||||
else:
|
||||
cols = int(np.ceil(np.sqrt(self.channels)))
|
||||
rows = int(np.ceil(self.channels / cols))
|
||||
cols = int(np.ceil(np.sqrt(self.n_channels)))
|
||||
rows = int(np.ceil(self.n_channels / cols))
|
||||
fig, ax = plt.subplots(rows, cols, sharex=True, sharey=False)
|
||||
fig.suptitle(title)
|
||||
fig.tight_layout()
|
||||
ax = np.atleast_1d(ax).transpose().flatten()
|
||||
for i in range(self.channels):
|
||||
ax[i].set_title(self.channel_names[i] if self.channel_names is not None else f"Channel {i+1}")
|
||||
if (i+1) % rows == 0:
|
||||
for i in range(self.n_channels):
|
||||
ax[i].set_title(self.channel_names[i] if self.channel_names is not None else f"Channel {i + 1}")
|
||||
if (i + 1) % rows == 0:
|
||||
ax[i].set_xlabel("Symbol")
|
||||
if i < rows:
|
||||
ax[i].set_ylabel("Amplitude")
|
||||
ax[i].grid()
|
||||
ax[i].set_axisbelow(True)
|
||||
ax[i].imshow(
|
||||
self.eye_data[i],
|
||||
self.eye_data[i] - 0.1,
|
||||
origin="lower",
|
||||
aspect="auto",
|
||||
cmap=cmap,
|
||||
extent=[self.x_bins[0], self.x_bins[-1], self.y_bins[i][0], self.y_bins[i][-1]],
|
||||
interpolation="gaussian",
|
||||
vmin=0,
|
||||
zorder=3,
|
||||
)
|
||||
ax[i].set_xlim((self.x_bins[0], self.x_bins[-1]))
|
||||
ymin = np.min(self.y_bins[:, 0])
|
||||
ymax = np.max(self.y_bins[:, -1])
|
||||
yspan = ymax - ymin
|
||||
ax[i].set_ylim((ymin - 0.1 * yspan, ymax + 0.1 * yspan))
|
||||
# if save_images:
|
||||
# image_dir = "images_out" if image_dir is None else image_dir
|
||||
# image_path = Path(image_dir) / (slugify(f"{datetime.now().strftime("%Y%m%d_%H%M%S")}_{title.replace(" ","_")}_{self.channel_names[i].replace(" ", "_") if self.channel_names is not None else f"{i + 1}"}_{ymin:.1f}_{ymax:.1f}") + ".png")
|
||||
# image_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
# # plt.imsave(
|
||||
# # image_path,
|
||||
# # self.eye_data[i] - 0.1,
|
||||
# # origin="lower",
|
||||
# # # aspect="auto",
|
||||
# # cmap=cmap,
|
||||
# # # extent=[self.x_bins[0], self.x_bins[-1], self.y_bins[i][0], self.y_bins[i][-1]],
|
||||
# # # interpolation="gaussian",
|
||||
# # vmin=0,
|
||||
# # # zorder=3,
|
||||
# # )
|
||||
if stats and self.eye_stats[i]["success"]:
|
||||
# add min_area above the plot
|
||||
ax[i].annotate(
|
||||
f"Min Area: {self.eye_stats[i]['min_area']:.2e}",
|
||||
xy=(0.05, ymax + 0.05 * yspan),
|
||||
# xycoords="axes fraction",
|
||||
ha="left",
|
||||
va="center",
|
||||
bbox=dict(facecolor="white", alpha=0.5, edgecolor="none"),
|
||||
)
|
||||
# # add min_area above the plot
|
||||
# ax[i].annotate(
|
||||
# f"Min Area: {self.eye_stats[i]['min_area']:.2e}",
|
||||
# xy=(0.05, ymax + 0.05 * yspan),
|
||||
# # xycoords="axes fraction",
|
||||
# ha="left",
|
||||
# va="center",
|
||||
# bbox=dict(facecolor="white", alpha=0.5, edgecolor="none"),
|
||||
# )
|
||||
|
||||
if all_stats:
|
||||
ax[i].plot([0, 2], [self.eye_stats[i]["levels"], self.eye_stats[i]["levels"]], "k--")
|
||||
ax[i].set_yticks(self.eye_stats[i]["levels"])
|
||||
y_ticks = (*self.eye_stats[i]["levels"], *self.eye_stats[i]["thresholds"])
|
||||
# y_ticks = np.sort(y_ticks)
|
||||
ax[i].set_yticks(y_ticks)
|
||||
# add arrows for amplitudes
|
||||
for j in range(len(self.eye_stats[i]["amplitudes"])):
|
||||
ax[i].annotate(
|
||||
@@ -193,35 +375,35 @@ class eye_diagram:
|
||||
except (ValueError, IndexError):
|
||||
pass
|
||||
# add arrows for eye widths
|
||||
for j in range(len(self.eye_stats[i]["widths"])):
|
||||
try:
|
||||
left = np.max(self.eye_stats[i]["time_clusters"][j][0])
|
||||
right = np.min(self.eye_stats[i]["time_clusters"][j][1])
|
||||
vertical = (self.eye_stats[i]["levels"][j] + self.eye_stats[i]["levels"][j + 1]) / 2
|
||||
# for j in range(len(self.eye_stats[i]["widths"])):
|
||||
# try:
|
||||
# left = np.max(self.eye_stats[i]["time_clusters"][j][0])
|
||||
# right = np.min(self.eye_stats[i]["time_clusters"][j][1])
|
||||
# vertical = (self.eye_stats[i]["levels"][j] + self.eye_stats[i]["levels"][j + 1]) / 2
|
||||
|
||||
ax[i].annotate(
|
||||
"",
|
||||
xy=(left, vertical),
|
||||
xytext=(right, vertical),
|
||||
arrowprops=dict(arrowstyle="<->", facecolor="black"),
|
||||
)
|
||||
ax[i].annotate(
|
||||
f"{self.eye_stats[i]['widths'][j]:.2e}",
|
||||
xy=((left + right) / 2 - 0.15, vertical + 0.01),
|
||||
bbox=dict(facecolor="white", alpha=0.5, edgecolor="none"),
|
||||
)
|
||||
except (ValueError, IndexError):
|
||||
pass
|
||||
# ax[i].annotate(
|
||||
# "",
|
||||
# xy=(left, vertical),
|
||||
# xytext=(right, vertical),
|
||||
# arrowprops=dict(arrowstyle="<->", facecolor="black"),
|
||||
# )
|
||||
# ax[i].annotate(
|
||||
# f"{self.eye_stats[i]['widths'][j]:.2e}",
|
||||
# xy=((left + right) / 2 - 0.15, vertical + 0.01),
|
||||
# bbox=dict(facecolor="white", alpha=0.5, edgecolor="none"),
|
||||
# )
|
||||
# except (ValueError, IndexError):
|
||||
# pass
|
||||
|
||||
# add area
|
||||
for j in range(len(self.eye_stats[i]["areas"])):
|
||||
horizontal = self.eye_stats[i]["time_midpoint"]
|
||||
vertical = (self.eye_stats[i]["levels"][j] + self.eye_stats[i]["levels"][j + 1]) / 2
|
||||
ax[i].annotate(
|
||||
f"{self.eye_stats[i]['areas'][j]:.2e}",
|
||||
xy=(horizontal + 0.035, vertical - 0.07),
|
||||
bbox=dict(facecolor="white", alpha=0.5, edgecolor="none"),
|
||||
)
|
||||
# # add area
|
||||
# for j in range(len(self.eye_stats[i]["areas"])):
|
||||
# horizontal = self.eye_stats[i]["time_midpoint"]
|
||||
# vertical = (self.eye_stats[i]["levels"][j] + self.eye_stats[i]["levels"][j + 1]) / 2
|
||||
# ax[i].annotate(
|
||||
# f"{self.eye_stats[i]['areas'][j]:.2e}",
|
||||
# xy=(horizontal + 0.035, vertical - 0.07),
|
||||
# bbox=dict(facecolor="white", alpha=0.5, edgecolor="none"),
|
||||
# )
|
||||
|
||||
fig.tight_layout()
|
||||
|
||||
@@ -229,26 +411,33 @@ class eye_diagram:
|
||||
plt.show()
|
||||
return fig
|
||||
|
||||
@staticmethod
|
||||
def calculate_thresholds(levels):
|
||||
ret = np.cumsum(levels, dtype=float)
|
||||
ret[2:] = ret[2:] - ret[:-2]
|
||||
return ret[1:] / 2
|
||||
|
||||
def analyse_single(self, data, index):
|
||||
warnings.filterwarnings("error")
|
||||
eye_stats = {}
|
||||
eye_stats["channel_name"] = str(index+1) if self.channel_names is None else self.channel_names[index]
|
||||
eye_stats["channel_name"] = str(index + 1) if self.channel_names is None else self.channel_names[index]
|
||||
try:
|
||||
approx_levels = eye_diagram.approximate_levels(data, self.n_levels)
|
||||
|
||||
time_bounds = eye_diagram.calculate_time_bounds(data, approx_levels)
|
||||
|
||||
eye_stats["time_midpoint"] = (time_bounds[0] + time_bounds[1]) / 2
|
||||
eye_stats["time_midpoint"] = float((time_bounds[0] + time_bounds[1]) / 2)
|
||||
# eye_stats["time_midpoint"] = 1.0
|
||||
|
||||
eye_stats["levels"], eye_stats["amplitude_clusters"] = eye_diagram.calculate_levels(
|
||||
data, approx_levels, time_bounds
|
||||
)
|
||||
|
||||
eye_stats["thresholds"] = self.calculate_thresholds(eye_stats["levels"])
|
||||
|
||||
eye_stats["amplitudes"] = np.diff(eye_stats["levels"])
|
||||
|
||||
eye_stats["heights"] = eye_diagram.calculate_eye_heights(
|
||||
eye_stats["amplitude_clusters"]
|
||||
)
|
||||
eye_stats["heights"] = eye_diagram.calculate_eye_heights(eye_stats["amplitude_clusters"])
|
||||
|
||||
eye_stats["widths"], eye_stats["time_clusters"] = eye_diagram.calculate_eye_widths(
|
||||
data, eye_stats["levels"]
|
||||
@@ -260,36 +449,59 @@ class eye_diagram:
|
||||
# if not (np.max(eye_stats['time_clusters'][j][0]) < eye_stats["time_midpoint"] < np.min(eye_stats['time_clusters'][j][1])):
|
||||
# raise ValueError
|
||||
|
||||
eye_stats["areas"] = eye_stats["heights"] * eye_stats["widths"]
|
||||
eye_stats["mean_area"] = np.mean(eye_stats["areas"])
|
||||
eye_stats["min_area"] = np.min(eye_stats["areas"])
|
||||
# eye_stats["areas"] = eye_stats["heights"] * eye_stats["widths"]
|
||||
# eye_stats["mean_area"] = np.mean(eye_stats["areas"])
|
||||
# eye_stats["min_area"] = np.min(eye_stats["areas"])
|
||||
|
||||
eye_stats["success"] = True
|
||||
except (RuntimeWarning, UserWarning, ValueError):
|
||||
eye_stats["success"] = False
|
||||
eye_stats["time_midpoint"] = 0
|
||||
eye_stats["levels"] = np.zeros(self.n_levels)
|
||||
eye_stats["amplitude_clusters"] = []
|
||||
eye_stats["amplitudes"] = np.zeros(self.n_levels - 1)
|
||||
eye_stats["heights"] = np.zeros(self.n_levels - 1)
|
||||
eye_stats["widths"] = np.zeros(self.n_levels - 1)
|
||||
eye_stats["areas"] = np.zeros(self.n_levels - 1)
|
||||
eye_stats["mean_area"] = 0
|
||||
eye_stats["min_area"] = 0
|
||||
eye_stats["time_midpoint"] = None
|
||||
eye_stats["levels"] = None
|
||||
eye_stats["thresholds"] = None
|
||||
eye_stats["amplitude_clusters"] = None
|
||||
eye_stats["amplitudes"] = None
|
||||
eye_stats["heights"] = None
|
||||
eye_stats["widths"] = None
|
||||
# eye_stats["areas"] = np.zeros(self.n_levels - 1)
|
||||
# eye_stats["mean_area"] = 0
|
||||
# eye_stats["min_area"] = 0
|
||||
warnings.resetwarnings()
|
||||
return eye_stats
|
||||
|
||||
def analyse(
|
||||
self, mode: Literal["default", "load", "save", "nosave"] = "default", file_or_dir: Path | None | str = None
|
||||
):
|
||||
# modes:
|
||||
# default: try to load eye data from file, if not found, generate and save
|
||||
# load: try to load eye data from file, if not found, generate but don't save
|
||||
# save: generate eye data and save
|
||||
update_save = True
|
||||
if mode == "load":
|
||||
self.load_data(file_or_dir)
|
||||
update_save = False
|
||||
elif mode == "default":
|
||||
try:
|
||||
self.load_data(file_or_dir)
|
||||
update_save = False
|
||||
except (FileNotFoundError, IsADirectoryError):
|
||||
pass
|
||||
|
||||
def analyse(self):
|
||||
self.eye_stats = []
|
||||
if self.multi_threaded:
|
||||
with multiprocessing.Pool() as pool:
|
||||
results = pool.starmap(self.analyse_single, [(self.raw_data[i], i) for i in range(self.channels)])
|
||||
for i, result in enumerate(results):
|
||||
self.eye_stats.append(result)
|
||||
else:
|
||||
for i in range(self.channels):
|
||||
self.eye_stats.append(self.analyse_single(self.raw_data[i], i))
|
||||
if not self.analysed:
|
||||
update_save = True
|
||||
self.eye_stats = []
|
||||
if self.multi_threaded:
|
||||
with multiprocessing.Pool() as pool:
|
||||
results = pool.starmap(self.analyse_single, [(self.raw_data[i], i) for i in range(self.n_channels)])
|
||||
for i, result in enumerate(results):
|
||||
self.eye_stats.append(result)
|
||||
else:
|
||||
for i in range(self.n_channels):
|
||||
self.eye_stats.append(self.analyse_single(self.raw_data[i], i))
|
||||
self.analysed = True
|
||||
|
||||
if mode == "save" or (mode == "default" and update_save):
|
||||
self.save_data(file_or_dir)
|
||||
|
||||
@staticmethod
|
||||
def approximate_levels(data, levels):
|
||||
@@ -431,7 +643,7 @@ class eye_diagram:
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
length = int(2**14)
|
||||
length = int(2**16)
|
||||
# data = generate_sample_data(length, noise=1)
|
||||
# data1 = generate_sample_data(length, noise=0.01)
|
||||
# data2 = generate_sample_data(length, noise=0.01, skew=1.2)
|
||||
@@ -439,12 +651,13 @@ if __name__ == "__main__":
|
||||
|
||||
# data = np.stack([data, data1, data2, data3])
|
||||
|
||||
data = generate_sample_data(length, noise=0.005)
|
||||
eye = eye_diagram(data, horizontal_bins=256, vertical_bins=256)
|
||||
attrs = ("success", "amplitudes", "time_midpoint", "levels", "heights", "widths", "area", "mean_area", "min_area")
|
||||
for i, channel in enumerate(eye.eye_stats):
|
||||
print(f"Channel {i}")
|
||||
print_data = {attr: channel[attr] for attr in attrs}
|
||||
print(print_data)
|
||||
data = generate_sample_data(length, noise=0.0000)
|
||||
eye = eye_diagram(data, horizontal_bins=1056, vertical_bins=1200)
|
||||
eye.plot(mode="nosave", stats=False)
|
||||
# attrs = ("success", "amplitudes", "time_midpoint", "levels", "heights", "widths")#, "area", "mean_area", "min_area")
|
||||
# for i, channel in enumerate(eye.eye_stats):
|
||||
# print(f"Channel {i}")
|
||||
# print_data = {attr: channel[attr] for attr in attrs}
|
||||
# print(print_data)
|
||||
|
||||
eye.plot()
|
||||
# eye.plot()
|
||||
|
||||
122
src/single-core-regen/util/mpl.py
Normal file
122
src/single-core-regen/util/mpl.py
Normal file
@@ -0,0 +1,122 @@
|
||||
# Copyright (c) 2015, Warren Weckesser. All rights reserved.
|
||||
# This software is licensed according to the "BSD 2-clause" license.
|
||||
|
||||
# modified by Joseph Hopfmüller in 2025,
|
||||
# for integration into optical regeneration analysis scripts
|
||||
|
||||
from pathlib import Path
|
||||
from matplotlib.colors import LinearSegmentedColormap
|
||||
import matplotlib.colors as colors
|
||||
import numpy as _np
|
||||
from .core import grid_count as _grid_count
|
||||
import matplotlib.pyplot as _plt
|
||||
import numpy as np
|
||||
from scipy.ndimage import gaussian_filter
|
||||
|
||||
|
||||
# from ._common import _common_doc
|
||||
|
||||
|
||||
__all__ = ["eyediagram"] # , 'eyediagram_lines']
|
||||
|
||||
|
||||
# def eyediagram_lines(y, window_size, offset=0, **plotkwargs):
|
||||
# """
|
||||
# Plot an eye diagram using matplotlib by repeatedly calling the `plot`
|
||||
# function.
|
||||
# <common>
|
||||
|
||||
# """
|
||||
# start = offset
|
||||
# while start < len(y):
|
||||
# end = start + window_size
|
||||
# if end > len(y):
|
||||
# end = len(y)
|
||||
# yy = y[start:end+1]
|
||||
# _plt.plot(_np.arange(len(yy)), yy, 'k', **plotkwargs)
|
||||
# start = end
|
||||
|
||||
# eyediagram_lines.__doc__ = eyediagram_lines.__doc__.replace("<common>",
|
||||
# _common_doc)
|
||||
|
||||
|
||||
eyemap = LinearSegmentedColormap.from_list(
|
||||
"eyemap",
|
||||
[
|
||||
(0, "#0000FF00"),
|
||||
(0.1, "blue"),
|
||||
(0.2, "cyan"),
|
||||
(0.5, "green"),
|
||||
(0.8, "yellow"),
|
||||
(0.9, "red"),
|
||||
(1, "magenta"),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def eyediagram(
|
||||
y,
|
||||
window_size,
|
||||
offset=0,
|
||||
colorbar=False,
|
||||
show=False,
|
||||
save_im=False,
|
||||
overwrite=False,
|
||||
blur: int | bool = True,
|
||||
save_path="out.png",
|
||||
bounds=None,
|
||||
**imshowkwargs,
|
||||
):
|
||||
"""
|
||||
Plot an eye diagram using matplotlib by creating an image and calling
|
||||
the `imshow` function.
|
||||
<common>
|
||||
"""
|
||||
if bounds is None:
|
||||
ymax = y.max()
|
||||
ymin = y.min()
|
||||
yamp = ymax - ymin
|
||||
ymin = ymin - 0.05 * yamp
|
||||
ymax = ymax + 0.05 * yamp
|
||||
ymin = np.floor(ymin * 10) / 10
|
||||
ymax = np.ceil(ymax * 10) / 10
|
||||
bounds = (ymin, ymax)
|
||||
counts = _grid_count(y, window_size, offset, bounds=bounds, size=(1000, 1200), blur=int(blur))
|
||||
counts = counts.astype(_np.float32)
|
||||
origin = imshowkwargs.pop("origin", "lower")
|
||||
cmap: colors.Colormap = imshowkwargs.pop("cmap", eyemap)
|
||||
vmin = imshowkwargs.pop("vmin", 1)
|
||||
vmax = imshowkwargs.pop("vmax", None)
|
||||
cmap.set_under("white", alpha=0)
|
||||
|
||||
if show:
|
||||
_plt.imshow(
|
||||
counts.T[::-1, :],
|
||||
extent=[0, 2, *bounds],
|
||||
origin=origin,
|
||||
cmap=cmap,
|
||||
vmin=vmin,
|
||||
vmax=vmax,
|
||||
**imshowkwargs,
|
||||
)
|
||||
_plt.grid()
|
||||
if colorbar:
|
||||
_plt.colorbar()
|
||||
|
||||
if Path(save_path).is_file() and not overwrite:
|
||||
save_im = False
|
||||
if save_im:
|
||||
from PIL import Image
|
||||
arr = counts.T[::-1, :]
|
||||
if origin == "lower":
|
||||
arr = arr[::-1]
|
||||
arr = (arr-arr.min())/(arr.max()-arr.min())
|
||||
image = Image.fromarray((cmap(arr)[:, :, :] * 255).astype(np.uint8))
|
||||
image.save(save_path)
|
||||
# print("-")
|
||||
|
||||
if show:
|
||||
_plt.show()
|
||||
|
||||
|
||||
# eyediagram.__doc__ = eyediagram.__doc__.replace("<common>", _common_doc)
|
||||
@@ -1,6 +1,9 @@
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
from .datasets import load_data
|
||||
if __name__ == "__main__":
|
||||
from datasets import load_data
|
||||
else:
|
||||
from .datasets import load_data
|
||||
|
||||
def eye(*, path=None, data=None, sps=None, title=None, symbols=1000, skipfirst=0, width=2, alpha=None, complex=False, show=True):
|
||||
"""Plot an eye diagram for the data given by filepath.
|
||||
@@ -20,6 +23,7 @@ def eye(*, path=None, data=None, sps=None, title=None, symbols=1000, skipfirst=0
|
||||
raise ValueError("Either path or data and sps must be given.")
|
||||
if path is not None:
|
||||
data, config = load_data(path, skipfirst, symbols)
|
||||
data = data.detach().cpu().numpy()[:, :4]
|
||||
sps = int(config["glova"]["sps"])
|
||||
if sps is None:
|
||||
raise ValueError("sps not set.")
|
||||
@@ -71,3 +75,6 @@ def eye(*, path=None, data=None, sps=None, title=None, symbols=1000, skipfirst=0
|
||||
plt.show()
|
||||
|
||||
return fig
|
||||
|
||||
if __name__ == "__main__":
|
||||
eye(path="data/20241229-163838-128-16384-50000-0-0.2-16.8-0.058-PAM4-0-0.16.ini", symbols=1000, width=2, alpha=0.1, complex=False)
|
||||
Reference in New Issue
Block a user