Compare commits: 33141bdf41 ... machine_le (5 commits)

| Author | SHA1 | Date |
|---|---|---|
| | 249fe1e940 | |
| | f38d0ca3bb | |
| | 3af73343c1 | |
| | 7a0b65f82d | |
| | 98305fdf47 | |
.gitignore (vendored, +2)

@@ -163,3 +163,5 @@ cython_debug/
 # and can be added to the global gitignore or merged into this file. For a more nuclear
 # option (not recommended) you can uncomment the following to ignore the entire idea folder.
 #.idea/
+tolerance_results/*
+data/*
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bc02b0099ea3bb136733e3d20817cad79b6c50c2e4b845f0d206455dde188cc4
+size 615

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d80ff6f2a84acf973fbdf81a05ed0b1902f8bf97856cd5132b646f6b1173f496
+size 615

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:54ac6b6a452aa6b7d312a4c8ab8f7ebe2f96c1c4170cbc56147e8f2f9d934ad6
+size 615

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bd2c6f4050488b6e857759d48aa1f1f37399d81cee1667d3668145e938d17c83
+size 615

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c59b8113092459a8751b385a7b1a6f10828626d2ec2f29775512157fd9bbc75c
+size 615

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:04eaa2a29b3302e5de3bf99bf6a57fc8f27361fd3df3cac9245e25ab99324829
+size 615

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:704d3b0b17b9d320f4717b5a5a8bdbc5714f3caa4efa7153e980766429e834f4
+size 615

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:29304f35a88fd777566105f8666fff1c9927beb32756822365bcf9c159feb98e
+size 615

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4a87488c12e0253b2bb5d1ac7aa1536f69ef569e62c2aab6a10149d753e049b4
+size 615

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ff47c8d5413881edb03cfadfcde3b550ef7089543615a467c4f0027edaf1455e
+size 615

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9e05a0a54c7e3aaffeca0ac90cce1b274a544d90b329a93d453392f5df4e91a8
+size 616

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:25e45b06b551ab8031c2159030d658999fcb3d1f0a34538c90768b94c8116771
+size 616

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:707ed73713b6c2d80d4e333b1ccdf4650f50aefefff58ddb471c1d5411954b3d
+size 616

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a8c3baf878943741d83835c1afed05c1f9780ff3f0df260c0d92706343e59c50
+size 616

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:98878d09510dedc24f1429bbd12cce49c66be6c9d279a28765b120efe820a171
+size 616

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fcbdaffa211d6b0b44b3ae1c66645999e95901bfdb2fffee4c45e34a0d901ee1
+size 649
data/npys/0d56eb7933de7bd4a847eaa1ba4fd2c4.npy (new file, +3)
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:73e922310068a66ab1a0c3d39b0b0fd7db798f2637b199a8c4bd113a38bb28c8
+size 134217856

data/npys/0fb99fd81fd3076612f6aca9e2926e36.npy (new file, +3)
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:43e80dc7d21aeff62c73f0ed02b20a4ac9b573d0e131a3ab8d1471077e03634b
+size 134217856

data/npys/2a7f021c1c1112554d469b4e7bc3081f.npy (new file, +3)
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fb74cfbeec54b4f263c08510312881527fc7e484604fa0a1213b596f175fecc2
+size 134217856

data/npys/3f6b6dfaecf6f1c003c6d9dfe950c6d0.npy (new file, +3)
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:447cca0af25e309c8be61216cea4fb2d3a8a967b0522760f1dab8f15e6b41574
+size 134217856

data/npys/4365258197699ae6a83e846b9985f606.npy (new file, +3)
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:63321906920de825fc7857aa9f2e4944c3b32f3eadf99da06b66f6599924bc4c
+size 134217856

data/npys/484ffd154056e75885022c286dbfafa7.npy (new file, +3)
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:92e3fcc41f05380cb7b334fefc0a30bb8c1dfd9ca5c3b2cfaad36b0c7093914e
+size 134217856

data/npys/55aac65b78a8181629f963815a7cdf5c.npy (new file, +3)
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0bf671d666d35617edd7bfb58548a5120bd92e5f9cb6edb4b5fc8d3bf5db8987
+size 134217856

data/npys/5aaa36769f52738e1fa37f4c79b9eed2.npy (new file, +3)
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bd429a5776555671b19d42e3ae152026dafd5bf95aeed9847df1432ed37f3eba
+size 134217856

data/npys/6789fdea2609799ef2e975907625b79a.h5 (new file, +3)
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1df90745cc2e6d4b0ad964fca2de1441e6e0b4b8345fbb0fbc1ffe9820674269
+size 134481920

data/npys/78f8295da779018e431af641e9f3eb76.npy (new file, +3)
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:477365edd696f66610298198257422a017f825e1f8bec4363bdfb0da0d741ebc
+size 134217856

data/npys/7a63d865c33ab1bab2697ee5a9c2ce3b.npy (new file, +3)
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1814f8ae0e6cdb69c0030741a3c6b35c74f2d6985eed34c0d5b4135384014abc
+size 134217856

data/npys/a0ddcc9a8f2fc25a1152ec78d54d9676.npy (new file, +3)
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d53bc0fd0897c7372c2f182b84163bcf401ad91f26b6949c6d7a1d70c5dbb513
+size 134217856

data/npys/cf33366e5842c443d7e67f89947a380f.npy (new file, +3)
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fc882f29530f7d683be631214f6667611e0aba87453a11157d71c50f3548fb3c
+size 134217856

data/npys/d30358afa19a86066b4fda6adb74c1b4.npy (new file, +3)
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0f0bc616c2e581444a6fa658e45ee942c1ef5a4d21f22363518331f8d80cbe62
+size 134217856

data/npys/ec3a1c902e17ca5cb5d85398cea182f7.npy (new file, +3)
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ad94f2af43a2a06ebddc78566cbef0ea538c3da9191682e2fde4ecddb061b0f0
+size 134217856

data/npys/f4a8fbbeced371651d851440bd565ee4.npy (new file, +3)
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6930349d1ae1479dcb4c9ee9eaefe2da3eae62539cb4b97de873a7a8b175e809
+size 134217856
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:746ea83013870351296e01e294905ee027291ef79fd78c1e6b69dd9ebaa1cba0
-size 10240000
+oid sha256:76934d1d202aea1311ba67f5ea35eeb99a9c5c856f491565032e7d54ca6f072d
+size 13598720
notes/models.md (new file, +37)

# models

## no polarisation flipping

```py
config_path="data/20241229-163*-128-16384-50000-*.ini"
model=".models/best_20241230_011907.tar"
```

```py
config_path="data/20241229-163*-128-16384-80000-*.ini"
model=".models/best_20241230_103752.tar"
```

```py
config_path="data/20241229-163*-128-16384-100000-*.ini"
model=".models/best_20241230_164534.tar"
```

## with polarisation flipping

Polarisation flipping: the signal is randomly rotated by 180°. A polarisation rotation can be detected by adding a tone to one of the polarisations, but with a direct-detection setup only modulo 180°. The randomly flipped signal should let the network learn to compensate for dispersion and PMD independently of the polarisation rotation. The training data includes the flipped signal as well, with no indication of whether the polarisation is flipped. (A sketch of this augmentation follows these notes.)

```py
config_path="data/20241229-163*-128-16384-50000-*.ini"
model=".models/best_20241231_000328.tar"
```

```py
config_path="data/20241229-163*-128-16384-80000-*.ini"
model=".models/best_20241231_163614.tar"
```

```py
config_path="data/20241229-163*-128-16384-100000-*.ini"
model=".models/best_20241231_170532.tar"
```
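The 180° flip described above amounts to a global sign inversion of the dual-polarisation field, which a direct-detection (power) receiver cannot observe. A minimal sketch of such an augmentation, assuming complex (Ex, Ey) sample pairs in the last dimension; the function name and tensor layout are illustrative, not taken from the repository:

```py
import torch

def random_pol_flip(x: torch.Tensor, p: float = 0.5) -> torch.Tensor:
    """Randomly rotate the polarisation of each example by 180 degrees.

    A 180-degree Jones rotation of (Ex, Ey) is a global sign flip, so the
    detected power |Ex|^2 + |Ey|^2 at a photodiode is unchanged.
    x: complex tensor of shape (batch, samples, 2).
    """
    flip = torch.rand(x.shape[0], device=x.device) < p       # per-example coin flip
    sign = (1.0 - 2.0 * flip.to(torch.float32)).to(x.dtype)  # -1 where flipped, else +1
    return x * sign.view(-1, 1, 1)                           # broadcast over samples/pols
```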
notes/tolerance_testing.md (new file, +59)

# Baseline Models

## a) D+S, pol_error 0, ortho_error 0, DGD 0

dataset

```raw
data/20250118-225840-128-16384-inf-100000-0-0.2-17.0-0.058-PAM4-0-0-10.ini
```

model

```raw
.models/best_20250118_225918.tar
```

## b) D+S, pol_error 0.4, ortho_error 0, DGD 0

dataset

```raw
data/20250116-214606-128-16384-inf-100000-0-0.2-17.0-0.058-PAM4-0-0-10.ini
```

model

```raw
.models/best_20250116_214816.tar
```

## c) D+S, pol_error 0, ortho_error 0.1, DGD 0

dataset

```raw
data/20250116-214547-128-16384-inf-100000-0-0.2-17.0-0.058-PAM4-0-0-10.ini
```

model

```raw
.models/best_20250117_122319.tar
```

## d) D+S, pol_error 0, ortho_error 0, DGD 10ps (1 T_sym)

birefringence angle pi/2 (worst case)

dataset

```raw
data/20250117-143926-128-16384-inf-100000-0-0.2-17.0-0.058-PAM4-0-0-10.ini
```

model

```raw
.models/best_20250117_144001.tar
```
pypho (submodule)
Submodule pypho updated: dd015f4852...e44fc477fe
@@ -26,6 +26,8 @@ import torch
 import torch.optim as optim
 import torch.utils.data
 
+import hypertraining.models as models
+
 from torch.utils.tensorboard import SummaryWriter
 
 import multiprocessing
@@ -253,14 +255,17 @@ class HyperTraining:
         model_kwargs = {
             "dims": (input_dim, *hidden_dims, self.model_settings.output_dim),
             "layer_function": layer_func,
-            "layer_parametrizations": layer_parametrizations,
-            "activation_function": afunc,
+            "layer_func_kwargs": self.model_settings.model_layer_kwargs,
+            "act_function": afunc,
+            "act_func_kwargs": None,
+            "parametrizations": layer_parametrizations,
             "dtype": dtype,
-            "droupout_prob": self.model_settings.dropout_prob,
-            "scale": scale_layers,
+            "dropout_prob": self.model_settings.dropout_prob,
+            "scale_layers": scale_layers,
+            "rotate": False,
         }
 
-        model = util.complexNN.regenerator(**model_kwargs)
+        model = models.regenerator(*model_kwargs.pop("dims"), **model_kwargs)
         n_nodes = sum(hidden_dims)
 
         if writer is not None:
@@ -381,7 +386,10 @@ class HyperTraining:
         running_loss = 0.0
         model.train()
         loader_len = len(train_loader)
-        for batch_idx, (x, y, _) in enumerate(train_loader):
+        for batch_idx, batch in enumerate(train_loader):
+            x = batch["x"]
+            y = batch["y"]
+
             if batch_idx >= self.optuna_settings._n_train_batches:
                 break
             model.zero_grad(set_to_none=True)
@@ -390,7 +398,7 @@ class HyperTraining:
                 y.to(self.pytorch_settings.device),
             )
             y_pred = model(x)
-            loss = util.complexNN.complex_mse_loss(y_pred, y)
+            loss = util.complexNN.complex_mse_loss(y_pred, y, power=True)
             loss_value = loss.item()
             loss.backward()
             optimizer.step()
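The switch to `power=True` in both the training and validation loss is not explained in the diff. One plausible reading, sketched here as an assumption rather than the repository's actual `util.complexNN.complex_mse_loss`, is that the loss is computed on detected power rather than on the complex field, so a global phase offset in the prediction does not count as error:

```py
import torch

def complex_mse_loss(y_pred: torch.Tensor, y: torch.Tensor, power: bool = False) -> torch.Tensor:
    """Hypothetical reconstruction; not the repository's definition."""
    if power:
        # compare |.|^2 (detected power): a global phase of y_pred drops out
        return torch.mean((y_pred.abs().square() - y.abs().square()).square())
    # field-domain MSE: mean squared magnitude of the complex residual
    return torch.mean((y_pred - y).abs().square())
```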
@@ -444,7 +452,9 @@ class HyperTraining:
         model.eval()
         running_error = 0
         with torch.no_grad():
-            for batch_idx, (x, y, _) in enumerate(valid_loader):
+            for batch_idx, batch in enumerate(valid_loader):
+                x = batch["x"]
+                y = batch["y"]
                 if batch_idx >= self.optuna_settings._n_valid_batches:
                     break
                 x, y = (
@@ -452,50 +462,44 @@ class HyperTraining:
                     y.to(self.pytorch_settings.device),
                 )
                 y_pred = model(x)
-                error = util.complexNN.complex_mse_loss(y_pred, y)
+                error = util.complexNN.complex_mse_loss(y_pred, y, power=True)
                 error_value = error.item()
                 running_error += error_value
 
         running_error /= min(len(valid_loader), self.optuna_settings._n_valid_batches)
 
         if writer is not None:
-            title_append, subtitle = self.build_title(trial)
-            writer.add_figure(
-                "fiber response",
-                self.plot_model_response(
-                    trial,
-                    model=model,
-                    title_append=title_append,
-                    subtitle=subtitle,
-                    show=False,
-                ),
-                epoch + 1,
-            )
-            writer.add_figure(
-                "eye diagram",
-                self.plot_model_response(
-                    trial,
-                    model=self.model,
-                    title_append=title_append,
-                    subtitle=subtitle,
-                    show=False,
-                    mode="eye",
-                ),
-                epoch + 1,
+            writer.add_scalar(
+                "eval loss",
+                running_error,
+                epoch,
             )
+            # if (epoch + 1) % 10 == 0 or epoch < 10:
+            #     # plotting is slow, so only do it every 10 epochs
+            #     title_append, subtitle = self.build_title(trial)
+            #     head_fig, eye_fig, powers_fig = self.plot_model_response(
+            #         model=model,
+            #         title_append=title_append,
+            #         subtitle=subtitle,
+            #         show=False,
+            #     )
+            #     writer.add_figure(
+            #         "fiber response",
+            #         head_fig,
+            #         epoch + 1,
+            #     )
+            #     writer.add_figure(
+            #         "eye diagram",
+            #         eye_fig,
+            #         epoch + 1,
+            #     )
 
-            writer.add_figure(
-                "powers",
-                self.plot_model_response(
-                    trial,
-                    model=self.model,
-                    title_append=title_append,
-                    subtitle=subtitle,
-                    mode="powers",
-                    show=False,
-                ),
-                epoch + 1,
-            )
+            # writer.add_figure(
+            #     "powers",
+            #     powers_fig,
+            #     epoch + 1,
+            # )
+            # writer.flush()
 
         # if enable_progress:
         #     progress.stop()
@@ -511,15 +515,18 @@ class HyperTraining:
 
         with torch.no_grad():
             model = model.to(self.pytorch_settings.device)
-            for x, y, timestamp in loader:
+            for batch in loader:
+                x = batch["x"]
+                y = batch["y"]
+                timestamp = batch["timestamp"]
                 x, y = (
                     x.to(self.pytorch_settings.device),
                     y.to(self.pytorch_settings.device),
                 )
                 if trace_powers:
-                    y_pred, powers = model(x, trace_powers).cpu()
+                    y_pred, powers = model(x, trace_powers=True).cpu()
                 else:
-                    y_pred = model(x, trace_powers).cpu()
+                    y_pred = model(x, trace_powers=True).cpu()
                 # x = x.cpu()
                 # y = y.cpu()
                 y_pred = y_pred.view(y_pred.shape[0], -1, 2)
@@ -539,7 +546,7 @@ class HyperTraining:
             return fiber_in, fiber_out, regen, timestamps, powers
         return fiber_in, fiber_out, regen, timestamps
 
-    def objective(self, trial: optuna.Trial, plot_before=False):
+    def objective(self, trial: optuna.Trial):
         if self.stop_study:
             trial.study.stop()
         model = None
@@ -555,54 +562,54 @@ class HyperTraining:
 
         title_append, subtitle = self.build_title(trial)
 
-        writer.add_figure(
-            "fiber response",
-            self.plot_model_response(
-                trial,
-                model=model,
-                title_append=title_append,
-                subtitle=subtitle,
-                show=False,
-            ),
-            0,
-        )
-        writer.add_figure(
-            "eye diagram",
-            self.plot_model_response(
-                trial,
-                model=self.model,
-                title_append=title_append,
-                subtitle=subtitle,
-                mode="eye",
-                show=False,
-            ),
-            0,
-        )
+        # writer.add_figure(
+        #     "fiber response",
+        #     self.plot_model_response(
+        #         trial,
+        #         model=model,
+        #         title_append=title_append,
+        #         subtitle=subtitle,
+        #         show=False,
+        #     ),
+        #     0,
+        # )
+        # writer.add_figure(
+        #     "eye diagram",
+        #     self.plot_model_response(
+        #         trial,
+        #         model=self.model,
+        #         title_append=title_append,
+        #         subtitle=subtitle,
+        #         mode="eye",
+        #         show=False,
+        #     ),
+        #     0,
+        # )
 
-        writer.add_figure(
-            "powers",
-            self.plot_model_response(
-                trial,
-                model=self.model,
-                title_append=title_append,
-                subtitle=subtitle,
-                mode="powers",
-                show=False,
-            ),
-            0,
-        )
+        # writer.add_figure(
+        #     "powers",
+        #     self.plot_model_response(
+        #         trial,
+        #         model=self.model,
+        #         title_append=title_append,
+        #         subtitle=subtitle,
+        #         mode="powers",
+        #         show=False,
+        #     ),
+        #     0,
+        # )
 
         train_loader, valid_loader = self.get_sliced_data(trial)
 
         optimizer_name = trial.suggest_categorical_optional("optimizer", self.optimizer_settings.optimizer)
 
-        lr = trial.suggest_float_optional("lr", self.optimizer_settings.learning_rate, log=True)
+        lr = trial.suggest_float_optional("lr", self.optimizer_settings.optimizer_kwargs["lr"], log=True)
 
         optimizer = getattr(optim, optimizer_name)(model.parameters(), lr=lr)
-        if self.optimizer_settings.scheduler is not None:
-            scheduler = getattr(optim.lr_scheduler, self.optimizer_settings.scheduler)(
-                optimizer, **self.optimizer_settings.scheduler_kwargs
-            )
+        # if self.optimizer_settings.scheduler is not None:
+        #     scheduler = getattr(optim.lr_scheduler, self.optimizer_settings.scheduler)(
+        #         optimizer, **self.optimizer_settings.scheduler_kwargs
+        #     )
 
         for epoch in range(self.pytorch_settings.epochs):
             trial.set_user_attr("epoch", epoch)
@@ -628,8 +635,8 @@ class HyperTraining:
                 writer,
                 # enable_progress=enable_progress,
             )
-            if self.optimizer_settings.scheduler is not None:
-                scheduler.step(error)
+            # if self.optimizer_settings.scheduler is not None:
+            #     scheduler.step(error)
 
             trial.set_user_attr("mse", error)
             trial.set_user_attr("log_mse", np.log10(error + np.finfo(float).eps))
@@ -645,10 +652,10 @@ class HyperTraining:
         if self.optuna_settings._multi_objective:
             return -np.log10(error + np.finfo(float).eps), trial.user_attrs.get("model_n_nodes", -1)
 
-        if self.pytorch_settings.save_models and model is not None:
-            save_path = Path(self.pytorch_settings.model_dir) / f"{self.optuna_settings.study_name}_{trial.number}.pth"
-            save_path.parent.mkdir(parents=True, exist_ok=True)
-            torch.save(model, save_path)
+        # if self.pytorch_settings.save_models and model is not None:
+        #     save_path = Path(self.pytorch_settings.model_dir) / f"{self.optuna_settings.study_name}_{trial.number}.pth"
+        #     save_path.parent.mkdir(parents=True, exist_ok=True)
+        #     torch.save(model, save_path)
 
         return error
@@ -8,7 +8,8 @@ from util.complexNN import (
     photodiode,
     EOActivation,
     polarimeter,
-    normalize_by_first
+    # normalize_by_first,
+    rotate,
 )
 
 
@@ -19,11 +20,11 @@ class polarisation_estimator2(Module):
             polarimeter(),
             torch.nn.Linear(4, 4),
             torch.nn.ReLU(),
-            torch.nn.Dropout(p=0.01),
+            # torch.nn.Dropout(p=0.01),
             torch.nn.Linear(4, 4),
             torch.nn.ReLU(),
-            torch.nn.Dropout(p=0.01),
-            torch.nn.Linear(4, 4),
+            # torch.nn.Dropout(p=0.01),
+            torch.nn.Linear(4, 1),
         )
 
     def forward(self, x):
@@ -123,7 +124,8 @@ class regenerator(Module):
         parametrizations: list[dict] = None,
         dtype=torch.float64,
         dropout_prob=0.01,
-        scale_layers=False,
+        prescale=1,
+        rotate=False,
     ):
         super(regenerator, self).__init__()
         self._n_hidden_layers = len(dims) - 2
@@ -131,14 +133,15 @@ class regenerator(Module):
         layer_func_kwargs = layer_func_kwargs or {}
         act_func_kwargs = act_func_kwargs or {}
 
-        self.build_model(dims, layer_function, layer_func_kwargs, act_function, act_func_kwargs, parametrizations, dtype, dropout_prob, scale_layers)
+        self.rotation = rotate
+        self.prescale = prescale
 
-    def build_model(self, dims, layer_function, layer_func_kwargs, act_function, act_func_kwargs, parametrizations, dtype, dropout_prob, scale_layers):
+        self.build_model(dims, layer_function, layer_func_kwargs, act_function, act_func_kwargs, parametrizations, dtype, dropout_prob)
+
+    def build_model(self, dims, layer_function, layer_func_kwargs, act_function, act_func_kwargs, parametrizations, dtype, dropout_prob):
         for i in range(0, self._n_hidden_layers):
             self.add_module(f"layer_{i}", Sequential())
 
-            if scale_layers:
-                self.get_submodule(f"layer_{i}").add_module("scale", Scale(dims[i]))
-
             module = layer_function(dims[i], dims[i + 1], dtype=dtype, **layer_func_kwargs)
             self.get_submodule(f"layer_{i}").add_module("ONN", module)
@@ -146,13 +149,14 @@ class regenerator(Module):
             module = act_function(size=dims[i + 1], **act_func_kwargs)
             self.get_submodule(f"layer_{i}").add_module("activation", module)
 
+            if dropout_prob is not None and dropout_prob > 0:
                 module = DropoutComplex(p=dropout_prob)
                 self.get_submodule(f"layer_{i}").add_module("dropout", module)
 
         self.add_module(f"layer_{self._n_hidden_layers}", Sequential())
 
-        if scale_layers:
-            self.get_submodule(f"layer_{self._n_hidden_layers}").add_module("scale", Scale(dims[-2]))
+        # if scale_layers:
+        #     self.get_submodule(f"layer_{self._n_hidden_layers}").add_module("scale", Scale(dims[-2]))
 
         module = layer_function(dims[-2], dims[-1], dtype=dtype, **layer_func_kwargs)
         self.get_submodule(f"layer_{self._n_hidden_layers}").add_module("ONN", module)
@@ -160,6 +164,14 @@ class regenerator(Module):
         module = act_function(size=dims[-1], **act_func_kwargs)
         self.get_submodule(f"layer_{self._n_hidden_layers}").add_module("activation", module)
 
+        module = Scale(size=dims[-1])
+        self.get_submodule(f"layer_{self._n_hidden_layers}").add_module("out_scale", module)
+
+        if self.rotation:
+            module = rotate()
+            self.add_module("rotate", module)
+
+
         # module = Scale(size=dims[-1])
         # self.get_submodule(f"layer_{self._n_hidden_layers}").add_module("out_scale", module)
 
@@ -190,15 +202,28 @@ class regenerator(Module):
             powers.append(x.abs().square().sum())
         return powers
 
-    def forward(self, x, trace_powers=False):
+    def forward(self, x, angle=None, pre_rot=False, trace_powers=False):
+        x = x * self.prescale
         powers = self._trace_powers(trace_powers, x)
-        x = self.layer_0(x)
-        powers = self._trace_powers(trace_powers, x, powers)
-        for i in range(1, self._n_hidden_layers):
+        # x = self.layer_0(x)
+        # powers = self._trace_powers(trace_powers, x, powers)
+        for i in range(0, self._n_hidden_layers):
             x = getattr(self, f"layer_{i}")(x)
             powers = self._trace_powers(trace_powers, x, powers)
         x = getattr(self, f"layer_{self._n_hidden_layers}")(x)
-        powers = self._trace_powers(trace_powers, x, powers)
-        if trace_powers:
-            return x, powers
-        return x
+        if self.rotation:
+            try:
+                x_rot = self.rotate(x, angle)
+            except AttributeError:
+                pass
+            powers = self._trace_powers(trace_powers, x_rot, powers)
+        else:
+            x_rot = x
+
+        if pre_rot and trace_powers:
+            return x_rot, x, powers
+        if pre_rot and not trace_powers:
+            return x_rot, x
+        if not pre_rot and trace_powers:
+            return x_rot, powers
+        return x_rot
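The new `forward` signature rotates the network output by `angle` when rotation is enabled. The `rotate` module itself comes from `util.complexNN` and is not shown in this diff; the following is a sketch of what such a Jones rotation could look like, with the tensor layout and behaviour assumed rather than confirmed by the source:

```py
import torch

class Rotate(torch.nn.Module):
    """Assumed behaviour of util.complexNN.rotate: apply a Jones rotation
    by `angle` to interleaved (Ex, Ey) polarisation pairs."""

    def forward(self, x: torch.Tensor, angle: torch.Tensor) -> torch.Tensor:
        pairs = x.view(x.shape[0], -1, 2)        # (batch, samples, 2)
        a = angle.view(-1)
        c, s = torch.cos(a), torch.sin(a)
        rot = torch.stack(
            (torch.stack((c, -s), dim=-1), torch.stack((s, c), dim=-1)), dim=-2
        ).to(pairs.dtype)                        # (batch, 2, 2) rotation matrices
        # Ex' = c*Ex - s*Ey, Ey' = s*Ex + c*Ey, batched over all samples
        return torch.matmul(pairs, rot.transpose(-1, -2)).view(x.shape)
```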
@@ -18,10 +18,14 @@ class DataSettings:
     shuffle: bool = True
     in_out_delay: float = 0
     xy_delay: tuple | float | int = 0
-    drop_first: int = 1000
+    drop_first: int = 64
+    drop_last: int = 64
     train_split: float = 0.8
     polarisations: tuple | list = (0,)
+    # cross_pol_interference: float = 0
     randomise_polarisations: bool = False
+    osnr: float | int = None
+    seed: int = None
 
     """
     change to:
@@ -91,6 +95,12 @@ class ModelSettings:
     """
 
 
+def _early_stop_default_kwargs():
+    return {
+        "threshold": 1e-05,
+        "plateau": 25,
+    }
+
 @dataclass
 class OptimizerSettings:
     optimizer: tuple | str = ("Adam", "RMSprop", "SGD")
@@ -99,6 +109,9 @@ class OptimizerSettings:
     scheduler: str | None = None
     scheduler_kwargs: dict | None = None
 
+    early_stopping: bool = False
+    early_stop_kwargs: dict = field(default_factory=_early_stop_default_kwargs)
+
     """
     change to:
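The new `early_stopping` flag and its default kwargs (`threshold`, `plateau`) suggest a patience-style criterion; how the trainer consumes them is not part of this diff. A sketch of one consistent interpretation, with all names and semantics assumed:

```py
class EarlyStopper:
    """Assumed semantics: stop once the validation loss has not improved
    by more than `threshold` for `plateau` consecutive epochs."""

    def __init__(self, threshold: float = 1e-05, plateau: int = 25):
        self.threshold = threshold
        self.plateau = plateau
        self.best = float("inf")  # best loss seen so far
        self.stale = 0            # epochs without sufficient improvement

    def step(self, loss: float) -> bool:
        if loss < self.best - self.threshold:
            self.best, self.stale = loss, 0
        else:
            self.stale += 1
        return self.stale >= self.plateau  # True -> stop training
```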
@@ -2,9 +2,9 @@ import copy
 from datetime import datetime
 from pathlib import Path
 import random
-from typing import Literal
 import matplotlib
 from matplotlib.colors import LinearSegmentedColormap
+from mpl_toolkits.axes_grid1 import make_axes_locatable
 import torch.nn.utils.parametrize
 
 try:
@@ -47,38 +47,98 @@ from .settings import (
     PytorchSettings,
 )
 
+from cmcrameri import cm
+
+# from matplotlib import colors as mcolors
+# alpha_map = mcolors.LinearSegmentedColormap(
+#     'alphamap',
+#     {
+#         'red': [(0, 0, 0), (1, 0, 0)],
+#         'green': [(0, 0, 0), (1, 0, 0)],
+#         'blue': [(0, 0, 0), (1, 0, 0)],
+#         'alpha': [
+#             (0, 1, 1),
+#             # (0.2, 0.2, 0.1),
+#             (1, 0, 0)
+#         ]
+#     }
+# )
+# alpha_map.set_bad(color="#AAAAAA")
+
+
+def pad_to_size(array, size):
+    if not hasattr(size, "__len__"):
+        size = (size, size)
+
+    left = (
+        (size[0] - array.shape[0] + 1) // 2 if size[0] is not None else 0
+    )
+    right = (
+        (size[0] - array.shape[0]) // 2 if size[0] is not None else 0
+    )
+    top = (
+        (size[1] - array.shape[1] + 1) // 2 if size[1] is not None else 0
+    )
+    bottom = (
+        (size[1] - array.shape[1]) // 2 if size[1] is not None else 0
+    )
+
+    array: np.ndarray = array
+    if array.ndim == 2:
+        return np.pad(
+            array,
+            (
+                (left, right),
+                (top, bottom),
+            ),
+            constant_values=(np.nan, np.nan),
+        )
+    elif array.ndim == 3:
+        return np.pad(
+            array,
+            (
+                (left, right),
+                (top, bottom),
+                (0, 0),
+            ),
+            constant_values=(np.nan, np.nan),
+        )
+
+
 def traverse_dict_update(target, source):
     for k, v in source.items():
         if isinstance(v, dict):
+            try:
                 if k not in target:
                     target[k] = {}
                 traverse_dict_update(target[k], v)
+            except TypeError:
+                if k not in target.__dict__:
+                    setattr(target, k, {})
+                traverse_dict_update(target.__dict__[k], v)
         else:
             try:
                 target[k] = v
             except TypeError:
                 target.__dict__[k] = v
 
 
 def get_parameter_names_and_values(model):
     def is_parametrized(module):
         if hasattr(module, "parametrizations"):
             return True
         return False
 
-    def _get_param_info(module, prefix='', parametrization=False):
+    def _get_param_info(module, prefix="", parametrization=False):
         param_list = []
         for name, param in module.named_parameters(recurse=parametrization):
             if parametrization and name.startswith("parametrizations"):
-                name_parts = name.split('.')
+                name_parts = name.split(".")
                 name = name_parts[1]
                 param = getattr(module, name)
-            full_name = prefix + ('.' if prefix else '') + name
+            full_name = prefix + ("." if prefix else "") + name
             param_value = param.data
             param_list.append((full_name, param_value))
 
         for child_name, child_module in module.named_children():
-            child_prefix = prefix + ('.' if prefix else '') + child_name
+            child_prefix = prefix + ("." if prefix else "") + child_name
             if child_name == "parametrizations":
                 continue
             param_list.extend(_get_param_info(child_module, child_prefix, is_parametrized(child_module)))
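A quick usage check of the `pad_to_size` helper added above: it centres an array on a NaN-padded canvas (NaN entries render as gaps in matplotlib images). The values here follow directly from the function's own arithmetic:

```py
import numpy as np

a = np.ones((2, 2))
padded = pad_to_size(a, 4)   # left = right = top = bottom = 1
print(padded.shape)          # (4, 4); the one-cell border is NaN
```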
@@ -87,6 +147,7 @@ def get_parameter_names_and_values(model):
 
     return _get_param_info(model)
 
+
 class PolarizationTrainer:
     def __init__(
         self,
@@ -219,7 +280,7 @@ class PolarizationTrainer:
 
         # dims = self.model_kwargs.pop("dims")
         model_kwargs = copy.deepcopy(self.model_kwargs)
-        self.model = models.polarisation_estimator(*model_kwargs.pop('dims'),**model_kwargs)
+        self.model = models.polarisation_estimator(*model_kwargs.pop("dims"), **model_kwargs)
         # self.model = models.polarisation_estimator2()
 
         if self.writer is not None:
@@ -260,6 +321,7 @@ class PolarizationTrainer:
             target_delay=in_out_delay,
             xy_delay=xy_delay,
             drop_first=self.data_settings.drop_first,
+            drop_last=self.data_settings.drop_last,
             dtype=dtype,
             real=not dtype.is_complex,
             num_symbols=num_symbols,
@@ -336,17 +398,20 @@ class PolarizationTrainer:
             write_div = 0
             loss_div = 0
             for batch_idx, batch in enumerate(train_loader):
-                x = batch["x"]
-                y = batch["sop"]
+                x = batch["angle_data2"]
+                y = batch["center_angle"]
                 self.model.zero_grad(set_to_none=True)
                 x, y = (
                     x.to(self.pytorch_settings.device),
                     y.to(self.pytorch_settings.device),
                 )
-                y_pred = self.model(x)
+                y_pred = self.model(x).abs().real
+                # y_pred = torch.fmod(y_pred, self.mod)
+                y = y.abs().real
+                # y = torch.fmod(y, self.mod)
+                # loss = torch.nn.functional.mse_loss(torch.cos(y_pred), torch.cos(y))
                 # loss = torch.nn.functional.smooth_l1_loss(torch.cos(torch.fmod(y_pred, torch.pi/2)).squeeze(), torch.cos(torch.fmod(y, torch.pi/2)).squeeze(), beta=0.5)
-                loss = torch.nn.functional.mse_loss(y_pred, y)
-                # loss = util.complexNN.naive_angle_loss(y_pred, y, mod=torch.pi/2)
+                loss = util.complexNN.naive_angle_loss(y_pred, y, mod=self.mod)
                 loss_value = loss.item()
                 loss.backward()
                 optimizer.step()
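The training loss moves from a plain MSE to `util.complexNN.naive_angle_loss(..., mod=self.mod)`, whose definition lies outside this diff. One consistent reading, given the mod-180° ambiguity discussed in notes/models.md, is an MSE on the wrapped angular difference; this sketch is an assumption, not the repository's code:

```py
import torch

def naive_angle_loss(y_pred: torch.Tensor, y: torch.Tensor, mod: float = torch.pi) -> torch.Tensor:
    """Hypothetical: MSE on the shortest angular distance modulo `mod`."""
    diff = torch.remainder(y_pred - y, mod)  # wrap into [0, mod)
    diff = torch.minimum(diff, mod - diff)   # shortest way around the circle
    return torch.mean(diff.square())
```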
@@ -356,7 +421,7 @@ class PolarizationTrainer:
                 loss_div += 1
 
                 if enable_progress:
-                    progress.update(task, advance=1, description=f"{loss_value:.3e}")
+                    progress.update(task, advance=1, description=f"{loss_value/np.pi*180:.3e} °")
 
                 if batch_idx % self.pytorch_settings.write_every == 0:
                     self.writer.add_scalar(
@@ -395,22 +460,26 @@ class PolarizationTrainer:
         loss_div = 0
         with torch.no_grad():
             for _, batch in enumerate(valid_loader):
-                x = batch["x"]
-                y = batch["sop"]
+                # x = batch["angle_data2"]
+                x = batch["angle_data2"]
+                y = batch["center_angle"]
                 x, y = (
                     x.to(self.pytorch_settings.device),
                     y.to(self.pytorch_settings.device),
                 )
-                y_pred = self.model(x)
+                y_pred = self.model(x).abs().real
+                # y_pred = torch.fmod(y_pred, self.mod)
+                y = y.abs().real
+                # y = torch.fmod(y, self.mod)
+                # loss = torch.nn.functional.mse_loss(torch.cos(y_pred), torch.cos(y))
                 # loss = torch.nn.functional.smooth_l1_loss(torch.cos(torch.fmod(y_pred, torch.pi/2)).squeeze(), torch.cos(torch.fmod(y, torch.pi/2)).squeeze(), beta=0.5)
-                loss = torch.nn.functional.mse_loss(y_pred, y)
-                # loss = util.complexNN.naive_angle_loss(y_pred, y, mod=torch.pi/2)
+                loss = util.complexNN.naive_angle_loss(y_pred, y, mod=self.mod)
                 loss_value = loss.item()
                 running_loss += loss_value
                 loss_div += 1
 
                 if enable_progress:
-                    progress.update(task, advance=1, description=f"{loss_value:.3e}")
+                    progress.update(task, advance=1, description=f"{loss_value/np.pi*180:.3e} °")
 
         running_loss = running_loss / loss_div
 
@@ -506,19 +575,19 @@ class PolarizationTrainer:
             for i, config_path in enumerate(self.data_settings.config_path):
                 paths = Path.cwd().glob(config_path)
                 for j, path in enumerate(paths):
-                    text = str(path) + '\n'
-                    with open(path, 'r') as f:
+                    text = str(path) + "\n"
+                    with open(path, "r") as f:
                         text += f.read()
-                    text += '\n'
+                    text += "\n"
                     self.writer.add_text(f"config_{i * len(self.data_settings.config_path) + j}", text)
 
         elif isinstance(self.data_settings.config_path, str):
             paths = Path.cwd().glob(self.data_settings.config_path)
             for j, path in enumerate(paths):
-                text = str(path) + '\n'
-                with open(path, 'r') as f:
+                text = str(path) + "\n"
+                with open(path, "r") as f:
                     text += f.read()
-                text += '\n'
+                text += "\n"
                 self.writer.add_text(f"config_{j}", text)
 
         self.writer.flush()
@@ -571,7 +640,8 @@ class PolarizationTrainer:
             if loss < self.best["loss"]:
                 self.best = checkpoint
                 save_path = (
-                    Path(self.pytorch_settings.model_dir) / f"best_pol_{self.writer.get_logdir().split('/')[-1]}.tar"
+                    Path(self.pytorch_settings.model_dir)
+                    / f"best_pol_{self.writer.get_logdir().split('/')[-1]}.tar"
                 )
                 save_path.parent.mkdir(parents=True, exist_ok=True)
                 self.save_checkpoint(self.best, save_path)
@@ -580,6 +650,7 @@ class PolarizationTrainer:
         self.writer.close()
         return self.best
 
+
 class RegenerationTrainer:
     def __init__(
         self,
@@ -592,6 +663,7 @@ class RegenerationTrainer:
         console=None,
         checkpoint_path=None,
         settings_override=None,
+        new_model=False,
         reset_epoch=False,
     ):
         self.resume = checkpoint_path is not None
@@ -605,12 +677,23 @@ class RegenerationTrainer:
             models.regenerator,
             torch.nn.utils.parametrizations.orthogonal,
         ])
+        # self.new_model = True
+        self.model_name = datetime.now().strftime("%Y%m%d_%H%M%S")
         if self.resume:
             print(f"loading checkpoint from {checkpoint_path}")
             self.checkpoint_dict = torch.load(checkpoint_path, weights_only=True)
             if settings_override is not None:
                 traverse_dict_update(self.checkpoint_dict["settings"], settings_override)
-            if reset_epoch:
+
+            if not new_model:
+                # self.new_model = False
+                checkpoint_file = checkpoint_path.split("/")[-1].split(".")[0]
+                if checkpoint_file.startswith("best"):
+                    self.model_name = "_".join(checkpoint_file.split("_")[1:])
+                else:
+                    self.model_name = "_".join(checkpoint_file.split("_")[:-1])
+
+            if new_model or reset_epoch:
                 self.checkpoint_dict["epoch"] = -1
 
         self.global_settings: GlobalSettings = self.checkpoint_dict["settings"]["global_settings"]
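Worked example of the checkpoint-name logic added above, using a checkpoint name that appears in notes/tolerance_testing.md:

```py
checkpoint_path = ".models/best_20250118_225918.tar"
checkpoint_file = checkpoint_path.split("/")[-1].split(".")[0]  # "best_20250118_225918"
if checkpoint_file.startswith("best"):
    model_name = "_".join(checkpoint_file.split("_")[1:])       # drop the "best" prefix
else:
    model_name = "_".join(checkpoint_file.split("_")[:-1])      # drop a trailing suffix
print(model_name)  # 20250118_225918
```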
@@ -636,11 +719,15 @@ class RegenerationTrainer:
         self.model_settings: ModelSettings = model_settings
         self.optimizer_settings: OptimizerSettings = optimizer_settings
 
+        if self.global_settings.seed is not None:
+            random.seed(self.global_settings.seed)
+            np.random.seed(self.global_settings.seed)
+
         self.console = console or Console()
         self.writer = None
 
     def setup_tb_writer(self, append=None):
-        log_dir = self.pytorch_settings.summary_dir + "/" + (datetime.now().strftime("%Y%m%d_%H%M%S"))
+        log_dir = self.pytorch_settings.summary_dir + "/" + self.model_name
         if append is not None:
             log_dir += "_" + str(append)
 
@@ -669,7 +756,7 @@ class RegenerationTrainer:
 
     def define_model(self, model_kwargs=None):
         if self.resume:
-            model_kwargs = self.checkpoint_dict["model_kwargs"]
+            model_kwargs = None
         else:
             model_kwargs = model_kwargs
 
@@ -678,6 +765,14 @@ class RegenerationTrainer:
 
         input_dim = 2 * self.data_settings.output_size
 
+        # if self.data_settings.polarisations:
+        #     input_dim *= 2
+
+        output_dim = self.model_settings.output_dim
+
+        if self.data_settings.polarisations:
+            output_dim *= 2
+
         dtype = getattr(torch, self.data_settings.dtype)
 
         afunc = getattr(util.complexNN, self.model_settings.model_activation_func)
@@ -689,7 +784,7 @@ class RegenerationTrainer:
         hidden_dims = [self.model_settings.overrides.get(f"n_hidden_nodes_{i}") for i in range(n_hidden_layers)]
 
         self.model_kwargs = {
-            "dims": (input_dim, *hidden_dims, self.model_settings.output_dim),
+            "dims": (input_dim, *hidden_dims, output_dim),
             "layer_function": layer_func,
             "layer_func_kwargs": self.model_settings.model_layer_kwargs,
             "act_function": afunc,
@@ -697,7 +792,7 @@ class RegenerationTrainer:
             "parametrizations": layer_parametrizations,
             "dtype": dtype,
             "dropout_prob": self.model_settings.dropout_prob,
-            "scale_layers": self.model_settings.scale,
+            "prescale": self.model_settings.scale,
             }
         else:
             self.model_kwargs = model_kwargs
@@ -706,10 +801,12 @@ class RegenerationTrainer:
 
         # dims = self.model_kwargs.pop("dims")
         model_kwargs = copy.deepcopy(self.model_kwargs)
-        self.model = models.regenerator(*model_kwargs.pop('dims'),**model_kwargs)
+        self.model = models.regenerator(*model_kwargs.pop("dims"), **model_kwargs)
 
         if self.writer is not None:
-            self.writer.add_graph(self.model, torch.zeros(1, input_dim, dtype=dtype))
+            self.writer.add_graph(
+                self.model, (torch.rand(1, input_dim, dtype=dtype), torch.rand(1, 1, dtype=dtype.to_real()))
+            )
 
         self.model = self.model.to(self.pytorch_settings.device)
         if self.resume:
@@ -728,13 +825,16 @@ class RegenerationTrainer:
 
         num_symbols = None
         config_path = self.data_settings.config_path
-        polarisations = self.data_settings.polarisations
         randomise_polarisations = self.data_settings.randomise_polarisations
+        polarisations = self.data_settings.polarisations
+        osnr = self.data_settings.osnr
+        # cross_pol_interference = self.data_settings.cross_pol_interference
         if override is not None:
             num_symbols = override.get("num_symbols", None)
             config_path = override.get("config_path", config_path)
             polarisations = override.get("polarisations", polarisations)
             randomise_polarisations = override.get("randomise_polarisation", randomise_polarisations)
+            # cross_pol_interference = override.get("angle_var", 0)
         # get dataset
         dataset = FiberRegenerationDataset(
             file_path=config_path,
@@ -743,11 +843,14 @@ class RegenerationTrainer:
             target_delay=in_out_delay,
             xy_delay=xy_delay,
             drop_first=self.data_settings.drop_first,
+            drop_last=self.data_settings.drop_last,
             dtype=dtype,
             real=not dtype.is_complex,
             num_symbols=num_symbols,
-            polarisations=polarisations,
             randomise_polarisations=randomise_polarisations,
+            polarisations=polarisations,
+            # cross_pol_interference=cross_pol_interference,
+            osnr = osnr,
         )
 
         dataset_size = len(dataset)
@@ -816,16 +919,25 @@ class RegenerationTrainer:
         running_loss = 0.0
         self.model.train()
         loader_len = len(train_loader)
+        # x_key = "x_stacked"# if self.data_settings.polarisations else "x"
+        # y_key = "y_stacked"# if self.data_settings.polarisations else "y"
+        x_key = "x"
+        y_key = "y"
         for batch_idx, batch in enumerate(train_loader):
-            x = batch["x"]
-            y = batch["y"]
+            x = batch[x_key]
+            y = batch[y_key]
+            angle = batch["mean_angle"]
             self.model.zero_grad(set_to_none=True)
-            x, y = (
+            x, y, angle = (
                 x.to(self.pytorch_settings.device),
                 y.to(self.pytorch_settings.device),
+                angle.to(self.pytorch_settings.device),
             )
-            y_pred = self.model(x)
+            y_pred = self.model(x, -angle)
+            # loss = util.complexNN.complex_mse_loss(y_pred*torch.sqrt(torch.ones(1, device=y.device)*5), y*torch.sqrt(torch.ones(1, device=y.device)*5), power=True)
             loss = util.complexNN.complex_mse_loss(y_pred, y, power=True)
 
             loss_value = loss.item()
             loss.backward()
             optimizer.step()
@@ -868,16 +980,24 @@ class RegenerationTrainer:
 
         self.model.eval()
         running_error = 0
+        # x_key = "x_stacked"# if self.data_settings.polarisations else "x"
+        # y_key = "y_stacked"# if self.data_settings.polarisations else "y"
+        x_key = "x"
+        y_key = "y"
        with torch.no_grad():
             for _, batch in enumerate(valid_loader):
-                x = batch["x"]
-                y = batch["y"]
-                x, y = (
+                x = batch[x_key]
+                y = batch[y_key]
+                angle = batch["mean_angle"]
+                x, y, angle = (
                     x.to(self.pytorch_settings.device),
                     y.to(self.pytorch_settings.device),
+                    angle.to(self.pytorch_settings.device),
                 )
-                y_pred = self.model(x)
+                y_pred = self.model(x, -angle)
+                # error = util.complexNN.complex_mse_loss(y_pred*torch.sqrt(torch.ones(1, device=y.device)*5), y*torch.sqrt(torch.ones(1, device=y.device)*5), power=True)
                 error = util.complexNN.complex_mse_loss(y_pred, y, power=True)
 
                 error_value = error.item()
                 running_error += error_value
 
@@ -894,7 +1014,7 @@ class RegenerationTrainer:
         if (epoch + 1) % 10 == 0 or epoch < 10:
             # plotting is slow, so only do it every 10 epochs
             title_append, subtitle = self.build_title(epoch + 1)
-            head_fig, eye_fig, powers_fig = self.plot_model_response(
+            head_fig, eye_fig, weight_fig, powers_fig = self.plot_model_response(
                 model=self.model,
                 title_append=title_append,
                 subtitle=subtitle,
@@ -910,6 +1030,11 @@ class RegenerationTrainer:
                 eye_fig,
                 epoch + 1,
             )
+            self.writer.add_figure(
+                "weights",
+                weight_fig,
+                epoch + 1,
+            )

             self.writer.add_figure(
                 "powers",
@@ -928,45 +1053,70 @@ class RegenerationTrainer:
     def run_model(self, model, loader, trace_powers=False):
         model.eval()
         fiber_out = []
+        fiber_out_rot = []
         fiber_in = []
         regen = []
         timestamps = []
+        angles = []
+        # x_key = "x_stacked"  # if self.data_settings.polarisations else "x"
+        # y_key = "y_stacked"  # if self.data_settings.polarisations else "y"
+        x_key = "x"
+        y_key = "y"
         with torch.no_grad():
             model = model.to(self.pytorch_settings.device)
             for batch in loader:
-                x = batch["x"]
-                y = batch["y"]
+                x = batch[x_key]
+                y = batch[y_key]
+                plot_target = batch["plot_target"]
+                angle = batch["mean_angle"]
+                # center_angle = batch["center_angle"]
                 timestamp = batch["timestamp"]
                 plot_data = batch["plot_data"]
-                x, y = (
+                plot_data_rot = batch["plot_data_rot"]
+                x, y, angle = (
                     x.to(self.pytorch_settings.device),
                     y.to(self.pytorch_settings.device),
+                    angle.to(self.pytorch_settings.device),
                 )
                 if trace_powers:
-                    y_pred, powers = model(x, trace_powers).cpu()
+                    y_pred, powers = model(x, -angle, True).cpu()
                 else:
-                    y_pred = model(x, trace_powers).cpu()
+                    y_pred = model(x, -angle).cpu()
                 # x = x.cpu()
                 # y = y.cpu()
+                # if self.data_settings.polarisations:
+                y_pred = y_pred[:, :2]
                 y_pred = y_pred.view(y_pred.shape[0], -1, 2)
-                y = y.view(y.shape[0], -1, 2)
-                plot_data = plot_data.view(plot_data.shape[0], -1, 2)
+                y_pred = y_pred[:, y_pred.shape[1]//2, :]
+                # y = y.view(y.shape[0], -1, 2)
+                # plot_data = plot_data.view(plot_data.shape[0], -1, 2)
+                # c = torch.cos(-angle).cpu()
+                # s = torch.sin(-angle).cpu()
+                # rot = torch.stack([torch.stack([c, -s], dim=1), torch.stack([s, c], dim=1)], dim=2).squeeze(-1)
+                # plot_data = torch.bmm(plot_data, rot.to(dtype=plot_data.dtype))
+                # plot_data = plot_data
+                # sines = torch.sin(-angle.cpu())
+                # cosines = torch.cos(-angle.cpu())
+                # plot_data = torch.stack((plot_data[..., 0] * cosines - plot_data[..., 1] * sines, plot_data[..., 0] * sines + plot_data[..., 1] * cosines), dim=-1)
                 # x = x.view(x.shape[0], -1, 2)

                 # timestamp = timestamp.view(-1, 1)
                 fiber_out.append(plot_data.squeeze())
-                fiber_in.append(y.squeeze())
+                fiber_out_rot.append(plot_data_rot.squeeze())
+                fiber_in.append(plot_target.squeeze())
                 regen.append(y_pred.squeeze())
                 timestamps.append(timestamp.squeeze())
+                angles.append(angle.squeeze())

         fiber_out = torch.vstack(fiber_out).cpu()
+        fiber_out_rot = torch.vstack(fiber_out_rot).cpu()
         fiber_in = torch.vstack(fiber_in).cpu()
         regen = torch.vstack(regen).cpu()
+        angles = torch.vstack(angles).cpu()
         timestamps = torch.concat(timestamps).cpu()
         if trace_powers:
-            return fiber_in, fiber_out, regen, timestamps, powers
-        return fiber_in, fiber_out, regen, timestamps
+            return fiber_in, fiber_out, fiber_out_rot, angles, regen, timestamps, powers
+        return fiber_in, fiber_out, fiber_out_rot, angles, regen, timestamps

     def write_parameters(self, epoch, attributes: list[str] | tuple[str] = None):
         parameter_list = get_parameter_names_and_values(self.model)
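The commented-out bmm block in the hunk above documents how the fiber output could be rotated back by mean_angle on the host; the dataset now ships a precomputed plot_data_rot instead. A self-contained sketch of that rotation (hypothetical helper, not part of the diff):

import torch

def rotate_polarisation(samples: torch.Tensor, angle: torch.Tensor) -> torch.Tensor:
    # samples: (N, 2) x/y polarisation pairs; angle: (N,) rotation per sample
    c, s = torch.cos(angle), torch.sin(angle)
    x = samples[:, 0] * c - samples[:, 1] * s
    y = samples[:, 0] * s + samples[:, 1] * c
    return torch.stack((x, y), dim=-1)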
@@ -998,7 +1148,7 @@ class RegenerationTrainer:
         )

         title_append, subtitle = self.build_title(0)
-        head_fig, eye_fig, powers_fig = self.plot_model_response(
+        head_fig, eye_fig, weight_fig, powers_fig = self.plot_model_response(
             model=self.model,
             title_append=title_append,
             subtitle=subtitle,
@@ -1014,6 +1164,11 @@ class RegenerationTrainer:
             eye_fig,
             0,
         )
+        self.writer.add_figure(
+            "weights",
+            weight_fig,
+            0,
+        )

         self.writer.add_figure(
             "powers",
@@ -1027,24 +1182,27 @@ class RegenerationTrainer:
             for i, config_path in enumerate(self.data_settings.config_path):
                 paths = Path.cwd().glob(config_path)
                 for j, path in enumerate(paths):
-                    text = str(path) + '\n'
-                    with open(path, 'r') as f:
+                    text = str(path) + "\n"
+                    with open(path, "r") as f:
                         text += f.read()
-                    text += '\n'
+                    text += "\n"
                     self.writer.add_text(f"config_{i * len(self.data_settings.config_path) + j}", text)
         elif isinstance(self.data_settings.config_path, str):
             paths = Path.cwd().glob(self.data_settings.config_path)
             for j, path in enumerate(paths):
-                text = str(path) + '\n'
-                with open(path, 'r') as f:
+                text = str(path) + "\n"
+                with open(path, "r") as f:
                     text += f.read()
-                text += '\n'
+                text += "\n"
                 self.writer.add_text(f"config_{j}", text)

         self.writer.flush()

         train_loader, valid_loader = self.get_sliced_data()

+        # train_loader.dataset.fiber_out.to(self.pytorch_settings.device)
+        # train_loader.dataset.fiber_in.to(self.pytorch_settings.device)

         optimizer_name = self.optimizer_settings.optimizer

         # lr = self.optimizer_settings.learning_rate
@@ -1074,6 +1232,7 @@ class RegenerationTrainer:
         # except ValueError:
         #     pass

+        self.early_stop_vals = {"min_loss": float("inf"), "plateau_cnt": 0}
         for epoch in range(self.best["epoch"] + 1, self.pytorch_settings.epochs):
             enable_progress = True
             if enable_progress:
@@ -1089,9 +1248,48 @@ class RegenerationTrainer:
                 epoch,
                 enable_progress=enable_progress,
             )
+            if self.early_stop(loss):
+                self.save_model_checkpoints(epoch, loss)
+                break
             if self.optimizer_settings.scheduler is not None:
                 self.scheduler.step(loss)
                 self.writer.add_scalar("learning rate", self.optimizer.param_groups[0]["lr"], epoch)
+            self.save_model_checkpoints(epoch, loss)
+            self.writer.flush()
+
+        save_path = (Path(self.pytorch_settings.model_dir) / f"best_{self.writer.get_logdir().split('/')[-1]}.tar")
+        print(f"Training complete. Best checkpoint: {save_path}")
+        self.writer.close()
+        return self.best
+
+    def early_stop(self, loss):
+        # not stopping early at all
+        if not self.optimizer_settings.early_stopping:
+            return False
+
+        # stopping because of abs threshold
+        if (loss_thr := self.optimizer_settings.early_stop_kwargs.get("threshold", None)) is not None:
+            if loss <= loss_thr:
+                print(f"Early stop: loss is below threshold ({loss:.2e} <= {loss_thr:.2e})")
+                return True
+
+        # update vals
+        if loss < self.early_stop_vals["min_loss"]:
+            self.early_stop_vals["min_loss"] = loss
+            self.early_stop_vals["plateau_cnt"] = 0
+            return False
+
+        # stopping because of plateau
+        if (plateau_thresh := self.optimizer_settings.early_stop_kwargs.get("plateau", None)) is not None:
+            self.early_stop_vals["plateau_cnt"] += 1
+            if self.early_stop_vals["plateau_cnt"] >= plateau_thresh:
+                print(f"Early stop: loss plateau length over threshold ({self.early_stop_vals['plateau_cnt']} >= {plateau_thresh})")
+                return True
+
+        # no stop
+        return False
+
+    def save_model_checkpoints(self, epoch, loss):
         if self.pytorch_settings.save_models and self.model is not None:
             save_path = (
                 Path(self.pytorch_settings.model_dir) / f"{self.writer.get_logdir().split('/')[-1]}_{epoch}.tar"
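The new early_stop reads two optional keys from optimizer_settings.early_stop_kwargs: "threshold" stops as soon as the epoch loss falls at or below an absolute value, and "plateau" stops after that many consecutive epochs without a new minimum. A standalone sketch of the same rule, assuming the settings shape used above:

early_stop_kwargs = {"threshold": 1e-6, "plateau": 128}
state = {"min_loss": float("inf"), "plateau_cnt": 0}

def should_stop(loss: float) -> bool:
    if loss <= early_stop_kwargs["threshold"]:
        return True                      # absolute threshold reached
    if loss < state["min_loss"]:
        state["min_loss"] = loss         # new best: reset the plateau counter
        state["plateau_cnt"] = 0
        return False
    state["plateau_cnt"] += 1            # no improvement this epoch
    return state["plateau_cnt"] >= early_stop_kwargs["plateau"]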
@@ -1107,15 +1305,12 @@ class RegenerationTrainer:
             )
             save_path.parent.mkdir(parents=True, exist_ok=True)
             self.save_checkpoint(self.best, save_path)
-            self.writer.flush()
-
-        self.writer.close()
-        return self.best

     def _plot_model_response_powers(self, powers, layer_names, title_append="", subtitle="", show=True):
         powers = [power / powers[0] for power in powers]
         fig, ax = plt.subplots()
         fig.set_figwidth(18)
+        fig.set_figheight(4)
         fig.suptitle(
             f"Energy conservation{f' {title_append}' if title_append else ''}{f'\n{subtitle}' if subtitle else ''}"
         )
@@ -1131,6 +1326,77 @@ class RegenerationTrainer:
             plt.show()
         return fig

+    def _plot_model_weights(self, model, title_append="", subtitle="", show=True):
+        model_params = []
+        plots = []
+        dims = []
+        for num, (layer_name, layer) in enumerate(model.named_children()):
+            onn_weights = layer.ONN.weight
+            onn_weights = onn_weights.detach().cpu().numpy()
+            onn_values = np.abs(onn_weights).real
+            onn_angles = np.mod(np.angle(onn_weights), 2*np.pi).real
+
+            model_params.append({layer_name: onn_weights})
+            plots.append({layer_name: (num, onn_values, onn_angles)})
+            dims.append(onn_weights.shape[0])
+
+        max_size = np.max(dims)
+
+        for plot in plots:
+            layer_name, (num, onn_values, onn_angles) = plot.popitem()
+
+            if num == 0:
+                # pad first, then seed the images (padding before assignment keeps heights consistent)
+                onn_angles = pad_to_size(onn_angles, (max_size, None))
+                onn_values = pad_to_size(onn_values, (max_size, None))
+                value_img = onn_values
+                angle_img = onn_angles
+            else:
+                onn_values = pad_to_size(onn_values, (max_size, onn_values.shape[1]+1))
+                onn_angles = pad_to_size(onn_angles, (max_size, onn_angles.shape[1]+1))
+                value_img = np.concatenate((value_img, onn_values), axis=1)
+                angle_img = np.concatenate((angle_img, onn_angles), axis=1)
+
+        value_img = np.ma.array(value_img, mask=np.isnan(value_img))
+        angle_img = np.ma.array(angle_img, mask=np.isnan(angle_img))
+
+        fig, axs = plt.subplots(1, 2, sharex=True, sharey=True, figsize=(18, 6.5))
+        fig.tight_layout()
+
+        dividers = map(make_axes_locatable, axs)
+        caxs = list(map(lambda d: d.append_axes("right", size="5%", pad=0.05), dividers))
+
+        masked_value_img = value_img
+        cmap = cm.batlow
+        cmap.set_bad(color="#AAAAAA")
+        im_val = axs[0].imshow(masked_value_img, cmap=cmap, vmin=0, vmax=1)
+        fig.colorbar(im_val, cax=caxs[0], orientation="vertical")
+
+        masked_angle_img = np.ma.masked_where(np.isnan(angle_img), angle_img)
+        cmap = cm.romaO
+        cmap.set_bad(color="#AAAAAA")
+        im_ang = axs[1].imshow(masked_angle_img, cmap=cmap, vmin=0, vmax=2*np.pi)
+        cbar = fig.colorbar(im_ang, cax=caxs[1], orientation="vertical", ticks=[i/8 * 2*np.pi for i in range(9)])
+        cbar.ax.set_yticklabels(["0", "π/4", "π/2", "3π/4", "π", "5π/4", "3π/2", "7π/4", "2π"])
+
+        axs[0].axis("off")
+        axs[1].axis("off")
+
+        axs[0].set_title("Values")
+        axs[1].set_title("Angles")
+
+        title = "Layer Weights"
+        if title_append:
+            title += f" {title_append}"
+        if subtitle:
+            title += f"\n{subtitle}"
+        fig.suptitle(title)
+
+        if show:
+            plt.show()
+        return fig
+
     def _plot_model_response_eye(
         self, *signals, timestamps, labels=None, sps=None, title_append="", subtitle="", show=True
     ):
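_plot_model_weights pads each layer's weight matrix with NaNs to a common height, concatenates the layers side by side, and lets the colormap's "bad" colour render the padding grey. A minimal, self-contained demonstration of that masking trick (standard matplotlib/numpy only):

import numpy as np
import matplotlib as mpl
from matplotlib import pyplot as plt

img = np.pad(np.random.rand(4, 3), ((2, 1), (0, 0)), constant_values=np.nan)  # NaN rows above/below
cmap = mpl.colormaps["viridis"].copy()
cmap.set_bad(color="#AAAAAA")                    # NaN cells draw grey, like the padding
plt.imshow(np.ma.masked_invalid(img), cmap=cmap, vmin=0, vmax=1)
plt.show()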
@@ -1184,6 +1450,7 @@ class RegenerationTrainer:

         fig, axs = plt.subplots(1, 2 * len(signals), sharex=True, sharey=True)
         fig.set_figwidth(18)
+        fig.set_figheight(4)
         fig.suptitle(f"Eye diagram{f' {title_append}' if title_append else ''}{f'\n{subtitle}' if subtitle else ''}")
         # xaxis = timestamps / sps
         # xaxis = np.arange(2 * sps) / sps
@@ -1253,7 +1520,7 @@ class RegenerationTrainer:
                 xaxis = timestamps / sps
             else:
                 xaxis = timestamps
-            ax.plot(xaxis, np.abs(signal[:, i]) ** 2, label=label)
+            ax.plot(xaxis, np.abs(signal[:, i]) ** 2, label=label, alpha=0.7)
             ax.set_xlabel("Sample" if sps is None else "Symbol")
             ax.set_ylabel("normalized power")
             ax.minorticks_on()
@@ -1281,7 +1548,9 @@ class RegenerationTrainer:
         model = model.to(self.pytorch_settings.device)
         model.eval()
         with torch.no_grad():
-            _, powers = model(input_data, trace_powers=True)
+            _, powers = model(
+                input_data, torch.zeros(input_data.shape[0], 1).to(self.pytorch_settings.device), trace_powers=True
+            )

         powers = [power.item() for power in powers]
         layer_names = [name for (name, _) in model.named_children()]
@@ -1292,33 +1561,48 @@ class RegenerationTrainer:

         data_settings_backup = copy.deepcopy(self.data_settings)
         pytorch_settings_backup = copy.deepcopy(self.pytorch_settings)
-        self.data_settings.drop_first = 99.5 + random.randint(0, 1000)
+        self.data_settings.drop_first = int(64 + random.randint(0, 1000))
         self.data_settings.shuffle = False
         self.data_settings.train_split = 1.0
         self.pytorch_settings.batchsize = max(self.pytorch_settings.head_symbols, self.pytorch_settings.eye_symbols)
-        config_path = random.choice(self.data_settings.config_path) if isinstance(self.data_settings.config_path, (list, tuple)) else self.data_settings.config_path
-        fiber_length = int(float(str(config_path).split('-')[4])/1000)
+        config_path = (
+            random.choice(self.data_settings.config_path)
+            if isinstance(self.data_settings.config_path, (list, tuple))
+            else self.data_settings.config_path
+        )
+        # fiber_length = int(float(str(config_path).split("-")[4]) / 1000)
         if not hasattr(self, "_plot_loader"):
             self._plot_loader, _ = self.get_sliced_data(
                 override={
                     "num_symbols": self.pytorch_settings.batchsize,
                     "config_path": config_path,
                     "shuffle": False,
-                    "polarisations": (np.random.rand(1)*np.pi*2,),
-                    "randomise_polarisation": False,
+                    # "polarisations": (np.random.rand(1) * np.pi * 2,),
+                    "polarisations": self.data_settings.polarisations,
+                    "randomise_polarisation": self.data_settings.randomise_polarisations,
                 }
             )
             self._sps = self._plot_loader.dataset.samples_per_symbol
+        fiber_length = float(self._plot_loader.dataset.config["fiber"]["length"])/1000
         self.data_settings = data_settings_backup
         self.pytorch_settings = pytorch_settings_backup

-        fiber_in, fiber_out, regen, timestamps = self.run_model(model, self._plot_loader)
+        fiber_in, fiber_out, fiber_out_rot, angles, regen, timestamps = self.run_model(model, self._plot_loader)
         fiber_in = fiber_in.view(-1, 2)
         fiber_out = fiber_out.view(-1, 2)
+        fiber_out_rot = fiber_out_rot.view(-1, 2)
+        angles = angles.view(-1, 1)
+        angles = angles.real
+        angles = torch.fmod(angles, 2*torch.pi)
+        angles = torch.div(angles, 2*torch.pi)
+        angles = torch.repeat_interleave(angles, 2, dim=1)
+
         regen = regen.view(-1, 2)

         fiber_in = fiber_in.numpy()
         fiber_out = fiber_out.numpy()
+        fiber_out_rot = fiber_out_rot.numpy()
+        angles = angles.numpy()
         regen = regen.numpy()
         timestamps = timestamps.numpy()

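The angle post-processing above normalises the accumulated rotation into [0, 1) for plotting: e.g. a mean angle of 9.0 rad becomes fmod(9.0, 2π) ≈ 2.717 rad, and 2.717 / (2π) ≈ 0.433, which is then drawn as the "normed angle" trace on the same axis as the normalised powers; repeat_interleave duplicates the value so it lines up with the two polarisation columns.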
@@ -1327,11 +1611,12 @@ class RegenerationTrainer:
        import gc

        head_fig = self._plot_model_response_head(
-            fiber_in[:self.pytorch_settings.head_symbols*self._sps],
             fiber_out[: self.pytorch_settings.head_symbols * self._sps],
+            fiber_in[: self.pytorch_settings.head_symbols * self._sps],
             regen[: self.pytorch_settings.head_symbols * self._sps],
+            angles[: self.pytorch_settings.head_symbols * self._sps],
             timestamps=timestamps[: self.pytorch_settings.head_symbols * self._sps],
-            labels=("fiber in", "fiber out", "regen"),
+            labels=("fiber out", "fiber in", "regen", "normed angle"),
             sps=self._sps,
             title_append=title_append + f" ({fiber_length} km)",
             subtitle=subtitle,
@@ -1349,9 +1634,11 @@ class RegenerationTrainer:
             subtitle=subtitle,
             show=show,
         )
+
+        weight_fig = self._plot_model_weights(model, title_append=title_append, subtitle=subtitle, show=show)
         gc.collect()

-        return head_fig, eye_fig, power_fig
+        return head_fig, eye_fig, weight_fig, power_fig

     def build_title(self, number: int):
         title_append = f"epoch {number}"
@@ -1361,7 +1648,7 @@ class RegenerationTrainer:
             self.model_settings.overrides.get(f"n_hidden_nodes_{i}", -1) for i in range(model_n_hidden_layers)
         ]
         model_dims.insert(0, input_dim)
-        model_dims.append(2)
+        model_dims.append(self.model_settings.output_dim)
         model_dims = [str(dim) for dim in model_dims]
         model_activation_func = self.model_settings.model_activation_func
         model_dtype = self.data_settings.dtype
217  src/single-core-regen/plot_model.py  Normal file
@@ -0,0 +1,217 @@
from pathlib import Path
import sys

from matplotlib import pyplot as plt
import numpy as np
import torch
import util
from hypertraining.settings import GlobalSettings, DataSettings, ModelSettings, OptimizerSettings, PytorchSettings
from hypertraining import models

# def move_to_location_in_size(array, location, size):
#     array_x, array_y = array.shape
#     location_x, location_y = location
#     size_x, size_y = size
#
#     left = location_x
#     right = size_x - array_x - location_x
#
#     top = location_y
#     bottom = size_y - array_y - location_y
#
#     return np.pad(
#         array,
#         (
#             (left, right),
#             (top, bottom),
#         ),
#         constant_values=(-np.inf, -np.inf),
#     )

def register_puccs_cmap(puccs_path=None):
    puccs_path = Path(__file__).resolve().parent / 'puccs.csv' if puccs_path is None else puccs_path

    colors = []
    # keys = None
    with open(puccs_path, "r") as f:
        for i, line in enumerate(f.readlines()):
            elements = tuple(line.split(","))
            # if i == 0:
            #     # keys = elements
            #     continue
            # else:
            try:
                colors.append(tuple(map(float, elements[4:])))
            except ValueError:
                continue
    # colors = []
    # for current in puccs_csv_data:
    #     colors.append(tuple(current[4:]))
    from matplotlib.colors import LinearSegmentedColormap
    import matplotlib as mpl
    mpl.colormaps.register(LinearSegmentedColormap.from_list('puccs', colors))

def pad_to_size(array, size):
    if not hasattr(size, "__len__"):
        size = (size, size)

    left = (size[0] - array.shape[0] + 1) // 2 if size[0] is not None else 0
    right = (size[0] - array.shape[0]) // 2 if size[0] is not None else 0
    top = (size[1] - array.shape[1] + 1) // 2 if size[1] is not None else 0
    bottom = (size[1] - array.shape[1]) // 2 if size[1] is not None else 0

    array: np.ndarray = array
    if array.ndim == 2:
        return np.pad(
            array,
            (
                (left, right),
                (top, bottom),
            ),
            constant_values=(np.nan, np.nan),
        )
    elif array.ndim == 3:
        return np.pad(
            array,
            (
                (left, right),
                (top, bottom),
                (0, 0)
            ),
            constant_values=(np.nan, np.nan),
        )

def model_plot(model_path, show=True):
    torch.serialization.add_safe_globals([
        *util.complexNN.__all__,
        GlobalSettings,
        DataSettings,
        ModelSettings,
        OptimizerSettings,
        PytorchSettings,
        models.regenerator,
        torch.nn.utils.parametrizations.orthogonal,
    ])
    checkpoint_dict = torch.load(model_path, weights_only=True)

    dims = checkpoint_dict["model_kwargs"].pop("dims")

    model = models.regenerator(*dims, **checkpoint_dict["model_kwargs"])
    model.load_state_dict(checkpoint_dict["model_state_dict"], strict=False)

    model_params = []
    plots = []
    max_size = np.max(dims)
    # max_act_size = np.max(dims[1:])

    # angles = [None, None]
    # weights = [None, None]

    for num, (layer_name, layer) in enumerate(model.named_children()):
        # each layer contains an "ONN" layer and an "activation" layer
        # activation layer is approximately the same for all layers and nodes -> rotation by 90 degrees
        onn_weights = layer.ONN.weight
        onn_weights = onn_weights.detach().cpu().numpy()
        onn_values = np.abs(onn_weights).real
        onn_angles = np.mod(np.angle(onn_weights), 2*np.pi).real

        model_params.append({layer_name: onn_weights})
        plots.append({layer_name: (num, onn_values, onn_angles)})  # , act_values, act_angles)})

    # fig, axs = plt.subplots(3, len(model_params)*2-1, figsize=(20, 5))

    for plot in plots:
        layer_name, (num, onn_values, onn_angles) = plot.popitem()

        if num == 0:
            # pad first, then seed the images (padding before assignment keeps heights consistent)
            onn_angles = pad_to_size(onn_angles, (max_size, None))
            onn_values = pad_to_size(onn_values, (max_size, None))
            value_img = onn_values
            angle_img = onn_angles
        else:
            onn_values = pad_to_size(onn_values, (max_size, onn_values.shape[1]+1))
            onn_angles = pad_to_size(onn_angles, (max_size, onn_angles.shape[1]+1))
            value_img = np.concatenate((value_img, onn_values), axis=1)
            angle_img = np.concatenate((angle_img, onn_angles), axis=1)

    value_img = np.ma.array(value_img, mask=np.isnan(value_img))
    angle_img = np.ma.array(angle_img, mask=np.isnan(angle_img))

    # from cmcrameri import cm
    from cmap import Colormap as cm
    import scicomap as sc
    # from matplotlib import colors as mcolors
    # alpha_map = mcolors.LinearSegmentedColormap(
    #     'alphamap',
    #     {
    #         'red': [(0, 0, 0), (1, 0, 0)],
    #         'green': [(0, 0, 0), (1, 0, 0)],
    #         'blue': [(0, 0, 0), (1, 0, 0)],
    #         'alpha': [
    #             (0, 1, 1),
    #             # (0.2, 0.2, 0.1),
    #             (1, 0, 0)
    #         ]
    #     }
    # )
    # alpha_map.set_bad(color="#AAAAAA")

    from mpl_toolkits.axes_grid1 import make_axes_locatable

    fig, axs = plt.subplots(1, 2, sharex=True, sharey=True, figsize=(16, 5))
    # fig.tight_layout()
    dividers = map(make_axes_locatable, axs)
    caxs = list(map(lambda d: d.append_axes("right", size="5%", pad=0.05), dividers))
    # masked_value_img = np.ma.masked_where(np.isnan(value_img), value_img)
    masked_value_img = value_img
    cmap = cm('google:turbo').to_matplotlib()
    # cmap = sc.ScicoSequential("rainbow").get_mpl_color_map()
    cmap.set_bad(color="#AAAAAA")
    im_val = axs[0].imshow(masked_value_img, cmap=cmap, vmin=0, vmax=1)
    fig.colorbar(im_val, cax=caxs[0], orientation="vertical")

    masked_angle_img = np.ma.masked_where(np.isnan(angle_img), angle_img)
    # cmap = cm('crameri:romao').to_matplotlib()
    # cmap = plt.get_cmap('puccs')
    # cmap = sc.ScicoCircular("colorwheel").get_mpl_color_map()
    cmap = cm('colorcet:CET_C8').to_matplotlib()
    cmap.set_bad(color="#AAAAAA")
    im_ang = axs[1].imshow(masked_angle_img, cmap=cmap, vmin=0, vmax=2*np.pi)
    cbar = fig.colorbar(im_ang, cax=caxs[1], orientation="vertical", ticks=[i/8 * 2*np.pi for i in range(9)])
    cbar.ax.set_yticklabels(["0", "π/4", "π/2", "3π/4", "π", "5π/4", "3π/2", "7π/4", "2π"])
    # im_ang_w = axs[2].imshow(masked_angle_img, cmap=cmap)
    # im_ang_w = axs[2].imshow(masked_value_img, cmap=alpha_map)

    axs[0].axis("off")
    axs[1].axis("off")
    # axs[2].axis("off")

    axs[0].set_title("Values")
    axs[1].set_title("Angles")
    # axs[2].set_title("Values and Angles")

    ...
    if show:
        plt.show()
    return fig

# model = models.regenerator(*dims, **model_kwargs)

if __name__ == "__main__":
    register_puccs_cmap()
    if len(sys.argv) > 1:
        model_plot(sys.argv[1])
    else:
        print("Please provide a model path as an argument")
    # model_plot(".models/best_20250114_224234.tar")
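pad_to_size centres an array inside the target shape, splitting the margin as (target - n + 1)//2 before and (target - n)//2 after, and leaving an axis untouched when its target is None. Worked example against the definition above:

import numpy as np

a = np.ones((4, 2))
b = pad_to_size(a, (7, None))        # rows: (7-4+1)//2 = 2 above, (7-4)//2 = 1 below
assert b.shape == (7, 2)
assert np.isnan(b[:2]).all() and np.isnan(b[-1]).all() and (b[2:6] == 1).all()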
102  src/single-core-regen/puccs.csv  Normal file
@@ -0,0 +1,102 @@
"x","L","a","b","R","G","B"
0.,0.5187848173343539,0.6399990176455989,0.67,0.8889427469969852,0.22673227640012172,0.
0.01,0.5374499525557803,0.604014067614707,0.6777967519386492,0.8956274406155226,0.27553288030331824,0.
0.02,0.5560867887452998,0.5680836759482211,0.6855816828789898,0.9019507507843885,0.318608215541461,0.
0.03,0.5746877595125583,0.5322224300667823,0.6933516322080414,0.907905487190649,0.3580633000693721,0.
0.04,0.5932314662487472,0.49647158484797804,0.7010976613543587,0.9134808162089558,0.3949845524063657,0.
0.05,0.6117000836392819,0.46086550613202343,0.7088123243737041,0.918668356138916,0.43002019316005363,0.
0.06,0.6300828534995973,0.4254249348741487,0.7164911273850869,0.923462736751354,0.4635961938811463,0.
0.07,0.6483763163456417,0.3901565406944371,0.7241326253017896,0.9278609626724071,0.49601354353255284,0.
0.08,0.6665840140182806,0.3550534951951814,0.7317382976124045,0.9318616057744784,0.5274983630587982,0.
0.09,0.6847162776119433,0.3200958808181962,0.7393124597949372,0.9354640163365924,0.5582303922647159,0.
0.1,0.7027902128942014,0.2852507189547545,0.7468622572263107,0.9386675557407496,0.5883604892249517,0.004034952213848706
0.11,0.7208298719332069,0.25047163906104203,0.7543977368741345,0.9414708123927996,0.6180221032545026,0.016031521294251994
0.12,0.7388665670611175,0.2156982733607376,0.7619319784446927,0.943870754968487,0.6473392272576862,0.029857267582036696
0.13,0.7569392765472108,0.18085547473834482,0.7694812638396673,0.9458617774020323,0.676432172396153,0.045365670193636125
0.14,0.7750950944867471,0.14585244938794778,0.7770652650825484,0.9474345911958609,0.7054219201084561,0.06017985923530026
0.15,0.793389684293558,0.11058188251425949,0.7847072337503834,0.9485749196617762,0.7344334940032564,0.07418869502646075
0.16,0.8117919447684838,0.07510373484536464,0.792394178330817,0.9492596163836376,0.7634480277996188,0.08767517868137237
0.17,0.8293050962981561,0.03629277424762101,0.799038155466063,0.9462308253550155,0.7922009241807345,0.10066327128139077
0.18,0.8213303100752708,-0.0062517290795987,0.7879999288492758,0.9088702681901394,0.7940579017644396,0.10139639009534024
0.19,0.8134831311534617,-0.048115463155645855,0.7771383286984362,0.8716809050191757,0.7954897210083888,0.10232311621802098
0.2,0.80558613530069,-0.0902449644291895,0.7662077749032042,0.8337524177888596,0.7965471523787845,0.10344968926026826
0.21,0.7975860185564765,-0.13292460297117392,0.7551344872795225,0.7947193410849823,0.7972381033243311,0.10477682283894393
0.22,0.7894147026971006,-0.17651756772919341,0.7438242359834689,0.7540941866826836,0.7975605026647324,0.10631182441371936
0.23,0.7809997374598548,-0.2214103719409295,0.7321767396537806,0.7112894518675287,0.7974995317311054,0.1080672415170634
0.24,0.7722646970273015,-0.2680107379394189,0.7200862142018722,0.6655745739336695,0.7970267795229349,0.11006041388465265
0.25,0.7631307298557146,-0.3167393290089981,0.7074435179925446,0.6160047476007512,0.7960993904970947,0.11231257117602686
0.26,0.7535192192483822,-0.36801555555407994,0.6941398344519211,0.5612859274945571,0.794659599537827,0.11484733363789801
0.27,0.7433557597838075,-0.42223636134393283,0.6800721760037781,0.4994862901720824,0.7926351396848288,0.11768844813479104
0.28,0.732575139048096,-0.479749646583324,0.6651502794883674,0.42731393423789277,0.7899410218414098,0.12085678487511567
0.29,0.7211269294461059,-0.5408244362880141,0.6493043460161184,0.3378265607222193,0.786483110019224,0.124366774034814
0.3,0.7090756028785993,-0.6051167807996883,0.6326236137723747,0.2098475715121697,0.7821998608677176,0.12819222127525928
0.31,0.7094510768540225,-0.6165036055456403,0.5630307498747129,0.15061488620640032,0.7845112116922692,0.21943537230975235
0.32,0.7174669421288304,-0.5917687864932311,0.4797229624661701,0.18766933782916642,0.7905828987725732,0.31091344246312086
0.33,0.7249009746435938,-0.5688293479200438,0.40246208306061504,0.21160609617940718,0.7962175427587832,0.38519766326885596
0.34,0.7317072855135611,-0.5478268906666535,0.3317250285377912,0.22717569971119178,0.8013847719431052,0.4490960048955565
0.35,0.7379328517830899,-0.5286164561226088,0.26702357292455026,0.23690087622812972,0.8061220291668977,0.5056371468159843
0.36,0.7436229063122554,-0.5110584677642499,0.20788761731555405,0.24226377668817778,0.8104638164122776,0.5563570758573497
0.37,0.7488251728809415,-0.4950056627547577,0.15382117501783654,0.24424372086048424,0.8144455902164638,0.6022301663745243
0.38,0.7535943992285348,-0.48028910419451787,0.10425526029155024,0.24352232677523483,0.818107753931944,0.6440238320299774
0.39,0.757994865186593,-0.4667104416936734,0.05852182167144754,0.240562414747303,0.8214980148949816,0.6824536572462205
0.4,0.7620994844391137,-0.4540446830999986,0.015863077249098356,0.2356325204239052,0.8246710357361025,0.7182393675419642
0.41,0.7659871096124125,-0.4420485102716773,-0.024540477496154123,0.22880568593963535,0.8276865975886148,0.7521146815529202
0.42,0.7697410958994951,-0.4304647113488041,-0.06355514164248566,0.21993360985514526,0.8306086550266585,0.7848331944479765
0.43,0.773446484628189,-0.4190308715098135,-0.10206473803580057,0.20858849290850018,0.833503273690861,0.8171544357676854
0.44,0.7771893686864673,-0.4074813310994203,-0.14096401824224686,0.1939295692427068,0.8364382500400466,0.8498448067259188
0.45,0.7810574093604746,-0.3955455908045306,-0.18116403397486242,0.17438366103820427,0.839483669055626,0.8836865023336339
0.46,0.7851360804917298,-0.3829599011818591,-0.2235531031349741,0.14679145002531463,0.8427091517444469,0.9194481212717681
0.47,0.789525027020907,-0.369416784561489,-0.26916682191206776,0.10278921007810798,0.8461971304126237,0.9580316568065935
0.48,0.7942371698732826,-0.35487637041943493,-0.3181394757087982,0.0013920913109500188,0.8499626968466341,0.9995866371771526
0.49,0.7773897680996302,-0.31852357140025195,-0.34537976514700053,0.10740420703601522,0.8254781216972907,1.
0.5,0.7604011244310231,-0.28211213216592784,-0.3722846952738428,0.1581725581872408,0.8008522647497104,1.
0.51,0.7433440454962605,-0.2455540169176899,-0.3992980063927199,0.19300141807932156,0.7761561224913385,1.
0.52,0.7262590833969331,-0.20893614020926626,-0.42635547610418184,0.2194621842292243,0.751443124097109,1.
0.53,0.709058602701224,-0.17207067467417486,-0.453595892719742,0.2405673704012788,0.7265803324554873,1.
0.54,0.6915768892539101,-0.1346024482921609,-0.48128169789479536,0.25788347992973676,0.701321051230534,1.
0.55,0.6736331627810209,-0.09614399811510127,-0.5096991935104321,0.2722888922216317,0.6753950894563805,1.
0.56,0.6551463184003872,-0.05652149358027936,-0.5389768254408652,0.28422807900785235,0.6486730893521468,1.
0.57,0.6361671326276888,-0.01584376303510615,-0.5690341788729347,0.293907374075009,0.6212117649042732,1.
0.58,0.6168396823565967,0.025580396234342995,-0.5996430791016598,0.301442767979156,0.5931976878638505,1.
0.59,0.5973210287815495,0.06741435793529688,-0.6305547881733555,0.30694603901024253,0.5648312189065924,1.
0.6,0.5777303704171711,0.10940264614179468,-0.661580531294122,0.3105418468883679,0.5362525958007331,1.
0.61,0.5581475370499237,0.15137416317967575,-0.6925938819599547,0.3123531986526998,0.5075386530652202,1.
0.62,0.5386227795100639,0.19322120739317136,-0.7235152578861672,0.31248922600720636,0.4787151440558522,1.
0.63,0.5191666876024412,0.23492108185347996,-0.754327887989376,0.31103663081260624,0.44973844514160927,1.
0.64,0.4996990584326256,0.2766456839100268,-0.7851587896650079,0.30803814950244496,0.4204116611935119,1.
0.65,0.479957679121191,0.3189570094767831,-0.8164232296840259,0.30343473603466015,0.390226489453496,1.
0.66,0.4600072725872886,0.3617163391430824,-0.8480187063016573,0.29717122075330515,0.3591178757512998,1.
0.67,0.44600100870220305,0.4113853615984094,-0.8697728377551008,0.3178994129506999,0.3295740682997879,1.
0.68,0.4574651571354146,0.44026390446569547,-0.8504539292487465,0.3842479358768364,0.3280946443367561,1.
0.69,0.4691809168948424,0.46977626401045774,-0.830711015748157,0.44293649140770447,0.3260767554252525,1.
0.7,0.4811696900083858,0.49997635259991063,-0.8105080314416201,0.49708450874457527,0.3234487047238236,1.
0.71,0.49350094811609174,0.5310391714342613,-0.7897279055963483,0.5485591109413528,0.3201099534066949,1.
0.72,0.5062548753068121,0.5631667067020758,-0.7682355153041539,0.5985798481027601,0.3159263917472715,1.
0.73,0.5195243020949684,0.5965928013272943,-0.7458744264238399,0.6480500606439057,0.31071717884730565,1.
0.74,0.5334043922713477,0.6315571758288618,-0.7224842728734379,0.6976685401842261,0.3042411890803418,1.
0.75,0.5479805812358602,0.6682750446095802,-0.697921082452685,0.7479712773579563,0.29618040787504757,1.
0.76,0.5633244502526606,0.7069267230777347,-0.6720642293775535,0.7993701361353484,0.28611136999256687,1.
0.77,0.5794956601139,0.7476624986056212,-0.6448131757501174,0.8521918014427678,0.2734527325942473,1.
0.78,0.5965429098573916,0.7906050455688622,-0.6160858559672187,0.9067003897516911,0.2573693489198746,1.
0.79,0.6145761476424179,0.8360313267658297,-0.5856969899409387,0.963334644317004,0.23648492980159264,1.
0.8,0.6232910688128902,0.859291371252556,-0.5300995185388214,1.,0.21867949406239662,0.9712088595948508
0.81,0.6159984336377875,0.8439887543380684,-0.44635440435952856,1.,0.21606849746358275,0.9041480210597966
0.82,0.6091642745073532,0.8296481879180277,-0.36787420852419694,1.,0.21421830096504035,0.8419706002336461
0.83,0.6025478038652375,0.8157644115969636,-0.2918938425681935,1.,0.21295365915197917,0.7823908751330636
0.84,0.5961857222953111,0.8024144366282877,-0.21883475834162458,0.9971140114799418,0.21220068235083267,0.7256713129328118
0.85,0.5900921771070883,0.7896279492437488,-0.1488594167412921,0.993273906363258,0.2118788857127918,0.671860243327784
0.86,0.5842771639541229,0.7774259239818333,-0.08208260304413262,0.9887084084529413,0.21191070453347688,0.6209624706933893
0.87,0.578741582584259,0.7658102488427286,-0.018514649521559012,0.9835846378805114,0.2122246941077346,0.5728987835613306
0.88,0.5734741590353537,0.7547572669288056,0.04197390858426542,0.9780378159372328,0.21275878699579343,0.5274829957183049
0.89,0.5684517008574971,0.7442183119942206,0.09964940221121898,0.9721670725313721,0.21346242315895625,0.4844270603851604
0.9,0.5636419856510335,0.7341257696545772,0.15488185789614228,0.9660363209686843,0.21429691147008262,0.4433660148378527
0.91,0.5590069340453534,0.7243997354573974,0.20810856081277884,0.9596781387247791,0.2152344151262528,0.4038812338146013
0.92,0.5545051525321143,0.7149533506766244,0.25980485409830323,0.9530986696850675,0.21625626438013962,0.3655130449917989
0.93,0.5500961975299247,0.705701749880514,0.3104351723857584,0.9462863346513658,0.21735046958786286,0.327780364198278
0.94,0.545740378056064,0.6965616468647046,0.36045530782708896,0.93921469089265,0.21851014470332586,0.29014917175372823
0.95,0.5414004092067859,0.6874548042588865,0.41029342232076466,0.9318478255642132,0.21973168075163751,0.2519897371806688
0.96,0.5370416605957644,0.6783085548415655,0.46034719456417006,0.9241434776436454,0.22101341980094052,0.2124579038400577
0.97,0.5326309593934517,0.6690532898786764,0.5109975653738162,0.9160532016485884,0.22235495330179011,0.17018252385769012
0.98,0.5281374148557197,0.6596241892863608,0.5625992691950712,0.90752576202319,0.22375597459867458,0.1223073280126531
0.99,0.5235317096396147,0.6499597345521199,0.615488972291106,0.8985077346125597,0.22521565729028564,0.05933950582860665
1.,0.5187848173343539,0.6399990176455989,0.67,0.8889427469969852,0.22673227640012172,0.
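puccs.csv stores a cyclic colormap sampled at x = 0.00 ... 1.00 (the first and last rows are identical, so the map wraps); register_puccs_cmap above keeps only the R,G,B columns (elements[4:]) and skips the header via the ValueError branch. Assumed usage, provided plot_model.py is importable:

import matplotlib as mpl
from plot_model import register_puccs_cmap

register_puccs_cmap()
cmap = mpl.colormaps["puccs"]   # cyclic: cmap(0.0) == cmap(1.0)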
@@ -1,6 +1,8 @@
 from datetime import datetime

 import optuna
+import torch
+import util
 from hypertraining.hypertraining import HyperTraining
 from hypertraining.settings import (
     GlobalSettings,
@@ -16,24 +18,29 @@ global_settings = GlobalSettings(
 )

 data_settings = DataSettings(
-    config_path="data/*-128-16384-100000-0-0-17-0-PAM4-0.ini",
+    # config_path="data/*-128-16384-100000-0-0-17-0-PAM4-0.ini",
+    config_path="data/20241204-131003-128-16384-100000-0-0-17-0-PAM4-0.ini",
     dtype="complex64",
     # symbols = (9, 20), # 13 symbol @ 10GBd <-> 1.3ns <-> 0.26m of fiber
-    symbols=13, # study: single_core_regen_20241123_011232
+    # symbols=13, # study: single_core_regen_20241123_011232
+    # symbols = (3, 13),
+    symbols=4,
     # output_size = (11, 32), # ballpark 26 taps -> 2 taps per input symbol -> 1 tap every 0.01m (model has 52 inputs)
-    output_size=26, # study: single_core_regen_20241123_011232 (model_input_dim/2)
+    # output_size=26, # study: single_core_regen_20241123_011232 (model_input_dim/2)
+    output_size=(8, 30),
     shuffle=True,
     in_out_delay=0,
     xy_delay=0,
-    drop_first=128 * 100,
+    drop_first=256,
     train_split=0.8,
+    randomise_polarisations=False,
 )

 pytorch_settings = PytorchSettings(
-    epochs=10000,
+    epochs=10,
     batchsize=2**10,
     device="cuda",
-    dataloader_workers=12,
+    dataloader_workers=4,
     dataloader_prefetch=4,
     summary_dir=".runs",
     write_every=2**5,
@@ -43,28 +50,70 @@ pytorch_settings = PytorchSettings(

 model_settings = ModelSettings(
     output_dim=2,
-    # n_hidden_layers = (3, 8),
-    n_hidden_layers=4,
-    overrides={
-        "n_hidden_nodes_0": 8,
-        "n_hidden_nodes_1": 6,
-        "n_hidden_nodes_2": 4,
-        "n_hidden_nodes_3": 8,
+    n_hidden_layers = (2, 5),
+    n_hidden_nodes=(2, 16),
+    model_activation_func="EOActivation",
+    dropout_prob=0,
+    model_layer_function="ONNRect",
+    model_layer_kwargs={"square": True},
+    # scale=(False, True),
+    scale=False,
+    model_layer_parametrizations=[
+        {
+            "tensor_name": "weight",
+            "parametrization": util.complexNN.energy_conserving,
         },
-    model_activation_func="Mag",
-    # satabsT0=(1e-6, 1),
+        {
+            "tensor_name": "alpha",
+            "parametrization": util.complexNN.clamp,
+        },
+        {
+            "tensor_name": "gain",
+            "parametrization": util.complexNN.clamp,
+            "kwargs": {
+                "min": 0,
+                "max": float("inf"),
+            },
+        },
+        {
+            "tensor_name": "phase_bias",
+            "parametrization": util.complexNN.clamp,
+            "kwargs": {
+                "min": 0,
+                "max": 2 * torch.pi,
+            },
+        },
+        {
+            "tensor_name": "scales",
+            "parametrization": util.complexNN.clamp,
+        },
+        {
+            "tensor_name": "angle",
+            "parametrization": util.complexNN.clamp,
+            "kwargs": {
+                "min": -torch.pi,
+                "max": torch.pi,
+            },
+        },
+        {
+            "tensor_name": "loss",
+            "parametrization": util.complexNN.clamp,
+        },
+    ],
 )

 optimizer_settings = OptimizerSettings(
-    optimizer="Adam",
-    # learning_rate = (1e-5, 1e-1),
-    learning_rate=5e-3
-    # learning_rate=5e-4,
+    optimizer="AdamW",
+    optimizer_kwargs={
+        "lr": 5e-3,
+        "amsgrad": True,
+        # "weight_decay": 1e-7,
+    },
 )

 optuna_settings = OptunaSettings(
-    n_trials=1,
-    n_workers=1,
+    n_trials=1024,
+    n_workers=8,
     timeout=3600,
     directions=("minimize",),
     metrics_names=("mse",),
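util.complexNN.clamp is referenced throughout the parametrization list but is not shown in this diff. In torch, a parametrization is simply a module whose forward maps the raw tensor to the constrained one; a minimal sketch under that assumption:

import torch
from torch import nn
from torch.nn.utils import parametrize

class Clamp(nn.Module):
    """Sketch of a clamp parametrization (util.complexNN.clamp itself is not in this diff)."""
    def __init__(self, min=0.0, max=1.0):
        super().__init__()
        self.min, self.max = min, max

    def forward(self, x):
        # the optimizer updates the unconstrained tensor; this keeps the visible one in range
        return x.clamp(self.min, self.max)

layer = nn.Linear(4, 4)
parametrize.register_parametrization(layer, "weight", Clamp(0.0, 2 * torch.pi))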
@@ -1,4 +1,4 @@
-from datetime import datetime
+# from datetime import datetime
 from pathlib import Path
 import matplotlib
 import numpy as np
@@ -13,7 +13,7 @@ from hypertraining.settings import (
     OptimizerSettings,
 )

-from hypertraining.training import RegenerationTrainer, PolarizationTrainer
+from hypertraining.training import RegenerationTrainer  # , PolarizationTrainer

 # import torch
 import json
@@ -26,25 +26,39 @@ global_settings = GlobalSettings(
 )

 data_settings = DataSettings(
-    config_path="data/*-128-16384-10000-0-0-0-0-PAM4-0-0.4.ini",
-    # config_path=[f"data/20241202-*-128-16384-{length}-0-0-17-0-PAM4-0.ini" for length in range(48000, 53000, 1000)],
+    # config_path="data/20250115-114319-128-16384-inf-100000-0-0-0-0-PAM4-1-0-10.ini", # no effects enabled - baseline
+    # config_path = "data/20250115-233553-128-16384-1060.0-100000-0-0.2-17.0-0.058-PAM4-1.0-0.0-10.ini", # dispersion + slope only
+    # config_path="data/20250115-115836-128-16384-60.0-100000-0-0.2-17-0.058-PAM4-1000-0.2-10.ini", # all linear effects enabled with realistic values + noise + pmd (delta_beta=0.2) + ortho_error = 0.1
+    # config_path="data/20250118-225840-128-16384-inf-100000-0-0.2-17.0-0.058-PAM4-0-0-10.ini", # a)
+    # config_path="data/20250116-214606-128-16384-inf-100000-0-0.2-17.0-0.058-PAM4-0-0-10.ini", # b)
+    # config_path="data/20250116-214547-128-16384-inf-100000-0-0.2-17.0-0.058-PAM4-0-0-10.ini", # c)
+    # config_path="data/20250117-143926-128-16384-inf-100000-0-0.2-17.0-0.058-PAM4-0-0-10.ini", # d) 10ps dgd
+    config_path="data/20250120-105720-128-16384-inf-100000-0-0.2-17-0.058-PAM4-0-0-10.ini", # d) 10ns
+
+    # config_path="data/20250114-215547-128-16384-60.0-100000-1.15-0.2-17-0.058-PAM4-1-0-10.ini", # with gamma=1.15, 2.5dBm launch power, no pmd
+
+
     dtype="complex64",
     # symbols = (9, 20), # 13 symbol @ 10GBd <-> 1.3ns <-> 0.26m of fiber
-    symbols=13, # study: single_core_regen_20241123_011232
+    symbols=4, # study: single_core_regen_20241123_011232 -> taps spread over 4 symbols @ 10GBd
     # output_size = (11, 32), # 26 taps -> 2 taps per input symbol -> 1 tap every 1cm (model has 52 inputs (x/y))
-    output_size=26, # study: single_core_regen_20241123_011232 (model_input_dim/2)
+    output_size=20, # = model_input_dim/2 study: single_core_regen_20241123_011232
-    shuffle=True,
+    shuffle=False,
-    drop_first=64,
+    drop_first=256,
+    drop_last=256,
     train_split=0.8,
-    randomise_polarisations=True,
+    randomise_polarisations=False,
+    polarisations=False,
+    # cross_pol_interference=0.01,
+    osnr=16,  # 16dB due to amplification with NF 5
 )

 pytorch_settings = PytorchSettings(
-    epochs=10000,
+    epochs=1000,
-    batchsize=2**14,
+    batchsize=2**13,
     device="cuda",
-    dataloader_workers=16,
+    dataloader_workers=32,
-    dataloader_prefetch=8,
+    dataloader_prefetch=4,
     summary_dir=".runs",
     write_every=2**5,
     save_models=True,
@@ -53,80 +67,51 @@ pytorch_settings = PytorchSettings(

 model_settings = ModelSettings(
     output_dim=2,
-    n_hidden_layers=5,
+    n_hidden_layers=3,
     overrides={
         # "hidden_layer_dims": (8, 8, 4, 4),
-        "n_hidden_nodes_0": 8,
+        "n_hidden_nodes_0": 16,
         "n_hidden_nodes_1": 8,
-        "n_hidden_nodes_2": 4,
-        "n_hidden_nodes_3": 4,
-        "n_hidden_nodes_4": 2,
+        "n_hidden_nodes_2": 8,
+        # "n_hidden_nodes_3": 4,
+        # "n_hidden_nodes_4": 2,
     },
     model_activation_func="EOActivation",
-    dropout_prob=0.01,
+    dropout_prob=0,
     model_layer_function="ONNRect",
     model_layer_kwargs={"square": True},
-    scale=False,
+    scale=2.0,
     model_layer_parametrizations=[
-        {
-            "tensor_name": "weight",
-            "parametrization": util.complexNN.energy_conserving,
-        },
+        # EOactivation
         {
             "tensor_name": "alpha",
             "parametrization": util.complexNN.clamp,
+            "kwargs": {
+                "min": 0,
+                "max": 1,
             },
+        },
+        # ONNRect
         {
-            "tensor_name": "gain",
+            "tensor_name": "weight",
+            "parametrization": torch.nn.utils.parametrizations.orthogonal,
+        },
+        # Scale
+        {
+            "tensor_name": "scale",
             "parametrization": util.complexNN.clamp,
             "kwargs": {
                 "min": 0,
-                "max": float("inf"),
+                "max": 10,
-            },
-        },
-        {
-            "tensor_name": "phase_bias",
-            "parametrization": util.complexNN.clamp,
-            "kwargs": {
-                "min": 0,
-                "max": 2 * torch.pi,
-            },
-        },
-        {
-            "tensor_name": "scales",
-            "parametrization": util.complexNN.clamp,
-        },
-        {
-            "tensor_name": "angle",
-            "parametrization": util.complexNN.clamp,
-            "kwargs": {
-                "min": -torch.pi,
-                "max": torch.pi,
-            },
-        },
-        # {
-        #     "tensor_name": "scale",
-        #     "parametrization": util.complexNN.clamp,
-        # },
-        # {
-        #     "tensor_name": "bias",
-        #     "parametrization": util.complexNN.clamp,
-        # },
-        # {
-        #     "tensor_name": "V",
-        #     "parametrization": torch.nn.utils.parametrizations.orthogonal,
-        # },
-        {
-            "tensor_name": "loss",
-            "parametrization": util.complexNN.clamp,
-        },
+            },
+        }
     ],
 )

 optimizer_settings = OptimizerSettings(
     optimizer="AdamW",
     optimizer_kwargs={
-        "lr": 0.01,
+        "lr": 0.005,
         "amsgrad": True,
         # "weight_decay": 1e-7,
     },
@@ -134,107 +119,19 @@ optimizer_settings = OptimizerSettings(
|
|||||||
scheduler="ReduceLROnPlateau",
|
scheduler="ReduceLROnPlateau",
|
||||||
scheduler_kwargs={
|
scheduler_kwargs={
|
||||||
"patience": 2**6,
|
"patience": 2**6,
|
||||||
"factor": 0.75,
|
"factor": 0.5,
|
||||||
# "threshold": 1e-3,
|
# "threshold": 1e-3,
|
||||||
"min_lr": 1e-6,
|
"min_lr": 1e-6,
|
||||||
"cooldown": 10,
|
"cooldown": 10,
|
||||||
},
|
},
|
||||||
)
|
early_stopping=True,
|
||||||
|
early_stop_kwargs={
|
||||||
|
"threshold": 1e-06,
|
||||||
def save_dict_to_file(dictionary, filename):
|
"plateau": 2**7,
|
||||||
"""
|
|
||||||
Save the best dictionary to a JSON file.
|
|
||||||
|
|
||||||
:param best: Dictionary containing the best training results.
|
|
||||||
:type best: dict
|
|
||||||
:param filename: Path to the JSON file where the dictionary will be saved.
|
|
||||||
:type filename: str
|
|
||||||
"""
|
|
||||||
with open(filename, "w") as f:
|
|
||||||
json.dump(dictionary, f, indent=4)
|
|
||||||
|
|
||||||
|
|
||||||
def sweep_lengths(*lengths, model=None, data_glob: str = None, strategy="newest"):
|
|
||||||
assert model is not None, "Model must be provided."
|
|
||||||
assert data_glob is not None, "Data glob must be provided."
|
|
||||||
model = model
|
|
||||||
|
|
||||||
fiber_ins = {}
|
|
||||||
fiber_outs = {}
|
|
||||||
regens = {}
|
|
||||||
timestampss = {}
|
|
||||||
|
|
||||||
trainer = RegenerationTrainer(
|
|
||||||
checkpoint_path=model,
|
|
||||||
)
|
|
||||||
trainer.define_model()
|
|
||||||
|
|
||||||
for length in lengths:
|
|
||||||
data_glob_length = data_glob.replace("{length}", str(length))
|
|
||||||
files = list(Path.cwd().glob(data_glob_length))
|
|
||||||
if len(files) == 0:
|
|
||||||
continue
|
|
||||||
if strategy == "newest":
|
|
||||||
sorted_kwargs = {
|
|
||||||
"key": lambda x: x.stat().st_mtime,
|
|
||||||
"reverse": True,
|
|
||||||
}
|
}
|
||||||
elif strategy == "oldest":
|
)
|
||||||
sorted_kwargs = {
|
|
||||||
"key": lambda x: x.stat().st_mtime,
|
|
||||||
"reverse": False,
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unknown strategy {strategy}.")
|
|
||||||
file = sorted(files, **sorted_kwargs)[0]
|
|
||||||
|
|
||||||
loader, _ = trainer.get_sliced_data(override={"config_path": file})
|
|
||||||
fiber_in, fiber_out, regen, timestamps = trainer.run_model(trainer.model, loader=loader)
|
|
||||||
|
|
||||||
fiber_ins[length] = fiber_in
|
|
||||||
fiber_outs[length] = fiber_out
|
|
||||||
regens[length] = regen
|
|
||||||
timestampss[length] = timestamps
|
|
||||||
|
|
||||||
data = torch.zeros(2 * len(timestampss.keys()) + 2, 2, tuple(fiber_outs.values())[-1].shape[0])
|
|
||||||
channel_names = ["" for _ in range(2 * len(timestampss.keys()) + 2)]
|
|
||||||
|
|
||||||
data[1, 0, :] = timestampss[tuple(timestampss.keys())[-1]] / 128
|
|
||||||
data[1, 1, :] = fiber_ins[tuple(timestampss.keys())[-1]][:, 0].abs().square()
|
|
||||||
|
|
||||||
channel_names[1] = "fiber in x"
|
|
||||||
|
|
||||||
for li, length in enumerate(timestampss.keys()):
|
|
||||||
data[2 + 2 * li, 0, :] = timestampss[length] / 128
|
|
||||||
data[2 + 2 * li, 1, :] = fiber_outs[length][:, 0].abs().square()
|
|
||||||
data[2 + 2 * li + 1, 0, :] = timestampss[length] / 128
|
|
||||||
data[2 + 2 * li + 1, 1, :] = regens[length][:, 0].abs().square()
|
|
||||||
|
|
||||||
channel_names[2 + 2 * li + 1] = f"regen x {length}"
|
|
||||||
channel_names[2 + 2 * li] = f"fiber out x {length}"
|
|
||||||
|
|
||||||
# get current backend
|
|
||||||
backend = matplotlib.get_backend()
|
|
||||||
|
|
||||||
matplotlib.use("TkCairo")
|
|
||||||
eye = util.eye_diagram.eye_diagram(data.to(dtype=torch.float32).detach().cpu().numpy(), channel_names=channel_names)
|
|
||||||
|
|
||||||
print_attrs = ("channel_name", "success", "min_area")
|
|
||||||
with np.printoptions(precision=3, suppress=True, formatter={"float": "{:0.3e}".format}):
|
|
||||||
for result in eye.eye_stats:
|
|
||||||
print_dict = {attr: result[attr] for attr in print_attrs}
|
|
||||||
rprint(print_dict)
|
|
||||||
rprint()
|
|
||||||
|
|
||||||
eye.plot(all_stats=False)
|
|
||||||
matplotlib.use(backend)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# lengths = range(90000, 100000+10000, 10000)
|
|
||||||
# lengths = [100000]
|
|
||||||
# sweep_lengths(*lengths, model=".models/best_20241204_132605.tar", data_glob="data/202412*-{length}-*-0.ini", strategy="newest")
|
|
||||||
|
|
||||||
trainer = RegenerationTrainer(
|
trainer = RegenerationTrainer(
|
||||||
global_settings=global_settings,
|
global_settings=global_settings,
|
||||||
@@ -242,67 +139,15 @@ if __name__ == "__main__":
|
|||||||
pytorch_settings=pytorch_settings,
|
pytorch_settings=pytorch_settings,
|
||||||
model_settings=model_settings,
|
model_settings=model_settings,
|
||||||
optimizer_settings=optimizer_settings,
|
optimizer_settings=optimizer_settings,
|
||||||
# checkpoint_path=".models/best_20241205_235929.tar",
|
checkpoint_path=".models/best_20250117_144001.tar",
|
||||||
# 20241202_143149
|
new_model=True,
|
||||||
|
settings_override={
|
||||||
|
"data_settings": data_settings.__dict__,
|
||||||
|
# "optimizer_settings": {
|
||||||
|
# "early_stop_kwargs":{
|
||||||
|
# "plateau": 2**8,
|
||||||
|
# }
|
||||||
|
# }
|
||||||
|
}
|
||||||
)
|
)
|
||||||
trainer.train()
|
trainer.train()
|
||||||
|
|
||||||
# from hypertraining.lighning_models import regenerator, regeneratorData
|
|
||||||
# import lightning as L
|
|
||||||
|
|
||||||
# model = regenerator(
|
|
||||||
# 2 * data_settings.output_size,
|
|
||||||
# *model_settings.overrides["hidden_layer_dims"],
|
|
||||||
# model_settings.output_dim,
|
|
||||||
# layer_function=getattr(util.complexNN, model_settings.model_layer_function),
|
|
||||||
# layer_func_kwargs=model_settings.model_layer_kwargs,
|
|
||||||
# act_function=getattr(util.complexNN, model_settings.model_activation_func),
|
|
||||||
# act_func_kwargs=None,
|
|
||||||
# parametrizations=model_settings.model_layer_parametrizations,
|
|
||||||
# dtype=getattr(torch, data_settings.dtype),
|
|
||||||
# dropout_prob=model_settings.dropout_prob,
|
|
||||||
# scale_layers=model_settings.scale,
|
|
||||||
# optimizer=getattr(torch.optim, optimizer_settings.optimizer),
|
|
||||||
# optimizer_kwargs=optimizer_settings.optimizer_kwargs,
|
|
||||||
# lr_scheduler=getattr(torch.optim.lr_scheduler, optimizer_settings.scheduler),
|
|
||||||
# lr_scheduler_kwargs=optimizer_settings.scheduler_kwargs,
|
|
||||||
# )
|
|
||||||
|
|
||||||
# dm = regeneratorData(
|
|
||||||
# config_globs=data_settings.config_path,
|
|
||||||
# output_symbols=data_settings.symbols,
|
|
||||||
# output_dim=data_settings.output_size,
|
|
||||||
# dtype=getattr(torch, data_settings.dtype),
|
|
||||||
# drop_first=data_settings.drop_first,
|
|
||||||
# shuffle=data_settings.shuffle,
|
|
||||||
# train_split=data_settings.train_split,
|
|
||||||
# batch_size=pytorch_settings.batchsize,
|
|
||||||
# loader_settings={
|
|
||||||
# "num_workers": pytorch_settings.dataloader_workers,
|
|
||||||
# "prefetch_factor": pytorch_settings.dataloader_prefetch,
|
|
||||||
# "pin_memory": True,
|
|
||||||
# "drop_last": True,
|
|
||||||
# },
|
|
||||||
# seed=global_settings.seed,
|
|
||||||
# )
|
|
||||||
|
|
||||||
# # writer = L.SummaryWriter(pytorch_settings.summary_dir + f"/{datetime.now().strftime('%Y%m%d_%H%M%S')}")
|
|
||||||
|
|
||||||
# # from torch.utils.tensorboard import SummaryWriter
|
|
||||||
|
|
||||||
# subdir = f"{datetime.now().strftime('%Y%m%d_%H%M%S')}"
|
|
||||||
|
|
||||||
# # writer = SummaryWriter(pytorch_settings.summary_dir + f"/{subdir}")
|
|
||||||
|
|
||||||
# logger = L.pytorch.loggers.TensorBoardLogger(pytorch_settings.summary_dir, name=subdir, log_graph=True)
|
|
||||||
|
|
||||||
# trainer = L.Trainer(
|
|
||||||
# fast_dev_run=False,
|
|
||||||
# # max_epochs=pytorch_settings.epochs,
|
|
||||||
# max_epochs=2,
|
|
||||||
# enable_checkpointing=True,
|
|
||||||
# default_root_dir=f".models/{subdir}/",
|
|
||||||
# logger=logger,
|
|
||||||
# )
|
|
||||||
|
|
||||||
# trainer.fit(model, dm)
|
|
||||||
|
|||||||
@@ -12,20 +12,22 @@ Full license text in LICENSE file
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import configparser
|
import configparser
|
||||||
|
# import copy
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
import hashlib
|
import hashlib
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import time
|
import time
|
||||||
|
import h5py
|
||||||
from matplotlib import pyplot as plt # noqa: F401
|
from matplotlib import pyplot as plt # noqa: F401
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
import add_pypho # noqa: F401
|
from . import add_pypho # noqa: F401
|
||||||
import pypho
|
import pypho
|
||||||
|
|
||||||
default_config = f"""
|
default_config = f"""
|
||||||
[glova]
|
[glova]
|
||||||
nos = 256
|
sps = 128
|
||||||
sps = 256
|
nos = 16384
|
||||||
f0 = 193414489032258.06
|
f0 = 193414489032258.06
|
||||||
symbolrate = 10e9
|
symbolrate = 10e9
|
||||||
wisdom_dir = "{str((Path.home() / ".pypho"))}"
|
wisdom_dir = "{str((Path.home() / ".pypho"))}"
|
||||||
@@ -37,9 +39,9 @@ length = 10000
|
|||||||
gamma = 1.14
|
gamma = 1.14
|
||||||
alpha = 0.2
|
alpha = 0.2
|
||||||
D = 17
|
D = 17
|
||||||
S = 0
|
S = 0.058
|
||||||
birefsteps = 0
|
bireflength = 10
|
||||||
max_delta_beta = 0.4
|
pmd_q = 0.2
|
||||||
; birefseed = 0xC0FFEE
|
; birefseed = 0xC0FFEE
|
||||||
|
|
||||||
[signal]
|
[signal]
|
||||||
@@ -47,17 +49,15 @@ max_delta_beta = 0.4
|
|||||||
|
|
||||||
modulation = "pam"
|
modulation = "pam"
|
||||||
mod_order = 4
|
mod_order = 4
|
||||||
mod_depth = 0.8
|
mod_depth = 1
|
||||||
|
|
||||||
max_jitter = 0.02
|
max_jitter = 0.02
|
||||||
; jitter_seed = 0xC0FFEE
|
; jitter_seed = 0xC0FFEE
|
||||||
|
|
||||||
laser_power = 0
|
laser_power = 0
|
||||||
edfa_power = 3
|
edfa_power = 0
|
||||||
edfa_nf = 5
|
edfa_nf = 5
|
||||||
|
|
||||||
pulse_shape = "gauss"
|
pulse_shape = "gauss"
|
||||||
fwhm = 0.33
|
fwhm = 0.33
|
||||||
|
osnr = "inf"
|
||||||
|
|
||||||
[data]
|
[data]
|
||||||
dir = "data"
|
dir = "data"
|
||||||
@@ -71,6 +71,7 @@ def get_config(config_file=None):
|
|||||||
"""
|
"""
|
||||||
if config_file is None:
|
if config_file is None:
|
||||||
config_file = Path(__file__).parent / "signal_generation.ini"
|
config_file = Path(__file__).parent / "signal_generation.ini"
|
||||||
|
config_file = Path(config_file)
|
||||||
if not config_file.exists():
|
if not config_file.exists():
|
||||||
with open(config_file, "w") as f:
|
with open(config_file, "w") as f:
|
||||||
f.write(default_config)
|
f.write(default_config)
|
||||||
@@ -83,7 +84,10 @@ def get_config(config_file=None):
|
|||||||
conf[section] = {}
|
conf[section] = {}
|
||||||
for key in config[section]:
|
for key in config[section]:
|
||||||
# print(f"{key} = {config[section][key]}")
|
# print(f"{key} = {config[section][key]}")
|
||||||
|
try:
|
||||||
conf[section][key] = eval(config[section][key])
|
conf[section][key] = eval(config[section][key])
|
||||||
|
except NameError:
|
||||||
|
conf[section][key] = float(config[section][key])
|
||||||
# if isinstance(conf[section][key], str):
|
# if isinstance(conf[section][key], str):
|
||||||
# conf[section][key] = config[section][key].strip('"')
|
# conf[section][key] = config[section][key].strip('"')
|
||||||
return conf
|
return conf
|
||||||
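A note on the try/eval fallback added above: bare INI words such as `inf` raise NameError under eval, which the new except branch turns into `float("inf")`. A minimal sketch of an equivalent parser that avoids executing arbitrary expressions (the helper name is illustrative, not part of the patch):

import ast

def parse_ini_value(raw: str):
    # quoted strings, ints, floats like 10e9 and hex like 0xC0FFEE parse as literals;
    # bare words fall back, e.g. "inf" -> float("inf"), mirroring the patched behaviour
    try:
        return ast.literal_eval(raw)
    except (ValueError, SyntaxError):
        return float(raw)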
@@ -96,7 +100,9 @@ class PDM_IM_IPM:
         mod_order=8,
         seed=None,
     ):
-        assert np.cbrt(mod_order) == int(np.cbrt(mod_order)) and mod_order > 1, "mod_order must be a cube of an integer greater than 1"
+        assert np.cbrt(mod_order) == int(np.cbrt(mod_order)) and mod_order > 1, (
+            "mod_order must be a cube of an integer greater than 1"
+        )
         self.glova = glova
         self.mod_order = mod_order
         self.symbols_per_dim = int(np.cbrt(mod_order))

@@ -110,14 +116,7 @@ class PDM_IM_IPM:

 class pam_generator:
     def __init__(
-        self,
+        self, glova, mod_order=None, mod_depth=0.5, pulse_shape="gauss", fwhm=0.33, seed=None, single_channel=False
-        glova,
-        mod_order=None,
-        mod_depth=0.5,
-        pulse_shape="gauss",
-        fwhm=0.33,
-        seed=None,
-        single_channel=False
     ) -> None:
         self.glova = glova
         self.pulse_shape = pulse_shape

@@ -138,15 +137,13 @@ class pam_generator:
         symbols_x = symbols[0] / (self.mod_order)
         diffs_x = np.diff(symbols_x, prepend=symbols_x[0])
         digital_x = self.generate_digital_signal(diffs_x, max_jitter)
-        digital_x = np.pad(
-            digital_x, (0, self.glova.sps // 2), "constant", constant_values=(0, 0)
-        )
+        digital_x = np.pad(digital_x, (0, self.glova.sps // 2), "constant", constant_values=(0, 0))

         # create analog signal of diff of symbols
         E_x = np.convolve(digital_x, wavelet)

         # convert to pam and set modulation depth (scale and move up such that 1 stays at 1)
-        E_x = np.cumsum(E_x) * self.modulation_depth + 2*(1 - self.modulation_depth)
+        E_x = np.cumsum(E_x) * self.modulation_depth + (1 - self.modulation_depth)

         # cut off the wavelet tails
         E_x = E_x[self.glova.sps // 2 + len(wavelet) // 2 - 1 : -len(wavelet) // 2]

@@ -158,25 +155,21 @@ class pam_generator:
             symbols_y = symbols[1] / (self.mod_order)
             diffs_y = np.diff(symbols_y, prepend=symbols_y[0])
             digital_y = self.generate_digital_signal(diffs_y, max_jitter)
-            digital_y = np.pad(
-                digital_y, (0, self.glova.sps // 2), "constant", constant_values=(0, 0)
-            )
+            digital_y = np.pad(digital_y, (0, self.glova.sps // 2), "constant", constant_values=(0, 0))
             E_y = np.convolve(digital_y, wavelet)

             E_y = np.cumsum(E_y) * self.modulation_depth + (1 - self.modulation_depth)

             E_y = E_y[self.glova.sps // 2 + len(wavelet) // 2 - 1 : -len(wavelet) // 2]

             E[0]["E"][1] = np.sqrt(np.multiply(np.square(E[0]["E"][1]), E_y))

             # rotate the signal on the y-polarisation by 90°
-            E[0]["E"][1] *= 1j
+            # E[0]["E"][1] *= 1j
         else:
             E[0]["E"][1] = np.zeros_like(E[0]["E"][0], dtype=E[0]["E"][0].dtype)
         return E

     def generate_digital_signal(self, symbols, max_jitter=0):
         rs = np.random.RandomState(self.seed)
         signal = np.zeros(self.glova.nos * self.glova.sps)

@@ -198,19 +191,19 @@ class pam_generator:
             endpoint=True,
         )
         sigma = self.fwhm / (1 * np.sqrt(2 * np.log(2))) * self.glova.sps
-        pulse = (
-            1
-            / (sigma * np.sqrt(2 * np.pi))
-            * np.exp(-np.square(sample_points) / (2 * np.square(sigma)))
-        )
+        pulse = 1 / (sigma * np.sqrt(2 * np.pi)) * np.exp(-np.square(sample_points) / (2 * np.square(sigma)))
         return pulse
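For reference, the wavelet above is a normalised Gaussian; for exp(-t²/(2σ²)) the textbook width relation is FWHM = 2·sqrt(2·ln 2)·σ. The code divides by 1·sqrt(2·ln 2) rather than 2·sqrt(2·ln 2), i.e. its σ is twice that of a pulse whose FWHM is fwhm·sps samples — a pre-existing convention that this patch does not touch. A sketch of the standard conversion, assuming fwhm is measured in samples:

import numpy as np

def fwhm_to_sigma(fwhm_samples: float) -> float:
    # FWHM = 2 * sqrt(2 * ln 2) * sigma for a Gaussian exp(-t^2 / (2 sigma^2))
    return fwhm_samples / (2 * np.sqrt(2 * np.log(2)))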
-def initialize_fiber_and_data(config, input_data_override=None):
+def initialize_fiber_and_data(config):
+    f0 = config["glova"].get("f0", None)
+    if f0 is None:
+        f0 = 299792458/(config["glova"].get("lambda0", 1550)*1e-9)
+    config["glova"]["f0"] = f0
     py_glova = pypho.setup(
         nos=config["glova"]["nos"],
         sps=config["glova"]["sps"],
-        f0=config["glova"]["f0"],
+        f0=f0,
         symbolrate=config["glova"]["symbolrate"],
         wisdom_dir=config["glova"]["wisdom_dir"],
         flags=config["glova"]["flags"],

@@ -221,29 +214,22 @@ def initialize_fiber_and_data(config, input_data_override=None):
     c_data = pypho.cfiber.DataWrapper(py_glova.sps * py_glova.nos)
     py_edfa = pypho.oamp(py_glova, Pmean=config["signal"]["edfa_power"], NF=config["signal"]["edfa_nf"])

-    if input_data_override is not None:
+    osnr = config["signal"]["osnr"] = float(config["signal"].get("osnr", "inf"))
-        c_data.E_in = input_data_override[0]
-        noise = input_data_override[1]
+    config["signal"]["seed"] = config["signal"].get("seed", (int(time.time() * 1000)) % 2**32)
-    else:
+    config["signal"]["jitter_seed"] = config["signal"].get("jitter_seed", (int(time.time() * 1000)) % 2**32)
-        config["signal"]["seed"] = config["signal"].get(
-            "seed", (int(time.time() * 1000)) % 2**32
-        )
-        config["signal"]["jitter_seed"] = config["signal"].get(
-            "jitter_seed", (int(time.time() * 1000)) % 2**32
-        )
     symbolsrc = pypho.symbols(
         py_glova, py_glova.nos, pattern="ones", p1=config["signal"]["mod_order"], seed=config["signal"]["seed"]
     )
-    laser = pypho.lasmod(
+    laserx = pypho.lasmod(py_glova, power=0, Df=0, theta=np.pi/4)
-        py_glova, power=config["signal"]["laser_power"]+1.5, Df=0, theta=np.pi / 4
+    # lasery = pypho.lasmod(py_glova, power=0, Df=25, theta=0)
-    )
     modulator = pam_generator(
         py_glova,
         mod_depth=config["signal"]["mod_depth"],
         pulse_shape=config["signal"]["pulse_shape"],
         fwhm=config["signal"]["fwhm"],
         seed=config["signal"]["jitter_seed"],
-        single_channel=False,
         mod_order=config["signal"]["mod_order"],
     )

@@ -254,17 +240,64 @@ def initialize_fiber_and_data(config, input_data_override=None):
     # symbols_x += 1

-    cw = laser()
+    cw = laserx()
+    # cwy = lasery()
+    # cw[0]['E'][0] = cw[0]['E'][0]
+    # cw[0]['E'][1] = cwy[0]['E'][0]

     source_signal = modulator(E=cw, symbols=(symbols_x, symbols_y))

+    if osnr != float("inf"):
+        osnr_lin = 10 ** (osnr / 10)
+        signal_power = np.sum(pypho.functions.getpower_W(source_signal[0]["E"]))
+        noise_power = signal_power / osnr_lin
+        noise = np.random.normal(0, 1, source_signal[0]["E"].shape) + 1j * np.random.normal(
+            0, 1, source_signal[0]["E"].shape
+        )
+        noise_power_is = np.sum(pypho.functions.getpower_W(noise))
+        noise = noise * np.sqrt(noise_power / noise_power_is)
+        noise_power_is = np.sum(pypho.functions.getpower_W(noise))
+        source_signal[0]["E"] += noise
+        source_signal[0]["noise"] = noise_power_is

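The block above loads the modulated signal with complex white Gaussian noise scaled to the configured OSNR. A self-contained sketch of the same scaling (assuming power is the mean squared magnitude, proportional to what pypho.functions.getpower_W sums):

import numpy as np

def add_osnr_noise(E: np.ndarray, osnr_db: float, rng=None):
    rng = rng or np.random.default_rng()
    signal_power = np.mean(np.abs(E) ** 2)
    noise_power = signal_power / 10 ** (osnr_db / 10)  # dB -> linear ratio
    noise = rng.standard_normal(E.shape) + 1j * rng.standard_normal(E.shape)
    noise *= np.sqrt(noise_power / np.mean(np.abs(noise) ** 2))  # rescale to target power
    return E + noise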
     # source_signal[0]['E'][1] += source_signal[0]['E'][1][np.argmin(np.abs(source_signal[0]['E'][1]))]

-    source_signal = py_edfa(E=source_signal)
+    ## side channels
+    # df = 100
+    # signal_power = pypho.functions.W_to_dBm(np.sum(pypho.functions.getpower_W(source_signal[0]["E"])))

+    # symbols_x_side = symbolsrc(pattern="random")
+    # symbols_y_side = symbolsrc(pattern="random")
+    # symbols_x_side[:3] = 0
+    # symbols_y_side[:3] = 0

+    # cw_left = laser(Df=-df)
+    # source_signal_left = modulator(E=cw_left, symbols=(symbols_x_side, symbols_y_side))

+    # cw_right = laser(Df=df)
+    # source_signal_right = modulator(E=cw_right, symbols=(symbols_y_side, symbols_x_side))

+    E_in_pure = source_signal[0]["E"]

+    nf = py_edfa.NF
+    pmean = py_edfa.Pmean

+    # ideal amplification to launch power into fiber
+    source_signal = py_edfa(E=source_signal, NF=0, Pmean=config["signal"]["laser_power"])
+    # source_signal_left = py_edfa(E=source_signal_left, NF=0, Pmean=config["signal"]["laser_power"])
+    # source_signal_right = py_edfa(E=source_signal_right, NF=0, Pmean=config["signal"]["laser_power"])

+    # source_signal[0]["E"][0] += source_signal_left[0]["E"][0] + source_signal_right[0]["E"][0]
+    # source_signal[0]["E"][1] += source_signal_left[0]["E"][1] + source_signal_right[0]["E"][1]

     c_data.E_in = source_signal[0]["E"]
     noise = source_signal[0]["noise"]

+    py_edfa.NF = nf
+    py_edfa.Pmean = pmean

     py_fiber = pypho.fiber(
         glova=py_glova,
         l=config["fiber"]["length"],

@@ -272,27 +305,32 @@ def initialize_fiber_and_data(config, input_data_override=None):
         gamma=config["fiber"]["gamma"],
         D=config["fiber"]["d"],
         S=config["fiber"]["s"],
+        phi_max=0.02,
     )
-    if config["fiber"].get("birefsteps", 0) > 0:
-        seed = config["fiber"].get(
+    config["fiber"]["birefsteps"] = config["fiber"].get(
-            "birefseed", (int(time.time() * 1000)) % 2**32
+        "birefsteps", config["fiber"]["length"] // config["fiber"].get("bireflength", config["fiber"]["length"])
     )
+    if config["fiber"]["birefsteps"] > 0:
+        config["fiber"]["bireflength"] = config["fiber"].get("bireflength", config["fiber"]["length"] / config["fiber"]["birefsteps"])
+        seed = config["fiber"].get("birefseed", (int(time.time() * 1000)) % 2**32)
         py_fiber.birefarray = pypho.birefringence_segment.create_pmd_fibre(
-            py_fiber.l,
+            config["fiber"]["length"],
-            py_fiber.l / config["fiber"]["birefsteps"],
+            config["fiber"]["bireflength"],
-            # maxDeltaD=config["fiber"]["d"]/5,
+            maxDeltaBeta=config["fiber"].get("pmd_q", 0)/np.sqrt(2*config["fiber"]["bireflength"]),
-            maxDeltaBeta = config["fiber"].get("max_delta_beta", 0),
             seed=seed,
         )
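The reworked birefringence setup derives the per-segment beat strength from a PMD parameter instead of a fixed max_delta_beta: with randomly oriented segments of length bireflength, the mean DGD grows with the square root of the total length, so dividing pmd_q by sqrt(2·bireflength) keeps the concatenated fibre consistent with that scaling. A sketch of the relation used above (unit conventions follow the config, not SI; this is an interpretation, not part of the patch):

import numpy as np

def max_delta_beta(pmd_q: float, bireflength: float) -> float:
    # per-segment differential beta for a random-mode-coupling PMD fibre model
    return pmd_q / np.sqrt(2 * bireflength)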
-    c_params = pypho.cfiber.ParamsWrapper.from_fiber(
+    elif (dgd := config['fiber'].get('dgd', 0)) > 0:
-        py_fiber, max_step=1e3 if py_fiber.gamma == 0 else 200
+        py_fiber.birefarray = [
-    )
+            pypho.birefringence_segment(z_point=0, angle=np.pi/2, delta_beta=1000*dgd/config["fiber"]["length"])
+        ]
+    c_params = pypho.cfiber.ParamsWrapper.from_fiber(py_fiber, max_step=config["fiber"]["length"] if py_fiber.gamma == 0 else 200)
     c_fiber = pypho.cfiber.FiberWrapper(c_data, c_params, c_glova)

-    return c_fiber, c_data, noise, py_edfa
+    return c_fiber, c_data, noise, py_edfa, (symbols_x, symbols_y), py_glova, E_in_pure


-def save_data(data, config):
+def save_data(data, config, **metadata):
     data_dir = Path(config["data"]["dir"])
     npy_dir = config["data"].get("npy_dir", "")
     save_dir = data_dir / npy_dir if len(npy_dir) else data_dir

@@ -307,6 +345,7 @@ def save_data(data, config):
     seed = config["signal"].get("seed", False)
     jitter_seed = config["signal"].get("jitter_seed", False)
     birefseed = config["fiber"].get("birefseed", False)
+    osnr = float(config["signal"].get("osnr", "inf"))

     config_content = "\n".join((
         f"; Generated by {str(Path(__file__).name)} @ {timestamp.strftime('%Y-%m-%d %H:%M:%S')}",

@@ -326,8 +365,11 @@ def save_data(data, config):
         f"D = {config['fiber']['d']}",
         f"S = {config['fiber']['s']}",
         f"birefsteps = {config['fiber'].get('birefsteps', 0)}",
-        f"max_delta_beta = {config['fiber'].get('max_delta_beta', 0)}",
+        f"pmd_q = {config['fiber'].get('pmd_q', 0)}",
         f"birefseed = {hex(birefseed)}" if birefseed else "; birefseed = not set",
+        f"dgd = {config['fiber'].get('dgd', 0)}",
+        f"ortho_error = {config['fiber'].get('ortho_error', 0)}",
+        f"pol_error = {config['fiber'].get('pol_error', 0)}",
         "",
         "[signal]",
         f"seed = {hex(seed)}" if seed else "; seed = not set",

@@ -335,100 +377,93 @@ def save_data(data, config):
         f'modulation = "{config["signal"]["modulation"]}"',
         f"mod_order = {config['signal']['mod_order']}",
         f"mod_depth = {config['signal']['mod_depth']}",
-        ""
+        "",
         f"max_jitter = {config['signal']['max_jitter']}",
         f"jitter_seed = {hex(jitter_seed)}" if jitter_seed else "; jitter_seed = not set",
-        ""
+        "",
         f"laser_power = {config['signal']['laser_power']}",
         f"edfa_power = {config['signal']['edfa_power']}",
         f"edfa_nf = {config['signal']['edfa_nf']}",
-        ""
+        f"osnr = {osnr}",
+        "",
         f'pulse_shape = "{config["signal"]["pulse_shape"]}"',
         f"fwhm = {config['signal']['fwhm']}",
         "",
         "[data]",
         f'dir = "{str(data_dir)}"',
         f'npy_dir = "{npy_dir}"',
-        "file = "
+        "file = ",
     ))
     config_hash = hashlib.md5(config_content.encode()).hexdigest()
-    save_file = f"{config_hash}.npy"
+    save_file = f"{config_hash}.h5"
     config_content += f'"{str(save_file)}"\n'

+    config_filename:Path = create_config_filename(config, data_dir, timestamp)
+    while config_filename.exists():
+        time.sleep(1)
+        config_filename = create_config_filename(config, data_dir=data_dir)

+    with open(config_filename, "w") as f:
+        f.write(config_content)

+    with h5py.File(save_dir / save_file, "w") as outfile:
+        outfile.create_dataset("data", data=data)  # editor's fix: the committed line passed the module-level function save_data instead of the data argument
+        outfile.create_dataset("symbols", data=metadata.pop("symbols"))
+        for key, value in metadata.items():
+            # if isinstance(value, dict):
+            #     value = json.dumps(model_runner.convert_arrays(value))
+            outfile.attrs[key] = value
+    # np.save(save_dir / save_file, save_data)

+    # print("Saved config to", config_filename)
+    # print("Saved data to", save_dir / save_file)

+    return config_filename

+def create_config_filename(config, data_dir:Path, timestamp=None):
+    if timestamp is None:
+        timestamp = datetime.now()
     filename_components = (
         timestamp.strftime("%Y%m%d-%H%M%S"),
         config["glova"]["sps"],
         config["glova"]["nos"],
+        config["signal"]["osnr"],
         config["fiber"]["length"],
         config["fiber"]["gamma"],
         config["fiber"]["alpha"],
         config["fiber"]["d"],
         config["fiber"]["s"],
         f"{config['signal']['modulation'].upper()}{config['signal']['mod_order']}",
-        config['fiber'].get('birefsteps',0),
+        config["fiber"].get("birefsteps", 0),
-        config["fiber"].get("max_delta_beta", 0),
+        config["fiber"].get("pmd_q", 0),
+        int(config["glova"]["symbolrate"] / 1e9),
     )

     lookup_file = "-".join(map(str, filename_components)) + ".ini"
-    with open(data_dir / lookup_file, "w") as f:
+    return data_dir / lookup_file
-        f.write(config_content)

-    np.save(save_dir / save_file, save_data)
+def length_loop(config, lengths, save=True):

-    print("Saved config to", data_dir / lookup_file)
-    print("Saved data to", save_dir / save_file)


-def length_loop(config, lengths, incremental=False, bireflength=None, save=True):
     lengths = sorted(lengths)
-    input_override = None
+    for length in lengths:
-    birefsteps_running = 0
-    for lind, length in enumerate(lengths):
-        # print(f"\nGenerating data for fiber length {length}")
-        if lind > 0 and incremental:
-            # set the length to the difference between the current and previous length -> incremental
-            length = lengths[lind] - lengths[lind - 1]
-        if incremental:
-            print(
-                f"\nGenerating data for fiber length {lengths[lind]}m [using {length}m increment]"
-            )
-        else:
         print(f"\nGenerating data for fiber length {length}m")
         config["fiber"]["length"] = length
-        if bireflength is not None and bireflength > 0:
-            config["fiber"]["birefsteps"] = length // bireflength
-        birefsteps_running += config["fiber"]["birefsteps"]
-        # set the input data to the output data of the previous run
-        cfiber, cdata, noise, edfa = initialize_fiber_and_data(
-            config, input_data_override=input_override
-        )

-        if lind == 0:
+        cfiber, cdata, noise, edfa, symbols, py_glova = initialize_fiber_and_data(config)  # note: initialize_fiber_and_data now returns seven values; E_in_pure is not unpacked here
-            cdata_orig = cdata

         mean_power_in = np.sum(pypho.functions.getpower_W(cdata.E_in))
-        print(
-            f"Mean power in: {mean_power_in:.3e} W ({pypho.functions.W_to_dBm(mean_power_in):.3e} dBm)"
-        )

         cfiber()

         mean_power_out = np.sum(pypho.functions.getpower_W(cdata.E_out))
-        print(
-            f"Mean power out: {mean_power_out:.3e} W ({pypho.functions.W_to_dBm(mean_power_out):.3e} dBm)"
-        )

-        if incremental:
+        print(f"Mean power in: {mean_power_in:.3e} W ({pypho.functions.W_to_dBm(mean_power_in):.3e} dBm)")
-            input_override = (cdata.E_out, noise)
+        print(f"Mean power out: {mean_power_out:.3e} W ({pypho.functions.W_to_dBm(mean_power_out):.3e} dBm)")
-            cdata.E_in = cdata_orig.E_in
-        config["fiber"]["length"] = lengths[lind]
-        if bireflength is not None:
-            config["fiber"]["birefsteps"] = birefsteps_running

-        E_tmp = [{'E': cdata.E_out, 'noise': noise*(-cfiber.params.l*cfiber.params.alpha)}]
+        E_tmp = [{"E": cdata.E_out, "noise": noise * (-cfiber.params.l * cfiber.params.alpha)}]
         E_tmp = edfa(E=E_tmp)
-        cdata.E_out = E_tmp[0]['E']
+        cdata.E_out = E_tmp[0]["E"]

+        mean_power_amp = np.sum(pypho.functions.getpower_W(cdata.E_out))
+        print(f"Mean power after EDFA: {mean_power_amp:.3e} W ({pypho.functions.W_to_dBm(mean_power_amp):.3e} dBm)")

         if save:
             save_data(cdata, config)

@@ -436,27 +471,55 @@ def length_loop(config, lengths, incremental=False, bireflength=None, save=True)


 def single_run_with_plot(config, save=True):
-    cfiber, cdata, noise, edfa = initialize_fiber_and_data(config)
+    cfiber, cdata, config_filename = single_run(config, save)
-
-    mean_power_in = np.sum(pypho.functions.getpower_W(cdata.E_in))
-    print(
-        f"Mean power in: {mean_power_in:.3e} W ({pypho.functions.W_to_dBm(mean_power_in):.3e} dBm)"
-    )
-
-    cfiber()
-
-    mean_power_out = np.sum(pypho.functions.getpower_W(cdata.E_out))
-    print(
-        f"Mean power out: {mean_power_out:.3e} W ({pypho.functions.W_to_dBm(mean_power_out):.3e} dBm)"
-    )
-
-    E_tmp = [{'E': cdata.E_out, 'noise': noise*(-cfiber.params.l*cfiber.params.alpha)}]
-    E_tmp = edfa(E=E_tmp)
-    cdata.E_out = E_tmp[0]['E']
-    if save:
-        save_data(cdata, config)
-
     in_out_eyes(cfiber, cdata, show_pols=False)
+    return config_filename


+def single_run(config, save=True, silent=True):
+    cfiber, cdata, noise, edfa, symbols, glova, E_in = initialize_fiber_and_data(config)

+    # transmit
+    cfiber()

+    # amplify
+    E_tmp = [{"E": cdata.E_out, "noise": noise}]

+    E_tmp = edfa(E=E_tmp)


+    # rotate
+    # ortho error
+    ortho_error = config["fiber"].get("ortho_error", 0)

+    E_tmp[0]["E"] = np.stack((
+        E_tmp[0]["E"][0] * np.cos(ortho_error/2) + E_tmp[0]["E"][1] * np.sin(ortho_error/2),
+        E_tmp[0]["E"][0] * np.sin(ortho_error/2) + E_tmp[0]["E"][1] * np.cos(ortho_error/2)
+    ), axis=0)


+    pol_error = config['fiber'].get('pol_error', 0)

+    E_tmp[0]["E"] = np.stack((
+        E_tmp[0]["E"][0] * np.cos(pol_error) - E_tmp[0]["E"][1] * np.sin(pol_error),
+        E_tmp[0]["E"][0] * np.sin(pol_error) + E_tmp[0]["E"][1] * np.cos(pol_error)
+    ), axis=0)

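Of the two impairments above, pol_error is a proper rotation (cos/−sin, sin/cos rows), while ortho_error uses +sin in both rows, deliberately breaking orthogonality between the received polarisation axes. A standalone sketch over a (2, n) field array (the helper name is illustrative):

import numpy as np

def apply_pol_errors(E, ortho_error=0.0, pol_error=0.0):
    # non-orthogonal mixing: both axes leak into each other with the same sign
    E = np.stack((E[0] * np.cos(ortho_error / 2) + E[1] * np.sin(ortho_error / 2),
                  E[0] * np.sin(ortho_error / 2) + E[1] * np.cos(ortho_error / 2)))
    # unitary rotation of the polarisation plane
    c, s = np.cos(pol_error), np.sin(pol_error)
    return np.stack((E[0] * c - E[1] * s, E[0] * s + E[1] * c))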

+    # output
+    cdata.E_out = E_tmp[0]["E"]

+    config_filename = None
+    symbols = np.array(symbols)
+    if save:
+        config_filename = save_data(cdata, config, **{"symbols": symbols})
+        if not silent:
+            print(f"Saved config to {config_filename}")
+    return cfiber, cdata, config_filename


 def in_out_eyes(cfiber, cdata, show_pols=False):
     fig, axs = plt.subplots(2, 2, sharex=True, sharey=True)

@@ -620,9 +683,7 @@ def plot_eye_diagram(
     signal = signal[: head * eye_width]
     if normalize:
         signal = signal / np.max(signal)
-    slices = np.lib.stride_tricks.sliding_window_view(signal, eye_width + 1)[
-        offset % (eye_width + 1) :: eye_width
-    ]
+    slices = np.lib.stride_tricks.sliding_window_view(signal, eye_width + 1)[offset % (eye_width + 1) :: eye_width]
     plt_ax = np.arange(-eye_width // 2, eye_width // 2 + 1) / samplerate
     for slice in slices:
         ax.plot(plt_ax, slice, color=color, alpha=0.1)

@@ -643,13 +704,26 @@ if __name__ == "__main__":
     # lengths = [*lengths, *lengths]
     lengths = (
         # 8000, 9000,
-        10000, 20000, 30000, 40000, 50000, 60000, 70000, 80000, 90000,
-        95000, 100000, 105000, 110000, 115000, 120000
+        10000,
+        20000,
+        30000,
+        40000,
+        50000,
+        60000,
+        70000,
+        80000,
+        90000,
+        95000,
+        100000,
+        105000,
+        110000,
+        115000,
+        120000,
     )
-    lengths = sorted(lengths)

-    length_loop(config, lengths, incremental=False, bireflength=1000, save=True)
+    # lengths = (10000,100000)

+    # length_loop(config, lengths, save=True)
     # birefringence is constant over coupling length -> several 100m -> bireflength=1000 (m)

-    # single_run_with_plot(config, save=True)
+    single_run_with_plot(config, save=False)

 138   src/single-core-regen/testing/prob_dens.ipynb   (new file; diff suppressed because one or more lines are too long)
 1351  src/single-core-regen/tolerance_testing.py      (new file; diff too large to display)

@@ -26,7 +26,7 @@ global_settings = GlobalSettings(
 )

 data_settings = DataSettings(
-    config_path="data/*-128-16384-10000-0-0-0-0-PAM4-0-0.4.ini",
+    config_path="data/20241211-105524-128-16384-1-0-0-0-0-PAM4-0-0.ini",
     # config_path=[f"data/20241202-*-128-16384-{length}-0-0-17-0-PAM4-0.ini" for length in range(48000, 53000, 1000)],
     dtype="complex64",
     # symbols = (9, 20), # 13 symbol @ 10GBd <-> 1.3ns <-> 0.26m of fiber

@@ -53,14 +53,14 @@ pytorch_settings = PytorchSettings(
 )

 model_settings = ModelSettings(
-    output_dim=3,
+    output_dim=1,
     n_hidden_layers=3,
     overrides={
-        "n_hidden_nodes_0": 2,
+        "n_hidden_nodes_0": 4,
-        "n_hidden_nodes_1": 2,
+        "n_hidden_nodes_1": 4,
-        "n_hidden_nodes_2": 2,
+        "n_hidden_nodes_2": 4,
     },
-    dropout_prob=0.01,
+    dropout_prob=0,
     model_layer_function="ONNRect",
     model_activation_func="EOActivation",
     model_layer_kwargs={"square": True},

@@ -110,20 +110,24 @@ model_settings = ModelSettings(
 )

 optimizer_settings = OptimizerSettings(
-    optimizer="AdamW",
+    optimizer="RMSprop",
+    # optimizer="AdamW",
     optimizer_kwargs={
-        "lr": 0.005,
+        "lr": 0.01,
-        "amsgrad": True,
+        "alpha": 0.9,
+        "momentum": 0.1,
+        "eps": 1e-8,
+        "centered": True,
+        # "amsgrad": True,
         # "weight_decay": 1e-7,
     },
-    # learning_rate=0.05,
     scheduler="ReduceLROnPlateau",
     scheduler_kwargs={
-        "patience": 2**6,
+        "patience": 2**5,
         "factor": 0.75,
         # "threshold": 1e-3,
         "min_lr": 1e-6,
-        "cooldown": 10,
+        # "cooldown": 10,
     },
 )

@@ -320,6 +320,29 @@ class normalize_by_first(nn.Module):
     def forward(self, data):
         return data / data[:, 0].unsqueeze(1)

+class rotate(nn.Module):
+    def __init__(self):
+        super(rotate, self).__init__()
+
+    def forward(self, data, angle):
+        # data -> (batch, n*2)
+        # angle -> (batch, n)
+        data_ = data
+        if angle.ndim == 1:
+            angle_ = angle.unsqueeze(1)
+        else:
+            angle_ = angle
+        angle_ = angle_.expand(-1, data_.shape[1]//2)
+        c = torch.cos(angle_)
+        s = torch.sin(angle_)
+        rot = torch.stack([torch.stack([c, -s], dim=2),
+                           torch.stack([s, c], dim=2)], dim=3)
+        d = torch.bmm(data_.reshape(-1, 1, 2), rot.view(-1, 2, 2).to(dtype=data_.dtype)).reshape(*data.shape)
+        # d = torch.bmm(data.unsqueeze(-1).mT, rot.to(dtype=data.dtype).mT).mT.squeeze(-1)
+
+        return d

 class photodiode(nn.Module):
     def __init__(self, size, bias=True):
         super(photodiode, self).__init__()

@@ -418,8 +441,7 @@ class input_rotator(nn.Module):
     # return out

-#### as defined by zhang et al
+#### as defined by zhang et alas

 class DropoutComplex(nn.Module):
     def __init__(self, p=0.5):

@@ -441,7 +463,7 @@ class Scale(nn.Module):
         self.scale = nn.Parameter(torch.ones(size, dtype=torch.float32))

     def forward(self, x):
-        return x * self.scale
+        return x * torch.sqrt(self.scale)

     def __repr__(self):
         return f"Scale({self.size})"
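With the change above, Scale multiplies the field by the square root of the learned parameter, so the optical power |x|² scales linearly with scale (which the clamp parametrization earlier in this change set bounds to [0, 10]). A quick numeric check:

import torch

x = torch.ones(4, dtype=torch.cfloat)
scale = torch.tensor(4.0)
y = x * torch.sqrt(scale)
assert torch.allclose(y.abs().square(), scale * x.abs().square())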
@@ -459,6 +481,15 @@ class Identity(nn.Module):
     def forward(self, x):
         return x

+class phase_shift(nn.Module):
+    def __init__(self, size):
+        super(phase_shift, self).__init__()
+        self.size = size
+        self.phase = nn.Parameter(torch.rand(size))
+
+    def forward(self, x):
+        return x * torch.exp(1j*self.phase)

 class PowRot(nn.Module):
     def __init__(self, bias=False):

@@ -487,7 +518,7 @@ class MZISingle(nn.Module):
         return x * torch.exp(1j * self.phi) * torch.sin(self.omega + self.func(x))

 def naive_angle_loss(x: torch.Tensor, target: torch.Tensor, mod=2*torch.pi):
-    return torch.fmod((x - target), mod).square().mean()
+    return torch.fmod((x.abs().real - target.abs().real), mod).abs().mean()

 def cosine_loss(x: torch.Tensor, target: torch.Tensor):
     return (2*(1 - torch.cos(x - target))).mean()

@@ -508,51 +539,46 @@ def angle_mse_loss(x: torch.Tensor, target: torch.Tensor):

 class EOActivation(nn.Module):
     def __init__(self, size=None):
-        # 10.1109/SiPhotonics60897.2024.10543376
+        # 10.1109/JSTQE.2019.2930455
         super(EOActivation, self).__init__()
         if size is None:
             raise ValueError("Size must be specified")
         self.size = size
-        self.alpha = nn.Parameter(torch.ones(size))
+        self.alpha = nn.Parameter(torch.rand(size))
-        self.V_bias = nn.Parameter(torch.ones(size))
+        self.gain = nn.Parameter(torch.rand(size))
-        self.gain = nn.Parameter(torch.ones(size))
+        self.V_bias = nn.Parameter(torch.rand(size))
-        # if bias:
+        # self.register_buffer("gain", torch.ones(size))
-        #     self.phase_bias = nn.Parameter(torch.zeros(size))
+        # self.register_buffer("responsivity", torch.ones(size))
-        # else:
+        # self.register_buffer("V_pi", torch.ones(size))
-        #     self.register_buffer("phase_bias", torch.zeros(size))
-        self.register_buffer("phase_bias", torch.clamp(torch.ones(size) + torch.randn(size)*0.1, 0, 1)*torch.pi)
-        self.register_buffer("responsivity", torch.ones(size)*0.9)
-        self.register_buffer("V_pi", torch.ones(size)*3)

         self.reset_weights()

     def reset_weights(self):
         if "alpha" in self._parameters:
-            self.alpha.data = torch.ones(self.size)*0.5
+            self.alpha.data = torch.rand(self.size)
-        if "V_pi" in self._parameters:
+        # if "V_pi" in self._parameters:
-            self.V_pi.data = torch.ones(self.size)*3
+        #     self.V_pi.data = torch.rand(self.size)*3
         if "V_bias" in self._parameters:
-            self.V_bias.data = torch.zeros(self.size)
+            self.V_bias.data = torch.randn(self.size)
         if "gain" in self._parameters:
-            self.gain.data = torch.ones(self.size)
+            self.gain.data = torch.rand(self.size)
-        if "responsivity" in self._parameters:
+        # if "responsivity" in self._parameters:
-            self.responsivity.data = torch.ones(self.size)*0.9
+        #     self.responsivity.data = torch.ones(self.size)*0.9
-        if "bias" in self._parameters:
+        # if "bias" in self._parameters:
-            self.phase_bias.data = torch.zeros(self.size)
+        #     self.phase_bias.data = torch.zeros(self.size)

     def forward(self, x: torch.Tensor):
-        phi_b = torch.pi * self.V_bias / (self.V_pi + 1e-8)
+        phi_b = torch.pi * self.V_bias  # / (self.V_pi)
-        g_phi = torch.pi * (self.alpha * self.gain * self.responsivity) / (self.V_pi + 1e-8)
+        g_phi = torch.pi * (self.alpha * self.gain)  # * self.responsivity) / (self.V_pi)
         intermediate = g_phi * x.abs().square() + phi_b
         return (
             1j
             * torch.sqrt(1 - self.alpha)
-            * torch.exp(-0.5j * (intermediate + self.phase_bias))
+            * torch.exp(-0.5j * intermediate)
             * torch.cos(0.5 * intermediate)
             * x
         )
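After this change, with phi = g_phi·|x|² + phi_b, the forward pass computes f(x) = j·sqrt(1 − alpha)·exp(−j·phi/2)·cos(phi/2)·x, the electro-optic activation of the cited paper with responsivity and V_pi now folded into the learned gain. A standalone numeric rendering (the values are illustrative, not trained parameters):

import torch

alpha, g_phi, phi_b = 0.5, 1.0, 0.3
x = torch.tensor([0.5 + 0.1j])
phi = g_phi * x.abs().square() + phi_b
y = 1j * torch.sqrt(torch.tensor(1.0 - alpha)) * torch.exp(-0.5j * phi) * torch.cos(0.5 * phi) * x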

 class Pow(nn.Module):
     """
     implements the activation function

@@ -693,6 +719,7 @@ __all__ = [
     MZISingle,
     EOActivation,
     photodiode,
+    phase_shift,
     # SaturableAbsorberLambertW,
     # SaturableAbsorber,
     # SpreadLayer,

 105   src/single-core-regen/util/core.py   (new file)
@@ -0,0 +1,105 @@
+# Copyright (c) 2015, Warren Weckesser. All rights reserved.
+# This software is licensed according to the "BSD 2-clause" license.
+
+import hashlib
+import h5py
+import numpy as _np
+from scipy.interpolate import interp1d as _interp1d
+from scipy.ndimage import gaussian_filter as _gaussian_filter
+from ._brescount import bres_curve_count as _bres_curve_count
+from pathlib import Path
+
+
+__all__ = ['grid_count']
+
+
+def grid_count(y, window_size, offset=0, size=None, fuzz=True, blur=0, bounds=None):
+    """
+    Parameters
+    ----------
+    `y` is the 1-d array of signal samples.
+
+    `window_size` is the number of samples to show horizontally in the
+    eye diagram.  Typically this is twice the number of samples in a
+    "symbol" (i.e. in a data bit).
+
+    `offset` is the number of initial samples to skip before computing
+    the eye diagram.  This allows the overall phase of the diagram to
+    be adjusted.
+
+    `size` must be a tuple of two integers.  It sets the size of the
+    array of counts, (height, width).  The default is (800, 640).
+
+    `fuzz`: If True, the values in `y` are reinterpolated with a
+    random "fuzz factor" before plotting in the eye diagram.  This
+    reduces an aliasing-like effect that arises with the use of
+    Bresenham's algorithm.
+
+    `bounds` must be a tuple of two floating point values, (ymin, ymax).
+    These set the y range of the returned array.  If not given, the
+    bounds are `(y.min() - 0.05*A, y.max() + 0.05*A)`, where `A` is
+    `y.max() - y.min()`.
+
+    Return Value
+    ------------
+    Returns a numpy array of integers.
+    """
+    # hash input params
+    param_ob = (y, window_size, offset, size, fuzz, blur, bounds)
+    param_hash = hashlib.md5(str(param_ob).encode()).hexdigest()
+    cache_dir = Path.home()/".eyediagram"/".cache"
+    cache_dir.mkdir(parents=True, exist_ok=True)
+    if (cache_dir/param_hash).is_file():
+        try:
+            with h5py.File(cache_dir/param_hash, "r") as infile:
+                counts = infile["counts"][:]
+            if len(counts) != 0:  # editor's fix: the committed `counts.len()` would raise AttributeError, so the cache was never hit
+                return counts
+        except:
+            pass
+
+    if size is None:
+        size = (800, 640)
+    height, width = size
+    dt = width / window_size
+    counts = _np.zeros((width, height), dtype=_np.int32)
+
+    if bounds is None:
+        ymin = y.min()
+        ymax = y.max()
+        yamp = ymax - ymin
+        ymin = ymin - 0.05*yamp
+        ymax = ymax + 0.05*yamp
+        ymax = _np.ceil(ymax*10)/10
+        ymin = _np.floor(ymin*10)/10
+    else:
+        ymin, ymax = bounds
+
+    start = offset
+    while start + window_size < len(y):
+        end = start + window_size
+        yy = y[start:end+1]
+        k = _np.arange(len(yy))
+        xx = dt*k
+        if fuzz:
+            f = _interp1d(xx, yy, kind='cubic')
+            jiggle = dt*(_np.random.beta(a=3, b=3, size=len(xx)-2) - 0.5)
+            xx[1:-1] += jiggle
+            yd = f(xx)
+        else:
+            yd = yy
+        iyd = (height * (yd - ymin)/(ymax - ymin)).astype(_np.int32)
+        _bres_curve_count(xx.astype(_np.int32), iyd, counts)
+
+        start = end
+
+    if blur != 0:
+        counts = _gaussian_filter(counts, sigma=blur)
+
+    with h5py.File(cache_dir/param_hash, "w") as outfile:
+        outfile.create_dataset("counts", data=counts)  # editor's fix: the committed key "data" did not match the "counts" key read above
+
+    return counts
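A hypothetical call of the new grid_count helper (the waveform is a stand-in; real use would pass a sampled PAM signal with window_size of about two symbol periods):

import numpy as np

y = np.sin(np.linspace(0, 200 * np.pi, 25600))   # placeholder signal
counts = grid_count(y, window_size=256, size=(400, 320), fuzz=True, blur=1)
print(counts.max())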
@@ -1,10 +1,12 @@
 from pathlib import Path
+import h5py
 import torch
 from torch.utils.data import Dataset

 # from torch.utils.data import Sampler
 import numpy as np
 import configparser
+import multiprocessing as mp

 # class SubsetSampler(Sampler[int]):
 #     """

@@ -24,7 +26,22 @@ import configparser
 #         return len(self.indices)

-def load_data(config_path, skipfirst=0, symbols=None, real=False, normalize=False, device=None, dtype=None):
+def load_from_file(datapath):
+    if str(datapath).endswith(".h5"):
+        symbols = None
+        with h5py.File(datapath, "r") as infile:
+            data = infile["data"][:]
+            try:
+                symbols = np.swapaxes(infile["symbols"][:], 0, 1)
+            except KeyError:
+                pass
+    else:
+        symbols = None
+        data = np.load(datapath)
+    return data, symbols


+def load_data(config_path, skipfirst=0, skiplast=0, symbols=None, real=False, normalize=1, device=None, dtype=None):
     filepath = Path(config_path)
     filepath = filepath.parent.glob(filepath.name)
     config = configparser.ConfigParser()

@@ -40,14 +57,28 @@ def load_data(config_path, skipfirst=0, symbols=None, real=False, normalize=Fals
     if symbols is None:
         symbols = int(config["glova"]["nos"]) - skipfirst

-    data = np.load(datapath)[int(skipfirst * sps) : int(symbols * sps + skipfirst * sps)]
+    data, orig_symbols = load_from_file(datapath)
-    timestamps = np.arange(int(skipfirst * sps), int(symbols * sps + skipfirst * sps))

-    if normalize:
+    data = data[int(skipfirst * sps) : int(symbols * sps + skipfirst * sps - skiplast * sps)]
-        # square gets normalized to 1, as the power is (proportional to) the square of the amplitude
+    orig_symbols = orig_symbols[skipfirst : symbols + skipfirst - skiplast]
-        a, b, c, d = np.square(data.T)
+    timestamps = np.arange(int(skipfirst * sps), int(symbols * sps + skipfirst * sps - skiplast * sps))
-        a, b, c, d = a / np.max(np.abs(a)), b / np.max(np.abs(b)), c / np.max(np.abs(c)), d / np.max(np.abs(d))
-        data = np.sqrt(np.array([a, b, c, d]).T)
+    data *= np.sqrt(normalize)

+    launch_power = float(config["signal"]["laser_power"])
+    output_power = float(config["signal"]["edfa_power"])

+    target_normalization = 10 ** (output_power / 10) / 10 ** (launch_power / 10)
+    # target_normalization *= 0.5 # allow 50% power loss, so the network can ignore parts of the signal

+    data[:, 0:2] *= np.sqrt(target_normalization)

|
# if normalize:
|
||||||
|
# # square gets normalized to 1, as the power is (proportional to) the square of the amplitude
|
||||||
|
# a, b, c, d = data.T
|
||||||
|
# a, b, c, d = a - np.min(np.abs(a)), b - np.min(np.abs(b)), c - np.min(np.abs(c)), d - np.min(np.abs(d))
|
||||||
|
# a, b, c, d = a / np.max(np.abs(a)), b / np.max(np.abs(b)), c / np.max(np.abs(c)), d / np.max(np.abs(d))
|
||||||
|
# data = np.array([a, b, c, d]).T
|
||||||
|
|
||||||
if real:
|
if real:
|
||||||
data = np.abs(data)
|
data = np.abs(data)
|
||||||
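The `target_normalization` factor converts the dB gap between EDFA output power and launch power into a linear power ratio, 10^(dP/10); field amplitudes then scale with its square root. A worked check (the dBm values are illustrative, not read from any config):

    launch_power = 0.2   # dBm
    output_power = 16.8  # dBm
    ratio = 10 ** (output_power / 10) / 10 ** (launch_power / 10)  # = 10 ** ((16.8 - 0.2) / 10)
    print(ratio)          # ~45.7x in power
    print(ratio ** 0.5)   # ~6.76x in amplitude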
@@ -58,7 +89,7 @@ def load_data(config_path, skipfirst=0, symbols=None, real=False, normalize=Fals
     data = torch.tensor(data, device=device, dtype=dtype)

-    return data, config
+    return data, config, orig_symbols


 def roll_along(arr, shifts, dim):
@@ -110,11 +141,15 @@ class FiberRegenerationDataset(Dataset):
         target_delay: float | int = 0,
         xy_delay: float | int = 0,
         drop_first: float | int = 0,
+        drop_last=0,
         dtype: torch.dtype = None,
         real: bool = False,
         device=None,
-        polarisations: tuple | list = (0,),
+        # osnr: float | None = None,
+        polarisations=None,
         randomise_polarisations: bool = False,
+        repeat_randoms: int = 1,
+        # cross_pol_interference: float = 0,
         **kwargs,
     ):
         """
@@ -148,37 +183,29 @@ class FiberRegenerationDataset(Dataset):
         assert drop_first >= 0, "drop_first must be non-negative"

         self.randomise_polarisations = randomise_polarisations
+        # self.cross_pol_interference = cross_pol_interference

-        faux = kwargs.pop("faux", False)
-
-        if faux:
-            data_raw = np.array(
-                [[i + 0.1j, i + 0.2j, i + 1.1j, i + 1.2j] for i in range(12800)],
-                dtype=np.complex128,
-            )
-            data_raw = torch.tensor(data_raw, device=device, dtype=dtype)
-            timestamps = torch.arange(12800)
-
-            data_raw = torch.concatenate([data_raw, timestamps.reshape(-1, 1)], axis=-1)
-
-            self.config = {
-                "data": {"dir": '"."', "npy_dir": '"."', "file": "faux"},
-                "glova": {"sps": 128},
-            }
-        else:
-            data_raw = None
-            self.config = None
-            files = []
+        data_raw = None
+        self.config = None
+        files = []
+        self.orig_symbols = None
         for file_path in file_path if isinstance(file_path, (tuple, list)) else [file_path]:
-            data, config = load_data(
+            data, config, orig_syms = load_data(
                 file_path,
                 skipfirst=drop_first,
+                skiplast=drop_last,
                 symbols=kwargs.get("num_symbols", None),
                 real=real,
-                normalize=True,
+                normalize=1000,
                 device=device,
                 dtype=dtype,
             )
+            if orig_syms is not None:
+                if self.orig_symbols is None:
+                    self.orig_symbols = orig_syms
+                else:
+                    # fix: np.concatenate, not np.concat -- the alias only exists in NumPy >= 2.0
+                    self.orig_symbols = np.concatenate((self.orig_symbols, orig_syms), axis=-1)

         if data_raw is None:
             data_raw = data
         else:
@@ -190,22 +217,19 @@ class FiberRegenerationDataset(Dataset):
         files.append(config["data"]["file"].strip('"'))
         self.config["data"]["file"] = str(files)

-        for i, angle in enumerate(torch.tensor(np.array(polarisations))):
-            data_raw_copy = data_raw.clone()
-            if angle == 0:
-                continue
-            sine = torch.sin(angle)
-            cosine = torch.cos(angle)
-            data_raw_copy[:, 2] = data_raw[:, 2] * cosine - data_raw[:, 3] * sine
-            data_raw_copy[:, 3] = data_raw[:, 2] * sine + data_raw[:, 3] * cosine
-            if i == 0:
-                data_raw = data_raw_copy
-            else:
-                data_raw = torch.cat([data_raw, data_raw_copy], dim=0)
+        # if polarisations is not None:
+        #     data_raw_clone = data_raw.clone()
+        #     # rotate the polarisation by 180 degrees
+        #     data_raw_clone[2, :] *= -1
+        #     data_raw_clone[3, :] *= -1
+        #     data_raw = torch.cat([data_raw, data_raw_clone], dim=0)
+
+        self.polarisations = bool(polarisations)

         self.device = data_raw.device

         self.samples_per_symbol = int(self.config["glova"]["sps"])
+        # self.num_symbols = int(self.config["glova"]["nos"])
         self.samples_per_slice = int(symbols * self.samples_per_symbol)
         self.symbols_per_slice = self.samples_per_slice / self.samples_per_symbol
@@ -278,23 +302,94 @@ class FiberRegenerationDataset(Dataset):
         timestamps = data_raw[4, :]
         data_raw = data_raw[:4, :]
         data_raw = data_raw.view(2, 2, -1)
-        timestamps_doubled = torch.cat([timestamps.unsqueeze(dim=0), timestamps.unsqueeze(dim=0)], dim=0).unsqueeze(
-            dim=1
-        )
-        data_raw = torch.cat([data_raw, timestamps_doubled], dim=1)
+        fiber_in = data_raw[0, :, :]
+        fiber_out = data_raw[1, :, :]
+        # timestamps_doubled = torch.cat([timestamps.unsqueeze(dim=0), timestamps.unsqueeze(dim=0)], dim=0).unsqueeze(
+        #     dim=1
+        # )
+        fiber_in = torch.cat([fiber_in, timestamps.unsqueeze(0)], dim=0)
+        fiber_out = torch.cat([fiber_out, timestamps.unsqueeze(0)], dim=0)
+
+        # fiber_out: [E_out_x, E_out_y, timestamps]
+
+        # add noise related to amplification necessary due to splitting of the signal
+        # gain_lin = output_dim*2
+        # gain_lin = 1
+        # edfa_nf = float(self.config["signal"]["edfa_nf"])
+        # nf_lin = 10**(edfa_nf/10)
+        # f0 = float(self.config["glova"]["f0"])
+
+        # noise_add = (gain_lin-1)*nf_lin*6.626e-34*f0*12.5e9
+
+        # noise = torch.randn_like(fiber_out[:2, :])
+        # noise_power = torch.sum(torch.mean(noise.abs().square().squeeze(), dim=-1))
+        # noise = noise * torch.sqrt(noise_add / noise_power)
+        # fiber_out[:2, :] += noise
+
+        # if osnr is None:
+        #     noisy = fiber_out[:2, :]
+        # else:
+        #     noisy = self.add_noise(fiber_out[:2, :], osnr)
+
+        # fiber_out = torch.cat([fiber_out, noisy], dim=0)
+
+        # fiber_out: [E_out_x, E_out_y, timestamps, E_out_x_noisy, E_out_y_noisy]
+
+        if repeat_randoms > 1:
+            fiber_in = fiber_in.repeat(1, 1, repeat_randoms)
+            fiber_out = fiber_out.repeat(1, 1, repeat_randoms)
+            # review: potential problems with repeated timestamps when plotting
+        else:
+            repeat_randoms = 1
+
+        if self.randomise_polarisations:
+            # note: the first two assignments below are dead code; the last one wins (sigma = 10 degrees)
+            angles = torch.fmod(torch.randperm(data_raw.shape[-1]*repeat_randoms, device=fiber_out.device), 2) * torch.pi
+            start_angle = torch.rand(1) * 2 * torch.pi
+            angles = start_angle + torch.cumsum(torch.randn(data_raw.shape[-1])/333, dim=0) * torch.pi * 2  # random walk
+            angles = torch.randn(data_raw.shape[-1], device=fiber_out.device) * 2*torch.pi / 36  # sigma = 10 degrees
+            # self.angles = torch.rand(self.data.shape[0]) * 2 * torch.pi
+        else:
+            angles = torch.zeros(data_raw.shape[-1], device=fiber_out.device)

+        sin = torch.sin(angles)
+        cos = torch.cos(angles)
+        rot = torch.stack([torch.stack([cos, -sin], dim=1), torch.stack([sin, cos], dim=1)], dim=2)
+        data_rot = torch.bmm(fiber_out[:2, :].T.unsqueeze(1), rot.to(dtype=data_raw.dtype)).squeeze(1).T
+        # data_rot_noisy = torch.bmm(fiber_out[3:5, :].T.unsqueeze(1), rot.to(dtype=data_raw.dtype)).squeeze(1).T
+        fiber_out = torch.cat((fiber_out, data_rot), dim=0)
+        fiber_out = torch.cat([fiber_out, angles.unsqueeze(0)], dim=0)
+
+        # fiber_in:
+        # 0 E_in_x,
+        # 1 E_in_y,
+        # 2 timestamps
+
+        # fiber_out:
+        # 0 E_out_x,
+        # 1 E_out_y,
+        # 2 timestamps,
+        # 3 E_out_x_rot,
+        # 4 E_out_y_rot,
+        # 5 angle

         # data_raw = torch.cat([data_raw, timestamps_doubled], dim=1)
         # data layout
         # [ [E_in_x, E_in_y, timestamps],
         #   [E_out_x, E_out_y, timestamps] ]

-        self.data = data_raw.unfold(dimension=-1, size=self.samples_per_slice, step=1)
-        self.data = self.data.movedim(-2, 0)
-
-        if randomise_polarisations:
-            self.angles = torch.rand(self.data.shape[0]) * np.pi * 2
-            # self.data[:, 1, :2, :] = self.rotate(self.data[:, 1, :2, :], self.angles)
-        else:
-            self.angles = torch.zeros(self.data.shape[0])
+        self.fiber_in = fiber_in.unfold(dimension=-1, size=self.samples_per_slice, step=1)
+        self.fiber_in = self.fiber_in.movedim(-2, 0)
+        self.fiber_out = fiber_out.unfold(dimension=-1, size=self.samples_per_slice, step=1)
+        self.fiber_out = self.fiber_out.movedim(-2, 0)
+
+        # if self.randomise_polarisations:
+        #     self.angles = torch.cumsum((torch.rand(self.fiber_out.shape[0]) - 0.5) * 2 * torch.pi * 2 / 5000, dim=0)
+
+        # self.data = data_raw.unfold(dimension=-1, size=self.samples_per_slice, step=1)
+        # self.data = self.data.movedim(-2, 0)
+        # self.angles = torch.zeros(self.data.shape[0])
+        ...
         # ...
         # -> [no_slices, 2, 3, samples_per_slice]
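The per-sample polarisation rotation above applies a 2x2 rotation matrix to every (x, y) field sample with a single batched matrix multiply. The same pattern as a standalone, shape-annotated sketch (names and sizes illustrative):

    import torch

    n = 5
    xy = torch.randn(2, n, dtype=torch.complex64)   # [2, samples], like fiber_out[:2, :]
    angles = torch.rand(n) * 2 * torch.pi

    sin, cos = torch.sin(angles), torch.cos(angles)
    rot = torch.stack([torch.stack([cos, -sin], dim=1),
                       torch.stack([sin, cos], dim=1)], dim=2)                # [n, 2, 2]
    rotated = torch.bmm(xy.T.unsqueeze(1), rot.to(xy.dtype)).squeeze(1).T     # [2, n]

Each sample is treated as a row vector, so this computes v' = v * R(angle_i) per sample i.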
@@ -305,69 +400,109 @@ class FiberRegenerationDataset(Dataset):
         # ...
         # ] -> [no_slices, 2, 3, samples_per_slice]

-        ...

     def __len__(self):
-        return self.data.shape[0]
+        return self.fiber_in.shape[0]

+    def add_noise(self, data, osnr):
+        osnr_lin = 10 ** (osnr / 10)
+        popt = torch.mean(data.abs().square().squeeze(), dim=-1)
+        noise = torch.randn_like(data)
+        pn = torch.mean(noise.abs().square().squeeze(), dim=-1)
+
+        mult = torch.sqrt(popt / (pn * osnr_lin))
+        mult = mult * torch.eye(popt.shape[0], device=mult.device)
+        mult = mult.to(dtype=noise.dtype)
+
+        noise = mult @ noise
+        pn = torch.mean(noise.abs().square().squeeze(), dim=-1)
+        noisy = data + noise
+        return noisy
+
     def __getitem__(self, idx):
         if isinstance(idx, slice):
             return [self.__getitem__(i) for i in range(*idx.indices(len(self)))]
         else:
-            data_slice = self.data[idx].squeeze()
-            data_slice = data_slice[:, :, : data_slice.shape[2] // self.output_dim * self.output_dim]
-            data_slice = data_slice.view(data_slice.shape[0], data_slice.shape[1], self.output_dim, -1)
-            # if self.randomise_polarisations:
-            #     angle = torch.rand(1) * torch.pi * 2
-            #     sine = torch.sin(angle)
-            #     cosine = torch.cos(angle)
-            #     data_slice_ = data_slice[1]
-            #     data_slice[1, 0] = data_slice_[0] * cosine - data_slice_[1] * sine
-            #     data_slice[1,1] = data_slice_[0] * sine + data_slice_[1] * cosine
-            # else:
-            #     angle = torch.zeros(1)
-
-            # data = data_slice[1, :2, :, 0]
-            angle = self.angles[idx]
-            data_index = 1
-            data_slice[1, :2, :, :] = self.rotate(data_slice[data_index, :2, :, :], angle)
-            data = data_slice[1, :2, :, 0]
-            # data = self.rotate(data, angle)
-            # for both polarisations (x/y), calculate the mean of the signal around the current symbol (-> corresponding to a lowpass filter)
-            angle_data = data_slice[1, :2, :, :].reshape(2, -1).mean(dim=1)
-            angle_data2 = self.complex_max(data_slice[1, :2, :, :].reshape(2, -1))
-            plot_data = data_slice[1, :2, self.output_dim // 2, 0]
-            sop = self.polarimeter(plot_data)
-            # angle_data = data_slice[1, :2, :self.output_dim//2, :].squeeze().reshape(2, -1).mean(dim=-1)
-            # angle = data_slice[1, 3, self.output_dim // 2, 0].real
-            target = data_slice[0, :2, self.output_dim // 2, 0]
-            target_timestamp = data_slice[0, 2, self.output_dim // 2, 0].real
+            # fiber in: [E_in_x, E_in_y, timestamps]
+            # fiber out: [E_out_x, E_out_y, timestamps, E_out_x_rot, E_out_y_rot, angle]
+
+            # if self.polarisations:
+            output_dim = self.output_dim // 2
+            self.output_dim = output_dim * 2
+
+            if not self.polarisations:
+                output_dim = 2 * output_dim
+
+            fiber_in = self.fiber_in[idx].squeeze()
+            fiber_out = self.fiber_out[idx].squeeze()
+
+            fiber_in = fiber_in[..., : fiber_in.shape[-1] // output_dim * output_dim]
+            fiber_out = fiber_out[..., : fiber_out.shape[-1] // output_dim * output_dim]
+
+            fiber_in = fiber_in.view(fiber_in.shape[0], output_dim, -1)
+            fiber_out = fiber_out.view(fiber_out.shape[0], output_dim, -1)
+
+            center_angle = fiber_out[5, output_dim // 2, 0]
+            angles = fiber_out[5, :, 0]
+            plot_data_rot = fiber_out[3:5, output_dim // 2, 0].detach().clone()
+            data = fiber_out[0:2, :, 0]
+            plot_data = fiber_out[0:2, output_dim // 2, 0].detach().clone()
+
+            target = fiber_in[:2, output_dim // 2, 0]
+            plot_target = fiber_in[:2, output_dim // 2, 0].detach().clone()
+            target_timestamp = fiber_in[2, output_dim // 2, 0].real
             ...

-            # data_timestamps = data[-1,:].real
-            # data = data[:-1, :]
-            # target_timestamp = target[-1].real
-            # target = target[:-1]
-            # plot_data = plot_data[:-1]
+            if self.polarisations:
+                rot = int(np.random.randint(2) * 2 - 1)
+                data = rot * data
+                target = rot * target
+                plot_data_rot = rot * plot_data_rot
+                center_angle = center_angle + (rot - 1) * torch.pi / 2  # rot: -1 or 1 -> -2 or 0 -> -pi or 0
+                angles = angles + (rot - 1) * torch.pi / 2
+
+            pol_flipped_data = -data
+            pol_flipped_target = -target

             # transpose to interleave the x and y data in the output tensor
             data = data.transpose(0, 1).flatten().squeeze()
-            angle_data = angle_data.flatten().squeeze()
-            angle_data2 = angle_data.flatten().squeeze()
-            angle = angle.flatten().squeeze()
+            data = data / torch.sqrt(torch.ones(1) * len(data))  # power loss due to splitting
+            pol_flipped_data = pol_flipped_data.transpose(0, 1).flatten().squeeze()
+            pol_flipped_data = pol_flipped_data / torch.sqrt(
+                torch.ones(1) * len(pol_flipped_data)
+            )  # power loss due to splitting
+            # angle_data = angle_data.transpose(0, 1).flatten().squeeze()
+            # angle_data2 = angle_data2.transpose(0,1).flatten().squeeze()
+            center_angle = center_angle.flatten().squeeze()
+            angles = angles.flatten().squeeze()
             # data_timestamps = data_timestamps.flatten().squeeze()
+            # target = target.transpose(0,1).flatten().squeeze()
             target = target.flatten().squeeze()
+            pol_flipped_target = pol_flipped_target.flatten().squeeze()
             target_timestamp = target_timestamp.flatten().squeeze()
+            plot_target = plot_target.flatten().squeeze()
+            plot_data = plot_data.flatten().squeeze()
+            plot_data_rot = plot_data_rot.flatten().squeeze()

-            return {"x": data, "y": target, "angle": angle, "sop": sop, "angle_data": angle_data, "angle_data2": angle_data2, "timestamp": target_timestamp, "plot_data": plot_data}
+            return {
+                "x": data,
+                "x_flipped": pol_flipped_data,
+                "x_stacked": torch.cat([data, pol_flipped_data], dim=-1),
+                "y": target,
+                "y_flipped": pol_flipped_target,
+                "y_stacked": torch.cat([target, pol_flipped_target], dim=-1),
+                "center_angle": center_angle,
+                "angles": angles,
+                "mean_angle": angles.mean(),
+                # "sop": sop,
+                # "angle_data": angle_data,
+                # "angle_data2": angle_data2,
+                "timestamp": target_timestamp,
+                "plot_target": plot_target,
+                "plot_data": plot_data,
+                "plot_data_rot": plot_data_rot,
+                # "plot_clean": fiber_out_plot_clean,
+            }

     def complex_max(self, data, dim=-1):
         # returns element(s) with the maximum absolute value along a given dimension
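The new `add_noise` helper scales white Gaussian noise so that mean signal power over mean noise power equals the requested linear OSNR, per channel. A compact check of that scaling rule (values illustrative):

    import torch

    osnr_db = 20.0
    osnr_lin = 10 ** (osnr_db / 10)  # 100x

    signal = torch.randn(2, 4096, dtype=torch.complex64)
    noise = torch.randn_like(signal)

    p_sig = signal.abs().square().mean(dim=-1)
    p_noise = noise.abs().square().mean(dim=-1)
    noise = noise * torch.sqrt(p_sig / (p_noise * osnr_lin)).unsqueeze(-1).to(noise.dtype)

    # measured OSNR should come out near 20 dB per channel
    print(10 * torch.log10(p_sig / noise.abs().square().mean(dim=-1)))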
@@ -376,7 +511,6 @@ class FiberRegenerationDataset(Dataset):
         # return max_values
         return torch.gather(data, dim, torch.argmax(data.abs(), dim=dim, keepdim=True)).squeeze(dim=dim)
-

     def rotate(self, data, angle):
         # rotates a 2d tensor by a given angle
         # data: [2, ...]
@@ -389,6 +523,24 @@ class FiberRegenerationDataset(Dataset):

         return torch.stack([data[0] * cosine - data[1] * sine, data[0] * sine + data[1] * cosine], dim=0)

+    def rotate_all(self):
+        def do_rotation(j, num_processes):
+            for i in range(len(self) // num_processes):
+                index = i * num_processes + j
+                self.data[index, 1, :2, :] = self.rotate(self.data[index, 1, :2, :], self.angles[index])
+
+        self.processes = []
+
+        for j in range(mp.cpu_count()):
+            self.processes.append(mp.Process(target=do_rotation, args=(j, mp.cpu_count())))
+            self.processes[-1].start()
+
+        for p in self.processes:
+            p.join()
+
+        for i in range(len(self) // mp.cpu_count() * mp.cpu_count(), len(self)):
+            self.data[i, 1, :2, :] = self.rotate(self.data[i, 1, :2, :], self.angles[i])
+
     def polarimeter(self, data):
         # data: [2, ...] -> x, y
         # returns [4] -> S0, S1, S2, S3
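Two caveats with `rotate_all` as written: `mp.Process` forks the interpreter, so each worker mutates its own copy of `self.data` and the writes never reach the parent unless the tensor is in shared memory; and it still references `self.data`/`self.angles`, which this same diff replaces with `fiber_in`/`fiber_out`. A hedged alternative that keeps the writes visible by using threads instead of processes (torch ops release the GIL for large tensors):

    from concurrent.futures import ThreadPoolExecutor

    def rotate_one(ds, i):
        # in-place write is visible to the caller because threads share ds.data
        ds.data[i, 1, :2, :] = ds.rotate(ds.data[i, 1, :2, :], ds.angles[i])

    def rotate_all_threaded(ds):
        with ThreadPoolExecutor() as pool:
            list(pool.map(lambda i: rotate_one(ds, i), range(len(ds))))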
@@ -404,4 +556,4 @@ class FiberRegenerationDataset(Dataset):
         S2 = (2 * I_45 - S0) / S0
         S3 = (2 * I_RHC - S0) / S0

-        return torch.stack([S1, S2, S3], dim=0)
+        return torch.stack([S0, S1, S2, S3], dim=0)
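Returning S0 alongside S1-S3 makes the polarimeter output match its own `# returns [4]` docstring. For reference, the Stokes vector can also be computed directly from the Jones field; a minimal sketch (formulae are the textbook definitions, the S3 sign convention varies between references, and the function name is ours):

    import torch

    def stokes(ex, ey):
        s0 = ex.abs()**2 + ey.abs()**2
        s1 = (ex.abs()**2 - ey.abs()**2) / s0
        s2 = 2 * (ex * ey.conj()).real / s0
        s3 = -2 * (ex * ey.conj()).imag / s0
        return torch.stack([s0, s1, s2, s3], dim=0)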
@@ -1,16 +1,23 @@
+from datetime import datetime
+import json
+from pathlib import Path
+from typing import Literal
+import h5py
 from matplotlib import pyplot as plt
 from matplotlib.colors import LinearSegmentedColormap
+# from cmap import Colormap as cm
 import numpy as np
 from scipy.cluster.vq import kmeans2
 import warnings
 import multiprocessing

 from rich.traceback import install
-from rich import pretty
-from rich import print

 install()
-pretty.install()
+# from rich import pretty
+# from rich import print

+# pretty.install()


 def generate_sample_data(n_symbols=1000, sps=128, noise=0.01, skew=1):
@@ -21,6 +28,7 @@ def generate_sample_data(n_symbols=1000, sps=128, noise=0.01, skew=1):
     xaxis = np.arange(0, len(signal)) / sps
     return np.vstack([xaxis, signal])
+

 def create_symbol_sequence(n_symbols, skew=1):
     np.random.seed(42)
     data = np.random.randint(0, 4, n_symbols) / 4
@@ -39,6 +47,14 @@ def generate_signal(data, sps):
     signal = np.convolve(data_padded, wavelet)
     signal = np.cumsum(signal)
     signal = signal[sps // 2 + len(wavelet) // 2 - 1 : -len(wavelet) // 2]
+    mi, ma = np.min(signal), np.max(signal)
+
+    signal = (signal - mi) / (ma - mi)
+
+    # rescale into [1 - mod, 1] so the lowest level keeps a non-zero amplitude
+    mod = 0.8
+
+    signal *= mod
+    signal += 1 - mod
+
     return signal
@@ -49,8 +65,8 @@ def normalization_with_noise(signal, noise=0):
     signal += awgn

     # min-max normalization
-    signal = signal - np.min(signal)
-    signal = signal / np.max(signal)
+    # signal = signal - np.min(signal)
+    # signal = signal / np.max(signal)
     return signal
@@ -68,27 +84,132 @@ def generate_wavelet(sps, oversample=3):


 class eye_diagram:
-    def __init__(self, data, *, channel_names=None, horizontal_bins=256, vertical_bins=1000, n_levels=4, multithreaded=True):
+    def __init__(
+        self,
+        data,
+        *,
+        channel_names=None,
+        horizontal_bins=256,
+        vertical_bins=1000,
+        n_levels=4,
+        multithreaded=True,
+        save_file_or_dir=None,
+    ):
         # data has shape [channels, 2, samples]
         # each sample has a timestamp and a value
         if data.ndim == 2:
             data = data[np.newaxis, :, :]
-        self.channel_names = channel_names
         self.raw_data = data
-        self.channels = data.shape[0]
+        self.y_bins = np.zeros(1)
+        self.x_bins = np.zeros(1)
+        self.eye_data = np.zeros(1)
+        self.channel_names = channel_names
+        self.n_channels = data.shape[0]
         self.n_levels = n_levels
-        self.eye_stats = [{"success": False} for _ in range(self.channels)]
+        self.eye_stats = [{"success": False} for _ in range(self.n_channels)]
         self.horizontal_bins = horizontal_bins
         self.vertical_bins = vertical_bins
         self.multi_threaded = multithreaded
+        self.analysed = False
         self.eye_built = False
-        self.analyse()
-
-    def generate_eye_data(self):
+        self.save_file = save_file_or_dir
+
+    def load_data(self, file=None):
+        file = self.save_file if file is None else file
+
+        if file is None:
+            raise FileNotFoundError("No file specified.")
+
+        self.save_file = str(file)
+        # self.file_or_dir = self.save_file
+        with h5py.File(file, "r") as infile:
+            self.y_bins = infile["y_bins"][:]
+            self.x_bins = infile["x_bins"][:]
+            self.eye_data = infile["eye_data"][:]
+            self.channel_names = infile.attrs["channel_names"]
+            self.n_channels = infile.attrs["n_channels"]
+            self.n_levels = infile.attrs["n_levels"]
+            self.eye_stats = infile.attrs["eye_stats"]
+            self.eye_stats = [json.loads(stat) for stat in self.eye_stats]
+            self.horizontal_bins = infile.attrs["horizontal_bins"]
+            self.vertical_bins = infile.attrs["vertical_bins"]
+            self.multi_threaded = infile.attrs["multithreaded"]
+            self.analysed = infile.attrs["analysed"]
+            self.eye_built = infile.attrs["eye_built"]
+
+    def save_data(self, file_or_dir=None):
+        file_or_dir = self.save_file if file_or_dir is None else file_or_dir
+        if file_or_dir is None:
+            file = Path(f"{datetime.now().strftime('%Y%m%d_%H%M%S')}_eye_data.eye.h5")
+        elif Path(file_or_dir).is_dir():
+            file = Path(file_or_dir) / f"{datetime.now().strftime('%Y%m%d_%H%M%S')}_eye_data.eye.h5"
+        else:
+            file = Path(file_or_dir)
+
+        # file.parent.mkdir(parents=True, exist_ok=True)
+
+        self.save_file = str(file)
+
+        with h5py.File(file, "w") as outfile:
+            outfile.create_dataset("eye_data", data=self.eye_data)
+            outfile.create_dataset("y_bins", data=self.y_bins)
+            outfile.create_dataset("x_bins", data=self.x_bins)
+            outfile.attrs["channel_names"] = self.channel_names
+            outfile.attrs["n_channels"] = self.n_channels
+            outfile.attrs["n_levels"] = self.n_levels
+            self.eye_stats = eye_diagram.convert_arrays(self.eye_stats)
+            outfile.attrs["eye_stats"] = [json.dumps(stat) for stat in self.eye_stats]
+            outfile.attrs["horizontal_bins"] = self.horizontal_bins
+            outfile.attrs["vertical_bins"] = self.vertical_bins
+            outfile.attrs["multithreaded"] = self.multi_threaded
+            outfile.attrs["analysed"] = self.analysed
+            outfile.attrs["eye_built"] = self.eye_built
+
+    @staticmethod
+    def convert_arrays(input_object):
+        """
+        convert ndarrays in (nested) dict to lists
+        """
+        if isinstance(input_object, np.ndarray):
+            return input_object.tolist()
+        elif isinstance(input_object, list):
+            return [eye_diagram.convert_arrays(old) for old in input_object]
+        elif isinstance(input_object, tuple):
+            return tuple(eye_diagram.convert_arrays(old) for old in input_object)
+        elif isinstance(input_object, dict):
+            dict_out = {}
+            for key, value in input_object.items():
+                dict_out[key] = eye_diagram.convert_arrays(value)
+            return dict_out
+        return input_object
+
+    def generate_eye_data(
+        self, mode: Literal["default", "load", "save", "nosave"] = "default", file_or_dir: Path | None | str = None
+    ):
+        # modes:
+        # default: try to load eye data from file, if not found, generate and save
+        # load: try to load eye data from file, if not found, generate but don't save
+        # save: generate eye data and save
+        update_save = True
+        if mode == "load":
+            self.load_data(file_or_dir)
+            update_save = False
+        elif mode == "default":
+            try:
+                self.load_data(file_or_dir)
+                update_save = False
+            except (FileNotFoundError, IsADirectoryError):
+                pass
+
+        if not self.eye_built:
+            update_save = True
             self.x_bins = np.linspace(0, 2, self.horizontal_bins, endpoint=False)
-        self.y_bins = np.zeros((self.channels, self.vertical_bins))
-        self.eye_data = np.zeros((self.channels, self.vertical_bins, self.horizontal_bins))
-        datas = [self.raw_data[i] for i in range(self.channels)]
-        if self.multi_threaded:
-            with multiprocessing.Pool() as pool:
-                results = pool.map(self.generate_eye_data_single, datas)
+            self.y_bins = np.zeros((self.n_channels, self.vertical_bins))
+            self.eye_data = np.zeros((self.n_channels, self.vertical_bins, self.horizontal_bins))
+            datas = [self.raw_data[i] for i in range(self.n_channels)]
+            if self.multi_threaded:
+                with multiprocessing.Pool() as pool:
+                    results = pool.map(self.generate_eye_data_single, datas)
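Usage sketch of the new persistence layer (file names are hypothetical; the `.eye.h5` suffix follows the naming used inside `save_data`):

    eye = eye_diagram(data, horizontal_bins=256, vertical_bins=256,
                      save_file_or_dir="results/")   # directory: a timestamped *.eye.h5 is created
    eye.generate_eye_data(mode="default")            # load if present, else build and save

    eye2 = eye_diagram(np.zeros((1, 2, 1)))          # shell object, then restore a previous run
    eye2.load_data("results/20250101_120000_eye_data.eye.h5")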
@@ -99,67 +220,128 @@ class eye_diagram:
                     self.eye_data[i], self.y_bins[i] = self.generate_eye_data_single(data)
         self.eye_built = True

+        if mode == "save" or (mode == "default" and update_save):
+            self.save_data(file_or_dir)
+
     def generate_eye_data_single(self, data):
         eye_data = np.zeros((self.vertical_bins, self.horizontal_bins))
         data_min = np.min(data[1, :])
         data_max = np.max(data[1, :])
+        # round down/up to 1 decimal
+        data_min = np.floor(data_min*10)/10
+        data_max = np.ceil(data_max*10)/10
+        # data_range = data_max - data_min
+        # data_min -= 0.1 * data_range
+        # data_max += 0.1 * data_range
+        # data_min = -0.05
+        # data_max += 0.05
+        # data[1,:] -= np.min(data[1, :])
+        # data[1,:] /= np.max(data[1, :])
+        # data_min = 0
+        # data_max = 1
         y_bins = np.linspace(data_min, data_max, self.vertical_bins, endpoint=False)
-        t_vals = data[0, :] % 2
-        val_vals = data[1, :]
+        t_vals = data[0, :] % 2  # + np.random.randn(*data[0, :].shape) * 1 / (512)
+        val_vals = data[1, :]  # + np.random.randn(*data[1, :].shape) * 1 / (320)
         x_indices = np.digitize(t_vals, self.x_bins) - 1
         y_indices = np.digitize(val_vals, y_bins) - 1
         np.add.at(eye_data, (y_indices, x_indices), 1)
         return eye_data, y_bins

-    def plot(self, title="Eye Diagram", stats=True, all_stats=True, show=True):
+    def plot(
+        self,
+        title="Eye Diagram",
+        stats=True,
+        all_stats=True,
+        show=True,
+        mode: Literal["default", "load", "save", "nosave"] = "default",
+        # save_images = False,
+        # image_dir = None,
+        # cmap=None,
+    ):
+        if stats and not self.analysed:
+            self.analyse(mode=mode)
         if not self.eye_built:
-            self.generate_eye_data()
+            self.generate_eye_data(mode=mode)
         cmap = LinearSegmentedColormap.from_list(
             "eyemap",
-            [(0, "white"), (0.1, "blue"), (0.2, "cyan"), (0.5, "green"), (0.8, "yellow"), (0.9, "red"), (1, "magenta")],
+            [
+                (0, "#FFFFFF00"),
+                (0.1, "blue"),
+                (0.2, "cyan"),
+                (0.5, "green"),
+                (0.8, "yellow"),
+                (0.9, "red"),
+                (1, "magenta"),
+            ],
         )
-        if self.channels % 2 == 0:
+        # cmap = cm('google:turbo_r' if cmap is None else cmap)
+        # first = cmap(-1)
+        # cmap = cmap.to_mpl()
+        # cmap.set_under(first, alpha=0)
+        if self.n_channels % 2 == 0:
             rows = 2
-            cols = self.channels // 2
+            cols = self.n_channels // 2
         else:
-            cols = int(np.ceil(np.sqrt(self.channels)))
-            rows = int(np.ceil(self.channels / cols))
+            cols = int(np.ceil(np.sqrt(self.n_channels)))
+            rows = int(np.ceil(self.n_channels / cols))
         fig, ax = plt.subplots(rows, cols, sharex=True, sharey=False)
         fig.suptitle(title)
+        fig.tight_layout()
         ax = np.atleast_1d(ax).transpose().flatten()
-        for i in range(self.channels):
+        for i in range(self.n_channels):
             ax[i].set_title(self.channel_names[i] if self.channel_names is not None else f"Channel {i + 1}")
             if (i + 1) % rows == 0:
                 ax[i].set_xlabel("Symbol")
             if i < rows:
                 ax[i].set_ylabel("Amplitude")
             ax[i].grid()
+            ax[i].set_axisbelow(True)
             ax[i].imshow(
-                self.eye_data[i],
+                self.eye_data[i] - 0.1,
                 origin="lower",
                 aspect="auto",
                 cmap=cmap,
                 extent=[self.x_bins[0], self.x_bins[-1], self.y_bins[i][0], self.y_bins[i][-1]],
+                interpolation="gaussian",
+                vmin=0,
+                zorder=3,
             )
             ax[i].set_xlim((self.x_bins[0], self.x_bins[-1]))
             ymin = np.min(self.y_bins[:, 0])
             ymax = np.max(self.y_bins[:, -1])
             yspan = ymax - ymin
             ax[i].set_ylim((ymin - 0.1 * yspan, ymax + 0.1 * yspan))
+            # if save_images:
+            #     image_dir = "images_out" if image_dir is None else image_dir
+            #     image_path = Path(image_dir) / (slugify(f"{datetime.now().strftime("%Y%m%d_%H%M%S")}_{title.replace(" ","_")}_{self.channel_names[i].replace(" ", "_") if self.channel_names is not None else f"{i + 1}"}_{ymin:.1f}_{ymax:.1f}") + ".png")
+            #     image_path.parent.mkdir(parents=True, exist_ok=True)
+            #     # plt.imsave(
+            #     #     image_path,
+            #     #     self.eye_data[i] - 0.1,
+            #     #     origin="lower",
+            #     #     # aspect="auto",
+            #     #     cmap=cmap,
+            #     #     # extent=[self.x_bins[0], self.x_bins[-1], self.y_bins[i][0], self.y_bins[i][-1]],
+            #     #     # interpolation="gaussian",
+            #     #     vmin=0,
+            #     #     # zorder=3,
+            #     # )
             if stats and self.eye_stats[i]["success"]:
-                # add min_area above the plot
-                ax[i].annotate(
-                    f"Min Area: {self.eye_stats[i]['min_area']:.2e}",
-                    xy=(0.05, ymax + 0.05 * yspan),
-                    # xycoords="axes fraction",
-                    ha="left",
-                    va="center",
-                    bbox=dict(facecolor="white", alpha=0.5, edgecolor="none"),
-                )
+                # # add min_area above the plot
+                # ax[i].annotate(
+                #     f"Min Area: {self.eye_stats[i]['min_area']:.2e}",
+                #     xy=(0.05, ymax + 0.05 * yspan),
+                #     # xycoords="axes fraction",
+                #     ha="left",
+                #     va="center",
+                #     bbox=dict(facecolor="white", alpha=0.5, edgecolor="none"),
+                # )

                 if all_stats:
                     ax[i].plot([0, 2], [self.eye_stats[i]["levels"], self.eye_stats[i]["levels"]], "k--")
-                    ax[i].set_yticks(self.eye_stats[i]["levels"])
+                    y_ticks = (*self.eye_stats[i]["levels"], *self.eye_stats[i]["thresholds"])
+                    # y_ticks = np.sort(y_ticks)
+                    ax[i].set_yticks(y_ticks)
                     # add arrows for amplitudes
                     for j in range(len(self.eye_stats[i]["amplitudes"])):
                         ax[i].annotate(
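The core of `generate_eye_data_single` is a 2-D histogram built by hand: `np.digitize` maps each (time, amplitude) pair to a bin index, and `np.add.at` accumulates counts correctly even when indices repeat. The same result as a self-contained check (bin counts and data are illustrative):

    import numpy as np

    rng = np.random.default_rng(0)
    t = rng.random(10_000) * 2          # times folded into [0, 2) symbol periods
    v = rng.random(10_000)              # amplitudes
    x_bins = np.linspace(0, 2, 64, endpoint=False)
    y_bins = np.linspace(0, 1, 64, endpoint=False)

    counts = np.zeros((64, 64))
    np.add.at(counts, (np.digitize(v, y_bins) - 1, np.digitize(t, x_bins) - 1), 1)

    # equivalent via histogram2d (edges need one extra closing point)
    ref, _, _ = np.histogram2d(v, t, bins=[np.append(y_bins, 1.0), np.append(x_bins, 2.0)])
    assert np.allclose(counts, ref)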
@@ -193,35 +375,35 @@ class eye_diagram:
                         except (ValueError, IndexError):
                             pass
                     # add arrows for eye widths
-                    for j in range(len(self.eye_stats[i]["widths"])):
-                        try:
-                            left = np.max(self.eye_stats[i]["time_clusters"][j][0])
-                            right = np.min(self.eye_stats[i]["time_clusters"][j][1])
-                            vertical = (self.eye_stats[i]["levels"][j] + self.eye_stats[i]["levels"][j + 1]) / 2
-
-                            ax[i].annotate(
-                                "",
-                                xy=(left, vertical),
-                                xytext=(right, vertical),
-                                arrowprops=dict(arrowstyle="<->", facecolor="black"),
-                            )
-                            ax[i].annotate(
-                                f"{self.eye_stats[i]['widths'][j]:.2e}",
-                                xy=((left + right) / 2 - 0.15, vertical + 0.01),
-                                bbox=dict(facecolor="white", alpha=0.5, edgecolor="none"),
-                            )
-                        except (ValueError, IndexError):
-                            pass
-
-                    # add area
-                    for j in range(len(self.eye_stats[i]["areas"])):
-                        horizontal = self.eye_stats[i]["time_midpoint"]
-                        vertical = (self.eye_stats[i]["levels"][j] + self.eye_stats[i]["levels"][j + 1]) / 2
-                        ax[i].annotate(
-                            f"{self.eye_stats[i]['areas'][j]:.2e}",
-                            xy=(horizontal + 0.035, vertical - 0.07),
-                            bbox=dict(facecolor="white", alpha=0.5, edgecolor="none"),
-                        )
+                    # for j in range(len(self.eye_stats[i]["widths"])):
+                    #     try:
+                    #         left = np.max(self.eye_stats[i]["time_clusters"][j][0])
+                    #         right = np.min(self.eye_stats[i]["time_clusters"][j][1])
+                    #         vertical = (self.eye_stats[i]["levels"][j] + self.eye_stats[i]["levels"][j + 1]) / 2
+
+                    #         ax[i].annotate(
+                    #             "",
+                    #             xy=(left, vertical),
+                    #             xytext=(right, vertical),
+                    #             arrowprops=dict(arrowstyle="<->", facecolor="black"),
+                    #         )
+                    #         ax[i].annotate(
+                    #             f"{self.eye_stats[i]['widths'][j]:.2e}",
+                    #             xy=((left + right) / 2 - 0.15, vertical + 0.01),
+                    #             bbox=dict(facecolor="white", alpha=0.5, edgecolor="none"),
+                    #         )
+                    #     except (ValueError, IndexError):
+                    #         pass
+
+                    # # add area
+                    # for j in range(len(self.eye_stats[i]["areas"])):
+                    #     horizontal = self.eye_stats[i]["time_midpoint"]
+                    #     vertical = (self.eye_stats[i]["levels"][j] + self.eye_stats[i]["levels"][j + 1]) / 2
+                    #     ax[i].annotate(
+                    #         f"{self.eye_stats[i]['areas'][j]:.2e}",
+                    #         xy=(horizontal + 0.035, vertical - 0.07),
+                    #         bbox=dict(facecolor="white", alpha=0.5, edgecolor="none"),
+                    #     )

         fig.tight_layout()
@@ -229,6 +411,12 @@ class eye_diagram:
             plt.show()
         return fig

+    @staticmethod
+    def calculate_thresholds(levels):
+        # midpoints between neighbouring levels, via a length-2 moving sum over cumsum
+        ret = np.cumsum(levels, dtype=float)
+        ret[2:] = ret[2:] - ret[:-2]
+        return ret[1:] / 2
+
     def analyse_single(self, data, index):
        warnings.filterwarnings("error")
        eye_stats = {}
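`calculate_thresholds` is the classic cumsum moving-sum trick: after the subtraction, `ret[k]` holds `levels[k-1] + levels[k]`, so `ret[1:] / 2` is the midpoint between every pair of adjacent amplitude levels, i.e. the decision thresholds of a PAM-4 eye. Worked through with illustrative levels:

    import numpy as np

    levels = np.array([0.2, 0.46, 0.73, 1.0])  # illustrative PAM-4 levels
    ret = np.cumsum(levels, dtype=float)       # [0.2, 0.66, 1.39, 2.39]
    ret[2:] = ret[2:] - ret[:-2]               # [0.2, 0.66, 1.19, 1.73]
    print(ret[1:] / 2)                         # [0.33, 0.595, 0.865]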
@@ -238,17 +426,18 @@ class eye_diagram:

            time_bounds = eye_diagram.calculate_time_bounds(data, approx_levels)

-           eye_stats["time_midpoint"] = (time_bounds[0] + time_bounds[1]) / 2
+           eye_stats["time_midpoint"] = float((time_bounds[0] + time_bounds[1]) / 2)
+           # eye_stats["time_midpoint"] = 1.0

            eye_stats["levels"], eye_stats["amplitude_clusters"] = eye_diagram.calculate_levels(
                data, approx_levels, time_bounds
            )

+           eye_stats["thresholds"] = self.calculate_thresholds(eye_stats["levels"])
+
            eye_stats["amplitudes"] = np.diff(eye_stats["levels"])

-           eye_stats["heights"] = eye_diagram.calculate_eye_heights(
-               eye_stats["amplitude_clusters"]
-           )
+           eye_stats["heights"] = eye_diagram.calculate_eye_heights(eye_stats["amplitude_clusters"])

            eye_stats["widths"], eye_stats["time_clusters"] = eye_diagram.calculate_eye_widths(
                data, eye_stats["levels"]
@@ -260,36 +449,59 @@ class eye_diagram:
            # if not (np.max(eye_stats['time_clusters'][j][0]) < eye_stats["time_midpoint"] < np.min(eye_stats['time_clusters'][j][1])):
            #     raise ValueError

-           eye_stats["areas"] = eye_stats["heights"] * eye_stats["widths"]
-           eye_stats["mean_area"] = np.mean(eye_stats["areas"])
-           eye_stats["min_area"] = np.min(eye_stats["areas"])
+           # eye_stats["areas"] = eye_stats["heights"] * eye_stats["widths"]
+           # eye_stats["mean_area"] = np.mean(eye_stats["areas"])
+           # eye_stats["min_area"] = np.min(eye_stats["areas"])

            eye_stats["success"] = True
        except (RuntimeWarning, UserWarning, ValueError):
            eye_stats["success"] = False
-           eye_stats["time_midpoint"] = 0
-           eye_stats["levels"] = np.zeros(self.n_levels)
-           eye_stats["amplitude_clusters"] = []
-           eye_stats["amplitudes"] = np.zeros(self.n_levels - 1)
-           eye_stats["heights"] = np.zeros(self.n_levels - 1)
-           eye_stats["widths"] = np.zeros(self.n_levels - 1)
-           eye_stats["areas"] = np.zeros(self.n_levels - 1)
-           eye_stats["mean_area"] = 0
-           eye_stats["min_area"] = 0
+           eye_stats["time_midpoint"] = None
+           eye_stats["levels"] = None
+           eye_stats["thresholds"] = None
+           eye_stats["amplitude_clusters"] = None
+           eye_stats["amplitudes"] = None
+           eye_stats["heights"] = None
+           eye_stats["widths"] = None
+           # eye_stats["areas"] = np.zeros(self.n_levels - 1)
+           # eye_stats["mean_area"] = 0
+           # eye_stats["min_area"] = 0
        warnings.resetwarnings()
        return eye_stats

-    def analyse(self):
+    def analyse(
+        self, mode: Literal["default", "load", "save", "nosave"] = "default", file_or_dir: Path | None | str = None
+    ):
+        # modes:
+        # default: try to load eye data from file, if not found, generate and save
+        # load: try to load eye data from file, if not found, generate but don't save
+        # save: generate eye data and save
+        update_save = True
+        if mode == "load":
+            self.load_data(file_or_dir)
+            update_save = False
+        elif mode == "default":
+            try:
+                self.load_data(file_or_dir)
+                update_save = False
+            except (FileNotFoundError, IsADirectoryError):
+                pass
+
+        if not self.analysed:
+            update_save = True
             self.eye_stats = []
             if self.multi_threaded:
                 with multiprocessing.Pool() as pool:
-                    results = pool.starmap(self.analyse_single, [(self.raw_data[i], i) for i in range(self.channels)])
+                    results = pool.starmap(self.analyse_single, [(self.raw_data[i], i) for i in range(self.n_channels)])
                     for i, result in enumerate(results):
                         self.eye_stats.append(result)
             else:
-                for i in range(self.channels):
+                for i in range(self.n_channels):
                     self.eye_stats.append(self.analyse_single(self.raw_data[i], i))
+            self.analysed = True
+
+        if mode == "save" or (mode == "default" and update_save):
+            self.save_data(file_or_dir)

     @staticmethod
     def approximate_levels(data, levels):
@@ -431,7 +643,7 @@ class eye_diagram:


 if __name__ == "__main__":
-    length = int(2**14)
+    length = int(2**16)
     # data = generate_sample_data(length, noise=1)
     # data1 = generate_sample_data(length, noise=0.01)
     # data2 = generate_sample_data(length, noise=0.01, skew=1.2)
@@ -439,12 +651,13 @@ if __name__ == "__main__":

     # data = np.stack([data, data1, data2, data3])

-    data = generate_sample_data(length, noise=0.005)
-    eye = eye_diagram(data, horizontal_bins=256, vertical_bins=256)
-    attrs = ("success", "amplitudes", "time_midpoint", "levels", "heights", "widths", "area", "mean_area", "min_area")
-    for i, channel in enumerate(eye.eye_stats):
-        print(f"Channel {i}")
-        print_data = {attr: channel[attr] for attr in attrs}
-        print(print_data)
-
-    eye.plot()
+    data = generate_sample_data(length, noise=0.0000)
+    eye = eye_diagram(data, horizontal_bins=1056, vertical_bins=1200)
+    eye.plot(mode="nosave", stats=False)
+    # attrs = ("success", "amplitudes", "time_midpoint", "levels", "heights", "widths")  # , "area", "mean_area", "min_area"
+    # for i, channel in enumerate(eye.eye_stats):
+    #     print(f"Channel {i}")
+    #     print_data = {attr: channel[attr] for attr in attrs}
+    #     print(print_data)

+    # eye.plot()
122 src/single-core-regen/util/mpl.py (new file)
@@ -0,0 +1,122 @@
+# Copyright (c) 2015, Warren Weckesser. All rights reserved.
+# This software is licensed according to the "BSD 2-clause" license.
+
+# modified by Joseph Hopfmüller in 2025,
+# for integration into optical regeneration analysis scripts
+
+from pathlib import Path
+from matplotlib.colors import LinearSegmentedColormap
+import matplotlib.colors as colors
+import numpy as _np
+from .core import grid_count as _grid_count
+import matplotlib.pyplot as _plt
+import numpy as np
+from scipy.ndimage import gaussian_filter
+
+
+# from ._common import _common_doc
+
+
+__all__ = ["eyediagram"]  # , 'eyediagram_lines']
+
+
+# def eyediagram_lines(y, window_size, offset=0, **plotkwargs):
+#     """
+#     Plot an eye diagram using matplotlib by repeatedly calling the `plot`
+#     function.
+#     <common>
+
+#     """
+#     start = offset
+#     while start < len(y):
+#         end = start + window_size
+#         if end > len(y):
+#             end = len(y)
+#         yy = y[start:end+1]
+#         _plt.plot(_np.arange(len(yy)), yy, 'k', **plotkwargs)
+#         start = end
+
+# eyediagram_lines.__doc__ = eyediagram_lines.__doc__.replace("<common>",
+#                                                             _common_doc)
+
+
+eyemap = LinearSegmentedColormap.from_list(
+    "eyemap",
+    [
+        (0, "#0000FF00"),
+        (0.1, "blue"),
+        (0.2, "cyan"),
+        (0.5, "green"),
+        (0.8, "yellow"),
+        (0.9, "red"),
+        (1, "magenta"),
+    ],
+)
+
+
+def eyediagram(
+    y,
+    window_size,
+    offset=0,
+    colorbar=False,
+    show=False,
+    save_im=False,
+    overwrite=False,
+    blur: int | bool = True,
+    save_path="out.png",
+    bounds=None,
+    **imshowkwargs,
+):
+    """
+    Plot an eye diagram using matplotlib by creating an image and calling
+    the `imshow` function.
+    <common>
+    """
+    if bounds is None:
+        ymax = y.max()
+        ymin = y.min()
+        yamp = ymax - ymin
+        ymin = ymin - 0.05 * yamp
+        ymax = ymax + 0.05 * yamp
+        ymin = np.floor(ymin * 10) / 10
+        ymax = np.ceil(ymax * 10) / 10
+        bounds = (ymin, ymax)
+    counts = _grid_count(y, window_size, offset, bounds=bounds, size=(1000, 1200), blur=int(blur))
+    counts = counts.astype(_np.float32)
+    origin = imshowkwargs.pop("origin", "lower")
+    cmap: colors.Colormap = imshowkwargs.pop("cmap", eyemap)
+    vmin = imshowkwargs.pop("vmin", 1)
+    vmax = imshowkwargs.pop("vmax", None)
+    cmap.set_under("white", alpha=0)
+
+    if show:
+        _plt.imshow(
+            counts.T[::-1, :],
+            extent=[0, 2, *bounds],
+            origin=origin,
+            cmap=cmap,
+            vmin=vmin,
+            vmax=vmax,
+            **imshowkwargs,
+        )
+        _plt.grid()
+        if colorbar:
+            _plt.colorbar()
+
+    if Path(save_path).is_file() and not overwrite:
+        save_im = False
+    if save_im:
+        from PIL import Image
+        arr = counts.T[::-1, :]
+        if origin == "lower":
+            arr = arr[::-1]
+        arr = (arr - arr.min()) / (arr.max() - arr.min())
+        image = Image.fromarray((cmap(arr)[:, :, :] * 255).astype(np.uint8))
+        image.save(save_path)
+        # print("-")
+
+    if show:
+        _plt.show()
+
+
+# eyediagram.__doc__ = eyediagram.__doc__.replace("<common>", _common_doc)
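Usage sketch for the new helper (the array contents are illustrative, the import path is assumed from the file location, and `window_size` is the number of samples spanning the two displayed symbol periods, matching the [0, 2] extent):

    import numpy as np
    from util.mpl import eyediagram

    sps = 128
    y = np.sin(np.linspace(0, 400 * np.pi, 100 * sps)) + 0.05 * np.random.randn(100 * sps)
    eyediagram(y, window_size=2 * sps, show=True, colorbar=True, blur=2)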
@@ -1,5 +1,8 @@
 import matplotlib.pyplot as plt
 import numpy as np
-from .datasets import load_data
+if __name__ == "__main__":
+    from datasets import load_data
+else:
+    from .datasets import load_data


 def eye(*, path=None, data=None, sps=None, title=None, symbols=1000, skipfirst=0, width=2, alpha=None, complex=False, show=True):
@@ -20,6 +23,7 @@ def eye(*, path=None, data=None, sps=None, title=None, symbols=1000, skipfirst=0
         raise ValueError("Either path or data and sps must be given.")
     if path is not None:
-        data, config = load_data(path, skipfirst, symbols)
+        # fix: load_data now returns (data, config, orig_symbols) and its second
+        # positional parameter is skiplast, so unpack three values and pass symbols by keyword
+        data, config, _ = load_data(path, skipfirst, symbols=symbols)
+        data = data.detach().cpu().numpy()[:, :4]
         sps = int(config["glova"]["sps"])
     if sps is None:
         raise ValueError("sps not set.")
@@ -71,3 +75,6 @@ def eye(*, path=None, data=None, sps=None, title=None, symbols=1000, skipfirst=0
         plt.show()

     return fig
+
+if __name__ == "__main__":
+    eye(path="data/20241229-163838-128-16384-50000-0-0.2-16.8-0.058-PAM4-0-0.16.ini", symbols=1000, width=2, alpha=0.1, complex=False)