Compare commits


10 Commits

94 changed files with 88490 additions and 1344 deletions

.gitignore (vendored): 2 additions

@@ -163,3 +163,5 @@ cython_debug/
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
tolerance_results/*
data/*

[66 new Git LFS pointer files (file names not shown in this view): 33 small pointers of 599–649 bytes and 33 pointers to ~134 MB data blobs (32 of 134,217,856 bytes, one of 134,481,920 bytes). Each added file is a standard three-line LFS pointer of the form:

version https://git-lfs.github.com/spec/v1
oid sha256:<object hash>
size <size in bytes>]
View File

@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:746ea83013870351296e01e294905ee027291ef79fd78c1e6b69dd9ebaa1cba0
size 10240000
oid sha256:76934d1d202aea1311ba67f5ea35eeb99a9c5c856f491565032e7d54ca6f072d
size 13598720

notes/models.md (new file): 37 additions

@@ -0,0 +1,37 @@
# models
## no polarisation flipping
```py
config_path="data/20241229-163*-128-16384-50000-*.ini"
model=".models/best_20241230_011907.tar"
```
```py
config_path="data/20241229-163*-128-16384-80000-*.ini"
model=".models/best_20241230_103752.tar"
```
```py
config_path="data/20241229-163*-128-16384-100000-*.ini"
model=".models/best_20241230_164534.tar"
```
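To see what one of these checkpoints actually contains before wiring it into an evaluation script, it can be inspected directly (a minimal sketch, assuming the .tar files are ordinary torch.save checkpoints):
```py
import torch

# assumption: the .tar checkpoints were written with torch.save
ckpt = torch.load(".models/best_20241230_011907.tar", map_location="cpu")
print(type(ckpt))
if isinstance(ckpt, dict):
    print(list(ckpt.keys()))  # inspect the layout rather than assuming it
```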
## with polarisation flipping
polarisation flipping: the signal is randomly rotated by 180°. Polarisation rotation can be detected by adding a tone on one of the polarisations, but with a direct-detection setup only modulo 180°. The randomly flipped signal should let the network learn to compensate dispersion and PMD independently of the polarisation rotation. The training data includes the flipped signal as well, with no indication of whether the polarisation is flipped.
```py
config_path="data/20241229-163*-128-16384-50000-*.ini"
model=".models/best_20241231_000328.tar"
```
```py
config_path="data/20241229-163*-128-16384-80000-*.ini"
model=".models/best_20241231_163614.tar"
```
```py
config_path="data/20241229-163*-128-16384-100000-*.ini"
model=".models/best_20241231_170532.tar"
```
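For reference, a 180° polarisation rotation is a Jones rotation with θ = π, which negates both field components; a hypothetical numpy sketch of the augmentation described above (helper name and shapes are illustrative):
```py
import numpy as np

def rotate_polarisation(E, theta):
    # E has shape (2, n_samples); rows are the x and y field components
    R = np.array([[np.cos(theta), -np.sin(theta)],
                  [np.sin(theta),  np.cos(theta)]])
    return R @ E

rng = np.random.default_rng()
E = np.vstack([np.ones(16, complex), np.zeros(16, complex)])
# flip with probability 0.5; the label carries no record of the flip
E_aug = rotate_polarisation(E, np.pi) if rng.random() < 0.5 else E
```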

(new file; name not shown in this view)

@@ -0,0 +1,59 @@
# Baseline Models
## a) D+S, pol_error 0, ortho_error 0, DGD 0
dataset
```raw
data/20250118-225840-128-16384-inf-100000-0-0.2-17.0-0.058-PAM4-0-0-10.ini
```
model
```raw
.models/best_20250118_225918.tar
```
## b) D+S, pol_error 0.4, ortho_error 0, DGD 0
dataset
```raw
data/20250116-214606-128-16384-inf-100000-0-0.2-17.0-0.058-PAM4-0-0-10.ini
```
model
```raw
.models/best_20250116_214816.tar
```
## c) D+S, pol_error 0, ortho_error 0.1, DGD 0
dataset
```raw
data/20250116-214547-128-16384-inf-100000-0-0.2-17.0-0.058-PAM4-0-0-10.ini
```
model
```raw
.models/best_20250117_122319.tar
```
## d) D+S, pol_error 0, ortho_error 0, DGD 10ps (1 T_sym)
birefringence angle pi/2 (worst case)
dataset
```raw
data/20250117-143926-128-16384-inf-100000-0-0.2-17.0-0.058-PAM4-0-0-10.ini
```
model
```raw
.models/best_20250117_144001.tar
```
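Case d) adds first-order PMD. A hedged numpy sketch of what a fixed DGD at birefringence angle π/2 does to a dual-polarisation field (an illustrative helper, not the pypho implementation):
```py
import numpy as np

def apply_dgd(E, dgd_s, dt, theta=np.pi / 2):
    # E: (2, n) complex field; delay the two birefringence axes by +/-DGD/2
    f = np.fft.fftfreq(E.shape[1], dt)
    c, s = np.cos(theta), np.sin(theta)
    R = np.array([[c, s], [-s, c]])           # rotate into the fiber axes
    Ef = np.fft.fft(R @ E, axis=1)
    Ef[0] *= np.exp(-1j * np.pi * f * dgd_s)  # slow axis: +DGD/2
    Ef[1] *= np.exp(+1j * np.pi * f * dgd_s)  # fast axis: -DGD/2
    return R.T @ np.fft.ifft(Ef, axis=1)      # back to the x/y basis

E_out = apply_dgd(np.random.randn(2, 4096) + 0j, dgd_s=10e-12, dt=1e-12)  # toy numbers
```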

Submodule pypho updated: dd015f4852...e44fc477fe

(deleted file: generate_signal.py)

@@ -1,535 +0,0 @@
"""
generate_signal.py
This file is part of the repo "optical-regeneration"
https://git.suuppl.dev/seppl/optical-regeneration.git
Joseph Hopfmüller
Copyright 2024
Licensed under the EUPL
Full license text in LICENSE file
"""
import configparser
from datetime import datetime
import hashlib
from pathlib import Path
import time
from matplotlib import pyplot as plt # noqa: F401
import numpy as np
import add_pypho # noqa: F401
import pypho
default_config = f"""
[glova]
nos = 256
sps = 256
f0 = 193414489032258.06
symbolrate = 10e9
wisdom_dir = "{str((Path.home() / ".pypho"))}"
flags = "FFTW_PATIENT"
nthreads = 32
[fiber]
length = 10000
gamma = 1.14
alpha = 0.2
D = 17
S = 0
birefsteps = 0
max_delta_beta = 0.4
; birefseed = 0xC0FFEE
[signal]
; seed = 0xC0FFEE
modulation = "pam"
mod_order = 4
mod_depth = 0.8
max_jitter = 0.02
; jitter_seed = 0xC0FFEE
laser_power = 0
edfa_power = 3
edfa_nf = 5
pulse_shape = "gauss"
fwhm = 0.33
[data]
dir = "data"
npy_dir = "npys"
"""
def get_config(config_file=None):
"""
DANGER! The function uses eval() to parse the config file. Do not use this function with untrusted input.
"""
if config_file is None:
config_file = Path(__file__).parent / "signal_generation.ini"
if not config_file.exists():
with open(config_file, "w") as f:
f.write(default_config)
config = configparser.ConfigParser()
config.read(config_file)
conf = {}
for section in config.sections():
# print(f"[{section}]")
conf[section] = {}
for key in config[section]:
# print(f"{key} = {config[section][key]}")
conf[section][key] = eval(config[section][key])
# if isinstance(conf[section][key], str):
# conf[section][key] = config[section][key].strip('"')
return conf
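# Safer alternative (sketch): ast.literal_eval accepts only Python literals
# and raises on arbitrary expressions, avoiding the eval() risk noted above:
#   import ast
#   try:
#       conf[section][key] = ast.literal_eval(config[section][key])
#   except (ValueError, SyntaxError):
#       conf[section][key] = config[section][key]  # keep the raw string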
class pam_generator:
def __init__(
self,
glova,
mod_order=None,
mod_depth=0.5,
pulse_shape="gauss",
fwhm=0.33,
seed=None,
single_channel=False
) -> None:
self.glova = glova
self.pulse_shape = pulse_shape
self.modulation_depth = mod_depth
self.mod_order = mod_order
self.fwhm = fwhm
self.seed = seed
self.single_channel = single_channel
def __call__(self, E, symbols, max_jitter=0):
max_jitter = int(round(max_jitter * self.glova.sps))
if self.pulse_shape == "gauss":
wavelet = self.gauss(oversampling=6)
else:
raise ValueError(f"Unknown pulse shape: {self.pulse_shape}")
# prepare symbols
symbols_x = symbols[0] / (self.mod_order or np.max(symbols[0]))
diffs_x = np.diff(symbols_x, prepend=symbols_x[0])
digital_x = self.generate_digital_signal(diffs_x, max_jitter)
digital_x = np.pad(
digital_x, (0, self.glova.sps // 2), "constant", constant_values=(0, 0)
)
# create analog signal of diff of symbols
E_x = np.convolve(digital_x, wavelet)
# convert to pam and set modulation depth (scale and move up such that 1 stays at 1)
E_x = np.cumsum(E_x) * self.modulation_depth + (1 - self.modulation_depth)
# cut off the wavelet tails
E_x = E_x[self.glova.sps // 2 + len(wavelet) // 2 - 1 : -len(wavelet) // 2]
# modulate the laser
E[0]["E"][0] = np.sqrt(np.multiply(np.square(E[0]["E"][0]), E_x))
if not self.single_channel:
symbols_y = symbols[1] / (self.mod_order or np.max(symbols[1]))
diffs_y = np.diff(symbols_y, prepend=symbols_y[0])
digital_y = self.generate_digital_signal(diffs_y, max_jitter)
digital_y = np.pad(
digital_y, (0, self.glova.sps // 2), "constant", constant_values=(0, 0)
)
E_y = np.convolve(digital_y, wavelet)
E_y = np.cumsum(E_y) * self.modulation_depth + (1 - self.modulation_depth)
E_y = E_y[self.glova.sps // 2 + len(wavelet) // 2 - 1 : -len(wavelet) // 2]
E[0]["E"][1] = np.sqrt(np.multiply(np.square(E[0]["E"][1]), E_y))
else:
E[0]["E"][1] = np.zeros_like(E[0]["E"][0], dtype=E[0]["E"][0].dtype)
return E
def generate_digital_signal(self, symbols, max_jitter=0):
rs = np.random.RandomState(self.seed)
signal = np.zeros(self.glova.nos * self.glova.sps)
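        # place one weighted impulse per symbol at its nominal sample index,
        # shifted by a random per-symbol jitter of up to +/-max_jitter samples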
for index in range(self.glova.nos):
            jitter = rs.randint(-max_jitter, max_jitter) if max_jitter else 0
signal_index = index * self.glova.sps + jitter
if signal_index < 0:
continue
if signal_index >= len(signal):
continue
signal[signal_index] = symbols[index]
return signal
def gauss(self, oversampling=1):
sample_points = np.linspace(
-oversampling * self.glova.sps,
oversampling * self.glova.sps,
oversampling * 2 * self.glova.sps,
endpoint=True,
)
sigma = self.fwhm / (1 * np.sqrt(2 * np.log(2))) * self.glova.sps
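        # note: the textbook FWHM relation is sigma = fwhm / (2*sqrt(2*ln 2));
        # the factor 1 here makes the pulse twice that nominal width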
pulse = (
1
/ (sigma * np.sqrt(2 * np.pi))
* np.exp(-np.square(sample_points) / (2 * np.square(sigma)))
)
return pulse
def initialize_fiber_and_data(config, input_data_override=None):
py_glova = pypho.setup(
nos=config["glova"]["nos"],
sps=config["glova"]["sps"],
f0=config["glova"]["f0"],
symbolrate=config["glova"]["symbolrate"],
wisdom_dir=config["glova"]["wisdom_dir"],
flags=config["glova"]["flags"],
nthreads=config["glova"]["nthreads"],
)
c_glova = pypho.cfiber.GlovaWrapper.from_setup(py_glova)
c_data = pypho.cfiber.DataWrapper(py_glova.sps * py_glova.nos)
py_edfa = pypho.oamp(py_glova, Pmean=config["signal"]["edfa_power"], NF=config["signal"]["edfa_nf"])
if input_data_override is not None:
c_data.E_in = input_data_override[0]
noise = input_data_override[1]
else:
config["signal"]["seed"] = config["signal"].get(
"seed", (int(time.time() * 1000)) % 2**32
)
config["signal"]["jitter_seed"] = config["signal"].get(
"jitter_seed", (int(time.time() * 1000)) % 2**32
)
symbolsrc = pypho.symbols(
py_glova, py_glova.nos, pattern="ones", p1=config["signal"]["mod_order"], seed=config["signal"]["seed"]
)
laser = pypho.lasmod(
py_glova, power=config["signal"]["laser_power"]+1.5, Df=0, theta=np.pi / 4
)
modulator = pam_generator(
py_glova,
mod_depth=config["signal"]["mod_depth"],
pulse_shape=config["signal"]["pulse_shape"],
fwhm=config["signal"]["fwhm"],
seed=config["signal"]["jitter_seed"],
single_channel=False
)
symbols_x = symbolsrc(pattern="random")
symbols_y = symbolsrc(pattern="random")
symbols_x[:3] = 0
symbols_y[:3] = 0
cw = laser()
source_signal = modulator(E=cw, symbols=(symbols_x, symbols_y))
source_signal = py_edfa(E=source_signal)
c_data.E_in = source_signal[0]["E"]
noise = source_signal[0]["noise"]
py_fiber = pypho.fiber(
glova=py_glova,
l=config["fiber"]["length"],
alpha=pypho.functions.dB_to_Neper(config["fiber"]["alpha"]) / 1000,
gamma=config["fiber"]["gamma"],
D=config["fiber"]["d"],
S=config["fiber"]["s"],
)
if config["fiber"].get("birefsteps", 0) > 0:
seed = config["fiber"].get(
"birefseed", (int(time.time() * 1000)) % 2**32
)
py_fiber.birefarray = pypho.birefringence_segment.create_pmd_fibre(
py_fiber.l,
py_fiber.l / config["fiber"]["birefsteps"],
# maxDeltaD=config["fiber"]["d"]/5,
maxDeltaBeta = config["fiber"].get("max_delta_beta", 0),
seed=seed,
)
c_params = pypho.cfiber.ParamsWrapper.from_fiber(
py_fiber, max_step=1e3 if py_fiber.gamma == 0 else 200
)
c_fiber = pypho.cfiber.FiberWrapper(c_data, c_params, c_glova)
return c_fiber, c_data, noise, py_edfa
def save_data(data, config):
data_dir = Path(config["data"]["dir"])
npy_dir = config["data"].get("npy_dir", "")
save_dir = data_dir / npy_dir if len(npy_dir) else data_dir
save_dir.mkdir(parents=True, exist_ok=True)
save_data = np.column_stack([
data.E_in[0],
data.E_in[1],
data.E_out[0],
data.E_out[1],
])
timestamp = datetime.now()
seed = config["signal"].get("seed", False)
jitter_seed = config["signal"].get("jitter_seed", False)
birefseed = config["fiber"].get("birefseed", False)
config_content = "\n".join((
f"; Generated by {str(Path(__file__).name)} @ {timestamp.strftime('%Y-%m-%d %H:%M:%S')}",
"[glova]",
f"sps = {config['glova']['sps']}",
f"nos = {config['glova']['nos']}",
f"f0 = {config['glova']['f0']}",
f"symbolrate = {config['glova']['symbolrate']}",
f'wisdom_dir = "{config["glova"]["wisdom_dir"]}"',
f'flags = "{config["glova"]["flags"]}"',
f"nthreads = {config['glova']['nthreads']}",
" ",
"[fiber]",
f"length = {config['fiber']['length']}",
f"gamma = {config['fiber']['gamma']}",
f"alpha = {config['fiber']['alpha']}",
f"D = {config['fiber']['d']}",
f"S = {config['fiber']['s']}",
f"birefsteps = {config['fiber'].get('birefsteps',0)}",
f"max_delta_beta = {config['fiber'].get('max_delta_beta', 0)}",
f"birefseed = {hex(birefseed)}" if birefseed else "; birefseed = not set",
"",
"[signal]",
f"seed = {hex(seed)}" if seed else "; seed = not set",
"",
f'modulation = "{config["signal"]["modulation"]}"',
f"mod_order = {config['signal']['mod_order']}",
f"mod_depth = {config['signal']['mod_depth']}",
""
f"max_jitter = {config['signal']['max_jitter']}",
f"jitter_seed = {hex(jitter_seed)}" if jitter_seed else "; jitter_seed = not set",
""
f"laser_power = {config['signal']['laser_power']}",
f"edfa_power = {config['signal']['edfa_power']}",
f"edfa_nf = {config['signal']['edfa_nf']}",
""
f'pulse_shape = "{config["signal"]["pulse_shape"]}"',
f"fwhm = {config['signal']['fwhm']}",
"",
"[data]",
f'dir = "{str(data_dir)}"',
f'npy_dir = "{npy_dir}"',
"file = "
))
config_hash = hashlib.md5(config_content.encode()).hexdigest()
save_file = f"{config_hash}.npy"
config_content += f'"{str(save_file)}"\n'
filename_components = (
timestamp.strftime("%Y%m%d-%H%M%S"),
config["glova"]["sps"],
config["glova"]["nos"],
config["fiber"]["length"],
config["fiber"]["gamma"],
config["fiber"]["alpha"],
config["fiber"]["d"],
config["fiber"]["s"],
f"{config['signal']['modulation'].upper()}{config['signal']['mod_order']}",
config['fiber'].get('birefsteps',0),
config["fiber"].get("max_delta_beta", 0),
)
lookup_file = "-".join(map(str, filename_components)) + ".ini"
with open(data_dir / lookup_file, "w") as f:
f.write(config_content)
np.save(save_dir / save_file, save_data)
print("Saved config to", data_dir / lookup_file)
print("Saved data to", save_dir / save_file)
def length_loop(config, lengths, incremental=False, bireflength=None, save=True):
lengths = sorted(lengths)
input_override = None
birefsteps_running = 0
for lind, length in enumerate(lengths):
# print(f"\nGenerating data for fiber length {length}")
if lind > 0 and incremental:
# set the length to the difference between the current and previous length -> incremental
length = lengths[lind] - lengths[lind - 1]
if incremental:
print(
f"\nGenerating data for fiber length {lengths[lind]}m [using {length}m increment]"
)
else:
print(f"\nGenerating data for fiber length {length}m")
config["fiber"]["length"] = length
if bireflength is not None and bireflength > 0:
config["fiber"]["birefsteps"] = length // bireflength
birefsteps_running += config["fiber"]["birefsteps"]
# set the input data to the output data of the previous run
cfiber, cdata, noise, edfa = initialize_fiber_and_data(
config, input_data_override=input_override
)
if lind == 0:
cdata_orig = cdata
mean_power_in = np.sum(pypho.functions.getpower_W(cdata.E_in))
print(
f"Mean power in: {mean_power_in:.3e} W ({pypho.functions.W_to_dBm(mean_power_in):.3e} dBm)"
)
cfiber()
mean_power_out = np.sum(pypho.functions.getpower_W(cdata.E_out))
print(
f"Mean power out: {mean_power_out:.3e} W ({pypho.functions.W_to_dBm(mean_power_out):.3e} dBm)"
)
if incremental:
input_override = (cdata.E_out, noise)
cdata.E_in = cdata_orig.E_in
config["fiber"]["length"] = lengths[lind]
if bireflength is not None:
config["fiber"]["birefsteps"] = birefsteps_running
        E_tmp = [{'E': cdata.E_out, 'noise': noise*np.exp(-cfiber.params.l*cfiber.params.alpha)}]  # attenuate noise over the span (assumed intent)
E_tmp = edfa(E=E_tmp)
cdata.E_out = E_tmp[0]['E']
if save:
save_data(cdata, config)
in_out_eyes(cfiber, cdata)
def single_run_with_plot(config, save=True):
cfiber, cdata, noise, edfa = initialize_fiber_and_data(config)
mean_power_in = np.sum(pypho.functions.getpower_W(cdata.E_in))
print(
f"Mean power in: {mean_power_in:.3e} W ({pypho.functions.W_to_dBm(mean_power_in):.3e} dBm)"
)
cfiber()
mean_power_out = np.sum(pypho.functions.getpower_W(cdata.E_out))
print(
f"Mean power out: {mean_power_out:.3e} W ({pypho.functions.W_to_dBm(mean_power_out):.3e} dBm)"
)
    E_tmp = [{'E': cdata.E_out, 'noise': noise*np.exp(-cfiber.params.l*cfiber.params.alpha)}]  # attenuate noise over the span (assumed intent)
E_tmp = edfa(E=E_tmp)
cdata.E_out = E_tmp[0]['E']
if save:
save_data(cdata, config)
in_out_eyes(cfiber, cdata)
def in_out_eyes(cfiber, cdata):
fig, axs = plt.subplots(2, 2, sharex=True, sharey=True)
eye_head = min(cfiber.glova.nos, 2000)
symbolrate_scale = 1e12
amplitude_scale = 1e3
plot_eye_diagram(
amplitude_scale * np.abs(cdata.E_in[0]) ** 2,
2 * cfiber.glova.sps,
normalize=False,
samplerate=cfiber.glova.symbolrate * cfiber.glova.sps / symbolrate_scale,
head=eye_head,
ax=axs[0][0],
show=False,
)
plot_eye_diagram(
amplitude_scale * np.abs(cdata.E_out[0]) ** 2,
2 * cfiber.glova.sps,
normalize=False,
samplerate=cfiber.glova.symbolrate * cfiber.glova.sps / symbolrate_scale,
head=eye_head,
ax=axs[0][1],
color="C1",
show=False,
)
plot_eye_diagram(
amplitude_scale * np.abs(cdata.E_in[1]) ** 2,
2 * cfiber.glova.sps,
normalize=False,
samplerate=cfiber.glova.symbolrate * cfiber.glova.sps / symbolrate_scale,
head=eye_head,
ax=axs[1][0],
show=False,
)
plot_eye_diagram(
amplitude_scale * np.abs(cdata.E_out[1]) ** 2,
2 * cfiber.glova.sps,
normalize=False,
samplerate=cfiber.glova.symbolrate * cfiber.glova.sps / symbolrate_scale,
head=eye_head,
ax=axs[1][1],
color="C1",
show=False,
)
title_map = [
["Input x", "Output x"],
["Input y", "Output y"],
]
title_map = np.array(title_map)
for ax, title in zip(axs.flatten(), title_map.flatten()):
ax.grid(True)
ax.set_xlabel("Time [ps]")
ax.set_ylabel("Power [mW]")
ax.set_title(title)
fig.tight_layout()
plt.show()
def plot_eye_diagram(
signal: np.ndarray,
eye_width,
offset=0,
*,
head=None,
samplerate=1,
normalize=True,
ax=None,
color="C0",
show=True,
):
ax = ax or plt.gca()
if head is not None:
signal = signal[: head * eye_width]
if normalize:
signal = signal / np.max(signal)
slices = np.lib.stride_tricks.sliding_window_view(signal, eye_width + 1)[
offset % (eye_width + 1) :: eye_width
]
plt_ax = np.arange(-eye_width // 2, eye_width // 2 + 1) / samplerate
    for trace in slices:
        ax.plot(plt_ax, trace, color=color, alpha=0.1)
ax.grid()
if show:
plt.show()
if __name__ == "__main__":
add_pypho.show_log()
config = get_config()
ranges = (10000,)
# scales = tuple(range(1, 10))
scales = (1,)
lengths = [range_ * scale for range_ in ranges for scale in scales]
# lengths.append(10*max(ranges))
lengths = [*lengths, *lengths]
lengths = sorted(lengths)
length_loop(config, lengths, incremental=False, bireflength=None, save=True)
# birefringence is constant over coupling length -> several 100m -> bireflength=1000 (m)
# single_run_with_plot(config, save=False)
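The modulator above encodes level changes rather than levels: it convolves the symbol differences with a pulse and integrates with cumsum, so each pulse becomes a smooth transition between PAM levels. A self-contained sketch of that trick (toy pulse shape and values):
```py
import numpy as np

sps = 32
levels = np.array([0, 1, 3, 2, 2, 0, 3]) / 3       # normalised PAM4 symbols
diffs = np.diff(levels, prepend=levels[0])         # transitions only
impulses = np.zeros(len(levels) * sps)
impulses[::sps] = diffs                            # one delta per symbol
edge = np.exp(-np.linspace(-3, 3, 2 * sps) ** 2)   # smooth edge shape
edge /= edge.sum()                                 # unit area -> exact levels
waveform = np.cumsum(np.convolve(impulses, edge))  # integrate edges
```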

(modified file; name not shown in this view)

@@ -1,7 +1,16 @@
import copy
from datetime import datetime
from pathlib import Path
import random
from typing import Literal
import matplotlib
from matplotlib.colors import LinearSegmentedColormap
import torch.nn.utils.parametrize
try:
matplotlib.use("cairo")
except ImportError:
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
@@ -11,35 +20,19 @@ import optuna
import warnings
import torch
import torch.nn as nn
# import torch.nn as nn
# import torch.nn.functional as F # mse_loss doesn't support complex numbers
import torch.optim as optim
import torch.utils.data
from torch.utils.tensorboard import SummaryWriter
import hypertraining.models as models
# from rich.progress import (
# Progress,
# TextColumn,
# BarColumn,
# TaskProgressColumn,
# TimeRemainingColumn,
# MofNCompleteColumn,
# TimeElapsedColumn,
# )
# from rich.console import Console
# from rich import print as rprint
from torch.utils.tensorboard import SummaryWriter
import multiprocessing
from util.datasets import FiberRegenerationDataset
# from util.optuna_helpers import (
# suggest_categorical_optional, # noqa: F401
# suggest_float_optional, # noqa: F401
# suggest_int_optional, # noqa: F401
# )
from util.optuna_helpers import install_optional_suggests
import util
@@ -65,7 +58,6 @@ class HyperTraining:
model_settings,
optimizer_settings,
optuna_settings,
# console=None,
):
self.global_settings: GlobalSettings = global_settings
self.data_settings: DataSettings = data_settings
@@ -75,11 +67,8 @@ class HyperTraining:
self.optuna_settings: OptunaSettings = optuna_settings
self.processes = None
# self.console = console or Console()
# set some extra settings to make the code more readable
self._extra_optuna_settings()
self.stop_study = True
self.stop_study = False
def setup_tb_writer(self, study_name=None, append=None):
log_dir = self.pytorch_settings.summary_dir + "/" + (study_name or self.optuna_settings.study_name)
@@ -229,7 +218,7 @@ class HyperTraining:
self.optuna_settings._parallel = self.optuna_settings._n_threads > 1
def define_model(self, trial: optuna.Trial, writer=None):
n_layers = trial.suggest_int_optional("model_n_hidden_layers", self.model_settings.n_hidden_layers)
n_hidden_layers = trial.suggest_int_optional("model_n_hidden_layers", self.model_settings.n_hidden_layers)
input_dim = trial.suggest_int_optional(
"model_input_dim",
@@ -245,32 +234,44 @@ class HyperTraining:
dtype = getattr(torch, dtype)
afunc = trial.suggest_categorical_optional("model_activation_func", self.model_settings.model_activation_func)
# T0 = trial.suggest_float_optional("T0", self.model_settings.satabsT0 , log=True)
afunc = getattr(util.complexNN, afunc)
layer_func = trial.suggest_categorical_optional("model_layer_function", self.model_settings.model_layer_function)
layer_func = getattr(util.complexNN, layer_func)
layer_parametrizations = self.model_settings.model_layer_parametrizations
layers = []
last_dim = input_dim
n_nodes = last_dim
for i in range(n_layers):
scale_layers = trial.suggest_categorical_optional("model_enable_scale_layers", self.model_settings.scale)
hidden_dims = []
for i in range(n_hidden_layers):
if hidden_dim_override := self.model_settings.overrides.get(f"n_hidden_nodes_{i}", False):
hidden_dim = trial.suggest_int_optional(f"model_hidden_dim_{i}", hidden_dim_override)
hidden_dims.append(trial.suggest_int_optional(f"model_hidden_dim_{i}", hidden_dim_override))
else:
hidden_dim = trial.suggest_int_optional(
hidden_dims.append(trial.suggest_int_optional(
f"model_hidden_dim_{i}",
self.model_settings.n_hidden_nodes,
)
layers.append(util.complexNN.ONNRect(last_dim, hidden_dim, dtype=dtype))
last_dim = hidden_dim
layers.append(getattr(util.complexNN, afunc)())
n_nodes += last_dim
layers.append(util.complexNN.ONNRect(last_dim, self.model_settings.output_dim, dtype=dtype))
model = nn.Sequential(*layers)
))
model_kwargs = {
"dims": (input_dim, *hidden_dims, self.model_settings.output_dim),
"layer_function": layer_func,
"layer_func_kwargs": self.model_settings.model_layer_kwargs,
"act_function": afunc,
"act_func_kwargs": None,
"parametrizations": layer_parametrizations,
"dtype": dtype,
"dropout_prob": self.model_settings.dropout_prob,
"scale_layers": scale_layers,
"rotate": False,
}
model = models.regenerator(*model_kwargs.pop("dims"), **model_kwargs)
n_nodes = sum(hidden_dims)
if writer is not None:
writer.add_graph(model, torch.zeros(1, input_dim, dtype=dtype), use_strict_trace=False)
n_params = sum(p.numel() for p in model.parameters())
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
trial.set_user_attr("model_n_params", n_params)
trial.set_user_attr("model_n_nodes", n_nodes)
@@ -384,7 +385,11 @@ class HyperTraining:
running_loss2 = 0.0
running_loss = 0.0
model.train()
for batch_idx, (x, y) in enumerate(train_loader):
loader_len = len(train_loader)
for batch_idx, batch in enumerate(train_loader):
x = batch["x"]
y = batch["y"]
if batch_idx >= self.optuna_settings._n_train_batches:
break
model.zero_grad(set_to_none=True)
@@ -393,7 +398,7 @@ class HyperTraining:
y.to(self.pytorch_settings.device),
)
y_pred = model(x)
loss = util.complexNN.complex_mse_loss(y_pred, y)
loss = util.complexNN.complex_mse_loss(y_pred, y, power=True)
loss_value = loss.item()
loss.backward()
optimizer.step()
@@ -408,14 +413,14 @@ class HyperTraining:
writer.add_scalar(
"training loss",
running_loss2 / (self.pytorch_settings.write_every if batch_idx > 0 else 1),
epoch * min(len(train_loader), self.optuna_settings._n_train_batches) + batch_idx,
epoch * min(loader_len, self.optuna_settings._n_train_batches) + batch_idx,
)
running_loss2 = 0.0
# if enable_progress:
# progress.stop()
return running_loss / min(len(train_loader), self.optuna_settings._n_train_batches)
return running_loss / min(loader_len, self.optuna_settings._n_train_batches)
def eval_model(
self,
@@ -446,9 +451,10 @@ class HyperTraining:
model.eval()
running_error = 0
running_error_2 = 0
with torch.no_grad():
for batch_idx, (x, y) in enumerate(valid_loader):
for batch_idx, batch in enumerate(valid_loader):
x = batch["x"]
y = batch["y"]
if batch_idx >= self.optuna_settings._n_valid_batches:
break
x, y = (
@@ -456,72 +462,91 @@ class HyperTraining:
y.to(self.pytorch_settings.device),
)
y_pred = model(x)
error = util.complexNN.complex_mse_loss(y_pred, y)
error = util.complexNN.complex_mse_loss(y_pred, y, power=True)
error_value = error.item()
running_error += error_value
running_error_2 += error_value
# if enable_progress:
# progress.update(task, advance=1, description=f"{error_value:.3e}")
if writer is not None:
if batch_idx % self.pytorch_settings.write_every == 0:
writer.add_scalar(
"eval loss",
running_error_2 / (self.pytorch_settings.write_every if batch_idx > 0 else 1),
epoch * min(len(valid_loader), self.optuna_settings._n_valid_batches) + batch_idx,
)
running_error_2 = 0.0
running_error /= min(len(valid_loader), self.optuna_settings._n_valid_batches)
if writer is not None:
title_append, subtitle = self.build_title(trial)
writer.add_figure(
"fiber response",
self.plot_model_response(
trial,
model=model,
title_append=title_append,
subtitle=subtitle,
show=False,
),
epoch + 1,
writer.add_scalar(
"eval loss",
running_error,
epoch,
)
# if (epoch + 1) % 10 == 0 or epoch < 10:
# # plotting is slow, so only do it every 10 epochs
# title_append, subtitle = self.build_title(trial)
# head_fig, eye_fig, powers_fig = self.plot_model_response(
# model=model,
# title_append=title_append,
# subtitle=subtitle,
# show=False,
# )
# writer.add_figure(
# "fiber response",
# head_fig,
# epoch + 1,
# )
# writer.add_figure(
# "eye diagram",
# eye_fig,
# epoch + 1,
# )
# writer.add_figure(
# "powers",
# powers_fig,
# epoch + 1,
# )
# writer.flush()
# if enable_progress:
# progress.stop()
return running_error
def run_model(self, model, loader):
def run_model(self, model, loader, trace_powers=False):
model.eval()
xs = []
ys = []
y_preds = []
fiber_out = []
fiber_in = []
regen = []
timestamps = []
with torch.no_grad():
model = model.to(self.pytorch_settings.device)
for x, y in loader:
for batch in loader:
x = batch["x"]
y = batch["y"]
timestamp = batch["timestamp"]
x, y = (
x.to(self.pytorch_settings.device),
y.to(self.pytorch_settings.device),
)
y_pred = model(x).cpu()
                if trace_powers:
                    y_pred, powers = model(x, trace_powers=True)
                    y_pred = y_pred.cpu()
                else:
                    y_pred = model(x).cpu()
# x = x.cpu()
# y = y.cpu()
y_pred = y_pred.view(y_pred.shape[0], -1, 2)
y = y.view(y.shape[0], -1, 2)
x = x.view(x.shape[0], -1, 2)
xs.append(x[:, 0, :].squeeze())
ys.append(y.squeeze())
y_preds.append(y_pred.squeeze())
# timestamp = timestamp.view(-1, 1)
fiber_out.append(x[:, x.shape[1] // 2, :].squeeze())
fiber_in.append(y.squeeze())
regen.append(y_pred.squeeze())
timestamps.append(timestamp.squeeze())
xs = torch.vstack(xs).cpu()
ys = torch.vstack(ys).cpu()
y_preds = torch.vstack(y_preds).cpu()
return ys, xs, y_preds
fiber_out = torch.vstack(fiber_out).cpu()
fiber_in = torch.vstack(fiber_in).cpu()
regen = torch.vstack(regen).cpu()
timestamps = torch.concat(timestamps).cpu()
if trace_powers:
return fiber_in, fiber_out, regen, timestamps, powers
return fiber_in, fiber_out, regen, timestamps
def objective(self, trial: optuna.Trial, plot_before=False):
def objective(self, trial: optuna.Trial):
if self.stop_study:
trial.study.stop()
model = None
@@ -537,29 +562,54 @@ class HyperTraining:
title_append, subtitle = self.build_title(trial)
writer.add_figure(
"fiber response",
self.plot_model_response(
trial,
model=model,
title_append=title_append,
subtitle=subtitle,
show=plot_before,
),
0,
)
# writer.add_figure(
# "fiber response",
# self.plot_model_response(
# trial,
# model=model,
# title_append=title_append,
# subtitle=subtitle,
# show=False,
# ),
# 0,
# )
# writer.add_figure(
# "eye diagram",
# self.plot_model_response(
# trial,
# model=self.model,
# title_append=title_append,
# subtitle=subtitle,
# mode="eye",
# show=False,
# ),
# 0,
# )
# writer.add_figure(
# "powers",
# self.plot_model_response(
# trial,
# model=self.model,
# title_append=title_append,
# subtitle=subtitle,
# mode="powers",
# show=False,
# ),
# 0,
# )
train_loader, valid_loader = self.get_sliced_data(trial)
optimizer_name = trial.suggest_categorical_optional("optimizer", self.optimizer_settings.optimizer)
lr = trial.suggest_float_optional("lr", self.optimizer_settings.learning_rate, log=True)
lr = trial.suggest_float_optional("lr", self.optimizer_settings.optimizer_kwargs["lr"], log=True)
optimizer = getattr(optim, optimizer_name)(model.parameters(), lr=lr)
if self.optimizer_settings.scheduler is not None:
scheduler = getattr(optim.lr_scheduler, self.optimizer_settings.scheduler)(
optimizer, **self.optimizer_settings.scheduler_kwargs
)
# if self.optimizer_settings.scheduler is not None:
# scheduler = getattr(optim.lr_scheduler, self.optimizer_settings.scheduler)(
# optimizer, **self.optimizer_settings.scheduler_kwargs
# )
for epoch in range(self.pytorch_settings.epochs):
trial.set_user_attr("epoch", epoch)
@@ -585,8 +635,8 @@ class HyperTraining:
writer,
# enable_progress=enable_progress,
)
if self.optimizer_settings.scheduler is not None:
scheduler.step(error)
# if self.optimizer_settings.scheduler is not None:
# scheduler.step(error)
trial.set_user_attr("mse", error)
trial.set_user_attr("log_mse", np.log10(error + np.finfo(float).eps))
@@ -602,14 +652,16 @@ class HyperTraining:
if self.optuna_settings._multi_objective:
return -np.log10(error + np.finfo(float).eps), trial.user_attrs.get("model_n_nodes", -1)
if self.pytorch_settings.save_models and model is not None:
save_path = Path(self.pytorch_settings.model_dir) / f"{self.optuna_settings.study_name}_{trial.number}.pth"
save_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(model, save_path)
# if self.pytorch_settings.save_models and model is not None:
# save_path = Path(self.pytorch_settings.model_dir) / f"{self.optuna_settings.study_name}_{trial.number}.pth"
# save_path.parent.mkdir(parents=True, exist_ok=True)
# torch.save(model, save_path)
return error
def _plot_model_response_eye(self, *signals, labels=None, sps=None, title_append="", subtitle="", show=True):
def _plot_model_response_eye(
self, *signals, timestamps, labels=None, sps=None, title_append="", subtitle="", show=True
):
if sps is None:
raise ValueError("sps must be provided")
if not hasattr(labels, "__iter__") or isinstance(labels, (str, type(None))):
@@ -624,27 +676,84 @@ class HyperTraining:
if not any(labels):
labels = [f"signal {i + 1}" for i in range(len(signals))]
fig, axs = plt.subplots(2, len(signals), sharex=True, sharey=True)
x_bins = np.linspace(0, 2, 2 * sps, endpoint=False)
y_bins = np.zeros((2 * len(signals), 1000))
eye_data = np.zeros((2 * len(signals), 1000, 2 * sps))
# signals = [signal.cpu().numpy() for signal in signals]
for i in range(len(signals) * 2):
eye_signal = signals[i // 2][:, i % 2] # x, y, x, y, ...
eye_signal = np.real(np.square(np.abs(eye_signal)))
data_min = np.min(eye_signal)
data_max = np.max(eye_signal)
y_bins[i] = np.linspace(data_min, data_max, 1000, endpoint=False)
for j in range(len(timestamps)):
t = timestamps[j] / sps
val = eye_signal[j]
x = np.digitize(t % 2, x_bins) - 1
y = np.digitize(val, y_bins[i]) - 1
eye_data[i][y][x] += 1
cmap = LinearSegmentedColormap.from_list(
"eyemap",
[
(0, "white"),
(0.001, "dodgerblue"),
(0.1, "blue"),
(0.2, "cyan"),
(0.5, "lime"),
(0.8, "gold"),
(1, "red"),
],
)
# ordering = np.argsort(timestamps)
# signals = [signal[ordering] for signal in signals]
# timestamps = timestamps[ordering]
fig, axs = plt.subplots(1, 2 * len(signals), sharex=True, sharey=True)
fig.set_figwidth(18)
fig.suptitle(f"Eye diagram{f' {title_append}' if title_append else ''}{f'\n{subtitle}' if subtitle else ''}")
xaxis = np.linspace(0, 2, 2 * sps, endpoint=False)
for j, (label, signal) in enumerate(zip(labels, signals)):
# xaxis = timestamps / sps
# xaxis = np.arange(2 * sps) / sps
for j, label in enumerate(labels):
x = eye_data[2 * j]
y = eye_data[2 * j + 1]
# x, y = signal.T
# signal = signal.cpu().numpy()
for i in range(len(signal) // sps - 1):
x, y = signal[i * sps : (i + 2) * sps].T
axs[0, j].plot(xaxis, np.abs(x) ** 2, color="C0", alpha=0.02)
axs[1, j].plot(xaxis, np.abs(y) ** 2, color="C0", alpha=0.02)
axs[0, j].set_title(label + " x")
axs[1, j].set_title(label + " y")
axs[0, j].set_xlabel("Symbol")
axs[1, j].set_xlabel("Symbol")
axs[0, j].set_ylabel("normalized power")
axs[1, j].set_ylabel("normalized power")
# for i in range(len(signal) // sps - 1):
# x, y = signal[i * sps : (i + 2) * sps].T
# axs[0 + 2 * j].scatter((timestamps/sps) % 2, np.abs(x) ** 2, color=f"C{j}", alpha=1 / (len(signal) // sps) * 10, s=1)
# axs[1 + 2 * j].scatter((timestamps/sps) % 2, np.abs(y) ** 2, color=f"C{j}", alpha=1 / (len(signal) // sps) * 10, s=1)
axs[0 + 2 * j].imshow(
x, aspect="auto", cmap=cmap, origin="lower", extent=[0, 2, y_bins[2 * j][0], y_bins[2 * j][-1]]
)
axs[1 + 2 * j].imshow(
y, aspect="auto", cmap=cmap, origin="lower", extent=[0, 2, y_bins[2 * j + 1][0], y_bins[2 * j + 1][-1]]
)
axs[0 + 2 * j].set_xlim((x_bins[0], x_bins[-1]))
axs[1 + 2 * j].set_xlim((x_bins[0], x_bins[-1]))
ymin = np.min(y_bins[:, 0])
ymax = np.max(y_bins[:, -1])
ydiff = ymax - ymin
axs[0 + 2 * j].set_ylim((ymin - 0.05 * ydiff, ymax + 0.05 * ydiff))
axs[1 + 2 * j].set_ylim((ymin - 0.05 * ydiff, ymax + 0.05 * ydiff))
axs[0 + 2 * j].set_title(label + " x")
axs[1 + 2 * j].set_title(label + " y")
axs[0 + 2 * j].set_xlabel("Symbol")
axs[1 + 2 * j].set_xlabel("Symbol")
axs[0 + 2 * j].set_box_aspect(1)
axs[1 + 2 * j].set_box_aspect(1)
axs[0].set_ylabel("normalized power")
fig.tight_layout()
# axs[1+2*len(labels)-1].set_ylabel("normalized power")
if show:
plt.show()
return fig
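The histogram-based eye plot above bins every sample by (timestamp mod 2 symbols, power) instead of overplotting thousands of traces. The same idea in a standalone sketch (illustrative names, using np.histogram2d in place of the manual digitize loop):
```py
import numpy as np

def eye_histogram(power, timestamps, sps, n_power_bins=1000):
    t = (timestamps / sps) % 2  # position within a two-symbol window
    hist, _, _ = np.histogram2d(
        t, power,
        bins=[2 * sps, n_power_bins],
        range=[[0, 2], [power.min(), power.max()]],
    )
    return hist.T  # rows: power bins, cols: time bins (ready for imshow)
```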
def _plot_model_response_head(self, *signals, labels=None, sps=None, title_append="", subtitle="", show=True):
def _plot_model_response_head(
self, *signals, timestamps, labels=None, sps=None, title_append="", subtitle="", show=True
):
if not hasattr(labels, "__iter__") or isinstance(labels, (str, type(None))):
labels = [labels]
else:
@@ -657,19 +766,31 @@ class HyperTraining:
if not any(labels):
labels = [f"signal {i + 1}" for i in range(len(signals))]
ordering = np.argsort(timestamps)
signals = [signal[ordering] for signal in signals]
timestamps = timestamps[ordering]
fig, axs = plt.subplots(1, 2, sharex=True, sharey=True)
fig.set_size_inches(18, 6)
fig.set_figwidth(18)
fig.set_figheight(4)
fig.suptitle(f"Fiber response{f' {title_append}' if title_append else ''}{f'\n{subtitle}' if subtitle else ''}")
for i, ax in enumerate(axs):
ax: plt.Axes
for signal, label in zip(signals, labels):
if sps is not None:
xaxis = np.linspace(0, len(signal) / sps, len(signal), endpoint=False)
xaxis = timestamps / sps
else:
xaxis = np.arange(len(signal))
xaxis = timestamps
ax.plot(xaxis, np.abs(signal[:, i]) ** 2, label=label)
ax.set_xlabel("Sample" if sps is None else "Symbol")
ax.set_ylabel("normalized power")
ax.minorticks_on()
ax.tick_params(axis="y", which="minor", left=False, right=False)
ax.grid(which="major", axis="x")
ax.grid(which="minor", axis="x", linestyle=":")
ax.grid(which="major", axis="y")
ax.legend(loc="upper right")
fig.tight_layout()
if show:
plt.show()
return fig
@@ -680,22 +801,52 @@ class HyperTraining:
model=None,
title_append="",
subtitle="",
mode: Literal["eye", "head"] = "head",
show=True,
mode: Literal["eye", "head", "powers"] = "head",
show=False,
):
if mode == "powers":
input_data = torch.ones(
1, 2 * self.data_settings.output_size, dtype=getattr(torch, self.data_settings.dtype)
).to(self.pytorch_settings.device)
model = model.to(self.pytorch_settings.device)
model.eval()
with torch.no_grad():
_, powers = model(input_data, trace_powers=True)
powers = [power.item() for power in powers]
layer_names = ["input", *[str(x).split("(")[0] for x in model._layers._modules.values()]]
# remove dropout layers
mask = [1 if "Dropout" not in layer_name else 0 for layer_name in layer_names]
layer_names = [layer_name for layer_name, m in zip(layer_names, mask) if m]
powers = [power for power, m in zip(powers, mask) if m]
fig = self._plot_model_response_powers(
powers, layer_names, title_append=title_append, subtitle=subtitle, show=show
)
return fig
data_settings_backup = copy.deepcopy(self.data_settings)
pytorch_settings_backup = copy.deepcopy(self.pytorch_settings)
self.data_settings.drop_first = 100*128
self.data_settings.drop_first = 99.5 + random.randint(0, 1000)
self.data_settings.shuffle = False
self.data_settings.train_split = 1.0
self.pytorch_settings.batchsize = (
self.pytorch_settings.eye_symbols if mode == "eye" else self.pytorch_settings.head_symbols
)
plot_loader, _ = self.get_sliced_data(trial, override={"num_symbols": self.pytorch_settings.batchsize})
config_path = random.choice(self.data_settings.config_path) if isinstance(self.data_settings.config_path, (list, tuple)) else self.data_settings.config_path
fiber_length = int(float(str(config_path).split('-')[-7])/1000)
plot_loader, _ = self.get_sliced_data(
trial,
override={
"num_symbols": self.pytorch_settings.batchsize,
"config_path": config_path,
}
)
self.data_settings = data_settings_backup
self.pytorch_settings = pytorch_settings_backup
fiber_in, fiber_out, regen = self.run_model(model, plot_loader)
fiber_in, fiber_out, regen, timestamps = self.run_model(model, plot_loader)
fiber_in = fiber_in.view(-1, 2)
fiber_out = fiber_out.view(-1, 2)
regen = regen.view(-1, 2)
@@ -703,6 +854,7 @@ class HyperTraining:
fiber_in = fiber_in.numpy()
fiber_out = fiber_out.numpy()
regen = regen.numpy()
timestamps = timestamps.numpy()
# https://github.com/matplotlib/matplotlib/issues/27713#issue-2104110987
# https://github.com/matplotlib/matplotlib/issues/27713#issuecomment-1915497463
@@ -713,9 +865,10 @@ class HyperTraining:
fiber_in,
fiber_out,
regen,
timestamps=timestamps,
labels=("fiber in", "fiber out", "regen"),
sps=plot_loader.dataset.samples_per_symbol,
title_append=title_append,
title_append=title_append + f" ({fiber_length} km)",
subtitle=subtitle,
show=show,
)
@@ -725,9 +878,10 @@ class HyperTraining:
fiber_in,
fiber_out,
regen,
timestamps=timestamps,
labels=("fiber in", "fiber out", "regen"),
sps=plot_loader.dataset.samples_per_symbol,
title_append=title_append,
title_append=title_append + f" ({fiber_length} km)",
subtitle=subtitle,
show=show,
)
@@ -739,7 +893,7 @@ class HyperTraining:
@staticmethod
def build_title(trial: optuna.trial.Trial):
title_append = f"for trial {trial.number}"
title_append = f"at epoch {trial.user_attrs.get("epoch", -1)} for trial {trial.number}"
model_n_hidden_layers = util.misc.multi_getattr((trial.params, trial.user_attrs), "model_n_hidden_layers", 0)
input_dim = util.misc.multi_getattr((trial.params, trial.user_attrs), "model_input_dim", 0)
model_dims = [

(new file, 443 lines; name not shown in this view)

@@ -0,0 +1,443 @@
from typing import Any
import lightning as L
import numpy as np
import torch
import torch.nn as nn
# import torch.nn.functional as F
from util.complexNN import DropoutComplex, Scale, ONNRect, EOActivation, energy_conserving, clamp, complex_mse_loss
from util.datasets import FiberRegenerationDataset
class regeneratorData(L.LightningDataModule):
def __init__(
self,
config_globs,
output_symbols,
output_dim,
dtype,
drop_first,
shuffle=True,
train_split=None,
batch_size=None,
loader_settings=None,
seed=None,
num_symbols=None,
test_globs=None,
):
super().__init__()
self._config_globs = config_globs
self._test_globs = test_globs
self._test_data_available = test_globs is not None
if self._test_data_available:
self.test_dataloader = self._test_dataloader
self._output_symbols = output_symbols
self._output_dim = output_dim
self._dtype = dtype
self._drop_first = drop_first
self._seed = seed
self._shuffle = shuffle
self._num_symbols = num_symbols
self._train_split = train_split if train_split is not None else 0.8
self.batch_size = batch_size if batch_size is not None else 1024
self._loader_settings = loader_settings if loader_settings is not None else {}
def _get_data(self):
self._data_train = FiberRegenerationDataset(
file_path=self._config_globs,
symbols=self._output_symbols,
output_dim=self._output_dim,
dtype=self._dtype,
real=not self._dtype.is_complex,
drop_first=self._drop_first,
num_symbols=self._num_symbols,
)
# self._data_plot = FiberRegenerationDataset(
# file_path=self._config_globs,
# symbols=self._output_symbols,
# output_dim=self._output_dim,
# dtype=self._dtype,
# real=not self._dtype.is_complex,
# drop_first=self._drop_first,
# num_symbols=400,
# )
if self._test_data_available:
self._data_test = FiberRegenerationDataset(
file_path=self._test_globs,
symbols=self._output_symbols,
output_dim=self._output_dim,
dtype=self._dtype,
real=not self._dtype.is_complex,
drop_first=self._drop_first,
num_symbols=self._num_symbols,
)
return self._data_train, self._data_test
return self._data_train
def _split_data(self, stage="fit", split=None, shuffle=None):
_split = split if split is not None else self._train_split
_shuffle = shuffle if shuffle is not None else self._shuffle
dataset_size = len(self._data_train)
indices = list(range(dataset_size))
split_index = int(np.floor(_split * dataset_size))
train_indices, valid_indices = indices[:split_index], indices[split_index:]
if _shuffle:
np.random.seed(self._seed)
np.random.shuffle(train_indices)
if _shuffle:
if stage == "fit" or stage == "predict":
self._train_sampler = torch.utils.data.SubsetRandomSampler(train_indices)
# if stage == "fit" or stage == "validate":
# self._valid_sampler = torch.utils.data.SubsetRandomSampler(valid_indices)
else:
if stage == "fit" or stage == "predict":
self._train_sampler = train_indices
if stage == "fit" or stage == "validate":
self._valid_sampler = valid_indices
if stage == "fit":
return self._train_sampler, self._valid_sampler
elif stage == "validate":
return self._valid_sampler
elif stage == "predict":
return self._train_sampler
def prepare_data(self):
self._get_data()
def setup(self, stage=None):
stage = stage or "fit"
self._split_data(stage=stage)
def train_dataloader(self):
return torch.utils.data.DataLoader(
self._data_train,
batch_size=self.batch_size,
sampler=self._train_sampler,
**self._loader_settings
)
def val_dataloader(self):
return torch.utils.data.DataLoader(
self._data_train,
batch_size=self.batch_size,
sampler=self._valid_sampler,
**self._loader_settings
)
def _test_dataloader(self):
return torch.utils.data.DataLoader(
self._data_test,
shuffle=self._shuffle,
batch_size=self.batch_size,
**self._loader_settings
)
def predict_dataloader(self):
return torch.utils.data.DataLoader(
self._data_plot,
shuffle=False,
batch_size=40,
pin_memory=True,
drop_last=True,
num_workers=4,
prefetch_factor=2,
)
# def plot_dataloader(self):
class regenerator(L.LightningModule):
def __init__(
self,
*dims,
layer_function=ONNRect,
layer_func_kwargs: dict | None = {"square": True},
act_function=EOActivation,
act_func_kwargs: dict | None = None,
parametrizations: list[dict] | None = [
{
"tensor_name": "weight",
"parametrization": energy_conserving,
},
{
"tensor_name": "alpha",
"parametrization": clamp,
},
],
dtype=torch.complex64,
dropout_prob=0.01,
scale_layers=False,
optimizer=torch.optim.AdamW,
optimizer_kwargs: dict | None = {
"lr": 0.01,
"amsgrad": True,
},
lr_scheduler=None,
lr_scheduler_kwargs: dict | None = {
"patience": 20,
"factor": 0.5,
"min_lr": 1e-6,
"cooldown": 10,
},
sps = 128,
# **kwargs,
):
torch.set_float32_matmul_precision('high')
layer_func_kwargs = layer_func_kwargs if layer_func_kwargs is not None else {}
act_func_kwargs = act_func_kwargs if act_func_kwargs is not None else {}
optimizer_kwargs = optimizer_kwargs if optimizer_kwargs is not None else {}
lr_scheduler_kwargs = lr_scheduler_kwargs if lr_scheduler_kwargs is not None else {}
super().__init__()
self.example_input_array = torch.randn(1, dims[0], dtype=dtype)
self._sps = sps
self.optimizer_settings = {
"optimizer": optimizer,
"optimizer_kwargs": optimizer_kwargs,
"lr_scheduler": lr_scheduler,
"lr_scheduler_kwargs": lr_scheduler_kwargs,
}
# if len(dims) == 0:
# try:
# dims = kwargs["dims"]
# except KeyError:
# raise ValueError("dims must be provided")
self._n_hidden_layers = len(dims) - 2
self.build_model(dims, layer_function, layer_func_kwargs, act_function, act_func_kwargs, parametrizations, dtype, dropout_prob, scale_layers)
def build_model(self, dims, layer_function, layer_func_kwargs, act_function, act_func_kwargs, parametrizations, dtype, dropout_prob, scale_layers):
input_layer = nn.Sequential(
layer_function(dims[0], dims[1], dtype=dtype, **layer_func_kwargs),
act_function(size=dims[1], **act_func_kwargs),
DropoutComplex(p=dropout_prob),
)
if scale_layers:
input_layer = nn.Sequential(Scale(dims[0]), input_layer)
self.layer_0 = input_layer
for i in range(1, self._n_hidden_layers):
layer = nn.Sequential(
layer_function(dims[i], dims[i + 1], dtype=dtype, **layer_func_kwargs),
act_function(size=dims[i + 1], **act_func_kwargs),
DropoutComplex(p=dropout_prob),
)
if scale_layers:
layer = nn.Sequential(Scale(dims[i]), layer)
setattr(self, f"layer_{i}", layer)
output_layer = nn.Sequential(
layer_function(dims[-2], dims[-1], dtype=dtype, **layer_func_kwargs),
act_function(size=dims[-1], **act_func_kwargs),
Scale(dims[-1]),
)
setattr(self, f"layer_{self._n_hidden_layers}", output_layer)
if parametrizations is not None:
self._apply_parametrizations(self, parametrizations)
def _apply_parametrizations(self, layer, parametrizations):
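        # walk the module tree: recurse into containers, and on leaf modules
        # register every parametrization whose tensor_name matches a parameter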
for sub_layer in layer.children():
if len(sub_layer._modules) > 0:
self._apply_parametrizations(sub_layer, parametrizations)
else:
for parametrization in parametrizations:
tensor_name = parametrization.get("tensor_name", None)
if tensor_name is None:
continue
parametrization_func = parametrization.get("parametrization", None)
if parametrization_func is None:
continue
param_kwargs = parametrization.get("kwargs", {})
if tensor_name in sub_layer._parameters:
parametrization_func(sub_layer, tensor_name, **param_kwargs)
def _trace_powers(self, enable, x, powers=None):
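        # record sum(|x|^2) at each layer boundary so the power evolution
        # through the network can be inspected; returns None unless enabled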
if not enable:
return
if powers is None:
powers = []
powers.append(x.abs().square().sum())
return powers
# def plot(self, mode):
# self.predict_step()
# def validation_epoch_end(self, outputs):
# x = torch.vstack([output['x'].view(output['x'].shape[0], -1, 2)[:, output['x'].shape[1]//2, :].squeeze() for output in outputs])
# y = torch.vstack([output['y'].view(output['y'].shape[0], -1, 2).squeeze() for output in outputs])
# y_hat = torch.vstack([output['y_hat'].view(output['y_hat'].shape[0], -1, 2).squeeze() for output in outputs])
# timesteps = torch.vstack([output['timesteps'].squeeze() for output in outputs])
# powers = torch.vstack([output['powers'] for output in outputs])
# return {'x': x, 'y': y, 'y_hat': y_hat, 'timesteps': timesteps, 'powers': powers}
def on_validation_epoch_end(self):
if self.current_epoch % 10 == 0 or self.current_epoch == self.trainer.max_epochs - 1 or self.current_epoch < 10:
x = self.val_outputs['x']
# x = x.view(x.shape[0], -1, 2)
# x = x[:, x.shape[1]//2, :].squeeze()
y = self.val_outputs['y']
# y = y.view(y.shape[0], -1, 2).squeeze()
y_hat = self.val_outputs['y_hat']
# y_hat = y_hat.view(y_hat.shape[0], -1, 2).squeeze()
timesteps = self.val_outputs['timesteps']
# timesteps = timesteps.squeeze()
powers = self.val_outputs['powers']
# powers = powers.squeeze()
fiber_in = x.detach().cpu().numpy()
fiber_out = y.detach().cpu().numpy()
regen = y_hat.detach().cpu().numpy()
timesteps = timesteps.detach().cpu().numpy()
# powers = np.array([power.detach().cpu().numpy() for power in powers])
# fiber_in = np.concat(fiber_in, axis=0)
# fiber_out = np.concat(fiber_out, axis=0)
# regen = np.concat(regen, axis=0)
# timesteps = np.concat(timesteps, axis=0)
# powers = powers.detach().cpu().numpy()
import gc
fig = self.plot_model_head(fiber_in, fiber_out, regen, timesteps, sps=self._sps)
self.logger.experiment.add_figure("model response", fig, self.current_epoch)
# fig = self.plot_model_eye(fiber_in, fiber_out, regen, timesteps, sps=self._sps)
# self.logger.experiment.add_figure("model eye", fig, self.current_epoch)
# fig = self.plot_model_powers(powers)
# self.logger.experiment.add_figure("powers", fig, self.current_epoch)
gc.collect()
# x, y, y_hat, timesteps, powers = self.validation_epoch_end(self.outputs)
# self.plot(x, y, y_hat, timesteps, powers)
def plot_model_head(self, fiber_in, fiber_out, regen, timesteps, sps):
import matplotlib
matplotlib.use("TkCairo")
import matplotlib.pyplot as plt
ordering = np.argsort(timesteps)
signals = [signal[ordering] for signal in [fiber_in, fiber_out, regen]]
timesteps = timesteps[ordering]
signals = [signal[:sps*40] for signal in signals]
timesteps = timesteps[:sps*40]
fig, axs = plt.subplots(1, 2, sharex=True, sharey=True)
fig.set_figwidth(16)
fig.set_figheight(4)
for i, ax in enumerate(axs):
for j, signal in enumerate(signals):
ax.plot(timesteps / sps, np.square(np.abs(signal[:,i])), label=["fiber in", "fiber out", "regen"][j] + [" x", " y"][i])
ax.set_xlabel("symbol")
ax.set_ylabel("amplitude")
ax.minorticks_on()
ax.tick_params(axis="y", which="minor", left=False, right=False)
ax.grid(which="major", axis="x")
ax.grid(which="minor", axis="x", linestyle=":")
ax.grid(which="major", axis="y")
ax.legend(loc="upper right")
fig.tight_layout()
return fig
def plot_model_eye(self, fiber_in, fiber_out, regen, timesteps, sps):
...
def plot_model_powers(self, powers):
...
def forward(self, x, trace_powers=False):
powers = self._trace_powers(trace_powers, x)
x = self.layer_0(x)
powers = self._trace_powers(trace_powers, x, powers)
for i in range(1, self._n_hidden_layers):
x = getattr(self, f"layer_{i}")(x)
powers = self._trace_powers(trace_powers, x, powers)
x = getattr(self, f"layer_{self._n_hidden_layers}")(x)
powers = self._trace_powers(trace_powers, x, powers)
if trace_powers:
return x, powers
return x
def configure_optimizers(self):
optimizer = self.optimizer_settings["optimizer"](
self.parameters(), **self.optimizer_settings["optimizer_kwargs"]
)
if self.optimizer_settings["lr_scheduler"] is not None:
lr_scheduler = self.optimizer_settings["lr_scheduler"](
optimizer, **self.optimizer_settings["lr_scheduler_kwargs"]
)
return {
"optimizer": optimizer,
"lr_scheduler": {
"scheduler": lr_scheduler,
"monitor": "val_loss",
}
}
return {"optimizer": optimizer}
def training_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0):
x, y, timesteps = batch
y_hat = self(x)
loss = complex_mse_loss(y_hat, y, power=True)
self.log("train_loss", loss, on_epoch=True, on_step=True)
return loss
def validation_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0):
x, y, timesteps = batch
if batch_idx == 0:
y_hat, powers = self.forward(x, trace_powers=True)
else:
y_hat = self.forward(x)
loss = complex_mse_loss(y_hat, y, power=True)
self.log("val_loss", loss, on_epoch=True)
y = y.view(y.shape[0], -1, 2).squeeze()
x = x.view(x.shape[0], -1, 2)
x = x[:, x.shape[1]//2, :].squeeze()
y_hat = y_hat.view(y_hat.shape[0], -1, 2).squeeze()
timesteps = timesteps.squeeze()
if batch_idx == 0:
powers = np.array([power.detach().cpu() for power in powers])
self.val_outputs = {"y": y, "x": x, "y_hat": y_hat, "timesteps": timesteps, "powers": powers}
else:
self.val_outputs["y"] = torch.vstack([self.val_outputs["y"], y])
self.val_outputs["x"] = torch.vstack([self.val_outputs["x"], x])
self.val_outputs["y_hat"] = torch.vstack([self.val_outputs["y_hat"], y_hat])
self.val_outputs["timesteps"] = torch.concat([self.val_outputs["timesteps"], timesteps], dim=0)
return loss
def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0):
x, y, timesteps = batch
y_hat = self(x)
loss = complex_mse_loss(y_hat, y, power=True)
self.log("test_loss", loss, on_epoch=True)
return loss
# def predict_step(self, batch, batch_idx):
# x, y, timesteps = batch
# y_hat = self(x)
# return y, x, y_hat, timesteps

View File

@@ -0,0 +1,229 @@
import torch
from torch.nn import Module, Sequential
from util.complexNN import (
DropoutComplex,
Scale,
ONNRect,
photodiode,
EOActivation,
polarimeter,
# normalize_by_first,
rotate,
)
class polarisation_estimator2(Module):
def __init__(self):
super(polarisation_estimator2, self).__init__()
self.layers = Sequential(
polarimeter(),
torch.nn.Linear(4, 4),
torch.nn.ReLU(),
# torch.nn.Dropout(p=0.01),
torch.nn.Linear(4, 4),
torch.nn.ReLU(),
# torch.nn.Dropout(p=0.01),
torch.nn.Linear(4, 1),
)
def forward(self, x):
# x = self.polarimeter(x)
for layer in self.layers:
x = layer(x)
return x
class polarisation_estimator(Module):
def __init__(
self,
*dims,
layer_function=ONNRect,
layer_func_kwargs: dict | None = None,
output_layer_function=photodiode,
# output_layer_func_kwargs: dict | None = None,
act_function=EOActivation,
act_func_kwargs: dict | None = None,
parametrizations: list[dict] | None = None,
dtype=torch.float64,
dropout_prob=0.01,
scale_layers=False,
):
super(polarisation_estimator, self).__init__()
self._n_hidden_layers = len(dims) - 2
layer_func_kwargs = layer_func_kwargs or {}
act_func_kwargs = act_func_kwargs or {}
self.build_model(dims, layer_function, layer_func_kwargs, output_layer_function, act_function, act_func_kwargs, parametrizations, dtype, dropout_prob, scale_layers)
def forward(self, x):
x = self.layer_0(x)
for i in range(1, self._n_hidden_layers):
x = getattr(self, f"layer_{i}")(x)
x = getattr(self, f"layer_{self._n_hidden_layers}")(x)
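# fold the estimated angle into [0, 2*pi)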
x = torch.remainder(x, 2 * torch.pi)
return x.squeeze()
def build_model(self, dims, layer_function, layer_func_kwargs, output_layer_function, act_function, act_func_kwargs, parametrizations, dtype, dropout_prob, scale_layers):
for i in range(0, self._n_hidden_layers):
self.add_module(f"layer_{i}", Sequential())
if scale_layers:
self.get_submodule(f"layer_{i}").add_module("scale", Scale(dims[i]))
module = layer_function(dims[i], dims[i + 1], dtype=dtype, **layer_func_kwargs)
self.get_submodule(f"layer_{i}").add_module("ONN", module)
module = act_function(size=dims[i + 1], **act_func_kwargs)
self.get_submodule(f"layer_{i}").add_module("activation", module)
module = DropoutComplex(p=dropout_prob)
self.get_submodule(f"layer_{i}").add_module("dropout", module)
self.add_module(f"layer_{self._n_hidden_layers}", Sequential())
if scale_layers:
self.get_submodule(f"layer_{self._n_hidden_layers}").add_module("scale", Scale(dims[-2]))
module = layer_function(dims[-2], dims[-1], dtype=dtype, **layer_func_kwargs)
self.get_submodule(f"layer_{self._n_hidden_layers}").add_module("ONN", module)
module = output_layer_function(size=dims[-1])
self.get_submodule(f"layer_{self._n_hidden_layers}").add_module("photodiode", module)
# module = normalize_by_first()
# self.get_submodule(f"layer_{self._n_hidden_layers}").add_module("normalize", module)
if parametrizations is not None:
self._apply_parametrizations(self, parametrizations)
def _apply_parametrizations(self, layer, parametrizations):
for sub_layer in layer.children():
if len(sub_layer._modules) > 0:
self._apply_parametrizations(sub_layer, parametrizations)
else:
for parametrization in parametrizations:
tensor_name = parametrization.get("tensor_name", None)
if tensor_name is None:
continue
parametrization_func = parametrization.get("parametrization", None)
if parametrization_func is None:
continue
param_kwargs = parametrization.get("kwargs", {})
if tensor_name in sub_layer._parameters:
parametrization_func(sub_layer, tensor_name, **param_kwargs)
class regenerator(Module):
def __init__(
self,
*dims,
layer_function=ONNRect,
layer_func_kwargs: dict | None = None,
act_function=EOActivation,
act_func_kwargs: dict | None = None,
parametrizations: list[dict] | None = None,
dtype=torch.float64,
dropout_prob=0.01,
prescale=1,
rotate=False,
):
super(regenerator, self).__init__()
self._n_hidden_layers = len(dims) - 2
layer_func_kwargs = layer_func_kwargs or {}
act_func_kwargs = act_func_kwargs or {}
self.rotation = rotate
self.prescale = prescale
self.build_model(dims, layer_function, layer_func_kwargs, act_function, act_func_kwargs, parametrizations, dtype, dropout_prob)
def build_model(self, dims, layer_function, layer_func_kwargs, act_function, act_func_kwargs, parametrizations, dtype, dropout_prob):
for i in range(0, self._n_hidden_layers):
self.add_module(f"layer_{i}", Sequential())
module = layer_function(dims[i], dims[i + 1], dtype=dtype, **layer_func_kwargs)
self.get_submodule(f"layer_{i}").add_module("ONN", module)
module = act_function(size=dims[i + 1], **act_func_kwargs)
self.get_submodule(f"layer_{i}").add_module("activation", module)
if dropout_prob is not None and dropout_prob > 0:
module = DropoutComplex(p=dropout_prob)
self.get_submodule(f"layer_{i}").add_module("dropout", module)
self.add_module(f"layer_{self._n_hidden_layers}", Sequential())
# if scale_layers:
# self.get_submodule(f"layer_{self._n_hidden_layers}").add_module("scale", Scale(dims[-2]))
module = layer_function(dims[-2], dims[-1], dtype=dtype, **layer_func_kwargs)
self.get_submodule(f"layer_{self._n_hidden_layers}").add_module("ONN", module)
module = act_function(size=dims[-1], **act_func_kwargs)
self.get_submodule(f"layer_{self._n_hidden_layers}").add_module("activation", module)
module = Scale(size=dims[-1])
self.get_submodule(f"layer_{self._n_hidden_layers}").add_module("out_scale", module)
if self.rotation:
module = rotate()
self.add_module("rotate", module)
# module = Scale(size=dims[-1])
# self.get_submodule(f"layer_{self._n_hidden_layers}").add_module("out_scale", module)
if parametrizations is not None:
self._apply_parametrizations(self, parametrizations)
def _apply_parametrizations(self, layer, parametrizations):
for sub_layer in layer.children():
if len(sub_layer._modules) > 0:
self._apply_parametrizations(sub_layer, parametrizations)
else:
for parametrization in parametrizations:
tensor_name = parametrization.get("tensor_name", None)
if tensor_name is None:
continue
parametrization_func = parametrization.get("parametrization", None)
if parametrization_func is None:
continue
param_kwargs = parametrization.get("kwargs", {})
if tensor_name in sub_layer._parameters:
parametrization_func(sub_layer, tensor_name, **param_kwargs)
def _trace_powers(self, enable, x, powers=None):
if not enable:
return
if powers is None:
powers = []
powers.append(x.abs().square().sum())
return powers
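# forward returns, depending on the flags: x_rot, (x_rot, powers),
# (x_rot, x) with x the pre-rotation output, or (x_rot, x, powers)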
def forward(self, x, angle=None, pre_rot=False, trace_powers=False):
x = x * self.prescale
powers = self._trace_powers(trace_powers, x)
# x = self.layer_0(x)
# powers = self._trace_powers(trace_powers, x, powers)
for i in range(0, self._n_hidden_layers):
x = getattr(self, f"layer_{i}")(x)
powers = self._trace_powers(trace_powers, x, powers)
x = getattr(self, f"layer_{self._n_hidden_layers}")(x)
if self.rotation:
try:
x_rot = self.rotate(x, angle)
except AttributeError:
x_rot = x  # no rotate module was built; fall back to the unrotated signal
powers = self._trace_powers(trace_powers, x_rot, powers)
else:
x_rot = x
if pre_rot and trace_powers:
return x_rot, x, powers
if pre_rot and not trace_powers:
return x_rot, x
if not pre_rot and trace_powers:
return x_rot, powers
return x_rot

View File

@@ -18,8 +18,28 @@ class DataSettings:
shuffle: bool = True
in_out_delay: float = 0
xy_delay: tuple | float | int = 0
drop_first: int = 1000
drop_first: int = 64
drop_last: int = 64
train_split: float = 0.8
polarisations: tuple | list = (0,)
# cross_pol_interference: float = 0
randomise_polarisations: bool = False
osnr: float | int = None
seed: int = None
"""
change to:
config_path: tuple | list | None = None
dtype: torch.dtype | None = None
symbols: int | float = 1
output_dim: int = 2
shuffle: bool = True
drop_first: float | int = 0
train_split: float = 0.8
randomise_polarisations: bool = False
"""
# pytorch settings
@@ -30,8 +50,8 @@ class PytorchSettings:
device: str = "cuda"
dataloader_workers: int = 2
dataloader_prefetch: int = 2
dataloader_workers: int = 1
dataloader_prefetch: int = 1
save_models: bool = True
model_dir: str = ".models"
@@ -56,6 +76,30 @@ class ModelSettings:
model_layer_kwargs: dict | None = None
model_layer_parametrizations: list= field(default_factory=list)
"""
change to:
dims: tuple | list | None = None
layer_function: nn.Module | None = None
layer_func_kwargs: dict | None = None
activation_function: nn.Module | None = None
activation_func_kwargs: dict | None = None
output_function: nn.Module | None = None
output_func_kwargs: dict | None = None
dropout_function: nn.Module | None = None
dropout_func_kwargs: dict | None = None
scale_function: nn.Module | None = None
scale_func_kwargs: dict | None = None
parametrizations: list | None = None
"""
def _early_stop_default_kwargs():
return {
"threshold": 1e-05,
"plateau": 25,
}
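# presumably: stop once val loss has improved by less than `threshold` for
# `plateau` consecutive epochs (matches early_stop_kwargs in the training script)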
@dataclass
class OptimizerSettings:
@@ -65,6 +109,20 @@ class OptimizerSettings:
scheduler: str | None = None
scheduler_kwargs: dict | None = None
early_stopping: bool = False
early_stop_kwargs: dict = field(default_factory=_early_stop_default_kwargs)
"""
change to:
optimizer: torch.optim.Optimizer | None = None
optimizer_kwargs: dict | None = None
learning_rate: float | None = None
scheduler: torch.optim.lr_scheduler | None = None
scheduler_kwargs: dict | None = None
"""
def _pruner_default_kwargs():
# MedianPruner

File diff suppressed because it is too large

View File

@@ -0,0 +1,217 @@
from pathlib import Path
import sys
from matplotlib import pyplot as plt
import numpy as np
import torch
import util
from hypertraining.settings import GlobalSettings, DataSettings, ModelSettings, OptimizerSettings, PytorchSettings
from hypertraining import models
# def move_to_location_in_size(array, location, size):
# array_x, array_y = array.shape
# location_x, location_y = location
# size_x, size_y = size
# left = location_x
# right = size_x - array_x - location_x
# top = location_y
# bottom = size_y - array_y - location_y
# return np.pad(
# array,
# (
# (left, right),
# (top, bottom),
# ),
# constant_values=(-np.inf, -np.inf),
# )
def register_puccs_cmap(puccs_path=None):
puccs_path = Path(__file__).resolve().parent / 'puccs.csv' if puccs_path is None else puccs_path
colors = []
# keys = None
with open(puccs_path, "r") as f:
for i, line in enumerate(f.readlines()):
elements = tuple(line.split(","))
# if i == 0:
# # keys = elements
# continue
# else:
try:
colors.append(tuple(map(float, elements[4:])))
except ValueError:
continue
# colors = []
# for current in puccs_csv_data:
# colors.append(tuple(current[4:]))
from matplotlib.colors import LinearSegmentedColormap
import matplotlib as mpl
mpl.colormaps.register(LinearSegmentedColormap.from_list('puccs', colors))
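# pad_to_size centre-pads an array with NaNs up to `size` per axis (None keeps
# an axis unchanged), so layer matrices of different shapes tile into one image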
def pad_to_size(array, size):
if not hasattr(size, "__len__"):
size = (size, size)
left = (
(size[0] - array.shape[0] + 1) // 2 if size[0] is not None else 0
)
right = (
(size[0] - array.shape[0]) // 2 if size[0] is not None else 0
)
top = (
(size[1] - array.shape[1] + 1) // 2 if size[1] is not None else 0
)
bottom = (
(size[1] - array.shape[1]) // 2 if size[1] is not None else 0
)
array: np.ndarray = array
if array.ndim == 2:
return np.pad(
array,
(
(left, right),
(top, bottom),
),
constant_values=(np.nan, np.nan),
)
elif array.ndim == 3:
return np.pad(
array,
(
(left, right),
(top, bottom),
(0,0)
),
constant_values=(np.nan, np.nan),
)
def model_plot(model_path, show=True):
torch.serialization.add_safe_globals([
*util.complexNN.__all__,
GlobalSettings,
DataSettings,
ModelSettings,
OptimizerSettings,
PytorchSettings,
models.regenerator,
torch.nn.utils.parametrizations.orthogonal,
])
checkpoint_dict = torch.load(model_path, weights_only=True)
dims = checkpoint_dict["model_kwargs"].pop("dims")
model = models.regenerator(*dims, **checkpoint_dict["model_kwargs"])
model.load_state_dict(checkpoint_dict["model_state_dict"], strict=False)
model_params = []
plots = []
max_size = np.max(dims)
# max_act_size = np.max(dims[1:])
# angles = [None, None]
# weights = [None, None]
for num, (layer_name, layer) in enumerate(model.named_children()):
# each layer contains an "ONN" layer and an "activation" layer
# activation layer is approximately the same for all layers and nodes -> rotation by 90 degrees
onn_weights = layer.ONN.weight
onn_weights = onn_weights.detach().cpu().numpy()
onn_values = np.abs(onn_weights).real
onn_angles = np.mod(np.angle(onn_weights), 2*np.pi).real
model_params.append({layer_name: onn_weights})
plots.append({layer_name: (num, onn_values, onn_angles)})#, act_values, act_angles)})
# fig, axs = plt.subplots(3, len(model_params)*2-1, figsize=(20, 5))
for plot in plots:
layer_name, (num, onn_values, onn_angles) = plot.popitem()
if num == 0:
value_img = onn_values
angle_img = onn_angles
onn_angles = pad_to_size(onn_angles, (max_size, None))
onn_values = pad_to_size(onn_values, (max_size, None))
else:
onn_values = pad_to_size(onn_values, (max_size, onn_values.shape[1]+1))
onn_angles = pad_to_size(onn_angles, (max_size, onn_angles.shape[1]+1))
value_img = np.concatenate((value_img, onn_values), axis=1)
angle_img = np.concatenate((angle_img, onn_angles), axis=1)
value_img = np.ma.array(value_img, mask=np.isnan(value_img))
angle_img = np.ma.array(angle_img, mask=np.isnan(angle_img))
# from cmcrameri import cm
from cmap import Colormap as cm
import scicomap as sc
# from matplotlib import colors as mcolors
# alpha_map = mcolors.LinearSegmentedColormap(
# 'alphamap',
# {
# 'red': [(0, 0, 0), (1, 0, 0)],
# 'green': [(0, 0, 0), (1, 0, 0)],
# 'blue': [(0, 0, 0), (1, 0, 0)],
# 'alpha': [
# (0, 1, 1),
# # (0.2, 0.2, 0.1),
# (1, 0, 0)
# ]
# }
# )
# alpha_map.set_bad(color="#AAAAAA")
from mpl_toolkits.axes_grid1 import make_axes_locatable
fig, axs = plt.subplots(1, 2, sharex=True, sharey=True, figsize=(16, 5))
# fig.tight_layout()
dividers = map(make_axes_locatable, axs)
caxs = list(map(lambda d: d.append_axes("right", size="5%", pad=0.05), dividers))
# masked_value_img = np.ma.masked_where(np.isnan(value_img), value_img)
masked_value_img = value_img
cmap = cm('google:turbo').to_matplotlib()
# cmap = sc.ScicoSequential("rainbow").get_mpl_color_map()
cmap.set_bad(color="#AAAAAA")
im_val = axs[0].imshow(masked_value_img, cmap=cmap, vmin=0, vmax=1)
fig.colorbar(im_val, cax=caxs[0], orientation="vertical")
masked_angle_img = np.ma.masked_where(np.isnan(angle_img), angle_img)
# cmap = cm('crameri:romao').to_matplotlib()
# cmap = plt.get_cmap('puccs')
# cmap = sc.ScicoCircular("colorwheel").get_mpl_color_map()
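# CET "C"-series maps are cyclic, so the colour wraps together with the phase at 2*pi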
cmap = cm('colorcet:CET_C8').to_matplotlib()
cmap.set_bad(color="#AAAAAA")
im_ang = axs[1].imshow(masked_angle_img, cmap=cmap, vmin=0, vmax=2*np.pi)
cbar = fig.colorbar(im_ang, cax=caxs[1], orientation="vertical", ticks=[i/8 * 2*np.pi for i in range(9)])
cbar.ax.set_yticklabels(["0", "π/4", "π/2", "3π/4", "π", "5π/4", "3π/2", "7π/4", ""])
# im_ang_w = axs[2].imshow(masked_angle_img, cmap=cmap)
# im_ang_w = axs[2].imshow(masked_value_img, cmap=alpha_map)
axs[0].axis("off")
axs[1].axis("off")
# axs[2].axis("off")
axs[0].set_title("Values")
axs[1].set_title("Angles")
# axs[2].set_title("Values and Angles")
...
if show:
plt.show()
return fig
# model = models.regenerator(*dims, **model_kwargs)
if __name__ == "__main__":
register_puccs_cmap()
if len(sys.argv) > 1:
model_plot(sys.argv[1])
else:
print("Please provide a model path as an argument")
# model_plot(".models/best_20250114_224234.tar")

View File

@@ -0,0 +1,102 @@
"x","L","a","b","R","G","B"
0.,0.5187848173343539,0.6399990176455989,0.67,0.8889427469969852,0.22673227640012172,0.
0.01,0.5374499525557803,0.604014067614707,0.6777967519386492,0.8956274406155226,0.27553288030331824,0.
0.02,0.5560867887452998,0.5680836759482211,0.6855816828789898,0.9019507507843885,0.318608215541461,0.
0.03,0.5746877595125583,0.5322224300667823,0.6933516322080414,0.907905487190649,0.3580633000693721,0.
0.04,0.5932314662487472,0.49647158484797804,0.7010976613543587,0.9134808162089558,0.3949845524063657,0.
0.05,0.6117000836392819,0.46086550613202343,0.7088123243737041,0.918668356138916,0.43002019316005363,0.
0.06,0.6300828534995973,0.4254249348741487,0.7164911273850869,0.923462736751354,0.4635961938811463,0.
0.07,0.6483763163456417,0.3901565406944371,0.7241326253017896,0.9278609626724071,0.49601354353255284,0.
0.08,0.6665840140182806,0.3550534951951814,0.7317382976124045,0.9318616057744784,0.5274983630587982,0.
0.09,0.6847162776119433,0.3200958808181962,0.7393124597949372,0.9354640163365924,0.5582303922647159,0.
0.1,0.7027902128942014,0.2852507189547545,0.7468622572263107,0.9386675557407496,0.5883604892249517,0.004034952213848706
0.11,0.7208298719332069,0.25047163906104203,0.7543977368741345,0.9414708123927996,0.6180221032545026,0.016031521294251994
0.12,0.7388665670611175,0.2156982733607376,0.7619319784446927,0.943870754968487,0.6473392272576862,0.029857267582036696
0.13,0.7569392765472108,0.18085547473834482,0.7694812638396673,0.9458617774020323,0.676432172396153,0.045365670193636125
0.14,0.7750950944867471,0.14585244938794778,0.7770652650825484,0.9474345911958609,0.7054219201084561,0.06017985923530026
0.15,0.793389684293558,0.11058188251425949,0.7847072337503834,0.9485749196617762,0.7344334940032564,0.07418869502646075
0.16,0.8117919447684838,0.07510373484536464,0.792394178330817,0.9492596163836376,0.7634480277996188,0.08767517868137237
0.17,0.8293050962981561,0.03629277424762101,0.799038155466063,0.9462308253550155,0.7922009241807345,0.10066327128139077
0.18,0.8213303100752708,-0.0062517290795987,0.7879999288492758,0.9088702681901394,0.7940579017644396,0.10139639009534024
0.19,0.8134831311534617,-0.048115463155645855,0.7771383286984362,0.8716809050191757,0.7954897210083888,0.10232311621802098
0.2,0.80558613530069,-0.0902449644291895,0.7662077749032042,0.8337524177888596,0.7965471523787845,0.10344968926026826
0.21,0.7975860185564765,-0.13292460297117392,0.7551344872795225,0.7947193410849823,0.7972381033243311,0.10477682283894393
0.22,0.7894147026971006,-0.17651756772919341,0.7438242359834689,0.7540941866826836,0.7975605026647324,0.10631182441371936
0.23,0.7809997374598548,-0.2214103719409295,0.7321767396537806,0.7112894518675287,0.7974995317311054,0.1080672415170634
0.24,0.7722646970273015,-0.2680107379394189,0.7200862142018722,0.6655745739336695,0.7970267795229349,0.11006041388465265
0.25,0.7631307298557146,-0.3167393290089981,0.7074435179925446,0.6160047476007512,0.7960993904970947,0.11231257117602686
0.26,0.7535192192483822,-0.36801555555407994,0.6941398344519211,0.5612859274945571,0.794659599537827,0.11484733363789801
0.27,0.7433557597838075,-0.42223636134393283,0.6800721760037781,0.4994862901720824,0.7926351396848288,0.11768844813479104
0.28,0.732575139048096,-0.479749646583324,0.6651502794883674,0.42731393423789277,0.7899410218414098,0.12085678487511567
0.29,0.7211269294461059,-0.5408244362880141,0.6493043460161184,0.3378265607222193,0.786483110019224,0.124366774034814
0.3,0.7090756028785993,-0.6051167807996883,0.6326236137723747,0.2098475715121697,0.7821998608677176,0.12819222127525928
0.31,0.7094510768540225,-0.6165036055456403,0.5630307498747129,0.15061488620640032,0.7845112116922692,0.21943537230975235
0.32,0.7174669421288304,-0.5917687864932311,0.4797229624661701,0.18766933782916642,0.7905828987725732,0.31091344246312086
0.33,0.7249009746435938,-0.5688293479200438,0.40246208306061504,0.21160609617940718,0.7962175427587832,0.38519766326885596
0.34,0.7317072855135611,-0.5478268906666535,0.3317250285377912,0.22717569971119178,0.8013847719431052,0.4490960048955565
0.35,0.7379328517830899,-0.5286164561226088,0.26702357292455026,0.23690087622812972,0.8061220291668977,0.5056371468159843
0.36,0.7436229063122554,-0.5110584677642499,0.20788761731555405,0.24226377668817778,0.8104638164122776,0.5563570758573497
0.37,0.7488251728809415,-0.4950056627547577,0.15382117501783654,0.24424372086048424,0.8144455902164638,0.6022301663745243
0.38,0.7535943992285348,-0.48028910419451787,0.10425526029155024,0.24352232677523483,0.818107753931944,0.6440238320299774
0.39,0.757994865186593,-0.4667104416936734,0.05852182167144754,0.240562414747303,0.8214980148949816,0.6824536572462205
0.4,0.7620994844391137,-0.4540446830999986,0.015863077249098356,0.2356325204239052,0.8246710357361025,0.7182393675419642
0.41,0.7659871096124125,-0.4420485102716773,-0.024540477496154123,0.22880568593963535,0.8276865975886148,0.7521146815529202
0.42,0.7697410958994951,-0.4304647113488041,-0.06355514164248566,0.21993360985514526,0.8306086550266585,0.7848331944479765
0.43,0.773446484628189,-0.4190308715098135,-0.10206473803580057,0.20858849290850018,0.833503273690861,0.8171544357676854
0.44,0.7771893686864673,-0.4074813310994203,-0.14096401824224686,0.1939295692427068,0.8364382500400466,0.8498448067259188
0.45,0.7810574093604746,-0.3955455908045306,-0.18116403397486242,0.17438366103820427,0.839483669055626,0.8836865023336339
0.46,0.7851360804917298,-0.3829599011818591,-0.2235531031349741,0.14679145002531463,0.8427091517444469,0.9194481212717681
0.47,0.789525027020907,-0.369416784561489,-0.26916682191206776,0.10278921007810798,0.8461971304126237,0.9580316568065935
0.48,0.7942371698732826,-0.35487637041943493,-0.3181394757087982,0.0013920913109500188,0.8499626968466341,0.9995866371771526
0.49,0.7773897680996302,-0.31852357140025195,-0.34537976514700053,0.10740420703601522,0.8254781216972907,1.
0.5,0.7604011244310231,-0.28211213216592784,-0.3722846952738428,0.1581725581872408,0.8008522647497104,1.
0.51,0.7433440454962605,-0.2455540169176899,-0.3992980063927199,0.19300141807932156,0.7761561224913385,1.
0.52,0.7262590833969331,-0.20893614020926626,-0.42635547610418184,0.2194621842292243,0.751443124097109,1.
0.53,0.709058602701224,-0.17207067467417486,-0.453595892719742,0.2405673704012788,0.7265803324554873,1.
0.54,0.6915768892539101,-0.1346024482921609,-0.48128169789479536,0.25788347992973676,0.701321051230534,1.
0.55,0.6736331627810209,-0.09614399811510127,-0.5096991935104321,0.2722888922216317,0.6753950894563805,1.
0.56,0.6551463184003872,-0.05652149358027936,-0.5389768254408652,0.28422807900785235,0.6486730893521468,1.
0.57,0.6361671326276888,-0.01584376303510615,-0.5690341788729347,0.293907374075009,0.6212117649042732,1.
0.58,0.6168396823565967,0.025580396234342995,-0.5996430791016598,0.301442767979156,0.5931976878638505,1.
0.59,0.5973210287815495,0.06741435793529688,-0.6305547881733555,0.30694603901024253,0.5648312189065924,1.
0.6,0.5777303704171711,0.10940264614179468,-0.661580531294122,0.3105418468883679,0.5362525958007331,1.
0.61,0.5581475370499237,0.15137416317967575,-0.6925938819599547,0.3123531986526998,0.5075386530652202,1.
0.62,0.5386227795100639,0.19322120739317136,-0.7235152578861672,0.31248922600720636,0.4787151440558522,1.
0.63,0.5191666876024412,0.23492108185347996,-0.754327887989376,0.31103663081260624,0.44973844514160927,1.
0.64,0.4996990584326256,0.2766456839100268,-0.7851587896650079,0.30803814950244496,0.4204116611935119,1.
0.65,0.479957679121191,0.3189570094767831,-0.8164232296840259,0.30343473603466015,0.390226489453496,1.
0.66,0.4600072725872886,0.3617163391430824,-0.8480187063016573,0.29717122075330515,0.3591178757512998,1.
0.67,0.44600100870220305,0.4113853615984094,-0.8697728377551008,0.3178994129506999,0.3295740682997879,1.
0.68,0.4574651571354146,0.44026390446569547,-0.8504539292487465,0.3842479358768364,0.3280946443367561,1.
0.69,0.4691809168948424,0.46977626401045774,-0.830711015748157,0.44293649140770447,0.3260767554252525,1.
0.7,0.4811696900083858,0.49997635259991063,-0.8105080314416201,0.49708450874457527,0.3234487047238236,1.
0.71,0.49350094811609174,0.5310391714342613,-0.7897279055963483,0.5485591109413528,0.3201099534066949,1.
0.72,0.5062548753068121,0.5631667067020758,-0.7682355153041539,0.5985798481027601,0.3159263917472715,1.
0.73,0.5195243020949684,0.5965928013272943,-0.7458744264238399,0.6480500606439057,0.31071717884730565,1.
0.74,0.5334043922713477,0.6315571758288618,-0.7224842728734379,0.6976685401842261,0.3042411890803418,1.
0.75,0.5479805812358602,0.6682750446095802,-0.697921082452685,0.7479712773579563,0.29618040787504757,1.
0.76,0.5633244502526606,0.7069267230777347,-0.6720642293775535,0.7993701361353484,0.28611136999256687,1.
0.77,0.5794956601139,0.7476624986056212,-0.6448131757501174,0.8521918014427678,0.2734527325942473,1.
0.78,0.5965429098573916,0.7906050455688622,-0.6160858559672187,0.9067003897516911,0.2573693489198746,1.
0.79,0.6145761476424179,0.8360313267658297,-0.5856969899409387,0.963334644317004,0.23648492980159264,1.
0.8,0.6232910688128902,0.859291371252556,-0.5300995185388214,1.,0.21867949406239662,0.9712088595948508
0.81,0.6159984336377875,0.8439887543380684,-0.44635440435952856,1.,0.21606849746358275,0.9041480210597966
0.82,0.6091642745073532,0.8296481879180277,-0.36787420852419694,1.,0.21421830096504035,0.8419706002336461
0.83,0.6025478038652375,0.8157644115969636,-0.2918938425681935,1.,0.21295365915197917,0.7823908751330636
0.84,0.5961857222953111,0.8024144366282877,-0.21883475834162458,0.9971140114799418,0.21220068235083267,0.7256713129328118
0.85,0.5900921771070883,0.7896279492437488,-0.1488594167412921,0.993273906363258,0.2118788857127918,0.671860243327784
0.86,0.5842771639541229,0.7774259239818333,-0.08208260304413262,0.9887084084529413,0.21191070453347688,0.6209624706933893
0.87,0.578741582584259,0.7658102488427286,-0.018514649521559012,0.9835846378805114,0.2122246941077346,0.5728987835613306
0.88,0.5734741590353537,0.7547572669288056,0.04197390858426542,0.9780378159372328,0.21275878699579343,0.5274829957183049
0.89,0.5684517008574971,0.7442183119942206,0.09964940221121898,0.9721670725313721,0.21346242315895625,0.4844270603851604
0.9,0.5636419856510335,0.7341257696545772,0.15488185789614228,0.9660363209686843,0.21429691147008262,0.4433660148378527
0.91,0.5590069340453534,0.7243997354573974,0.20810856081277884,0.9596781387247791,0.2152344151262528,0.4038812338146013
0.92,0.5545051525321143,0.7149533506766244,0.25980485409830323,0.9530986696850675,0.21625626438013962,0.3655130449917989
0.93,0.5500961975299247,0.705701749880514,0.3104351723857584,0.9462863346513658,0.21735046958786286,0.327780364198278
0.94,0.545740378056064,0.6965616468647046,0.36045530782708896,0.93921469089265,0.21851014470332586,0.29014917175372823
0.95,0.5414004092067859,0.6874548042588865,0.41029342232076466,0.9318478255642132,0.21973168075163751,0.2519897371806688
0.96,0.5370416605957644,0.6783085548415655,0.46034719456417006,0.9241434776436454,0.22101341980094052,0.2124579038400577
0.97,0.5326309593934517,0.6690532898786764,0.5109975653738162,0.9160532016485884,0.22235495330179011,0.17018252385769012
0.98,0.5281374148557197,0.6596241892863608,0.5625992691950712,0.90752576202319,0.22375597459867458,0.1223073280126531
0.99,0.5235317096396147,0.6499597345521199,0.615488972291106,0.8985077346125597,0.22521565729028564,0.05933950582860665
1.,0.5187848173343539,0.6399990176455989,0.67,0.8889427469969852,0.22673227640012172,0.

View File

@@ -1,6 +1,8 @@
from datetime import datetime
import optuna
import torch
import util
from hypertraining.hypertraining import HyperTraining
from hypertraining.settings import (
GlobalSettings,
@@ -16,24 +18,29 @@ global_settings = GlobalSettings(
)
data_settings = DataSettings(
config_path="data/*-128-16384-100000-0-0-17-0-PAM4-0.ini",
# config_path="data/*-128-16384-100000-0-0-17-0-PAM4-0.ini",
config_path="data/20241204-131003-128-16384-100000-0-0-17-0-PAM4-0.ini",
dtype="complex64",
# symbols = (9, 20), # 13 symbol @ 10GBd <-> 1.3ns <-> 0.26m of fiber
symbols=13, # study: single_core_regen_20241123_011232
# symbols=13, # study: single_core_regen_20241123_011232
# symbols = (3, 13),
symbols=4,
# output_size = (11, 32), # ballpark 26 taps -> 2 taps per input symbol -> 1 tap every 0.01m (model has 52 inputs)
output_size=26, # study: single_core_regen_20241123_011232 (model_input_dim/2)
# output_size=26, # study: single_core_regen_20241123_011232 (model_input_dim/2)
output_size=(8, 30),
shuffle=True,
in_out_delay=0,
xy_delay=0,
drop_first=128 * 100,
drop_first=256,
train_split=0.8,
randomise_polarisations=False,
)
pytorch_settings = PytorchSettings(
epochs=10000,
epochs=10,
batchsize=2**10,
device="cuda",
dataloader_workers=12,
dataloader_workers=4,
dataloader_prefetch=4,
summary_dir=".runs",
write_every=2**5,
@@ -43,28 +50,70 @@ pytorch_settings = PytorchSettings(
model_settings = ModelSettings(
output_dim=2,
# n_hidden_layers = (3, 8),
n_hidden_layers=4,
overrides={
"n_hidden_nodes_0": 8,
"n_hidden_nodes_1": 6,
"n_hidden_nodes_2": 4,
"n_hidden_nodes_3": 8,
},
model_activation_func="Mag",
# satabsT0=(1e-6, 1),
n_hidden_layers = (2, 5),
n_hidden_nodes=(2, 16),
model_activation_func="EOActivation",
dropout_prob=0,
model_layer_function="ONNRect",
model_layer_kwargs={"square": True},
# scale=(False, True),
scale=False,
model_layer_parametrizations=[
{
"tensor_name": "weight",
"parametrization": util.complexNN.energy_conserving,
},
{
"tensor_name": "alpha",
"parametrization": util.complexNN.clamp,
},
{
"tensor_name": "gain",
"parametrization": util.complexNN.clamp,
"kwargs": {
"min": 0,
"max": float("inf"),
},
},
{
"tensor_name": "phase_bias",
"parametrization": util.complexNN.clamp,
"kwargs": {
"min": 0,
"max": 2 * torch.pi,
},
},
{
"tensor_name": "scales",
"parametrization": util.complexNN.clamp,
},
{
"tensor_name": "angle",
"parametrization": util.complexNN.clamp,
"kwargs": {
"min": -torch.pi,
"max": torch.pi,
},
},
{
"tensor_name": "loss",
"parametrization": util.complexNN.clamp,
},
],
)
optimizer_settings = OptimizerSettings(
optimizer="Adam",
# learning_rate = (1e-5, 1e-1),
learning_rate=5e-3
# learning_rate=5e-4,
optimizer="AdamW",
optimizer_kwargs={
"lr": 5e-3,
"amsgrad": True,
# "weight_decay": 1e-7,
},
)
optuna_settings = OptunaSettings(
n_trials=1,
n_workers=1,
n_trials=1024,
n_workers=8,
timeout=3600,
directions=("minimize",),
metrics_names=("mse",),

View File

@@ -1,6 +1,10 @@
# from datetime import datetime
from pathlib import Path
import matplotlib
import numpy as np
import torch
import torch.utils.tensorboard
import torch.utils.tensorboard.summary
from hypertraining.settings import (
GlobalSettings,
DataSettings,
@@ -9,7 +13,7 @@ from hypertraining.settings import (
OptimizerSettings,
)
from hypertraining.training import Trainer
from hypertraining.training import RegenerationTrainer#, PolarizationTrainer
# import torch
import json
@@ -22,26 +26,39 @@ global_settings = GlobalSettings(
)
data_settings = DataSettings(
# config_path="data/*-128-16384-50000-0-0-17-0-PAM4-0.ini",
config_path=[f"data/*-128-16384-{length}-0-0-17-0-PAM4-0.ini" for length in (40000, 50000, 60000)],
# config_path="data/20250115-114319-128-16384-inf-100000-0-0-0-0-PAM4-1-0-10.ini", # no effects enabled - baseline
# config_path = "data/20250115-233553-128-16384-1060.0-100000-0-0.2-17.0-0.058-PAM4-1.0-0.0-10.ini", # dispersion + slope only
# config_path="data/20250115-115836-128-16384-60.0-100000-0-0.2-17-0.058-PAM4-1000-0.2-10.ini", # all linear effects enabled with realistic values + noise + pmd (delta_beta=0.2) + ortho_error = 0.1
# config_path="data/20250118-225840-128-16384-inf-100000-0-0.2-17.0-0.058-PAM4-0-0-10.ini", # a)
# config_path="data/20250116-214606-128-16384-inf-100000-0-0.2-17.0-0.058-PAM4-0-0-10.ini", # b)
# config_path="data/20250116-214547-128-16384-inf-100000-0-0.2-17.0-0.058-PAM4-0-0-10.ini", # c)
# config_path="data/20250117-143926-128-16384-inf-100000-0-0.2-17.0-0.058-PAM4-0-0-10.ini", # d) 10ps dgd
config_path="data/20250120-105720-128-16384-inf-100000-0-0.2-17-0.058-PAM4-0-0-10.ini", # d) 10ns
# config_path="data/20250114-215547-128-16384-60.0-100000-1.15-0.2-17-0.058-PAM4-1-0-10.ini", # with gamma=1.15, 2.5dBm launch power, no pmd
dtype="complex64",
# symbols = (9, 20), # 13 symbol @ 10GBd <-> 1.3ns <-> 0.26m of fiber
symbols=13, # study: single_core_regen_20241123_011232
symbols=4, # study: single_core_regen_20241123_011232 -> taps spread over 4 symbols @ 10GBd
# output_size = (11, 32), # 26 taps -> 2 taps per input symbol -> 1 tap every 1cm (model has 52 inputs (x/y))
output_size=26, # study: single_core_regen_20241123_011232 (model_input_dim/2)
shuffle=True,
in_out_delay=0,
xy_delay=0,
drop_first=128 * 64,
output_size=20, # = model_input_dim/2 study: single_core_regen_20241123_011232
shuffle=False,
drop_first=256,
drop_last=256,
train_split=0.8,
randomise_polarisations=False,
polarisations=False,
# cross_pol_interference=0.01,
osnr=16, #16dB due to amplification with NF 5
)
pytorch_settings = PytorchSettings(
epochs=10000,
batchsize=2**12,
epochs=1000,
batchsize=2**13,
device="cuda",
dataloader_workers=12,
dataloader_prefetch=8,
dataloader_workers=32,
dataloader_prefetch=4,
summary_dir=".runs",
write_every=2**5,
save_models=True,
@@ -50,70 +67,51 @@ pytorch_settings = PytorchSettings(
model_settings = ModelSettings(
output_dim=2,
n_hidden_layers=4,
n_hidden_layers=3,
overrides={
"n_hidden_nodes_0": 4,
"n_hidden_nodes_1": 4,
"n_hidden_nodes_2": 4,
"n_hidden_nodes_3": 4,
# "hidden_layer_dims": (8, 8, 4, 4),
"n_hidden_nodes_0": 16,
"n_hidden_nodes_1": 8,
"n_hidden_nodes_2": 8,
# "n_hidden_nodes_3": 4,
# "n_hidden_nodes_4": 2,
},
model_activation_func="EOActivation",
dropout_prob=0.01,
dropout_prob=0,
model_layer_function="ONNRect",
model_layer_kwargs={"square": True},
scale=True,
scale=2.0,
model_layer_parametrizations=[
{
"tensor_name": "weight",
"parametrization": util.complexNN.energy_conserving,
},
# EOactivation
{
"tensor_name": "alpha",
"parametrization": util.complexNN.clamp,
"kwargs": {
"min": 0,
"max": 1,
},
},
# ONNRect
{
"tensor_name": "gain",
"tensor_name": "weight",
"parametrization": torch.nn.utils.parametrizations.orthogonal,
},
# Scale
{
"tensor_name": "scale",
"parametrization": util.complexNN.clamp,
"kwargs": {
"min": 0,
"max": float("inf"),
"max": 10,
},
},
{
"tensor_name": "phase_bias",
"parametrization": util.complexNN.clamp,
"kwargs": {
"min": 0,
"max": 2 * torch.pi,
},
},
{
"tensor_name": "scales",
"parametrization": util.complexNN.clamp,
},
# {
# "tensor_name": "scale",
# "parametrization": util.complexNN.clamp,
# },
# {
# "tensor_name": "bias",
# "parametrization": util.complexNN.clamp,
# },
# {
# "tensor_name": "V",
# "parametrization": torch.nn.utils.parametrizations.orthogonal,
# },
{
"tensor_name": "loss",
"parametrization": util.complexNN.clamp,
},
}
],
)
optimizer_settings = OptimizerSettings(
optimizer="AdamW",
optimizer_kwargs={
"lr": 0.05,
"lr": 0.005,
"amsgrad": True,
# "weight_decay": 1e-7,
},
@@ -121,96 +119,35 @@ optimizer_settings = OptimizerSettings(
scheduler="ReduceLROnPlateau",
scheduler_kwargs={
"patience": 2**6,
"factor": 0.75,
"factor": 0.5,
# "threshold": 1e-3,
"min_lr": 1e-6,
"cooldown": 10,
},
early_stopping=True,
early_stop_kwargs={
"threshold": 1e-06,
"plateau": 2**7,
}
)
def save_dict_to_file(dictionary, filename):
"""
Save a dictionary to a JSON file.
:param dictionary: Dictionary to serialise to JSON.
:type dictionary: dict
:param filename: Path to the JSON file where the dictionary will be saved.
:type filename: str
"""
with open(filename, "w") as f:
json.dump(dictionary, f, indent=4)
def sweep_lengths(*lengths, model=None):
assert model is not None, "Model must be provided."
fiber_ins = {}
fiber_outs = {}
regens = {}
timestampss = {}
for length in lengths:
trainer = Trainer(
checkpoint_path=model,
settings_override={
"data_settings": {
"config_path": f"data/*-128-16384-{length}-0-0-17-0-PAM4-0.ini",
"train_split": 1,
"shuffle": True,
}
},
)
trainer.define_model()
loader, _ = trainer.get_sliced_data()
fiber_in, fiber_out, regen, timestamps = trainer.run_model(trainer.model, loader=loader)
fiber_ins[length] = fiber_in
fiber_outs[length] = fiber_out
regens[length] = regen
timestampss[length] = timestamps
data = torch.zeros(2 * len(lengths), 2, fiber_out.shape[0])
channel_names = ["" for _ in range(2 * len(lengths))]
for li, length in enumerate(lengths):
data[2 * li, 0, :] = timestampss[length] / 128
data[2 * li, 1, :] = regens[length][:, 0].abs().square()
data[2 * li + 1, 0, :] = timestampss[length] / 128
data[2 * li + 1, 1, :] = regens[length][:, 1].abs().square()
channel_names[2 * li] = f"regen x {length}"
channel_names[2 * li + 1] = f"regen y {length}"
# get current backend
backend = matplotlib.get_backend()
matplotlib.use("TkCairo")
eye = util.eye_diagram.eye_diagram(data.to(dtype=torch.float32).detach().cpu().numpy(), channel_names=channel_names)
print_attrs = ("channel", "success", "min_area")
with np.printoptions(precision=3, suppress=True, formatter={'float': '{:0.3e}'.format}):
for result in eye.eye_stats:
print_dict = {attr: result[attr] for attr in print_attrs}
rprint(print_dict)
rprint()
eye.plot()
matplotlib.use(backend)
if __name__ == "__main__":
# sweep_lengths(30000, 40000, 50000, 60000, 70000, model=".models/best_20241202_143149.tar")
trainer = Trainer(
trainer = RegenerationTrainer(
global_settings=global_settings,
data_settings=data_settings,
pytorch_settings=pytorch_settings,
model_settings=model_settings,
optimizer_settings=optimizer_settings,
# checkpoint_path=".models/best_20241202_143149.tar",
# 20241202_143149
checkpoint_path=".models/best_20250117_144001.tar",
new_model=True,
settings_override={
"data_settings": data_settings.__dict__,
# "optimizer_settings": {
# "early_stop_kwargs":{
# "plateau": 2**8,
# }
# }
}
)
trainer.train()
trainer.train()

View File

@@ -0,0 +1,729 @@
"""
generate_signal.py
This file is part of the repo "optical-regeneration"
https://git.suuppl.dev/seppl/optical-regeneration.git
Joseph Hopfmüller
Copyright 2024
Licensed under the EUPL
Full license text in LICENSE file
"""
import configparser
# import copy
from datetime import datetime
import hashlib
from pathlib import Path
import time
import h5py
from matplotlib import pyplot as plt # noqa: F401
import numpy as np
from . import add_pypho # noqa: F401
import pypho
default_config = f"""
[glova]
sps = 128
nos = 16384
f0 = 193414489032258.06
symbolrate = 10e9
wisdom_dir = "{str((Path.home() / ".pypho"))}"
flags = "FFTW_PATIENT"
nthreads = 32
[fiber]
length = 10000
gamma = 1.14
alpha = 0.2
D = 17
S = 0.058
bireflength = 10
pmd_q = 0.2
; birefseed = 0xC0FFEE
[signal]
; seed = 0xC0FFEE
modulation = "pam"
mod_order = 4
mod_depth = 1
max_jitter = 0.02
; jitter_seed = 0xC0FFEE
laser_power = 0
edfa_power = 0
edfa_nf = 5
pulse_shape = "gauss"
fwhm = 0.33
osnr = "inf"
[data]
dir = "data"
npy_dir = "npys"
"""
def get_config(config_file=None):
"""
DANGER! The function uses eval() to parse the config file. Do not use this function with untrusted input.
"""
if config_file is None:
config_file = Path(__file__).parent / "signal_generation.ini"
config_file = Path(config_file)
if not config_file.exists():
with open(config_file, "w") as f:
f.write(default_config)
config = configparser.ConfigParser()
config.read(config_file)
conf = {}
for section in config.sections():
# print(f"[{section}]")
conf[section] = {}
for key in config[section]:
# print(f"{key} = {config[section][key]}")
try:
conf[section][key] = eval(config[section][key])
except NameError:
conf[section][key] = float(config[section][key])
# if isinstance(conf[section][key], str):
# conf[section][key] = config[section][key].strip('"')
return conf
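# A safer variant (a sketch, not used elsewhere in this repo): ast.literal_eval
# accepts Python literals only, so an untrusted config value cannot execute code.
import ast
def parse_config_value(raw: str):
    """Parse a config value as a Python literal, falling back to float, then str."""
    try:
        return ast.literal_eval(raw)
    except (ValueError, SyntaxError):
        try:
            return float(raw)
        except ValueError:
            return raw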
class PDM_IM_IPM:
def __init__(
self,
glova,
mod_order=8,
seed=None,
):
assert mod_order > 1 and int(np.rint(np.cbrt(mod_order))) ** 3 == mod_order, (
"mod_order must be a cube of an integer greater than 1"
)
self.glova = glova
self.mod_order = mod_order
self.symbols_per_dim = int(np.rint(np.cbrt(mod_order)))
self.seed = seed
def generate_symbols(self, n):
rs = np.random.RandomState(self.seed)
symbols = rs.randint(0, self.mod_order, n)
return symbols
class pam_generator:
def __init__(
self, glova, mod_order=None, mod_depth=0.5, pulse_shape="gauss", fwhm=0.33, seed=None, single_channel=False
) -> None:
self.glova = glova
self.pulse_shape = pulse_shape
self.modulation_depth = mod_depth
self.mod_order = mod_order
self.fwhm = fwhm
self.seed = seed
self.single_channel = single_channel
def __call__(self, E, symbols, max_jitter=0):
max_jitter = int(round(max_jitter * self.glova.sps))
if self.pulse_shape == "gauss":
wavelet = self.gauss(oversampling=6)
else:
raise ValueError(f"Unknown pulse shape: {self.pulse_shape}")
# prepare symbols
symbols_x = symbols[0] / (self.mod_order)
diffs_x = np.diff(symbols_x, prepend=symbols_x[0])
digital_x = self.generate_digital_signal(diffs_x, max_jitter)
digital_x = np.pad(digital_x, (0, self.glova.sps // 2), "constant", constant_values=(0, 0))
# create analog signal of diff of symbols
E_x = np.convolve(digital_x, wavelet)
# convert to pam and set modulation depth (scale and move up such that 1 stays at 1)
E_x = np.cumsum(E_x) * self.modulation_depth + (1 - self.modulation_depth)
# cut off the wavelet tails
E_x = E_x[self.glova.sps // 2 + len(wavelet) // 2 - 1 : -len(wavelet) // 2]
# modulate the laser
E[0]["E"][0] = np.sqrt(np.multiply(np.square(E[0]["E"][0]), E_x))
if not self.single_channel:
symbols_y = symbols[1] / (self.mod_order)
diffs_y = np.diff(symbols_y, prepend=symbols_y[0])
digital_y = self.generate_digital_signal(diffs_y, max_jitter)
digital_y = np.pad(digital_y, (0, self.glova.sps // 2), "constant", constant_values=(0, 0))
E_y = np.convolve(digital_y, wavelet)
E_y = np.cumsum(E_y) * self.modulation_depth + (1 - self.modulation_depth)
E_y = E_y[self.glova.sps // 2 + len(wavelet) // 2 - 1 : -len(wavelet) // 2]
E[0]["E"][1] = np.sqrt(np.multiply(np.square(E[0]["E"][1]), E_y))
# rotate the signal on the y-polarisation by 90°
# E[0]["E"][1] *= 1j
else:
E[0]["E"][1] = np.zeros_like(E[0]["E"][0], dtype=E[0]["E"][0].dtype)
return E
def generate_digital_signal(self, symbols, max_jitter=0):
rs = np.random.RandomState(self.seed)
signal = np.zeros(self.glova.nos * self.glova.sps)
for index in range(self.glova.nos):
jitter = rs.randint(-max_jitter, max_jitter) if max_jitter else 0
signal_index = index * self.glova.sps + jitter
if signal_index < 0:
continue
if signal_index >= len(signal):
continue
signal[signal_index] = symbols[index]
return signal
def gauss(self, oversampling=1):
sample_points = np.linspace(
-oversampling * self.glova.sps,
oversampling * self.glova.sps,
oversampling * 2 * self.glova.sps,
endpoint=True,
)
sigma = self.fwhm / (1 * np.sqrt(2 * np.log(2))) * self.glova.sps
pulse = 1 / (sigma * np.sqrt(2 * np.pi)) * np.exp(-np.square(sample_points) / (2 * np.square(sigma)))
return pulse
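# Worked example (comment only): with the default fwhm=0.33 symbol durations
# and sps=128, the formula above gives sigma = 0.33 / sqrt(2*ln 2) * 128
# ≈ 35.9 samples. Note the textbook relation is FWHM = 2*sqrt(2*ln 2)*sigma
# (which would give ≈ 17.9 samples); the factor of 1 here is kept from the
# original code and may be intentional.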
def initialize_fiber_and_data(config):
f0 = config["glova"].get("f0", None)
if f0 is None:
f0 = 299792458/(config["glova"].get("lambda0", 1550)*1e-9)
config["glova"]["f0"] = f0
py_glova = pypho.setup(
nos=config["glova"]["nos"],
sps=config["glova"]["sps"],
f0=f0,
symbolrate=config["glova"]["symbolrate"],
wisdom_dir=config["glova"]["wisdom_dir"],
flags=config["glova"]["flags"],
nthreads=config["glova"]["nthreads"],
)
c_glova = pypho.cfiber.GlovaWrapper.from_setup(py_glova)
c_data = pypho.cfiber.DataWrapper(py_glova.sps * py_glova.nos)
py_edfa = pypho.oamp(py_glova, Pmean=config["signal"]["edfa_power"], NF=config["signal"]["edfa_nf"])
osnr = config["signal"]["osnr"] = float(config["signal"].get("osnr", "inf"))
config["signal"]["seed"] = config["signal"].get("seed", (int(time.time() * 1000)) % 2**32)
config["signal"]["jitter_seed"] = config["signal"].get("jitter_seed", (int(time.time() * 1000)) % 2**32)
symbolsrc = pypho.symbols(
py_glova, py_glova.nos, pattern="ones", p1=config["signal"]["mod_order"], seed=config["signal"]["seed"]
)
laserx = pypho.lasmod(py_glova, power=0, Df=0, theta=np.pi/4)
# lasery = pypho.lasmod(py_glova, power=0, Df=25, theta=0)
modulator = pam_generator(
py_glova,
mod_depth=config["signal"]["mod_depth"],
pulse_shape=config["signal"]["pulse_shape"],
fwhm=config["signal"]["fwhm"],
seed=config["signal"]["jitter_seed"],
mod_order=config["signal"]["mod_order"],
)
symbols_x = symbolsrc(pattern="random")
symbols_y = symbolsrc(pattern="random")
symbols_x[:3] = 0
symbols_y[:3] = 0
# symbols_x += 1
cw = laserx()
# cwy = lasery()
# cw[0]['E'][0] = cw[0]['E'][0]
# cw[0]['E'][1] = cwy[0]['E'][0]
source_signal = modulator(E=cw, symbols=(symbols_x, symbols_y))
if osnr != float("inf"):
osnr_lin = 10 ** (osnr / 10)
signal_power = np.sum(pypho.functions.getpower_W(source_signal[0]["E"]))
noise_power = signal_power / osnr_lin
noise = np.random.normal(0, 1, source_signal[0]["E"].shape) + 1j * np.random.normal(
0, 1, source_signal[0]["E"].shape
)
noise_power_is = np.sum(pypho.functions.getpower_W(noise))
noise = noise * np.sqrt(noise_power / noise_power_is)
noise_power_is = np.sum(pypho.functions.getpower_W(noise))
source_signal[0]["E"] += noise
source_signal[0]["noise"] = noise_power_is
# source_signal[0]['E'][1] += source_signal[0]['E'][1][np.argmin(np.abs(source_signal[0]['E'][1]))]
## side channels
# df = 100
# signal_power = pypho.functions.W_to_dBm(np.sum(pypho.functions.getpower_W(source_signal[0]["E"])))
# symbols_x_side = symbolsrc(pattern="random")
# symbols_y_side = symbolsrc(pattern="random")
# symbols_x_side[:3] = 0
# symbols_y_side[:3] = 0
# cw_left = laser(Df=-df)
# source_signal_left = modulator(E=cw_left, symbols=(symbols_x_side, symbols_y_side))
# cw_right = laser(Df=df)
# source_signal_right = modulator(E=cw_right, symbols=(symbols_y_side, symbols_x_side))
E_in_pure = source_signal[0]["E"]
nf = py_edfa.NF
pmean = py_edfa.Pmean
# ideal amplification to launch power into fiber
source_signal = py_edfa(E=source_signal, NF=0, Pmean=config["signal"]["laser_power"])
# source_signal_left = py_edfa(E=source_signal_left, NF=0, Pmean=config["signal"]["laser_power"])
# source_signal_right = py_edfa(E=source_signal_right, NF=0, Pmean=config["signal"]["laser_power"])
# source_signal[0]["E"][0] += source_signal_left[0]["E"][0] + source_signal_right[0]["E"][0]
# source_signal[0]["E"][1] += source_signal_left[0]["E"][1] + source_signal_right[0]["E"][1]
c_data.E_in = source_signal[0]["E"]
noise = source_signal[0]["noise"]
py_edfa.NF = nf
py_edfa.Pmean = pmean
py_fiber = pypho.fiber(
glova=py_glova,
l=config["fiber"]["length"],
alpha=pypho.functions.dB_to_Neper(config["fiber"]["alpha"]) / 1000,
gamma=config["fiber"]["gamma"],
D=config["fiber"]["d"],
S=config["fiber"]["s"],
phi_max=0.02,
)
config["fiber"]["birefsteps"] = config["fiber"].get(
"birefsteps", config["fiber"]["length"] // config["fiber"].get("bireflength", config["fiber"]["length"])
)
if config["fiber"]["birefsteps"] > 0:
config["fiber"]["bireflength"] = config["fiber"].get("bireflength", config["fiber"]["length"] / config["fiber"]["birefsteps"])
seed = config["fiber"].get("birefseed", (int(time.time() * 1000)) % 2**32)
py_fiber.birefarray = pypho.birefringence_segment.create_pmd_fibre(
config["fiber"]["length"],
config["fiber"]["bireflength"],
maxDeltaBeta=config["fiber"].get("pmd_q", 0)/np.sqrt(2*config["fiber"]["bireflength"]),
seed=seed,
)
elif (dgd := config['fiber'].get('dgd', 0)) > 0:
py_fiber.birefarray = [
pypho.birefringence_segment(z_point=0, angle=np.pi/2, delta_beta=1000*dgd/config["fiber"]["length"])
]
c_params = pypho.cfiber.ParamsWrapper.from_fiber(py_fiber, max_step=config["fiber"]["length"] if py_fiber.gamma == 0 else 200)
c_fiber = pypho.cfiber.FiberWrapper(c_data, c_params, c_glova)
return c_fiber, c_data, noise, py_edfa, (symbols_x, symbols_y), py_glova, E_in_pure
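# Illustrative sketch (assumed helper, not part of the original file): the OSNR
# noise loading above scales complex AWGN so that P_signal / P_noise matches
# the requested OSNR in dB, i.e. P_noise = P_signal / 10**(OSNR/10).
def _example_osnr_noise(E, osnr_db, rng=None):
    rng = np.random.default_rng() if rng is None else rng
    p_signal = np.mean(np.abs(E) ** 2)
    p_noise = p_signal / 10 ** (osnr_db / 10)
    noise = rng.normal(size=E.shape) + 1j * rng.normal(size=E.shape)
    noise *= np.sqrt(p_noise / np.mean(np.abs(noise) ** 2))  # rescale to target power
    return E + noise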
def save_data(data, config, **metadata):
data_dir = Path(config["data"]["dir"])
npy_dir = config["data"].get("npy_dir", "")
save_dir = data_dir / npy_dir if len(npy_dir) else data_dir
save_dir.mkdir(parents=True, exist_ok=True)
save_array = np.column_stack([
data.E_in[0],
data.E_in[1],
data.E_out[0],
data.E_out[1],
])
timestamp = datetime.now()
seed = config["signal"].get("seed", False)
jitter_seed = config["signal"].get("jitter_seed", False)
birefseed = config["fiber"].get("birefseed", False)
osnr = float(config["signal"].get("osnr", "inf"))
config_content = "\n".join((
f"; Generated by {str(Path(__file__).name)} @ {timestamp.strftime('%Y-%m-%d %H:%M:%S')}",
"[glova]",
f"sps = {config['glova']['sps']}",
f"nos = {config['glova']['nos']}",
f"f0 = {config['glova']['f0']}",
f"symbolrate = {config['glova']['symbolrate']}",
f'wisdom_dir = "{config["glova"]["wisdom_dir"]}"',
f'flags = "{config["glova"]["flags"]}"',
f"nthreads = {config['glova']['nthreads']}",
"",
"[fiber]",
f"length = {config['fiber']['length']}",
f"gamma = {config['fiber']['gamma']}",
f"alpha = {config['fiber']['alpha']}",
f"D = {config['fiber']['d']}",
f"S = {config['fiber']['s']}",
f"birefsteps = {config['fiber'].get('birefsteps', 0)}",
f"pmd_q = {config['fiber'].get('pmd_q', 0)}",
f"birefseed = {hex(birefseed)}" if birefseed else "; birefseed = not set",
f"dgd = {config['fiber'].get('dgd', 0)}",
f"ortho_error = {config['fiber'].get('ortho_error', 0)}",
f"pol_error = {config['fiber'].get('pol_error', 0)}",
"",
"[signal]",
f"seed = {hex(seed)}" if seed else "; seed = not set",
"",
f'modulation = "{config["signal"]["modulation"]}"',
f"mod_order = {config['signal']['mod_order']}",
f"mod_depth = {config['signal']['mod_depth']}",
"",
f"max_jitter = {config['signal']['max_jitter']}",
f"jitter_seed = {hex(jitter_seed)}" if jitter_seed else "; jitter_seed = not set",
"",
f"laser_power = {config['signal']['laser_power']}",
f"edfa_power = {config['signal']['edfa_power']}",
f"edfa_nf = {config['signal']['edfa_nf']}",
f"osnr = {osnr}",
"",
f'pulse_shape = "{config["signal"]["pulse_shape"]}"',
f"fwhm = {config['signal']['fwhm']}",
"",
"[data]",
f'dir = "{str(data_dir)}"',
f'npy_dir = "{npy_dir}"',
"file = ",
))
config_hash = hashlib.md5(config_content.encode()).hexdigest()
save_file = f"{config_hash}.h5"
config_content += f'"{str(save_file)}"\n'
config_filename:Path = create_config_filename(config, data_dir, timestamp)
while config_filename.exists():
time.sleep(1)
config_filename = create_config_filename(config, data_dir=data_dir)
with open(config_filename, "w") as f:
f.write(config_content)
with h5py.File(save_dir / save_file, "w") as outfile:
outfile.create_dataset("data", data=save_data)
outfile.create_dataset("symbols", data=metadata.pop("symbols"))
for key, value in metadata.items():
# if isinstance(value, dict):
# value = json.dumps(model_runner.convert_arrays(value))
outfile.attrs[key] = value
# np.save(save_dir / save_file, save_data)
# print("Saved config to", config_filename)
# print("Saved data to", save_dir / save_file)
return config_filename
def create_config_filename(config, data_dir:Path, timestamp=None):
if timestamp is None:
timestamp = datetime.now()
filename_components = (
timestamp.strftime("%Y%m%d-%H%M%S"),
config["glova"]["sps"],
config["glova"]["nos"],
config["signal"]["osnr"],
config["fiber"]["length"],
config["fiber"]["gamma"],
config["fiber"]["alpha"],
config["fiber"]["d"],
config["fiber"]["s"],
f"{config['signal']['modulation'].upper()}{config['signal']['mod_order']}",
config["fiber"].get("birefsteps", 0),
config["fiber"].get("pmd_q", 0),
int(config["glova"]["symbolrate"] / 1e9),
)
lookup_file = "-".join(map(str, filename_components)) + ".ini"
return data_dir / lookup_file
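# Example (comment only): with the default config, a run started on
# 2024-11-15 17:55:17 would produce a lookup file named like
# 20241115-175517-128-16384-inf-10000-1.14-0.2-17-0.058-PAM4-1000-0.2-10.ini
# (timestamp, sps, nos, osnr, length, gamma, alpha, D, S, modulation,
# birefsteps, pmd_q, symbolrate in GBd, matching filename_components above).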
def length_loop(config, lengths, save=True):
lengths = sorted(lengths)
for length in lengths:
print(f"\nGenerating data for fiber length {length}m")
config["fiber"]["length"] = length
cfiber, cdata, noise, edfa, symbols, py_glova, _E_in = initialize_fiber_and_data(config)
mean_power_in = np.sum(pypho.functions.getpower_W(cdata.E_in))
cfiber()
mean_power_out = np.sum(pypho.functions.getpower_W(cdata.E_out))
print(f"Mean power in: {mean_power_in:.3e} W ({pypho.functions.W_to_dBm(mean_power_in):.3e} dBm)")
print(f"Mean power out: {mean_power_out:.3e} W ({pypho.functions.W_to_dBm(mean_power_out):.3e} dBm)")
E_tmp = [{"E": cdata.E_out, "noise": noise * (-cfiber.params.l * cfiber.params.alpha)}]
E_tmp = edfa(E=E_tmp)
cdata.E_out = E_tmp[0]["E"]
mean_power_amp = np.sum(pypho.functions.getpower_W(cdata.E_out))
print(f"Mean power after EDFA: {mean_power_amp:.3e} W ({pypho.functions.W_to_dBm(mean_power_amp):.3e} dBm)")
if save:
save_data(cdata, config)
in_out_eyes(cfiber, cdata)
def single_run_with_plot(config, save=True):
cfiber, cdata, config_filename = single_run(config, save)
in_out_eyes(cfiber, cdata, show_pols=False)
return config_filename
def single_run(config, save=True, silent=True):
cfiber, cdata, noise, edfa, symbols, glova, E_in = initialize_fiber_and_data(config)
# transmit
cfiber()
# amplify
E_tmp = [{"E": cdata.E_out, "noise": noise}]
E_tmp = edfa(E=E_tmp)
# rotate
# ortho error
ortho_error = config["fiber"].get("ortho_error", 0)
E_tmp[0]["E"] = np.stack((
E_tmp[0]["E"][0] * np.cos(ortho_error/2) + E_tmp[0]["E"][1] * np.sin(ortho_error/2),
E_tmp[0]["E"][0] * np.sin(ortho_error/2) + E_tmp[0]["E"][1] * np.cos(ortho_error/2)
), axis=0)
pol_error = config['fiber'].get('pol_error', 0)
E_tmp[0]["E"] = np.stack((
E_tmp[0]["E"][0] * np.cos(pol_error) - E_tmp[0]["E"][1] * np.sin(pol_error),
E_tmp[0]["E"][0] * np.sin(pol_error) + E_tmp[0]["E"][1] * np.cos(pol_error)
), axis=0)
# output
cdata.E_out = E_tmp[0]["E"]
config_filename = None
symbols = np.array(symbols)
if save:
config_filename = save_data(cdata, config, symbols=symbols)
if not silent:
print(f"Saved config to {config_filename}")
return cfiber, cdata, config_filename
def in_out_eyes(cfiber, cdata, show_pols=False):
fig, axs = plt.subplots(2, 2, sharex=True, sharey=True)
eye_head = min(cfiber.glova.nos, 2000)
symbolrate_scale = 1e12
amplitude_scale = 1e3
plot_eye_diagram(
amplitude_scale * np.abs(cdata.E_in[0]) ** 2,
2 * cfiber.glova.sps,
normalize=False,
samplerate=cfiber.glova.symbolrate * cfiber.glova.sps / symbolrate_scale,
head=eye_head,
ax=axs[0][0],
show=False,
color="C0",
)
if show_pols:
plot_eye_diagram(
amplitude_scale * np.abs(cdata.E_in[0].real) ** 2,
2 * cfiber.glova.sps,
normalize=False,
samplerate=cfiber.glova.symbolrate * cfiber.glova.sps / symbolrate_scale,
head=eye_head,
ax=axs[0][0],
color="C2",
show=False,
)
plot_eye_diagram(
amplitude_scale * np.abs(cdata.E_in[0].imag) ** 2,
2 * cfiber.glova.sps,
normalize=False,
samplerate=cfiber.glova.symbolrate * cfiber.glova.sps / symbolrate_scale,
head=eye_head,
ax=axs[0][0],
color="C3",
show=False,
)
plot_eye_diagram(
amplitude_scale * np.abs(cdata.E_out[0]) ** 2,
2 * cfiber.glova.sps,
normalize=False,
samplerate=cfiber.glova.symbolrate * cfiber.glova.sps / symbolrate_scale,
head=eye_head,
ax=axs[0][1],
color="C1",
show=False,
)
if show_pols:
plot_eye_diagram(
amplitude_scale * np.abs(cdata.E_out[0].real) ** 2,
2 * cfiber.glova.sps,
normalize=False,
samplerate=cfiber.glova.symbolrate * cfiber.glova.sps / symbolrate_scale,
head=eye_head,
ax=axs[0][1],
color="C4",
show=False,
)
plot_eye_diagram(
amplitude_scale * np.abs(cdata.E_out[0].imag) ** 2,
2 * cfiber.glova.sps,
normalize=False,
samplerate=cfiber.glova.symbolrate * cfiber.glova.sps / symbolrate_scale,
head=eye_head,
ax=axs[0][1],
color="C5",
show=False,
)
plot_eye_diagram(
amplitude_scale * np.abs(cdata.E_in[1]) ** 2,
2 * cfiber.glova.sps,
normalize=False,
samplerate=cfiber.glova.symbolrate * cfiber.glova.sps / symbolrate_scale,
head=eye_head,
ax=axs[1][0],
color="C0",
show=False,
)
if show_pols:
plot_eye_diagram(
amplitude_scale * np.abs(cdata.E_in[1].real) ** 2,
2 * cfiber.glova.sps,
normalize=False,
samplerate=cfiber.glova.symbolrate * cfiber.glova.sps / symbolrate_scale,
head=eye_head,
ax=axs[1][0],
color="C2",
show=False,
)
plot_eye_diagram(
amplitude_scale * np.abs(cdata.E_in[1].imag) ** 2,
2 * cfiber.glova.sps,
normalize=False,
samplerate=cfiber.glova.symbolrate * cfiber.glova.sps / symbolrate_scale,
head=eye_head,
ax=axs[1][0],
color="C3",
show=False,
)
plot_eye_diagram(
amplitude_scale * np.abs(cdata.E_out[1]) ** 2,
2 * cfiber.glova.sps,
normalize=False,
samplerate=cfiber.glova.symbolrate * cfiber.glova.sps / symbolrate_scale,
head=eye_head,
ax=axs[1][1],
color="C1",
show=False,
)
if show_pols:
plot_eye_diagram(
amplitude_scale * np.abs(cdata.E_out[1].real) ** 2,
2 * cfiber.glova.sps,
normalize=False,
samplerate=cfiber.glova.symbolrate * cfiber.glova.sps / symbolrate_scale,
head=eye_head,
ax=axs[1][1],
color="C4",
show=False,
)
plot_eye_diagram(
amplitude_scale * np.abs(cdata.E_out[1].imag) ** 2,
2 * cfiber.glova.sps,
normalize=False,
samplerate=cfiber.glova.symbolrate * cfiber.glova.sps / symbolrate_scale,
head=eye_head,
ax=axs[1][1],
color="C5",
show=False,
)
title_map = [
["Input x", "Output x"],
["Input y", "Output y"],
]
title_map = np.array(title_map)
for ax, title in zip(axs.flatten(), title_map.flatten()):
ax.grid(True)
ax.set_xlabel("Time [ps]")
ax.set_ylabel("Power [mW]")
ax.set_title(title)
fig.tight_layout()
plt.show()
def plot_eye_diagram(
signal: np.ndarray,
eye_width,
offset=0,
*,
head=None,
samplerate=1,
normalize=True,
ax=None,
color="C0",
show=True,
):
ax = ax or plt.gca()
if head is not None:
signal = signal[: head * eye_width]
if normalize:
signal = signal / np.max(signal)
slices = np.lib.stride_tricks.sliding_window_view(signal, eye_width + 1)[offset % (eye_width + 1) :: eye_width]
plt_ax = np.arange(-eye_width // 2, eye_width // 2 + 1) / samplerate
for trace in slices:
ax.plot(plt_ax, trace, color=color, alpha=0.1)
ax.grid()
if show:
plt.show()
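# Usage sketch (comment only; cdata/cfiber as returned by single_run):
# plot the input power of the x polarisation over two symbol periods.
# plot_eye_diagram(
#     np.abs(cdata.E_in[0]) ** 2,
#     2 * cfiber.glova.sps,
#     samplerate=cfiber.glova.symbolrate * cfiber.glova.sps / 1e12,  # ps axis
#     normalize=False,
# )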
if __name__ == "__main__":
add_pypho.show_log()
config = get_config()
# ranges = (1000,10000)
# scales = tuple(range(1, 10))
# scales = (1,)
# lengths = [range_ * scale for range_ in ranges for scale in scales]
# lengths.append(10*max(ranges))
# lengths = [*lengths, *lengths]
lengths = (
# 8000, 9000,
10000,
20000,
30000,
40000,
50000,
60000,
70000,
80000,
90000,
95000,
100000,
105000,
110000,
115000,
120000,
)
# lengths = (10000,100000)
# length_loop(config, lengths, save=True)
# birefringence is constant over coupling length -> several 100m -> bireflength=1000 (m)
single_run_with_plot(config, save=False)

View File

@@ -39,7 +39,7 @@ import numpy as np
if __name__ == "__main__":
dataset = FiberRegenerationDataset("data/20241115-175517-128-16384-10000-0-0-17-0-PAM4-0.ini", symbols=13, drop_first=100, output_dim=26, num_symbols=100)
dataset = FiberRegenerationDataset("data/202412*-128-16384-50000-0-0-17-0-PAM4-0.ini", symbols=13, drop_first=100, output_dim=26, num_symbols=100)
loader = DataLoader(dataset, batch_size=10, shuffle=True)

File diff suppressed because one or more lines are too long

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,234 @@
from datetime import datetime
from pathlib import Path
import matplotlib
import numpy as np
import torch
import torch.utils.tensorboard
import torch.utils.tensorboard.summary
from hypertraining.settings import (
GlobalSettings,
DataSettings,
PytorchSettings,
ModelSettings,
OptimizerSettings,
)
from hypertraining.training import RegenerationTrainer, PolarizationTrainer
# import torch
import json
import util
from rich import print as rprint
global_settings = GlobalSettings(
seed=0xC0FFEE,
)
data_settings = DataSettings(
config_path="data/20241211-105524-128-16384-1-0-0-0-0-PAM4-0-0.ini",
# config_path=[f"data/20241202-*-128-16384-{length}-0-0-17-0-PAM4-0.ini" for length in range(48000, 53000, 1000)],
dtype="complex64",
# symbols = (9, 20), # 13 symbol @ 10GBd <-> 1.3ns <-> 0.26m of fiber
symbols=13, # study: single_core_regen_20241123_011232
# output_size = (11, 32), # 26 taps -> 2 taps per input symbol -> 1 tap every 1cm (model has 52 inputs (x/y))
output_size=26, # study: single_core_regen_20241123_011232 (model_input_dim/2)
shuffle=True,
drop_first=64,
train_split=0.8,
# polarisations=tuple(np.random.rand(2)*2*np.pi),
randomise_polarisations=True,
)
pytorch_settings = PytorchSettings(
epochs=10000,
batchsize=2**12,
device="cuda",
dataloader_workers=16,
dataloader_prefetch=8,
summary_dir=".runs",
write_every=2**5,
save_models=True,
model_dir=".models",
)
model_settings = ModelSettings(
output_dim=1,
n_hidden_layers=3,
overrides={
"n_hidden_nodes_0": 4,
"n_hidden_nodes_1": 4,
"n_hidden_nodes_2": 4,
},
dropout_prob=0,
model_layer_function="ONNRect",
model_activation_func="EOActivation",
model_layer_kwargs={"square": True},
scale=False,
model_layer_parametrizations=[
{
"tensor_name": "weight",
"parametrization": util.complexNN.energy_conserving,
},
{
"tensor_name": "alpha",
"parametrization": util.complexNN.clamp,
},
{
"tensor_name": "gain",
"parametrization": util.complexNN.clamp,
"kwargs": {
"min": 0,
"max": float("inf"),
},
},
{
"tensor_name": "phase_bias",
"parametrization": util.complexNN.clamp,
"kwargs": {
"min": 0,
"max": 2 * torch.pi,
},
},
{
"tensor_name": "scales",
"parametrization": util.complexNN.clamp,
},
{
"tensor_name": "angle",
"parametrization": util.complexNN.clamp,
"kwargs": {
"min": 0,
"max": 2*torch.pi,
},
},
{
"tensor_name": "loss",
"parametrization": util.complexNN.clamp,
},
],
)
optimizer_settings = OptimizerSettings(
optimizer="RMSprop",
# optimizer="AdamW",
optimizer_kwargs={
"lr": 0.01,
"alpha": 0.9,
"momentum": 0.1,
"eps": 1e-8,
"centered": True,
# "amsgrad": True,
# "weight_decay": 1e-7,
},
scheduler="ReduceLROnPlateau",
scheduler_kwargs={
"patience": 2**5,
"factor": 0.75,
# "threshold": 1e-3,
"min_lr": 1e-6,
# "cooldown": 10,
},
)
def save_dict_to_file(dictionary, filename):
"""
Save a dictionary to a JSON file.
:param dictionary: Dictionary to be saved.
:type dictionary: dict
:param filename: Path to the JSON file where the dictionary will be saved.
:type filename: str
"""
with open(filename, "w") as f:
json.dump(dictionary, f, indent=4)
def sweep_lengths(*lengths, model=None, data_glob: str = None, strategy="newest"):
assert model is not None, "Model must be provided."
assert data_glob is not None, "Data glob must be provided."
fiber_ins = {}
fiber_outs = {}
regens = {}
timestampss = {}
trainer = RegenerationTrainer(
checkpoint_path=model,
)
trainer.define_model()
for length in lengths:
data_glob_length = data_glob.replace("{length}", str(length))
files = list(Path.cwd().glob(data_glob_length))
if len(files) == 0:
continue
if strategy == "newest":
sorted_kwargs = {
"key": lambda x: x.stat().st_mtime,
"reverse": True,
}
elif strategy == "oldest":
sorted_kwargs = {
"key": lambda x: x.stat().st_mtime,
"reverse": False,
}
else:
raise ValueError(f"Unknown strategy {strategy}.")
file = sorted(files, **sorted_kwargs)[0]
loader, _ = trainer.get_sliced_data(override={"config_path": file})
fiber_in, fiber_out, regen, timestamps = trainer.run_model(trainer.model, loader=loader)
fiber_ins[length] = fiber_in
fiber_outs[length] = fiber_out
regens[length] = regen
timestampss[length] = timestamps
data = torch.zeros(2 * len(timestampss.keys()) + 2, 2, tuple(fiber_outs.values())[-1].shape[0])
channel_names = ["" for _ in range(2 * len(timestampss.keys()) + 2)]
data[1, 0, :] = timestampss[tuple(timestampss.keys())[-1]] / 128
data[1, 1, :] = fiber_ins[tuple(timestampss.keys())[-1]][:, 0].abs().square()
channel_names[1] = "fiber in x"
for li, length in enumerate(timestampss.keys()):
data[2 + 2 * li, 0, :] = timestampss[length] / 128
data[2 + 2 * li, 1, :] = fiber_outs[length][:, 0].abs().square()
data[2 + 2 * li + 1, 0, :] = timestampss[length] / 128
data[2 + 2 * li + 1, 1, :] = regens[length][:, 0].abs().square()
channel_names[2 + 2 * li + 1] = f"regen x {length}"
channel_names[2 + 2 * li] = f"fiber out x {length}"
# get current backend
backend = matplotlib.get_backend()
matplotlib.use("TkCairo")
eye = util.eye_diagram.eye_diagram(data.to(dtype=torch.float32).detach().cpu().numpy(), channel_names=channel_names)
print_attrs = ("channel_name", "success", "min_area")
with np.printoptions(precision=3, suppress=True, formatter={"float": "{:0.3e}".format}):
for result in eye.eye_stats:
print_dict = {attr: result[attr] for attr in print_attrs}
rprint(print_dict)
rprint()
eye.plot(all_stats=False)
matplotlib.use(backend)
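# Usage sketch (comment only; the checkpoint path and glob are placeholders):
# sweep_lengths(
#     10000, 50000, 100000,
#     model=".models/some_checkpoint.tar",
#     data_glob="data/*-128-16384-{length}-*-PAM4-*.ini",
#     strategy="newest",
# )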
if __name__ == "__main__":
trainer = PolarizationTrainer(
global_settings=global_settings,
data_settings=data_settings,
pytorch_settings=pytorch_settings,
model_settings=model_settings,
optimizer_settings=optimizer_settings,
# checkpoint_path='.models/pol_pol_20241208_122418_1116.tar',
# reset_epoch=True
)
trainer.train()

View File

@@ -260,12 +260,117 @@ class ONNRect(nn.Module):
self.crop = lambda x: x
self.crop.__doc__ = "No cropping"
def forward(self, x):
x = self.pad(x)
x = self.pad(x).to(dtype=self.weight.dtype)
out = self.crop((self.weight @ x.mT).mT)
return out
class polarimeter(nn.Module):
def __init__(self):
super(polarimeter, self).__init__()
# self.input_length = input_length
def forward(self, data):
# S0 = I
# S1 = (2*I_x - I)/I
# S2 = (2*I_45 - I)/I
# S3 = (2*I_RHC - I)/I
# # data: (batch, input_length*2) -> (batch, input_length, 2)
data = data.view(data.shape[0], -1, 2)
x = data[:, :, 0].mean(dim=1)
y = data[:, :, 1].mean(dim=1)
# x = x.mean(dim=1)
# y = y.mean(dim=1)
# angle = torch.atan2(y.abs().square().real, x.abs().square().real)
# return torch.stack([angle, angle, angle, angle], dim=1)
# horizontal polarisation
I_x = x.abs().square()
# vertical polarisation
I_y = y.abs().square()
# 45 degree polarisation (the analyser passes (x + y)/sqrt(2), hence the 1/2)
I_45 = 0.5 * (x + y).abs().square()
# right hand circular polarisation (same 1/sqrt(2) normalisation)
I_RHC = 0.5 * (x + 1j*y).abs().square()
# S0 = I_x + I_y
# S1 = I_x - I_y
# S2 = I_45 - I_m45
# S3 = I_RHC - I_LHC
S0 = (I_x + I_y)
S1 = ((2*I_x - S0)/S0)
S2 = ((2*I_45 - S0)/S0)
S3 = ((2*I_RHC - S0)/S0)
return torch.stack([S0/S0, S1, S2, S3], dim=1)  # S1..S3 are already normalised by S0 above
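# Worked example (comment only, with the 1/2-normalised analysers above):
# x = 1, y = 0 (horizontal): I_x=1, I_y=0, I_45=0.5, I_RHC=0.5
# -> (S0, S1, S2, S3) = (1, 1, 0, 0) after normalisation.
# x = y = 1/sqrt(2) (linear 45 degrees): I_45=1 -> (1, 0, 1, 0).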
class normalize_by_first(nn.Module):
def __init__(self):
super(normalize_by_first, self).__init__()
def forward(self, data):
return data / data[:, 0].unsqueeze(1)
class rotate(nn.Module):
def __init__(self):
super(rotate, self).__init__()
def forward(self, data, angle):
# data -> (batch, n*2)
# angle -> (batch, n)
data_ = data
if angle.ndim == 1:
angle_ = angle.unsqueeze(1)
else:
angle_ = angle
angle_ = angle_.expand(-1, data_.shape[1]//2)
c = torch.cos(angle_)
s = torch.sin(angle_)
rot = torch.stack([torch.stack([c, -s], dim=2),
torch.stack([s, c], dim=2)], dim=3)
d = torch.bmm(data_.reshape(-1, 1, 2), rot.view(-1, 2, 2).to(dtype=data_.dtype)).reshape(*data.shape)
# d = torch.bmm(data.unsqueeze(-1).mT, rot.to(dtype=data.dtype).mT).mT.squeeze(-1)
return d
class photodiode(nn.Module):
def __init__(self, size, bias=True):
super(photodiode, self).__init__()
self.input_dim = size
self.scale = nn.Parameter(torch.rand(size))
self.pd_bias = nn.Parameter(torch.rand(size))
def forward(self, x):
return x.abs().square().to(dtype=x.dtype.to_real()).mul(self.scale).add(self.pd_bias)
class input_rotator(nn.Module):
def __init__(self, input_dim):
super(input_rotator, self).__init__()
assert input_dim % 2 == 0, "Input dimension must be even"
self.input_dim = input_dim
# self.angle = nn.Parameter(torch.randn(1, dtype=self.dtype.to_real()))
def forward(self, x, angle=None):
# take channels (0,1), (2,3), ... and rotate them by the angle
angle = self.angle if angle is None else angle  # "or" would call bool() on a tensor
sine = torch.sin(angle)
cosine = torch.cos(angle)
rot = torch.tensor([[cosine, -sine], [sine, cosine]], dtype=x.dtype)
return torch.matmul(x.view(-1, 2), rot).view(x.shape)
# def __repr__(self):
# return f"ONNRect({self.input_dim}, {self.output_dim})"
@@ -336,8 +441,7 @@ class ONNRect(nn.Module):
# return out
#### as defined by zhang et al
class DropoutComplex(nn.Module):
def __init__(self, p=0.5):
@@ -359,7 +463,7 @@ class Scale(nn.Module):
self.scale = nn.Parameter(torch.ones(size, dtype=torch.float32))
def forward(self, x):
return x * self.scale
return x * torch.sqrt(self.scale)
def __repr__(self):
return f"Scale({self.size})"
@@ -371,11 +475,20 @@ class Identity(nn.Module):
M(z) = z
"""
def __init__(self):
def __init__(self, size=None):
super(Identity, self).__init__()
def forward(self, x):
return x
class phase_shift(nn.Module):
def __init__(self, size):
super(phase_shift, self).__init__()
self.size = size
self.phase = nn.Parameter(torch.rand(size))
def forward(self, x):
return x * torch.exp(1j*self.phase)
class PowRot(nn.Module):
@@ -404,54 +517,68 @@ class MZISingle(nn.Module):
def forward(self, x: torch.Tensor):
return x * torch.exp(1j * self.phi) * torch.sin(self.omega + self.func(x))
def naive_angle_loss(x: torch.Tensor, target: torch.Tensor, mod=2*torch.pi):
return torch.fmod((x.abs().real - target.abs().real), mod).abs().mean()
def cosine_loss(x: torch.Tensor, target: torch.Tensor):
return (2*(1 - torch.cos(x - target))).mean()
def angle_mse_loss(x: torch.Tensor, target: torch.Tensor):
x = torch.fmod(x, 2*torch.pi)
target = torch.fmod(target, 2*torch.pi)
x_cos = torch.cos(x)
x_sin = torch.sin(x)
target_cos = torch.cos(target)
target_sin = torch.sin(target)
cos_diff = x_cos - target_cos
sin_diff = x_sin - target_sin
squared_diff = cos_diff**2 + sin_diff**2
return squared_diff.mean()
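# Worked example (comment only): for x = 0.1 and target = 2*pi - 0.1, plain MSE
# sees a difference of ~6.08 although the angles are only 0.2 rad apart. On the
# unit circle, (cos x - cos t)^2 + (sin x - sin t)^2 = 2*(1 - cos(x - t)),
# which for the 0.2 rad gap is ~0.04, the value angle_mse_loss returns.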
class EOActivation(nn.Module):
def __init__(self, bias, size=None):
# 10.1109/SiPhotonics60897.2024.10543376
def __init__(self, size=None):
# 10.1109/JSTQE.2019.2930455
super(EOActivation, self).__init__()
if size is None:
raise ValueError("Size must be specified")
self.size = size
self.alpha = nn.Parameter(torch.ones(size))
self.V_bias = nn.Parameter(torch.ones(size))
self.gain = nn.Parameter(torch.ones(size))
# if bias:
# self.phase_bias = nn.Parameter(torch.zeros(size))
# else:
# self.register_buffer("phase_bias", torch.zeros(size))
self.register_buffer("phase_bias", torch.clamp(torch.ones(size) + torch.randn(size)*0.1, 0, 1)*torch.pi)
self.register_buffer("responsivity", torch.ones(size)*0.9)
self.register_buffer("V_pi", torch.ones(size)*3)
self.alpha = nn.Parameter(torch.rand(size))
self.gain = nn.Parameter(torch.rand(size))
self.V_bias = nn.Parameter(torch.rand(size))
# self.register_buffer("gain", torch.ones(size))
# self.register_buffer("responsivity", torch.ones(size))
# self.register_buffer("V_pi", torch.ones(size))
self.reset_weights()
def reset_weights(self):
if "alpha" in self._parameters:
self.alpha.data = torch.ones(self.size)*0.5
if "V_pi" in self._parameters:
self.V_pi.data = torch.ones(self.size)*3
self.alpha.data = torch.rand(self.size)
# if "V_pi" in self._parameters:
# self.V_pi.data = torch.rand(self.size)*3
if "V_bias" in self._parameters:
self.V_bias.data = torch.zeros(self.size)
self.V_bias.data = torch.randn(self.size)
if "gain" in self._parameters:
self.gain.data = torch.ones(self.size)
if "responsivity" in self._parameters:
self.responsivity.data = torch.ones(self.size)*0.9
if "bias" in self._parameters:
self.phase_bias.data = torch.zeros(self.size)
self.gain.data = torch.rand(self.size)
# if "responsivity" in self._parameters:
# self.responsivity.data = torch.ones(self.size)*0.9
# if "bias" in self._parameters:
# self.phase_bias.data = torch.zeros(self.size)
def forward(self, x: torch.Tensor):
phi_b = torch.pi * self.V_bias / (self.V_pi + 1e-8)
g_phi = torch.pi * (self.alpha * self.gain * self.responsivity) / (self.V_pi + 1e-8)
phi_b = torch.pi * self.V_bias  # / self.V_pi
g_phi = torch.pi * (self.alpha * self.gain)  # * self.responsivity / self.V_pi
intermediate = g_phi * x.abs().square() + phi_b
return (
1j
* torch.sqrt(1 - self.alpha)
* torch.exp(-0.5j * (intermediate + self.phase_bias))
* torch.exp(-0.5j * intermediate)
* torch.cos(0.5 * intermediate)
* x
)
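# Transfer function implemented above: with phi = g_phi*|x|^2 + phi_b,
#   f(x) = 1j * sqrt(1 - alpha) * exp(-1j*phi/2) * cos(phi/2) * x,
# i.e. an intensity-dependent phase shift and amplitude modulation of the
# field, following the electro-optic activation of the cited reference.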
class Pow(nn.Module):
"""
implements the activation function
@@ -574,6 +701,7 @@ class ZReLU(nn.Module):
__all__ = [
complex_sse_loss,
complex_mse_loss,
angle_mse_loss,
UnitaryLayer,
unitary,
energy_conserving,
@@ -590,6 +718,8 @@ __all__ = [
ZReLU,
MZISingle,
EOActivation,
photodiode,
phase_shift,
# SaturableAbsorberLambertW,
# SaturableAbsorber,
# SpreadLayer,

View File

@@ -0,0 +1,105 @@
# Copyright (c) 2015, Warren Weckesser. All rights reserved.
# This software is licensed according to the "BSD 2-clause" license.
import hashlib
import h5py
import numpy as _np
from scipy.interpolate import interp1d as _interp1d
from scipy.ndimage import gaussian_filter as _gaussian_filter
from ._brescount import bres_curve_count as _bres_curve_count
from pathlib import Path
__all__ = ['grid_count']
def grid_count(y, window_size, offset=0, size=None, fuzz=True, blur=0, bounds=None):
"""
Parameters
----------
`y` is the 1-d array of signal samples.
`window_size` is the number of samples to show horizontally in the
eye diagram. Typically this is twice the number of samples in a
"symbol" (i.e. in a data bit).
`offset` is the number of initial samples to skip before computing
the eye diagram. This allows the overall phase of the diagram to
be adjusted.
`size` must be a tuple of two integers. It sets the size of the
array of counts, (height, width). The default is (800, 640).
`fuzz`: If True, the values in `y` are reinterpolated with a
random "fuzz factor" before plotting in the eye diagram. This
reduces an aliasing-like effect that arises with the use of
Bresenham's algorithm.
`blur`: If nonzero, the counts array is smoothed with a Gaussian
filter whose standard deviation is `blur` (in bins).
`bounds` must be a tuple of two floating point values, (ymin, ymax).
These set the y range of the returned array. If not given, the
bounds are `(y.min() - 0.05*A, y.max() + 0.05*A)`, where `A` is
`y.max() - y.min()`.
Return Value
------------
Returns a numpy array of integers.
"""
# hash input params
param_ob = (y, window_size, offset, size, fuzz, blur, bounds)
param_hash = hashlib.md5(str(param_ob).encode()).hexdigest()
cache_dir = Path.home()/".eyediagram"/".cache"
cache_dir.mkdir(parents=True, exist_ok=True)
if (cache_dir/param_hash).is_file():
try:
with h5py.File(cache_dir/param_hash, "r") as infile:
counts = infile["counts"][:]
if counts.size != 0:
return counts
except (OSError, KeyError):
pass
if size is None:
size = (800, 640)
height, width = size
dt = width / window_size
counts = _np.zeros((width, height), dtype=_np.int32)
if bounds is None:
ymin = y.min()
ymax = y.max()
yamp = ymax - ymin
ymin = ymin - 0.05*yamp
ymax = ymax + 0.05*yamp
ymax = _np.ceil(ymax*10)/10
ymin = _np.floor(ymin*10)/10
else:
ymin, ymax = bounds
start = offset
while start + window_size < len(y):
end = start + window_size
yy = y[start:end+1]
k = _np.arange(len(yy))
xx = dt*k
if fuzz:
f = _interp1d(xx, yy, kind='cubic')
jiggle = dt*(_np.random.beta(a=3, b=3, size=len(xx)-2) - 0.5)
xx[1:-1] += jiggle
yd = f(xx)
else:
yd = yy
iyd = (height * (yd - ymin)/(ymax - ymin)).astype(_np.int32)
_bres_curve_count(xx.astype(_np.int32), iyd, counts)
start = end
if blur != 0:
counts = _gaussian_filter(counts, sigma=blur)
with h5py.File(cache_dir/param_hash, "w") as outfile:
outfile.create_dataset("data", data=counts)
return counts
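# Usage sketch (comment only; all names local to this example): eye-diagram
# counts for a noisy sine with 200 samples per window (two 100-sample periods).
# y = _np.sin(2 * _np.pi * _np.arange(20000) / 100) + 0.05 * _np.random.randn(20000)
# counts = grid_count(y, window_size=200, size=(400, 320), fuzz=True, blur=1)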

View File

@@ -1,10 +1,12 @@
from pathlib import Path
import h5py
import torch
from torch.utils.data import Dataset
# from torch.utils.data import Sampler
import numpy as np
import configparser
import multiprocessing as mp
# class SubsetSampler(Sampler[int]):
# """
@@ -24,7 +26,22 @@ import configparser
# return len(self.indices)
def load_data(config_path, skipfirst=0, symbols=None, real=False, normalize=False, device=None, dtype=None):
def load_from_file(datapath):
if str(datapath).endswith(".h5"):
symbols = None
with h5py.File(datapath, "r") as infile:
data = infile["data"][:]
try:
symbols = np.swapaxes(infile["symbols"][:], 0, 1)
except KeyError:
pass
else:
symbols = None
data = np.load(datapath)
return data, symbols
def load_data(config_path, skipfirst=0, skiplast=0, symbols=None, real=False, normalize=1, device=None, dtype=None):
filepath = Path(config_path)
filepath = filepath.parent.glob(filepath.name)
config = configparser.ConfigParser()
@@ -40,25 +57,39 @@ def load_data(config_path, skipfirst=0, symbols=None, real=False, normalize=Fals
if symbols is None:
symbols = int(config["glova"]["nos"]) - skipfirst
data = np.load(datapath)[int(skipfirst * sps) : int(symbols * sps + skipfirst * sps)]
timestamps = np.arange(int(skipfirst * sps), int(symbols * sps + skipfirst * sps))
data, orig_symbols = load_from_file(datapath)
if normalize:
# square gets normalized to 1, as the power is (proportional to) the square of the amplitude
a, b, c, d = np.square(data.T)
a, b, c, d = a / np.max(np.abs(a)), b / np.max(np.abs(b)), c / np.max(np.abs(c)), d / np.max(np.abs(d))
data = np.sqrt(np.array([a, b, c, d]).T)
data = data[int(skipfirst * sps) : int(symbols * sps + skipfirst * sps - skiplast * sps)]
orig_symbols = orig_symbols[skipfirst : symbols + skipfirst - skiplast]
timestamps = np.arange(int(skipfirst * sps), int(symbols * sps + skipfirst * sps - skiplast * sps))
data *= np.sqrt(normalize)
launch_power = float(config["signal"]["laser_power"])
output_power = float(config["signal"]["edfa_power"])
target_normalization = 10 ** (output_power / 10) / 10 ** (launch_power / 10)
# target_normalization *= 0.5 # allow 50% power loss, so the network can ignore parts of the signal
data[:, 0:2] *= np.sqrt(target_normalization)
# if normalize:
# # square gets normalized to 1, as the power is (proportional to) the square of the amplitude
# a, b, c, d = data.T
# a, b, c, d = a - np.min(np.abs(a)), b - np.min(np.abs(b)), c - np.min(np.abs(c)), d - np.min(np.abs(d))
# a, b, c, d = a / np.max(np.abs(a)), b / np.max(np.abs(b)), c / np.max(np.abs(c)), d / np.max(np.abs(d))
# data = np.array([a, b, c, d]).T
if real:
data = np.abs(data)
config["glova"]["nos"] = str(symbols)
data = np.concatenate([data, timestamps.reshape(-1,1)], axis=-1)
data = np.concatenate([data, timestamps.reshape(-1, 1)], axis=-1)
data = torch.tensor(data, device=device, dtype=dtype)
return data, config
return data, config, orig_symbols
def roll_along(arr, shifts, dim):
@@ -110,9 +141,15 @@ class FiberRegenerationDataset(Dataset):
target_delay: float | int = 0,
xy_delay: float | int = 0,
drop_first: float | int = 0,
drop_last=0,
dtype: torch.dtype = None,
real: bool = False,
device=None,
# osnr: float|None = None,
polarisations=None,
randomise_polarisations: bool = False,
repeat_randoms: int = 1,
# cross_pol_interference: float = 0,
**kwargs,
):
"""
@@ -145,50 +182,54 @@ class FiberRegenerationDataset(Dataset):
assert output_dim is None or output_dim > 0, "output_len must be positive or None"
assert drop_first >= 0, "drop_first must be non-negative"
faux = kwargs.pop("faux", False)
self.randomise_polarisations = randomise_polarisations
# self.cross_pol_interference = cross_pol_interference
if faux:
data_raw = np.array(
[[i + 0.1j, i + 0.2j, i + 1.1j, i + 1.2j] for i in range(12800)],
dtype=np.complex128,
data_raw = None
self.config = None
files = []
self.orig_symbols = None
for file_path in file_path if isinstance(file_path, (tuple, list)) else [file_path]:
data, config, orig_syms = load_data(
file_path,
skipfirst=drop_first,
skiplast=drop_last,
symbols=kwargs.get("num_symbols", None),
real=real,
normalize=1000,
device=device,
dtype=dtype,
)
data_raw = torch.tensor(data_raw, device=device, dtype=dtype)
timestamps = torch.arange(12800)
data_raw = torch.concatenate([data_raw, timestamps.reshape(-1, 1)], axis=-1)
self.config = {
"data": {"dir": '"."', "npy_dir": '"."', "file": "faux"},
"glova": {"sps": 128},
}
else:
data_raw = None
self.config = None
files = []
for file_path in (file_path if isinstance(file_path, (tuple, list)) else [file_path]):
data, config = load_data(
file_path,
skipfirst=drop_first,
symbols=kwargs.get("num_symbols", None),
real=real,
normalize=True,
device=device,
dtype=dtype,
)
if data_raw is None:
data_raw = data
if orig_syms is not None:
if self.orig_symbols is None:
self.orig_symbols = orig_syms
else:
data_raw = torch.cat([data_raw, data], dim=0)
if self.config is None:
self.config = config
else:
assert self.config["glova"]["sps"] == config["glova"]["sps"], "samples per symbol must be the same"
files.append(config["data"]["file"].strip('"'))
self.config["data"]["file"] = str(files)
self.orig_symbols = np.concat((self.orig_symbols, orig_syms), axis=-1)
if data_raw is None:
data_raw = data
else:
data_raw = torch.cat([data_raw, data], dim=0)
if self.config is None:
self.config = config
else:
assert self.config["glova"]["sps"] == config["glova"]["sps"], "samples per symbol must be the same"
files.append(config["data"]["file"].strip('"'))
self.config["data"]["file"] = str(files)
# if polarisations is not None:
# data_raw_clone = data_raw.clone()
# # rotate the polarisation by 180 degrees
# data_raw_clone[2, :] *= -1
# data_raw_clone[3, :] *= -1
# data_raw = torch.cat([data_raw, data_raw_clone], dim=0)
self.polarisations = bool(polarisations)
self.device = data_raw.device
self.samples_per_symbol = int(self.config["glova"]["sps"])
# self.num_symbols = int(self.config["glova"]["nos"])
self.samples_per_slice = int(symbols * self.samples_per_symbol)
self.symbols_per_slice = self.samples_per_slice / self.samples_per_symbol
@@ -258,17 +299,98 @@ class FiberRegenerationDataset(Dataset):
elif self.target_delay_samples < 0:
data_raw = data_raw[:, : self.target_delay_samples]
timestamps = data_raw[-1, :]
data_raw = data_raw[:-1, :]
timestamps = data_raw[4, :]
data_raw = data_raw[:4, :]
data_raw = data_raw.view(2, 2, -1)
timestamps_doubled = torch.cat([timestamps.unsqueeze(dim=0), timestamps.unsqueeze(dim=0)], dim=0).unsqueeze(dim=1)
data_raw = torch.cat([data_raw, timestamps_doubled], dim=1)
fiber_in = data_raw[0, :, :]
fiber_out = data_raw[1, :, :]
# timestamps_doubled = torch.cat([timestamps.unsqueeze(dim=0), timestamps.unsqueeze(dim=0)], dim=0).unsqueeze(
# dim=1
# )
fiber_in = torch.cat([fiber_in, timestamps.unsqueeze(0)], dim=0)
fiber_out = torch.cat([fiber_out, timestamps.unsqueeze(0)], dim=0)
# fiber_out: [E_out_x, E_out_y, timestamps]
# add noise related to amplification necessary due to splitting of the signal
# gain_lin = output_dim*2
# gain_lin = 1
# edfa_nf = float(self.config["signal"]["edfa_nf"])
# nf_lin = 10**(edfa_nf/10)
# f0 = float(self.config["glova"]["f0"])
# noise_add = (gain_lin-1)*nf_lin*6.626e-34*f0*12.5e9
# noise = torch.randn_like(fiber_out[:2, :])
# noise_power = torch.sum(torch.mean(noise.abs().square().squeeze(), dim=-1))
# noise = noise * torch.sqrt(noise_add / noise_power)
# fiber_out[:2, :] += noise
# if osnr is None:
# noisy = fiber_out[:2, :]
# else:
# noisy = self.add_noise(fiber_out[:2, :], osnr)
# fiber_out = torch.cat([fiber_out, noisy], dim=0)
# fiber_out: [E_out_x, E_out_y, timestamps, E_out_x_noisy, E_out_y_noisy]
if repeat_randoms > 1:
fiber_in = fiber_in.repeat(1, 1, repeat_randoms)
fiber_out = fiber_out.repeat(1, 1, repeat_randoms)
# review: potential problems with repeated timestamps when plotting
else:
repeat_randoms = 1
if self.randomise_polarisations:
# the first two assignments below were immediately overwritten; kept commented for reference
# angles = torch.fmod(torch.randperm(data_raw.shape[-1]*repeat_randoms, device=fiber_out.device), 2) * torch.pi
# start_angle = torch.rand(1) * 2 * torch.pi
# angles = start_angle + torch.cumsum(torch.randn(data_raw.shape[-1])/333, dim=0) * torch.pi * 2 # random walk
angles = torch.randn(data_raw.shape[-1], device=fiber_out.device) * 2*torch.pi / 36 # sigma = 10 degrees
# self.angles = torch.rand(self.data.shape[0]) * 2 * torch.pi
else:
angles = torch.zeros(data_raw.shape[-1], device=fiber_out.device)
sin = torch.sin(angles)
cos = torch.cos(angles)
rot = torch.stack([torch.stack([cos, -sin], dim=1), torch.stack([sin, cos], dim=1)], dim=2)
data_rot = torch.bmm(fiber_out[:2, :].T.unsqueeze(1), rot.to(dtype=data_raw.dtype)).squeeze(1).T
# data_rot_noisy = torch.bmm(fiber_out[3:5, :].T.unsqueeze(1), rot.to(dtype=data_raw.dtype)).squeeze(1).T
fiber_out = torch.cat((fiber_out, data_rot), dim=0)
fiber_out = torch.cat([fiber_out, angles.unsqueeze(0)], dim=0)
# fiber_in:
# 0 E_in_x,
# 1 E_in_y,
# 2 timestamps
# fiber_out:
# 0 E_out_x,
# 1 E_out_y,
# 2 timestamps,
# 3 E_out_x_rot,
# 4 E_out_y_rot,
# 5 angle
# data_raw = torch.cat([data_raw, timestamps_doubled], dim=1)
# data layout
# [ [E_in_x, E_in_y, timestamps],
# [E_out_x, E_out_y, timestamps] ]
self.data = data_raw.unfold(dimension=-1, size=self.samples_per_slice, step=1)
self.data = self.data.movedim(-2, 0)
self.fiber_in = fiber_in.unfold(dimension=-1, size=self.samples_per_slice, step=1)
self.fiber_in = self.fiber_in.movedim(-2, 0)
self.fiber_out = fiber_out.unfold(dimension=-1, size=self.samples_per_slice, step=1)
self.fiber_out = self.fiber_out.movedim(-2, 0)
# if self.randomise_polarisations:
# self.angles = torch.cumsum((torch.rand(self.fiber_out.shape[0]) - 0.5) * 2 * torch.pi * 2 / 5000, dim=0)
# self.data = data_raw.unfold(dimension=-1, size=self.samples_per_slice, step=1)
# self.data = self.data.movedim(-2, 0)
# self.angles = torch.zeros(self.data.shape[0])
...
# ...
# -> [no_slices, 2, 3, samples_per_slice]
# data layout
@@ -278,33 +400,160 @@ class FiberRegenerationDataset(Dataset):
# ...
# ] -> [no_slices, 2, 3, samples_per_slice]
...
def __len__(self):
return self.data.shape[0]
return self.fiber_in.shape[0]
def add_noise(self, data, osnr):
osnr_lin = 10 ** (osnr / 10)
popt = torch.mean(data.abs().square().squeeze(), dim=-1)
noise = torch.randn_like(data)
pn = torch.mean(noise.abs().square().squeeze(), dim=-1)
mult = torch.sqrt(popt / (pn * osnr_lin))
mult = mult * torch.eye(popt.shape[0], device=mult.device)
mult = mult.to(dtype=noise.dtype)
noise = mult @ noise
pn = torch.mean(noise.abs().square().squeeze(), dim=-1)
noisy = data + noise
return noisy
def __getitem__(self, idx):
if isinstance(idx, slice):
return [self.__getitem__(i) for i in range(*idx.indices(len(self)))]
else:
data_slice = self.data[idx].squeeze()
# fiber in: [E_in_x, E_in_y, timestamps]
# fiber out: [E_out_x, E_out_y, timestamps, E_out_x_rot, E_out_y_rot, angle]
data_slice = data_slice[:, :, :data_slice.shape[2] // self.output_dim * self.output_dim]
# if self.polarisations:
output_dim = self.output_dim // 2
self.output_dim = output_dim * 2
data_slice = data_slice.view(data_slice.shape[0], data_slice.shape[1], self.output_dim, -1)
target = data_slice[0, :, self.output_dim//2, 0]
data = data_slice[1, :, :, 0]
# data_timestamps = data[-1,:].real
data = data[:-1, :]
target_timestamp = target[-1].real
target = target[:-1]
if not self.polarisations:
output_dim = 2 * output_dim
fiber_in = self.fiber_in[idx].squeeze()
fiber_out = self.fiber_out[idx].squeeze()
fiber_in = fiber_in[..., : fiber_in.shape[-1] // output_dim * output_dim]
fiber_out = fiber_out[..., : fiber_out.shape[-1] // output_dim * output_dim]
fiber_in = fiber_in.view(fiber_in.shape[0], output_dim, -1)
fiber_out = fiber_out.view(fiber_out.shape[0], output_dim, -1)
center_angle = fiber_out[5, output_dim // 2, 0]
angles = fiber_out[5, :, 0]
plot_data_rot = fiber_out[3:5, output_dim // 2, 0].detach().clone()
data = fiber_out[0:2, :, 0]
plot_data = fiber_out[0:2, output_dim // 2, 0].detach().clone()
target = fiber_in[:2, output_dim // 2, 0]
plot_target = fiber_in[:2, output_dim // 2, 0].detach().clone()
target_timestamp = fiber_in[2, output_dim // 2, 0].real
...
if self.polarisations:
rot = int(np.random.randint(2) * 2 - 1)
data = rot * data
target = rot * target
plot_data_rot = rot * plot_data_rot
center_angle = center_angle + (rot - 1) * torch.pi / 2 # rot: -1 or 1 -> -2 or 0 -> -pi or 0
angles = angles + (rot - 1) * torch.pi / 2
pol_flipped_data = -data
pol_flipped_target = -target
# transpose to interleave the x and y data in the output tensor
data = data.transpose(0, 1).flatten().squeeze()
data = data / torch.sqrt(torch.ones(1) * len(data)) # power loss due to splitting
pol_flipped_data = pol_flipped_data.transpose(0, 1).flatten().squeeze()
pol_flipped_data = pol_flipped_data / torch.sqrt(
torch.ones(1) * len(pol_flipped_data)
) # power loss due to splitting
# angle_data = angle_data.transpose(0, 1).flatten().squeeze()
# angle_data2 = angle_data2.transpose(0,1).flatten().squeeze()
center_angle = center_angle.flatten().squeeze()
angles = angles.flatten().squeeze()
# data_timestamps = data_timestamps.flatten().squeeze()
# target = target.transpose(0,1).flatten().squeeze()
target = target.flatten().squeeze()
pol_flipped_target = pol_flipped_target.flatten().squeeze()
target_timestamp = target_timestamp.flatten().squeeze()
plot_target = plot_target.flatten().squeeze()
plot_data = plot_data.flatten().squeeze()
plot_data_rot = plot_data_rot.flatten().squeeze()
return data, target, target_timestamp
return {
"x": data,
"x_flipped": pol_flipped_data,
"x_stacked": torch.cat([data, pol_flipped_data], dim=-1),
"y": target,
"y_flipped": pol_flipped_target,
"y_stacked": torch.cat([target, pol_flipped_target], dim=-1),
"center_angle": center_angle,
"angles": angles,
"mean_angle": angles.mean(),
# "sop": sop,
# "angle_data": angle_data,
# "angle_data2": angle_data2,
"timestamp": target_timestamp,
"plot_target": plot_target,
"plot_data": plot_data,
"plot_data_rot": plot_data_rot,
# "plot_clean": fiber_out_plot_clean,
}
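# Batch usage sketch (comment only): with a DataLoader each key batches along
# dim 0, so batch["x"] has shape [batch, 2*output_dim] here and
# batch["center_angle"] has shape [batch].
# loader = DataLoader(dataset, batch_size=32, shuffle=True)
# batch = next(iter(loader)); model_in, target = batch["x"], batch["y"]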
def complex_max(self, data, dim=-1):
# returns element(s) with the maximum absolute value along a given dimension
# ind = torch.argmax(data.abs(), dim=dim, keepdim=True)
# max_values = torch.gather(data, dim, ind).squeeze(dim=dim)
# return max_values
return torch.gather(data, dim, torch.argmax(data.abs(), dim=dim, keepdim=True)).squeeze(dim=dim)
def rotate(self, data, angle):
# rotates a 2d tensor by a given angle
# data: [2, ...]
# angle: [1]
# returns: [2, ...]
# get sine and cosine of the angle
sine = torch.sin(angle)
cosine = torch.cos(angle)
return torch.stack([data[0] * cosine - data[1] * sine, data[0] * sine + data[1] * cosine], dim=0)
def rotate_all(self):
def do_rotation(j, num_processes):
for i in range(len(self) // num_processes):
index = i * num_processes + j
self.data[index, 1, :2, :] = self.rotate(self.data[index, 1, :2, :], self.angles[index])
self.processes = []
for j in range(mp.cpu_count()):
self.processes.append(mp.Process(target=do_rotation, args=(j, mp.cpu_count())))
self.processes[-1].start()
for p in self.processes:
p.join()
for i in range(len(self) // mp.cpu_count() * mp.cpu_count(), len(self)):
self.data[i, 1, :2, :] = self.rotate(self.data[i, 1, :2, :], self.angles[i])
def polarimeter(self, data):
# data: [2, ...] -> x, y
# returns [4] -> S0, S1, S2, S3
x = data[0].mean()
y = data[1].mean()
I_X = x.abs().square()
I_Y = y.abs().square()
I_45 = 0.5 * (x + y).abs().square()
I_RHC = 0.5 * (x + 1j * y).abs().square()
S0 = I_X + I_Y
S1 = (2 * I_X - S0) / S0
S2 = (2 * I_45 - S0) / S0
S3 = (2 * I_RHC - S0) / S0
return torch.stack([S0, S1, S2, S3], dim=0)

View File

@@ -1,15 +1,23 @@
from datetime import datetime
import json
from pathlib import Path
from typing import Literal
import h5py
from matplotlib import pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
# from cmap import Colormap as cm
import numpy as np
from scipy.cluster.vq import kmeans2
import warnings
import multiprocessing
from rich.traceback import install
from rich import pretty
from rich import print
install()
pretty.install()
# from rich import pretty
# from rich import print
# pretty.install()
def generate_sample_data(n_symbols=1000, sps=128, noise=0.01, skew=1):
@@ -20,6 +28,7 @@ def generate_sample_data(n_symbols=1000, sps=128, noise=0.01, skew=1):
xaxis = np.arange(0, len(signal)) / sps
return np.vstack([xaxis, signal])
def create_symbol_sequence(n_symbols, skew=1):
np.random.seed(42)
data = np.random.randint(0, 4, n_symbols) / 4
@@ -38,6 +47,14 @@ def generate_signal(data, sps):
signal = np.convolve(data_padded, wavelet)
signal = np.cumsum(signal)
signal = signal[sps // 2 + len(wavelet) // 2 - 1 : -len(wavelet) // 2]
mi, ma = np.min(signal), np.max(signal)
signal = (signal - mi) / (ma - mi)
mod = 0.8
signal *= mod
signal += 1 - mod
return signal
@@ -48,8 +65,8 @@ def normalization_with_noise(signal, noise=0):
signal += awgn
# min-max normalization
signal = signal - np.min(signal)
signal = signal / np.max(signal)
# signal = signal - np.min(signal)
# signal = signal / np.max(signal)
return signal
@@ -67,197 +84,424 @@ def generate_wavelet(sps, oversample=3):
class eye_diagram:
def __init__(self, data, *, channel_names=None, horizontal_bins=256, vertical_bins=1000, n_levels=4):
def __init__(
self,
data,
*,
channel_names=None,
horizontal_bins=256,
vertical_bins=1000,
n_levels=4,
multithreaded=True,
save_file_or_dir=None,
):
# data has shape [channels, 2, samples]
# each sample has a timestamp and a value
if data.ndim == 2:
data = data[np.newaxis, :, :]
self.channel_names = channel_names
self.raw_data = data
self.channels = data.shape[0]
self.y_bins = np.zeros(1)
self.x_bins = np.zeros(1)
self.eye_data = np.zeros(1)
self.channel_names = channel_names
self.n_channels = data.shape[0]
self.n_levels = n_levels
self.eye_stats = [{"success": False} for _ in range(self.channels)]
self.eye_stats = [{"success": False} for _ in range(self.n_channels)]
self.horizontal_bins = horizontal_bins
self.vertical_bins = vertical_bins
self.multi_threaded = multithreaded
self.analysed = False
self.eye_built = False
self.analyse(self.n_levels)
def generate_eye_data(self):
self.x_bins = np.linspace(0, 2, self.horizontal_bins, endpoint=False)
self.y_bins = np.zeros((self.channels, self.vertical_bins))
self.eye_data = np.zeros((self.channels, self.vertical_bins, self.horizontal_bins))
for i in range(self.channels):
data_min = np.min(self.raw_data[i, 1, :])
data_max = np.max(self.raw_data[i, 1, :])
self.y_bins[i] = np.linspace(data_min, data_max, self.vertical_bins, endpoint=False)
t_vals = self.raw_data[i, 0, :] % 2
val_vals = self.raw_data[i, 1, :]
x_indices = np.digitize(t_vals, self.x_bins) - 1
y_indices = np.digitize(val_vals, self.y_bins[i]) - 1
np.add.at(self.eye_data[i], (y_indices, x_indices), 1)
self.eye_built = True
self.save_file = save_file_or_dir
def load_data(self, file=None):
file = self.save_file if file is None else file
if file is None:
raise FileNotFoundError("No file specified.")
self.save_file = str(file)
# self.file_or_dir = self.save_file
with h5py.File(file, "r") as infile:
self.y_bins = infile["y_bins"][:]
self.x_bins = infile["x_bins"][:]
self.eye_data = infile["eye_data"][:]
self.channel_names = infile.attrs["channel_names"]
self.n_channels = infile.attrs["n_channels"]
self.n_levels = infile.attrs["n_levels"]
self.eye_stats = infile.attrs["eye_stats"]
self.eye_stats = [json.loads(stat) for stat in self.eye_stats]
self.horizontal_bins = infile.attrs["horizontal_bins"]
self.vertical_bins = infile.attrs["vertical_bins"]
self.multi_threaded = infile.attrs["multithreaded"]
self.analysed = infile.attrs["analysed"]
self.eye_built = infile.attrs["eye_built"]
def save_data(self, file_or_dir=None):
file_or_dir = self.save_file if file_or_dir is None else file_or_dir
if file_or_dir is None:
file = Path(f"{datetime.now().strftime('%Y%m%d_%H%M%S')}_eye_data.eye.h5")
elif Path(file_or_dir).is_dir():
file = Path(file_or_dir) / f"{datetime.now().strftime('%Y%m%d_%H%M%S')}_eye_data.eye.h5"
else:
file = Path(file_or_dir)
# file.parent.mkdir(parents=True, exist_ok=True)
self.save_file = str(file)
with h5py.File(file, "w") as outfile:
outfile.create_dataset("eye_data", data=self.eye_data)
outfile.create_dataset("y_bins", data=self.y_bins)
outfile.create_dataset("x_bins", data=self.x_bins)
outfile.attrs["channel_names"] = self.channel_names
outfile.attrs["n_channels"] = self.n_channels
outfile.attrs["n_levels"] = self.n_levels
self.eye_stats = eye_diagram.convert_arrays(self.eye_stats)
outfile.attrs["eye_stats"] = [json.dumps(stat) for stat in self.eye_stats]
outfile.attrs["horizontal_bins"] = self.horizontal_bins
outfile.attrs["vertical_bins"] = self.vertical_bins
outfile.attrs["multithreaded"] = self.multi_threaded
outfile.attrs["analysed"] = self.analysed
outfile.attrs["eye_built"] = self.eye_built
@staticmethod
def convert_arrays(input_object):
"""
convert ndarrays in (nested) dict to lists
"""
if isinstance(input_object, np.ndarray):
return input_object.tolist()
elif isinstance(input_object, list):
return [eye_diagram.convert_arrays(old) for old in input_object]
elif isinstance(input_object, tuple):
return tuple(eye_diagram.convert_arrays(old) for old in input_object)
elif isinstance(input_object, dict):
dict_out = {}
for key, value in input_object.items():
dict_out[key] = eye_diagram.convert_arrays(value)
return dict_out
return input_object
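# Example (comment only):
# eye_diagram.convert_arrays({"levels": np.array([0.0, 1.0]), "n": 4})
# -> {"levels": [0.0, 1.0], "n": 4}, making the dict JSON-serialisable
# for the eye_stats attributes written in save_data above.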
def generate_eye_data(
self, mode: Literal["default", "load", "save", "nosave"] = "default", file_or_dir: Path | None | str = None
):
# modes:
# default: try to load eye data from file, if not found, generate and save
# load: try to load eye data from file, if not found, generate but don't save
# save: generate eye data and save
# nosave: generate eye data without loading or saving
update_save = True
if mode == "load":
self.load_data(file_or_dir)
update_save = False
elif mode == "default":
try:
self.load_data(file_or_dir)
update_save = False
except (FileNotFoundError, IsADirectoryError):
pass
def plot(self, title="Eye Diagram", stats=True, show=True):
if not self.eye_built:
self.generate_eye_data()
update_save = True
self.x_bins = np.linspace(0, 2, self.horizontal_bins, endpoint=False)
self.y_bins = np.zeros((self.n_channels, self.vertical_bins))
self.eye_data = np.zeros((self.n_channels, self.vertical_bins, self.horizontal_bins))
datas = [self.raw_data[i] for i in range(self.n_channels)]
if self.multi_threaded:
with multiprocessing.Pool() as pool:
results = pool.map(self.generate_eye_data_single, datas)
for i, result in enumerate(results):
self.eye_data[i], self.y_bins[i] = result
else:
for i, data in enumerate(datas):
self.eye_data[i], self.y_bins[i] = self.generate_eye_data_single(data)
self.eye_built = True
if mode == "save" or (mode == "default" and update_save):
self.save_data(file_or_dir)
def generate_eye_data_single(self, data):
eye_data = np.zeros((self.vertical_bins, self.horizontal_bins))
data_min = np.min(data[1, :])
data_max = np.max(data[1, :])
# round down/up to 1 decimal
data_min = np.floor(data_min*10)/10
data_max = np.ceil(data_max*10)/10
# data_range = data_max - data_min
# data_min -= 0.1 * data_range
# data_max += 0.1 * data_range
# data_min = -0.05
# data_max += 0.05
# data[1,:] -= np.min(data[1, :])
# data[1,:] /= np.max(data[1, :])
# data_min = 0
# data_max = 1
y_bins = np.linspace(data_min, data_max, self.vertical_bins, endpoint=False)
t_vals = data[0, :] % 2 # + np.random.randn(*data[0, :].shape) * 1 / (512)
val_vals = data[1, :] # + np.random.randn(*data[1, :].shape) * 1 / (320)
x_indices = np.digitize(t_vals, self.x_bins) - 1
y_indices = np.digitize(val_vals, y_bins) - 1
np.add.at(eye_data, (y_indices, x_indices), 1)
return eye_data, y_bins
def plot(
self,
title="Eye Diagram",
stats=True,
all_stats=True,
show=True,
mode: Literal["default", "load", "save", "nosave"] = "default",
# save_images = False,
# image_dir = None,
# cmap=None,
):
if stats and not self.analysed:
self.analyse(mode=mode)
if not self.eye_built:
self.generate_eye_data(mode=mode)
        cmap = LinearSegmentedColormap.from_list(
            "eyemap",
            [
                (0, "#FFFFFF00"),
                (0.1, "blue"),
                (0.2, "cyan"),
                (0.5, "green"),
                (0.8, "yellow"),
                (0.9, "red"),
                (1, "magenta"),
            ],
        )
        # cmap = cm('google:turbo_r' if cmap is None else cmap)
        # first = cmap(-1)
        # cmap = cmap.to_mpl()
        # cmap.set_under(first, alpha=0)
        if self.n_channels % 2 == 0:
            rows = 2
            cols = self.n_channels // 2
        else:
            cols = int(np.ceil(np.sqrt(self.n_channels)))
            rows = int(np.ceil(self.n_channels / cols))
fig, ax = plt.subplots(rows, cols, sharex=True, sharey=False)
fig.suptitle(title)
fig.tight_layout()
ax = np.atleast_1d(ax).transpose().flatten()
        for i in range(self.n_channels):
            ax[i].set_title(self.channel_names[i] if self.channel_names is not None else f"Channel {i + 1}")
            if (i + 1) % rows == 0:
                ax[i].set_xlabel("Symbol")
            if i < rows:
                ax[i].set_ylabel("Amplitude")
ax[i].grid()
ax[i].set_axisbelow(True)
ax[i].imshow(
                self.eye_data[i] - 0.1,
origin="lower",
aspect="auto",
cmap=cmap,
extent=[self.x_bins[0], self.x_bins[-1], self.y_bins[i][0], self.y_bins[i][-1]],
interpolation="gaussian",
vmin=0,
zorder=3,
)
ax[i].set_xlim((self.x_bins[0], self.x_bins[-1]))
ymin = np.min(self.y_bins[:, 0])
ymax = np.max(self.y_bins[:, -1])
yspan = ymax - ymin
ax[i].set_ylim((ymin - 0.1 * yspan, ymax + 0.1 * yspan))
# if save_images:
# image_dir = "images_out" if image_dir is None else image_dir
# image_path = Path(image_dir) / (slugify(f"{datetime.now().strftime("%Y%m%d_%H%M%S")}_{title.replace(" ","_")}_{self.channel_names[i].replace(" ", "_") if self.channel_names is not None else f"{i + 1}"}_{ymin:.1f}_{ymax:.1f}") + ".png")
# image_path.parent.mkdir(parents=True, exist_ok=True)
# # plt.imsave(
# # image_path,
# # self.eye_data[i] - 0.1,
# # origin="lower",
# # # aspect="auto",
# # cmap=cmap,
# # # extent=[self.x_bins[0], self.x_bins[-1], self.y_bins[i][0], self.y_bins[i][-1]],
# # # interpolation="gaussian",
# # vmin=0,
# # # zorder=3,
# # )
if stats and self.eye_stats[i]["success"]:
                # # add min_area above the plot
                # ax[i].annotate(
                #     f"Min Area: {self.eye_stats[i]['min_area']:.2e}",
                #     xy=(0.05, ymax + 0.05 * yspan),
                #     # xycoords="axes fraction",
                #     ha="left",
                #     va="center",
                #     bbox=dict(facecolor="white", alpha=0.5, edgecolor="none"),
                # )
                if all_stats:
                    ax[i].plot([0, 2], [self.eye_stats[i]["levels"], self.eye_stats[i]["levels"]], "k--")
                    y_ticks = (*self.eye_stats[i]["levels"], *self.eye_stats[i]["thresholds"])
                    # y_ticks = np.sort(y_ticks)
                    ax[i].set_yticks(y_ticks)
                # add arrows for amplitudes
                for j in range(len(self.eye_stats[i]["amplitudes"])):
                    ax[i].annotate(
                        "",
                        xy=(0.05, self.eye_stats[i]["levels"][j]),
                        xytext=(0.05, self.eye_stats[i]["levels"][j + 1]),
                        arrowprops=dict(arrowstyle="<->", facecolor="black"),
                    )
                    ax[i].annotate(
                        f"{self.eye_stats[i]['amplitudes'][j]:.2e}",
                        xy=(0.06, (self.eye_stats[i]["levels"][j] + self.eye_stats[i]["levels"][j + 1]) / 2),
                        bbox=dict(facecolor="white", alpha=0.5, edgecolor="none"),
                    )
                # add arrows for eye heights
                for j in range(len(self.eye_stats[i]["heights"])):
                    try:
                        bot = np.max(self.eye_stats[i]["amplitude_clusters"][j])
                        top = np.min(self.eye_stats[i]["amplitude_clusters"][j + 1])
                        ax[i].annotate(
                            "",
                            xy=(self.eye_stats[i]["time_midpoint"], bot),
                            xytext=(self.eye_stats[i]["time_midpoint"], top),
                            arrowprops=dict(arrowstyle="<->", facecolor="black"),
                        )
                        ax[i].annotate(
                            f"{self.eye_stats[i]['heights'][j]:.2e}",
                            xy=(self.eye_stats[i]["time_midpoint"] + 0.015, (bot + top) / 2 + 0.04),
                            bbox=dict(facecolor="white", alpha=0.5, edgecolor="none"),
                        )
                    except (ValueError, IndexError):
                        pass
                # add arrows for eye widths
                # for j in range(len(self.eye_stats[i]["widths"])):
                #     try:
                #         left = np.max(self.eye_stats[i]["time_clusters"][j][0])
                #         right = np.min(self.eye_stats[i]["time_clusters"][j][1])
                #         vertical = (self.eye_stats[i]["levels"][j] + self.eye_stats[i]["levels"][j + 1]) / 2
                #         ax[i].annotate(
                #             "",
                #             xy=(left, vertical),
                #             xytext=(right, vertical),
                #             arrowprops=dict(arrowstyle="<->", facecolor="black"),
                #         )
                #         ax[i].annotate(
                #             f"{self.eye_stats[i]['widths'][j]:.2e}",
                #             xy=((left + right) / 2 - 0.15, vertical + 0.01),
                #             bbox=dict(facecolor="white", alpha=0.5, edgecolor="none"),
                #         )
                #     except (ValueError, IndexError):
                #         pass
                # # add area
                # for j in range(len(self.eye_stats[i]["areas"])):
                #     horizontal = self.eye_stats[i]["time_midpoint"]
                #     vertical = (self.eye_stats[i]["levels"][j] + self.eye_stats[i]["levels"][j + 1]) / 2
                #     ax[i].annotate(
                #         f"{self.eye_stats[i]['areas'][j]:.2e}",
                #         xy=(horizontal + 0.035, vertical - 0.07),
                #         bbox=dict(facecolor="white", alpha=0.5, edgecolor="none"),
                #     )
fig.tight_layout()
if show:
plt.show()
return fig
@staticmethod
def calculate_thresholds(levels):
ret = np.cumsum(levels, dtype=float)
ret[2:] = ret[2:] - ret[:-2]
return ret[1:] / 2
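The cumsum expression above is a two-sample moving sum, so the method returns the midpoints between adjacent levels, i.e. the decision thresholds. A quick check for equally spaced PAM4 levels:

    levels = np.array([0.0, 1.0, 2.0, 3.0])
    print(eye_diagram.calculate_thresholds(levels))  # [0.5 1.5 2.5]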
def analyse_single(self, data, index):
warnings.filterwarnings("error")
        eye_stats = {}
        eye_stats["channel_name"] = str(index + 1) if self.channel_names is None else self.channel_names[index]
        try:
            approx_levels = eye_diagram.approximate_levels(data, self.n_levels)
            time_bounds = eye_diagram.calculate_time_bounds(data, approx_levels)
            eye_stats["time_midpoint"] = float((time_bounds[0] + time_bounds[1]) / 2)
            # eye_stats["time_midpoint"] = 1.0
            eye_stats["levels"], eye_stats["amplitude_clusters"] = eye_diagram.calculate_levels(
                data, approx_levels, time_bounds
            )
            eye_stats["thresholds"] = self.calculate_thresholds(eye_stats["levels"])
            eye_stats["amplitudes"] = np.diff(eye_stats["levels"])
            eye_stats["heights"] = eye_diagram.calculate_eye_heights(eye_stats["amplitude_clusters"])
            eye_stats["widths"], eye_stats["time_clusters"] = eye_diagram.calculate_eye_widths(
                data, eye_stats["levels"]
            )
            # # check if time clusters are valid (upper bound > time_midpoint > lower bound)
            # # if not: raise ValueError
            # for j in range(len(eye_stats['time_clusters'])):
            #     if not (np.max(eye_stats['time_clusters'][j][0]) < eye_stats["time_midpoint"] < np.min(eye_stats['time_clusters'][j][1])):
            #         raise ValueError
            # eye_stats["areas"] = eye_stats["heights"] * eye_stats["widths"]
            # eye_stats["mean_area"] = np.mean(eye_stats["areas"])
            # eye_stats["min_area"] = np.min(eye_stats["areas"])
            eye_stats["success"] = True
        except (RuntimeWarning, UserWarning, ValueError):
            eye_stats["success"] = False
            eye_stats["time_midpoint"] = None
            eye_stats["levels"] = None
            eye_stats["thresholds"] = None
            eye_stats["amplitude_clusters"] = None
            eye_stats["amplitudes"] = None
            eye_stats["heights"] = None
            eye_stats["widths"] = None
            # eye_stats["areas"] = np.zeros(self.n_levels - 1)
            # eye_stats["mean_area"] = 0
            # eye_stats["min_area"] = 0
        warnings.resetwarnings()
        return eye_stats
    def analyse(
        self, mode: Literal["default", "load", "save", "nosave"] = "default", file_or_dir: Path | str | None = None
    ):
        # modes:
        # default: try to load eye stats from file, if not found, generate and save
        # load: try to load eye stats from file, if not found, generate but don't save
        # save: generate eye stats and save
        # nosave: generate eye stats but don't save
update_save = True
if mode == "load":
self.load_data(file_or_dir)
update_save = False
elif mode == "default":
try:
self.load_data(file_or_dir)
update_save = False
except (FileNotFoundError, IsADirectoryError):
pass
if not self.analysed:
update_save = True
self.eye_stats = []
if self.multi_threaded:
with multiprocessing.Pool() as pool:
results = pool.starmap(self.analyse_single, [(self.raw_data[i], i) for i in range(self.n_channels)])
for i, result in enumerate(results):
self.eye_stats.append(result)
else:
for i in range(self.n_channels):
self.eye_stats.append(self.analyse_single(self.raw_data[i], i))
self.analysed = True
if mode == "save" or (mode == "default" and update_save):
self.save_data(file_or_dir)
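For reference, a minimal standalone sketch of the starmap dispatch pattern used here, with a hypothetical toy worker standing in for analyse_single:

    import multiprocessing

    def work(data, index):  # hypothetical stand-in for analyse_single
        return index, sum(data)

    if __name__ == "__main__":
        jobs = [([1, 2], 0), ([3, 4], 1)]
        with multiprocessing.Pool() as pool:
            print(pool.starmap(work, jobs))  # [(0, 3), (1, 7)]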
@staticmethod
def approximate_levels(data, levels):
@@ -399,7 +643,7 @@ class eye_diagram:
if __name__ == "__main__":
    length = int(2**16)
# data = generate_sample_data(length, noise=1)
# data1 = generate_sample_data(length, noise=0.01)
# data2 = generate_sample_data(length, noise=0.01, skew=1.2)
@@ -407,12 +651,13 @@ if __name__ == "__main__":
# data = np.stack([data, data1, data2, data3])
    data = generate_sample_data(length, noise=0.0000)
    eye = eye_diagram(data, horizontal_bins=1056, vertical_bins=1200)
    eye.plot(mode="nosave", stats=False)
    # attrs = ("success", "amplitudes", "time_midpoint", "levels", "heights", "widths")#, "area", "mean_area", "min_area")
    # for i, channel in enumerate(eye.eye_stats):
    #     print(f"Channel {i}")
    #     print_data = {attr: channel[attr] for attr in attrs}
    #     print(print_data)
    # eye.plot()

View File

@@ -0,0 +1,122 @@
# Copyright (c) 2015, Warren Weckesser. All rights reserved.
# This software is licensed according to the "BSD 2-clause" license.
# modified by Joseph Hopfmüller in 2025,
# for integration into optical regeneration analysis scripts
from pathlib import Path

import matplotlib.colors as colors
import matplotlib.pyplot as _plt
import numpy as _np  # legacy alias, kept alongside np below
import numpy as np
from matplotlib.colors import LinearSegmentedColormap
from scipy.ndimage import gaussian_filter

from .core import grid_count as _grid_count

# from ._common import _common_doc

__all__ = ["eyediagram"]  # , 'eyediagram_lines']
# def eyediagram_lines(y, window_size, offset=0, **plotkwargs):
# """
# Plot an eye diagram using matplotlib by repeatedly calling the `plot`
# function.
# <common>
# """
# start = offset
# while start < len(y):
# end = start + window_size
# if end > len(y):
# end = len(y)
# yy = y[start:end+1]
# _plt.plot(_np.arange(len(yy)), yy, 'k', **plotkwargs)
# start = end
# eyediagram_lines.__doc__ = eyediagram_lines.__doc__.replace("<common>",
# _common_doc)
eyemap = LinearSegmentedColormap.from_list(
"eyemap",
[
(0, "#0000FF00"),
(0.1, "blue"),
(0.2, "cyan"),
(0.5, "green"),
(0.8, "yellow"),
(0.9, "red"),
(1, "magenta"),
],
)
def eyediagram(
y,
window_size,
offset=0,
colorbar=False,
show=False,
save_im=False,
overwrite=False,
blur: int | bool = True,
save_path="out.png",
bounds=None,
**imshowkwargs,
):
"""
Plot an eye diagram using matplotlib by creating an image and calling
the `imshow` function.
<common>
"""
if bounds is None:
ymax = y.max()
ymin = y.min()
yamp = ymax - ymin
ymin = ymin - 0.05 * yamp
ymax = ymax + 0.05 * yamp
ymin = np.floor(ymin * 10) / 10
ymax = np.ceil(ymax * 10) / 10
bounds = (ymin, ymax)
counts = _grid_count(y, window_size, offset, bounds=bounds, size=(1000, 1200), blur=int(blur))
counts = counts.astype(_np.float32)
origin = imshowkwargs.pop("origin", "lower")
cmap: colors.Colormap = imshowkwargs.pop("cmap", eyemap)
vmin = imshowkwargs.pop("vmin", 1)
vmax = imshowkwargs.pop("vmax", None)
cmap.set_under("white", alpha=0)
if show:
_plt.imshow(
counts.T[::-1, :],
extent=[0, 2, *bounds],
origin=origin,
cmap=cmap,
vmin=vmin,
vmax=vmax,
**imshowkwargs,
)
_plt.grid()
if colorbar:
_plt.colorbar()
if Path(save_path).is_file() and not overwrite:
save_im = False
if save_im:
from PIL import Image
arr = counts.T[::-1, :]
if origin == "lower":
arr = arr[::-1]
        arr = (arr - arr.min()) / (arr.max() - arr.min())  # normalize counts to [0, 1] for the colormap
        image = Image.fromarray((cmap(arr) * 255).astype(np.uint8))
        image.save(save_path)
# print("-")
if show:
_plt.show()
# eyediagram.__doc__ = eyediagram.__doc__.replace("<common>", _common_doc)
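A hedged usage sketch of the eyediagram() helper above on a toy two-level trace. The assumptions here: window_size is the number of samples folded onto the two-UI window, and _grid_count accepts the call as written in the function body:

    import numpy as np

    sps = 16                                                     # samples per symbol (toy value)
    trace = np.repeat(np.random.randint(0, 2, 500), sps).astype(float)
    eyediagram(trace, window_size=2 * sps, show=True, colorbar=True)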

View File

@@ -1,6 +1,9 @@
import matplotlib.pyplot as plt
import numpy as np
if __name__ == "__main__":
from datasets import load_data
else:
from .datasets import load_data
def eye(*, path=None, data=None, sps=None, title=None, symbols=1000, skipfirst=0, width=2, alpha=None, complex=False, show=True):
"""Plot an eye diagram for the data given by filepath.
@@ -20,6 +23,7 @@ def eye(*, path=None, data=None, sps=None, title=None, symbols=1000, skipfirst=0
raise ValueError("Either path or data and sps must be given.")
if path is not None:
data, config = load_data(path, skipfirst, symbols)
data = data.detach().cpu().numpy()[:, :4]
sps = int(config["glova"]["sps"])
if sps is None:
raise ValueError("sps not set.")
@@ -71,3 +75,6 @@ def eye(*, path=None, data=None, sps=None, title=None, symbols=1000, skipfirst=0
plt.show()
return fig
if __name__ == "__main__":
eye(path="data/20241229-163838-128-16384-50000-0-0.2-16.8-0.058-PAM4-0-0.16.ini", symbols=1000, width=2, alpha=0.1, complex=False)

82160
src/visualization/viz.ipynb Normal file

File diff suppressed because it is too large