Compare commits
15 Commits
950516197c
...
machine_le
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
249fe1e940 | ||
|
|
f38d0ca3bb | ||
|
|
3af73343c1 | ||
|
|
7a0b65f82d | ||
|
|
98305fdf47 | ||
|
|
33141bdf41 | ||
|
|
638b62ee03 | ||
|
|
98d0b5ba8d | ||
|
|
39ae13d0af | ||
|
|
0e29b87395 | ||
|
|
884d9f73c9 | ||
|
|
08c81428ef | ||
|
|
9299316f81 | ||
|
|
2fcf621a50 | ||
|
|
36f89f379d |
2
.gitignore
vendored
2
.gitignore
vendored
@@ -163,3 +163,5 @@ cython_debug/
|
||||
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
||||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||
#.idea/
|
||||
tolerance_results/*
|
||||
data/*
|
||||
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:1250b8258fe86ed75db52637b69e1d5738105be1b60fd45049738f711d390192
|
||||
size 606
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:beb48236b04818824df1110815ecacdd2f2796a441d4c50bc944711cc3d6ef94
|
||||
size 606
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:505efd2a124a4e1bba5fb4342bc1427b24042e64db5c60a96cb66d0a2b19378f
|
||||
size 605
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:166c482c5cbf586468865358a4b3cf101be4b16c42a590a8b517cdf7f9c2fd7a
|
||||
size 605
|
||||
3
data/20241211-100154-128-16384-1-0-0-0-0-PAM4-0-0.ini
Normal file
3
data/20241211-100154-128-16384-1-0-0-0-0-PAM4-0-0.ini
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:f86feab7455c42e0234090b5b3461341b2a3637414356863b92443e8970692a4
|
||||
size 599
|
||||
3
data/20241211-105524-128-16384-1-0-0-0-0-PAM4-0-0.ini
Normal file
3
data/20241211-105524-128-16384-1-0-0-0-0-PAM4-0-0.ini
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:ac7eb8830fa7bb91e18ae87ab7b55d463738dbfd1f5662957c06c69ac98a214f
|
||||
size 599
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:670487527d53db564874c6bdbe057fcca48e19963275af92690a9f22e22764b2
|
||||
size 616
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:b2cf718ae05cebb5bdcbe4bb9e7924fc0856c7f236ec79b3a4e0cc0f3e84bc72
|
||||
size 616
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:edc005000a97d3eba4f573d35736c57350f17ff69017acf0c873413d2b824abe
|
||||
size 616
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:91e4b00d042393e967d783605da22ddea0846a4ff7328354033eb47fa96e6055
|
||||
size 616
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:9495d7059e92069b4182e1096b4690e5bc2122a4c7ac3cde84d3fd2e71537405
|
||||
size 616
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:db22dc55b809c359a432294308b3cd70c861ff7f67b1002cd296f2e05e58da53
|
||||
size 616
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:d78908aeda67365c2efee3efcbf71d05a91c2391b54893a4f6b060a7d2327714
|
||||
size 616
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:b75f51f0d35df1953c29874c2e3c111ccacdff624ce76aa13f24f37fa8f37997
|
||||
size 616
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:f7e942995d3ddf4009a93d3f4178457a2219822e8790e0cc73abd4cae0429333
|
||||
size 616
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:2befcef5adb9f8206292b443278d603f02575b8ffcd2e93f3d6d0db8eb009798
|
||||
size 616
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:c96f91e93524cb49e8cad67da973989cd66bfb56d04c370c770c6ed2610b35f3
|
||||
size 618
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:2ef989d392f318b37fecabeb7e5fa24db14c8867227862153736eeae908ef30f
|
||||
size 618
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:44ffed0e7387f70fc8d524d9d3d9587017860752b6199cc78545bcb6692e68af
|
||||
size 618
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:27191a866d99a3ab81959bc93f496e9d656b59e13ec77f2ff393079b468765ae
|
||||
size 618
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:d3c61cf0f502298db563d97bd3aadd3694e5834434377ff28f553de0e78873ee
|
||||
size 618
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:bc02b0099ea3bb136733e3d20817cad79b6c50c2e4b845f0d206455dde188cc4
|
||||
size 615
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:d80ff6f2a84acf973fbdf81a05ed0b1902f8bf97856cd5132b646f6b1173f496
|
||||
size 615
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:54ac6b6a452aa6b7d312a4c8ab8f7ebe2f96c1c4170cbc56147e8f2f9d934ad6
|
||||
size 615
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:bd2c6f4050488b6e857759d48aa1f1f37399d81cee1667d3668145e938d17c83
|
||||
size 615
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:c59b8113092459a8751b385a7b1a6f10828626d2ec2f29775512157fd9bbc75c
|
||||
size 615
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:04eaa2a29b3302e5de3bf99bf6a57fc8f27361fd3df3cac9245e25ab99324829
|
||||
size 615
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:704d3b0b17b9d320f4717b5a5a8bdbc5714f3caa4efa7153e980766429e834f4
|
||||
size 615
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:29304f35a88fd777566105f8666fff1c9927beb32756822365bcf9c159feb98e
|
||||
size 615
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:4a87488c12e0253b2bb5d1ac7aa1536f69ef569e62c2aab6a10149d753e049b4
|
||||
size 615
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:ff47c8d5413881edb03cfadfcde3b550ef7089543615a467c4f0027edaf1455e
|
||||
size 615
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:9e05a0a54c7e3aaffeca0ac90cce1b274a544d90b329a93d453392f5df4e91a8
|
||||
size 616
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:25e45b06b551ab8031c2159030d658999fcb3d1f0a34538c90768b94c8116771
|
||||
size 616
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:707ed73713b6c2d80d4e333b1ccdf4650f50aefefff58ddb471c1d5411954b3d
|
||||
size 616
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:a8c3baf878943741d83835c1afed05c1f9780ff3f0df260c0d92706343e59c50
|
||||
size 616
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:98878d09510dedc24f1429bbd12cce49c66be6c9d279a28765b120efe820a171
|
||||
size 616
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:fcbdaffa211d6b0b44b3ae1c66645999e95901bfdb2fffee4c45e34a0d901ee1
|
||||
size 649
|
||||
3
data/npys/0d56eb7933de7bd4a847eaa1ba4fd2c4.npy
Normal file
3
data/npys/0d56eb7933de7bd4a847eaa1ba4fd2c4.npy
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:73e922310068a66ab1a0c3d39b0b0fd7db798f2637b199a8c4bd113a38bb28c8
|
||||
size 134217856
|
||||
3
data/npys/0fb99fd81fd3076612f6aca9e2926e36.npy
Normal file
3
data/npys/0fb99fd81fd3076612f6aca9e2926e36.npy
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:43e80dc7d21aeff62c73f0ed02b20a4ac9b573d0e131a3ab8d1471077e03634b
|
||||
size 134217856
|
||||
3
data/npys/10503c5808359136376a663b39fe9451.npy
Normal file
3
data/npys/10503c5808359136376a663b39fe9451.npy
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:bfe3303d51a74602dfeb8c185f1cd49f39811fbda5e72399f8e112bc7fc7d5bc
|
||||
size 134217856
|
||||
3
data/npys/1506dbe6b4855599647d54c9b83e4913.npy
Normal file
3
data/npys/1506dbe6b4855599647d54c9b83e4913.npy
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:074c07bf5673060e5fe26162a01959030595671e10d607f80df7ff0e035c8f7b
|
||||
size 134217856
|
||||
3
data/npys/20959ec61eeecedc5baa28ca7a25c6a2.npy
Normal file
3
data/npys/20959ec61eeecedc5baa28ca7a25c6a2.npy
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:05ca3cddb57c739ba57c6a5a871607715cb926d87d6ececbcb385c6f11ad5690
|
||||
size 134217856
|
||||
3
data/npys/290c74cedb7a35eef6221e6120ed6019.npy
Normal file
3
data/npys/290c74cedb7a35eef6221e6120ed6019.npy
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:798a9cb026a88a336d612e28033354f00c412e74bea98672ae4d4dd276af97be
|
||||
size 134217856
|
||||
3
data/npys/2916537fb1df77ce203ce4eed3f7f8fd.npy
Normal file
3
data/npys/2916537fb1df77ce203ce4eed3f7f8fd.npy
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:50477064b1f8f9d0629a0cc294cf4195ea4bffb0f853df1c9e1177723d0ac9a6
|
||||
size 134217856
|
||||
3
data/npys/2a7f021c1c1112554d469b4e7bc3081f.npy
Normal file
3
data/npys/2a7f021c1c1112554d469b4e7bc3081f.npy
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:fb74cfbeec54b4f263c08510312881527fc7e484604fa0a1213b596f175fecc2
|
||||
size 134217856
|
||||
3
data/npys/39d009618e14a82fdcce8299fe8025c7.npy
Normal file
3
data/npys/39d009618e14a82fdcce8299fe8025c7.npy
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:4bd3a9265ac0200b04a2f5f095edefbbf7df18436cf1066813c2eb93b234f5f3
|
||||
size 134217856
|
||||
3
data/npys/3e8d8834ce3bc4ceea8a4da93d5d4c7e.npy
Normal file
3
data/npys/3e8d8834ce3bc4ceea8a4da93d5d4c7e.npy
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:bc6bae5f56197c905482795cffd6158254cb2f6ae38b848f16afe230fe6a9250
|
||||
size 134217856
|
||||
3
data/npys/3f6b6dfaecf6f1c003c6d9dfe950c6d0.npy
Normal file
3
data/npys/3f6b6dfaecf6f1c003c6d9dfe950c6d0.npy
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:447cca0af25e309c8be61216cea4fb2d3a8a967b0522760f1dab8f15e6b41574
|
||||
size 134217856
|
||||
3
data/npys/4365258197699ae6a83e846b9985f606.npy
Normal file
3
data/npys/4365258197699ae6a83e846b9985f606.npy
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:63321906920de825fc7857aa9f2e4944c3b32f3eadf99da06b66f6599924bc4c
|
||||
size 134217856
|
||||
3
data/npys/445b029eb26fba84333d7ed701ad637d.npy
Normal file
3
data/npys/445b029eb26fba84333d7ed701ad637d.npy
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:420fffd570da7bce1912a75d0d3d1cafb6faa0029701b204d9054914fa5d499a
|
||||
size 134217856
|
||||
3
data/npys/484ffd154056e75885022c286dbfafa7.npy
Normal file
3
data/npys/484ffd154056e75885022c286dbfafa7.npy
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:92e3fcc41f05380cb7b334fefc0a30bb8c1dfd9ca5c3b2cfaad36b0c7093914e
|
||||
size 134217856
|
||||
3
data/npys/55aac65b78a8181629f963815a7cdf5c.npy
Normal file
3
data/npys/55aac65b78a8181629f963815a7cdf5c.npy
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:0bf671d666d35617edd7bfb58548a5120bd92e5f9cb6edb4b5fc8d3bf5db8987
|
||||
size 134217856
|
||||
3
data/npys/5aaa36769f52738e1fa37f4c79b9eed2.npy
Normal file
3
data/npys/5aaa36769f52738e1fa37f4c79b9eed2.npy
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:bd429a5776555671b19d42e3ae152026dafd5bf95aeed9847df1432ed37f3eba
|
||||
size 134217856
|
||||
3
data/npys/5d775c0087b5389f79ff09112a5da948.npy
Normal file
3
data/npys/5d775c0087b5389f79ff09112a5da948.npy
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:4c39a10ad5ff46d977afbf84c5c6ecce4c7c0f8e37faf00c1a3b7a264f01b1cd
|
||||
size 134217856
|
||||
3
data/npys/6789fdea2609799ef2e975907625b79a.h5
Normal file
3
data/npys/6789fdea2609799ef2e975907625b79a.h5
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:1df90745cc2e6d4b0ad964fca2de1441e6e0b4b8345fbb0fbc1ffe9820674269
|
||||
size 134481920
|
||||
3
data/npys/68c8067f1d602a2afce20b08eeafa2f0.npy
Normal file
3
data/npys/68c8067f1d602a2afce20b08eeafa2f0.npy
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:1953d4308a837de3e976f9568cd571925e1c377746013220e263bbc1454b3be6
|
||||
size 134217856
|
||||
3
data/npys/78f8295da779018e431af641e9f3eb76.npy
Normal file
3
data/npys/78f8295da779018e431af641e9f3eb76.npy
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:477365edd696f66610298198257422a017f825e1f8bec4363bdfb0da0d741ebc
|
||||
size 134217856
|
||||
3
data/npys/7a63d865c33ab1bab2697ee5a9c2ce3b.npy
Normal file
3
data/npys/7a63d865c33ab1bab2697ee5a9c2ce3b.npy
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:1814f8ae0e6cdb69c0030741a3c6b35c74f2d6985eed34c0d5b4135384014abc
|
||||
size 134217856
|
||||
3
data/npys/7cb0340abb0154d0efafa5a8c7b70944.npy
Normal file
3
data/npys/7cb0340abb0154d0efafa5a8c7b70944.npy
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:d6c84d6be6411e0447e51d0090bce1465bcf609ae63a711ad88823f646809efc
|
||||
size 134217856
|
||||
3
data/npys/8818212e607784aa0a19d8d04370d6e1.npy
Normal file
3
data/npys/8818212e607784aa0a19d8d04370d6e1.npy
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:6e8cdb0328ed84f16e1f5e586acf981a16ae60710b97c3de74f9eb873fafd0cc
|
||||
size 134217856
|
||||
3
data/npys/9850899a5a98e27a33d5dc6650816fea.npy
Normal file
3
data/npys/9850899a5a98e27a33d5dc6650816fea.npy
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:39975a057e7c41a50d99164212b9279d69ad872295b3195d74e34cb6695e6938
|
||||
size 134217856
|
||||
3
data/npys/a0ddcc9a8f2fc25a1152ec78d54d9676.npy
Normal file
3
data/npys/a0ddcc9a8f2fc25a1152ec78d54d9676.npy
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:d53bc0fd0897c7372c2f182b84163bcf401ad91f26b6949c6d7a1d70c5dbb513
|
||||
size 134217856
|
||||
3
data/npys/b93fe9dee334240943c3a95f9e8bfa87.npy
Normal file
3
data/npys/b93fe9dee334240943c3a95f9e8bfa87.npy
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:81aa3425ea73ef060976d4021908dd453a69ba8036de53cff93f52dcec05ba32
|
||||
size 134217856
|
||||
3
data/npys/ba169fa8dcd6439724407c1ad4f9f440.npy
Normal file
3
data/npys/ba169fa8dcd6439724407c1ad4f9f440.npy
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:7a4e1cf76fb2c54ad755745f70757cbebce0b093d33ba2ad437636497da4f68b
|
||||
size 134217856
|
||||
3
data/npys/c31e3076e8c72dedae2c17c477e189f8.npy
Normal file
3
data/npys/c31e3076e8c72dedae2c17c477e189f8.npy
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:ff0411b01ecfa5a0663343f21ba90bf6fa95e5bf5673e99ec26ba2055bbb19f9
|
||||
size 134217856
|
||||
3
data/npys/ceb9676f372a236b2a6ec2a3c535947a.npy
Normal file
3
data/npys/ceb9676f372a236b2a6ec2a3c535947a.npy
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:43b78cdded1b5acd9d569920357d45c9433348a955b794524ee705845828825c
|
||||
size 134217856
|
||||
3
data/npys/cf33366e5842c443d7e67f89947a380f.npy
Normal file
3
data/npys/cf33366e5842c443d7e67f89947a380f.npy
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:fc882f29530f7d683be631214f6667611e0aba87453a11157d71c50f3548fb3c
|
||||
size 134217856
|
||||
3
data/npys/d30358afa19a86066b4fda6adb74c1b4.npy
Normal file
3
data/npys/d30358afa19a86066b4fda6adb74c1b4.npy
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:0f0bc616c2e581444a6fa658e45ee942c1ef5a4d21f22363518331f8d80cbe62
|
||||
size 134217856
|
||||
3
data/npys/df3f6e545a475e0e826db85505b1dfe1.npy
Normal file
3
data/npys/df3f6e545a475e0e826db85505b1dfe1.npy
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:b6311b36169637883c3b6723adac7626f68053d7bbeaee09336bf43b7754c662
|
||||
size 134217856
|
||||
3
data/npys/df441985bbf743f38bfe87b7abb0c474.npy
Normal file
3
data/npys/df441985bbf743f38bfe87b7abb0c474.npy
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:00383256ee1eea343837cf72c2e7b923c540f8871d7b929230e23475f87d993c
|
||||
size 134217856
|
||||
3
data/npys/ec3a1c902e17ca5cb5d85398cea182f7.npy
Normal file
3
data/npys/ec3a1c902e17ca5cb5d85398cea182f7.npy
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:ad94f2af43a2a06ebddc78566cbef0ea538c3da9191682e2fde4ecddb061b0f0
|
||||
size 134217856
|
||||
3
data/npys/eed647d6cb11272ef0d7a0c2b89ee2a0.npy
Normal file
3
data/npys/eed647d6cb11272ef0d7a0c2b89ee2a0.npy
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:6e8cdb0328ed84f16e1f5e586acf981a16ae60710b97c3de74f9eb873fafd0cc
|
||||
size 134217856
|
||||
3
data/npys/f084c8957ce4ed430fed2bd47d482df5.npy
Normal file
3
data/npys/f084c8957ce4ed430fed2bd47d482df5.npy
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:89c3c809fbc2e605869ce9347a10734330aaf3cb56f9bbd66e97295b9edeb642
|
||||
size 134217856
|
||||
3
data/npys/f4a8fbbeced371651d851440bd565ee4.npy
Normal file
3
data/npys/f4a8fbbeced371651d851440bd565ee4.npy
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:6930349d1ae1479dcb4c9ee9eaefe2da3eae62539cb4b97de873a7a8b175e809
|
||||
size 134217856
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:746ea83013870351296e01e294905ee027291ef79fd78c1e6b69dd9ebaa1cba0
|
||||
size 10240000
|
||||
oid sha256:76934d1d202aea1311ba67f5ea35eeb99a9c5c856f491565032e7d54ca6f072d
|
||||
size 13598720
|
||||
|
||||
37
notes/models.md
Normal file
37
notes/models.md
Normal file
@@ -0,0 +1,37 @@
|
||||
# models
|
||||
|
||||
## no polarisation flipping
|
||||
|
||||
```py
|
||||
config_path="data/20241229-163*-128-16384-50000-*.ini"
|
||||
model=".models/best_20241230_011907.tar"
|
||||
```
|
||||
|
||||
```py
|
||||
config_path="data/20241229-163*-128-16384-80000-*.ini"
|
||||
model=".models/best_20241230_103752.tar"
|
||||
```
|
||||
|
||||
```py
|
||||
config_path="data/20241229-163*-128-16384-100000-*.ini"
|
||||
model=".models/best_20241230_164534.tar"
|
||||
```
|
||||
|
||||
## with polarisation flipping
|
||||
|
||||
polarisation flipping: signal is randomly rotated by 180°. polarization rotation can be detected by adding a tone on one of the polarisations, but only to mod 180° with a direct detection setup. the randomly flipped signal should allow the network to hopefully learn to compensate for dispersion, pmd independently from the polarization rot. the training data includes the flipped signal as well, but no indication if the polarisation is flipped.
|
||||
|
||||
```py
|
||||
config_path="data/20241229-163*-128-16384-50000-*.ini"
|
||||
model=".models/best_20241231_000328.tar"
|
||||
```
|
||||
|
||||
```py
|
||||
config_path="data/20241229-163*-128-16384-80000-*.ini"
|
||||
model=".models/best_20241231_163614.tar"
|
||||
```
|
||||
|
||||
```py
|
||||
config_path="data/20241229-163*-128-16384-100000-*.ini"
|
||||
model=".models/best_20241231_170532.tar"
|
||||
```
|
||||
59
notes/tolerance_testing.md
Normal file
59
notes/tolerance_testing.md
Normal file
@@ -0,0 +1,59 @@
|
||||
# Baseline Models
|
||||
|
||||
## a) D+S, pol_error 0, ortho_error 0, DGD 0
|
||||
|
||||
dataset
|
||||
|
||||
```raw
|
||||
data/20250118-225840-128-16384-inf-100000-0-0.2-17.0-0.058-PAM4-0-0-10.ini
|
||||
```
|
||||
|
||||
model
|
||||
|
||||
```raw
|
||||
.models/best_20250118_225918.tar
|
||||
```
|
||||
|
||||
## b) D+S, pol_error 0.4, ortho_error 0, DGD 0
|
||||
|
||||
dataset
|
||||
|
||||
```raw
|
||||
data/20250116-214606-128-16384-inf-100000-0-0.2-17.0-0.058-PAM4-0-0-10.ini
|
||||
```
|
||||
|
||||
model
|
||||
|
||||
```raw
|
||||
.models/best_20250116_214816.tar
|
||||
```
|
||||
|
||||
## c) D+S, pol_error 0, ortho_error 0.1, DGD 0
|
||||
|
||||
dataset
|
||||
|
||||
```raw
|
||||
data/20250116-214547-128-16384-inf-100000-0-0.2-17.0-0.058-PAM4-0-0-10.ini
|
||||
```
|
||||
|
||||
model
|
||||
|
||||
```raw
|
||||
.models/best_20250117_122319.tar
|
||||
```
|
||||
|
||||
## d) D+S, pol_error 0, ortho_error 0, DGD 10ps (1 T_sym)
|
||||
|
||||
birefringence angle pi/2 (worst case)
|
||||
|
||||
dataset
|
||||
|
||||
```raw
|
||||
data/20250117-143926-128-16384-inf-100000-0-0.2-17.0-0.058-PAM4-0-0-10.ini
|
||||
```
|
||||
|
||||
model
|
||||
|
||||
```raw
|
||||
.models/best_20250117_144001.tar
|
||||
```
|
||||
2
pypho
2
pypho
Submodule pypho updated: 08d6dadf20...e44fc477fe
@@ -1,509 +0,0 @@
|
||||
"""
|
||||
generate_signal.py
|
||||
|
||||
This file is part of the repo "optical-regeneration"
|
||||
https://git.suuppl.dev/seppl/optical-regeneration.git
|
||||
|
||||
Joseph Hopfmüller
|
||||
Copyright 2024
|
||||
Licensed under the EUPL
|
||||
|
||||
Full license text in LICENSE file
|
||||
"""
|
||||
|
||||
import configparser
|
||||
from datetime import datetime
|
||||
import hashlib
|
||||
from pathlib import Path
|
||||
import time
|
||||
from matplotlib import pyplot as plt # noqa: F401
|
||||
import numpy as np
|
||||
|
||||
import add_pypho # noqa: F401
|
||||
import pypho
|
||||
|
||||
default_config = f"""
|
||||
[glova]
|
||||
nos = 256
|
||||
sps = 256
|
||||
f0 = 193414489032258.06
|
||||
symbolrate = 10e9
|
||||
wisdom_dir = "{str((Path.home() / ".pypho"))}"
|
||||
flags = "FFTW_PATIENT"
|
||||
nthreads = 32
|
||||
|
||||
[fiber]
|
||||
length = 10000
|
||||
gamma = 1.14
|
||||
alpha = 0.2
|
||||
D = 17
|
||||
S = 0
|
||||
birefsteps = 1
|
||||
; birefseed = 0xC0FFEE
|
||||
|
||||
[signal]
|
||||
; seed = 0xC0FFEE
|
||||
|
||||
modulation = "pam"
|
||||
mod_order = 4
|
||||
mod_depth = 0.8
|
||||
|
||||
max_jitter = 0.02
|
||||
; jitter_seed = 0xC0FFEE
|
||||
|
||||
laser_power = 0
|
||||
edfa_power = 3
|
||||
edfa_nf = 5
|
||||
|
||||
pulse_shape = "gauss"
|
||||
fwhm = 0.33
|
||||
|
||||
[data]
|
||||
dir = "data"
|
||||
npy_dir = "npys"
|
||||
"""
|
||||
|
||||
|
||||
def get_config(config_file=None):
|
||||
"""
|
||||
DANGER! The function uses eval() to parse the config file. Do not use this function with untrusted input.
|
||||
"""
|
||||
if config_file is None:
|
||||
config_file = Path(__file__).parent / "signal_generation.ini"
|
||||
if not config_file.exists():
|
||||
with open(config_file, "w") as f:
|
||||
f.write(default_config)
|
||||
config = configparser.ConfigParser()
|
||||
config.read(config_file)
|
||||
|
||||
conf = {}
|
||||
for section in config.sections():
|
||||
# print(f"[{section}]")
|
||||
conf[section] = {}
|
||||
for key in config[section]:
|
||||
# print(f"{key} = {config[section][key]}")
|
||||
conf[section][key] = eval(config[section][key])
|
||||
# if isinstance(conf[section][key], str):
|
||||
# conf[section][key] = config[section][key].strip('"')
|
||||
return conf
|
||||
|
||||
|
||||
class pam_generator:
|
||||
def __init__(
|
||||
self,
|
||||
glova,
|
||||
mod_order=None,
|
||||
mod_depth=0.5,
|
||||
pulse_shape="gauss",
|
||||
fwhm=0.33,
|
||||
seed=None,
|
||||
) -> None:
|
||||
self.glova = glova
|
||||
self.pulse_shape = pulse_shape
|
||||
self.modulation_depth = mod_depth
|
||||
self.mod_order = mod_order
|
||||
self.fwhm = fwhm
|
||||
self.seed = seed
|
||||
|
||||
def __call__(self, E, symbols, max_jitter=0):
|
||||
symbols_x = symbols[0] / (self.mod_order or np.max(symbols[0]))
|
||||
symbols_y = symbols[1] / (self.mod_order or np.max(symbols[1]))
|
||||
|
||||
diffs_x = np.diff(symbols_x, prepend=symbols_x[0])
|
||||
diffs_y = np.diff(symbols_y, prepend=symbols_y[0])
|
||||
|
||||
max_jitter = int(round(max_jitter * self.glova.sps))
|
||||
digital_x = self.generate_digital_signal(diffs_x, max_jitter)
|
||||
digital_y = self.generate_digital_signal(diffs_y, max_jitter)
|
||||
|
||||
digital_x = np.pad(
|
||||
digital_x, (0, self.glova.sps // 2), "constant", constant_values=(0, 0)
|
||||
)
|
||||
digital_y = np.pad(
|
||||
digital_y, (0, self.glova.sps // 2), "constant", constant_values=(0, 0)
|
||||
)
|
||||
|
||||
if self.pulse_shape == "gauss":
|
||||
wavelet = self.gauss(oversampling=6)
|
||||
else:
|
||||
raise ValueError(f"Unknown pulse shape: {self.pulse_shape}")
|
||||
|
||||
# create analog signal of diff of symbols
|
||||
E_x = np.convolve(digital_x, wavelet)
|
||||
E_y = np.convolve(digital_y, wavelet)
|
||||
|
||||
# convert to pam and set modulation depth (scale and move up such that 1 stays at 1)
|
||||
E_x = np.cumsum(E_x) * self.modulation_depth + (1 - self.modulation_depth)
|
||||
E_y = np.cumsum(E_y) * self.modulation_depth + (1 - self.modulation_depth)
|
||||
|
||||
# cut off the wavelet tails
|
||||
E_x = E_x[self.glova.sps // 2 + len(wavelet) // 2 - 1 : -len(wavelet) // 2]
|
||||
E_y = E_y[self.glova.sps // 2 + len(wavelet) // 2 - 1 : -len(wavelet) // 2]
|
||||
|
||||
# modulate the laser
|
||||
E[0]["E"][0] = np.sqrt(np.multiply(np.square(E[0]["E"][0]), E_x))
|
||||
E[0]["E"][1] = np.sqrt(np.multiply(np.square(E[0]["E"][1]), E_y))
|
||||
|
||||
return E
|
||||
|
||||
def generate_digital_signal(self, symbols, max_jitter=0):
|
||||
rs = np.random.RandomState(self.seed)
|
||||
signal = np.zeros(self.glova.nos * self.glova.sps)
|
||||
for index in range(self.glova.nos):
|
||||
jitter = max_jitter != 0 and rs.randint(-max_jitter, max_jitter)
|
||||
signal_index = index * self.glova.sps + jitter
|
||||
if signal_index < 0:
|
||||
continue
|
||||
if signal_index >= len(signal):
|
||||
continue
|
||||
signal[signal_index] = symbols[index]
|
||||
return signal
|
||||
|
||||
def gauss(self, oversampling=1):
|
||||
sample_points = np.linspace(
|
||||
-oversampling * self.glova.sps,
|
||||
oversampling * self.glova.sps,
|
||||
oversampling * 2 * self.glova.sps,
|
||||
endpoint=True,
|
||||
)
|
||||
sigma = self.fwhm / (1 * np.sqrt(2 * np.log(2))) * self.glova.sps
|
||||
pulse = (
|
||||
1
|
||||
/ (sigma * np.sqrt(2 * np.pi))
|
||||
* np.exp(-np.square(sample_points) / (2 * np.square(sigma)))
|
||||
)
|
||||
return pulse
|
||||
|
||||
|
||||
def initialize_fiber_and_data(config, input_data_override=None):
|
||||
py_glova = pypho.setup(
|
||||
nos=config["glova"]["nos"],
|
||||
sps=config["glova"]["sps"],
|
||||
f0=config["glova"]["f0"],
|
||||
symbolrate=config["glova"]["symbolrate"],
|
||||
wisdom_dir=config["glova"]["wisdom_dir"],
|
||||
flags=config["glova"]["flags"],
|
||||
nthreads=config["glova"]["nthreads"],
|
||||
)
|
||||
|
||||
c_glova = pypho.cfiber.GlovaWrapper.from_setup(py_glova)
|
||||
c_data = pypho.cfiber.DataWrapper(py_glova.sps * py_glova.nos)
|
||||
py_edfa = pypho.oamp(py_glova, Pmean=config["signal"]["edfa_power"], NF=config["signal"]["edfa_nf"])
|
||||
|
||||
if input_data_override is not None:
|
||||
c_data.E_in = input_data_override[0]
|
||||
noise = input_data_override[1]
|
||||
else:
|
||||
config["signal"]["seed"] = config["signal"].get(
|
||||
"seed", (int(time.time() * 1000)) % 2**32
|
||||
)
|
||||
config["signal"]["jitter_seed"] = config["signal"].get(
|
||||
"jitter_seed", (int(time.time() * 1000)) % 2**32
|
||||
)
|
||||
symbolsrc = pypho.symbols(
|
||||
py_glova, py_glova.nos, pattern="ones", p1=config["signal"]["mod_order"], seed=config["signal"]["seed"]
|
||||
)
|
||||
laser = pypho.lasmod(
|
||||
py_glova, power=config["signal"]["laser_power"]+1.5, Df=0, theta=np.pi / 4
|
||||
)
|
||||
modulator = pam_generator(
|
||||
py_glova,
|
||||
mod_depth=config["signal"]["mod_depth"],
|
||||
pulse_shape=config["signal"]["pulse_shape"],
|
||||
fwhm=config["signal"]["fwhm"],
|
||||
seed=config["signal"]["jitter_seed"],
|
||||
)
|
||||
|
||||
symbols_x = symbolsrc(pattern="random")
|
||||
symbols_y = symbolsrc(pattern="random")
|
||||
symbols_x[:3] = 0
|
||||
symbols_y[:3] = 0
|
||||
|
||||
cw = laser()
|
||||
|
||||
source_signal = modulator(E=cw, symbols=(symbols_x, symbols_y))
|
||||
|
||||
source_signal = py_edfa(E=source_signal)
|
||||
|
||||
c_data.E_in = source_signal[0]["E"]
|
||||
noise = source_signal[0]["noise"]
|
||||
|
||||
py_fiber = pypho.fiber(
|
||||
glova=py_glova,
|
||||
l=config["fiber"]["length"],
|
||||
alpha=pypho.functions.dB_to_Neper(config["fiber"]["alpha"]) / 1000,
|
||||
gamma=config["fiber"]["gamma"],
|
||||
D=config["fiber"]["d"],
|
||||
S=config["fiber"]["s"],
|
||||
)
|
||||
if config["fiber"]["birefsteps"] > 0:
|
||||
config["fiber"]["birefseed"] = config["fiber"].get(
|
||||
"birefseed", (int(time.time() * 1000)) % 2**32
|
||||
)
|
||||
py_fiber.birefarray = pypho.birefringence_segment.create_pmd_fibre(
|
||||
py_fiber.l,
|
||||
py_fiber.l / config["fiber"]["birefsteps"],
|
||||
0,
|
||||
config["fiber"]["birefseed"],
|
||||
)
|
||||
c_params = pypho.cfiber.ParamsWrapper.from_fiber(
|
||||
py_fiber, max_step=1e3 if py_fiber.gamma == 0 else 200
|
||||
)
|
||||
c_fiber = pypho.cfiber.FiberWrapper(c_data, c_params, c_glova)
|
||||
|
||||
return c_fiber, c_data, noise, py_edfa
|
||||
|
||||
|
||||
def save_data(data, config):
|
||||
data_dir = Path(config["data"]["dir"])
|
||||
npy_dir = config["data"].get("npy_dir", "")
|
||||
save_dir = data_dir / npy_dir if len(npy_dir) else data_dir
|
||||
save_dir.mkdir(parents=True, exist_ok=True)
|
||||
save_data = np.column_stack([
|
||||
data.E_in[0],
|
||||
data.E_in[1],
|
||||
data.E_out[0],
|
||||
data.E_out[1],
|
||||
])
|
||||
timestamp = datetime.now()
|
||||
seed = config["signal"].get("seed", False)
|
||||
jitter_seed = config["signal"].get("jitter_seed", False)
|
||||
birefseed = config["fiber"].get("birefseed", False)
|
||||
|
||||
config_content = "\n".join((
|
||||
f"; Generated by {str(Path(__file__).name)} @ {timestamp.strftime('%Y-%m-%d %H:%M:%S')}",
|
||||
"[glova]",
|
||||
f"sps = {config['glova']['sps']}",
|
||||
f"nos = {config['glova']['nos']}",
|
||||
f"f0 = {config['glova']['f0']}",
|
||||
f"symbolrate = {config['glova']['symbolrate']}",
|
||||
f'wisdom_dir = "{config["glova"]["wisdom_dir"]}"',
|
||||
f'flags = "{config["glova"]["flags"]}"',
|
||||
f"nthreads = {config['glova']['nthreads']}",
|
||||
" ",
|
||||
"[fiber]",
|
||||
f"length = {config['fiber']['length']}",
|
||||
f"gamma = {config['fiber']['gamma']}",
|
||||
f"alpha = {config['fiber']['alpha']}",
|
||||
f"D = {config['fiber']['d']}",
|
||||
f"S = {config['fiber']['s']}",
|
||||
f"birefsteps = {config['fiber']['birefsteps']}",
|
||||
f"birefseed = {hex(birefseed)}" if birefseed else "; birefseed = not set",
|
||||
"",
|
||||
"[signal]",
|
||||
f"seed = {hex(seed)}" if seed else "; seed = not set",
|
||||
"",
|
||||
f'modulation = "{config["signal"]["modulation"]}"',
|
||||
f"mod_order = {config['signal']['mod_order']}",
|
||||
f"mod_depth = {config['signal']['mod_depth']}",
|
||||
""
|
||||
f"max_jitter = {config['signal']['max_jitter']}",
|
||||
f"jitter_seed = {hex(jitter_seed)}" if jitter_seed else "; jitter_seed = not set",
|
||||
""
|
||||
f"laser_power = {config['signal']['laser_power']}",
|
||||
f"edfa_power = {config['signal']['edfa_power']}",
|
||||
f"edfa_nf = {config['signal']['edfa_nf']}",
|
||||
""
|
||||
f'pulse_shape = "{config["signal"]["pulse_shape"]}"',
|
||||
f"fwhm = {config['signal']['fwhm']}",
|
||||
"",
|
||||
"[data]",
|
||||
f'dir = "{str(data_dir)}"',
|
||||
f'npy_dir = "{npy_dir}"',
|
||||
"file = "
|
||||
))
|
||||
config_hash = hashlib.md5(config_content.encode()).hexdigest()
|
||||
save_file = f"{config_hash}.npy"
|
||||
config_content += f'"{str(save_file)}"\n'
|
||||
|
||||
filename_components = (
|
||||
timestamp.strftime("%Y%m%d-%H%M%S"),
|
||||
config["glova"]["sps"],
|
||||
config["glova"]["nos"],
|
||||
config["fiber"]["length"],
|
||||
config["fiber"]["gamma"],
|
||||
config["fiber"]["alpha"],
|
||||
config["fiber"]["d"],
|
||||
config["fiber"]["s"],
|
||||
f"{config['signal']['modulation'].upper()}{config['signal']['mod_order']}",
|
||||
config["fiber"]["birefsteps"],
|
||||
)
|
||||
|
||||
lookup_file = "-".join(map(str, filename_components)) + ".ini"
|
||||
with open(data_dir / lookup_file, "w") as f:
|
||||
f.write(config_content)
|
||||
|
||||
np.save(save_dir / save_file, save_data)
|
||||
|
||||
print("Saved config to", data_dir / lookup_file)
|
||||
print("Saved data to", save_dir / save_file)
|
||||
|
||||
|
||||
def length_loop(config, lengths, incremental=False):
|
||||
lengths = sorted(lengths)
|
||||
input_override = None
|
||||
for lind, length in enumerate(lengths):
|
||||
# print(f"\nGenerating data for fiber length {length}")
|
||||
if lind > 0 and incremental:
|
||||
# set the length to the difference between the current and previous length -> incremental
|
||||
length = lengths[lind] - lengths[lind - 1]
|
||||
if incremental:
|
||||
print(
|
||||
f"\nGenerating data for fiber length {lengths[lind]}m [using {length}m increment]"
|
||||
)
|
||||
else:
|
||||
print(f"\nGenerating data for fiber length {length}m")
|
||||
config["fiber"]["length"] = length
|
||||
# set the input data to the output data of the previous run
|
||||
cfiber, cdata, noise, edfa = initialize_fiber_and_data(
|
||||
config, input_data_override=input_override
|
||||
)
|
||||
|
||||
if lind == 0:
|
||||
cdata_orig = cdata
|
||||
|
||||
mean_power_in = np.sum(pypho.functions.getpower_W(cdata.E_in))
|
||||
print(
|
||||
f"Mean power in: {mean_power_in:.3e} W ({pypho.functions.W_to_dBm(mean_power_in):.3e} dBm)"
|
||||
)
|
||||
|
||||
cfiber()
|
||||
|
||||
|
||||
mean_power_out = np.sum(pypho.functions.getpower_W(cdata.E_out))
|
||||
print(
|
||||
f"Mean power out: {mean_power_out:.3e} W ({pypho.functions.W_to_dBm(mean_power_out):.3e} dBm)"
|
||||
)
|
||||
|
||||
if incremental:
|
||||
input_override = (cdata.E_out, noise)
|
||||
cdata.E_in = cdata_orig.E_in
|
||||
config["fiber"]["length"] = lengths[lind]
|
||||
|
||||
E_tmp = [{'E': cdata.E_out, 'noise': noise*(-cfiber.params.l*cfiber.params.alpha)}]
|
||||
E_tmp = edfa(E=E_tmp)
|
||||
cdata.E_out = E_tmp[0]['E']
|
||||
save_data(cdata, config)
|
||||
|
||||
in_out_eyes(cfiber, cdata)
|
||||
|
||||
|
||||
def single_run_with_plot(config, save=True):
|
||||
cfiber, cdata, noise, edfa = initialize_fiber_and_data(config)
|
||||
|
||||
mean_power_in = np.sum(pypho.functions.getpower_W(cdata.E_in))
|
||||
print(
|
||||
f"Mean power in: {mean_power_in:.3e} W ({pypho.functions.W_to_dBm(mean_power_in):.3e} dBm)"
|
||||
)
|
||||
|
||||
cfiber()
|
||||
|
||||
mean_power_out = np.sum(pypho.functions.getpower_W(cdata.E_out))
|
||||
print(
|
||||
f"Mean power out: {mean_power_out:.3e} W ({pypho.functions.W_to_dBm(mean_power_out):.3e} dBm)"
|
||||
)
|
||||
|
||||
E_tmp = [{'E': cdata.E_out, 'noise': noise*(-cfiber.params.l*cfiber.params.alpha)}]
|
||||
E_tmp = edfa(E=E_tmp)
|
||||
cdata.E_out = E_tmp[0]['E']
|
||||
if save:
|
||||
save_data(cdata, config)
|
||||
|
||||
in_out_eyes(cfiber, cdata)
|
||||
|
||||
def in_out_eyes(cfiber, cdata):
|
||||
fig, axs = plt.subplots(2, 2, sharex=True, sharey=True)
|
||||
eye_head = min(cfiber.glova.nos, 2000)
|
||||
symbolrate_scale = 1e12
|
||||
amplitude_scale = 1e3
|
||||
plot_eye_diagram(
|
||||
amplitude_scale * np.abs(cdata.E_in[0]) ** 2,
|
||||
2 * cfiber.glova.sps,
|
||||
normalize=False,
|
||||
samplerate=cfiber.glova.symbolrate * cfiber.glova.sps / symbolrate_scale,
|
||||
head=eye_head,
|
||||
ax=axs[0][0],
|
||||
show=False,
|
||||
)
|
||||
plot_eye_diagram(
|
||||
amplitude_scale * np.abs(cdata.E_out[0]) ** 2,
|
||||
2 * cfiber.glova.sps,
|
||||
normalize=False,
|
||||
samplerate=cfiber.glova.symbolrate * cfiber.glova.sps / symbolrate_scale,
|
||||
head=eye_head,
|
||||
ax=axs[0][1],
|
||||
color="C1",
|
||||
show=False,
|
||||
)
|
||||
plot_eye_diagram(
|
||||
amplitude_scale * np.abs(cdata.E_in[1]) ** 2,
|
||||
2 * cfiber.glova.sps,
|
||||
normalize=False,
|
||||
samplerate=cfiber.glova.symbolrate * cfiber.glova.sps / symbolrate_scale,
|
||||
head=eye_head,
|
||||
ax=axs[1][0],
|
||||
show=False,
|
||||
)
|
||||
plot_eye_diagram(
|
||||
amplitude_scale * np.abs(cdata.E_out[1]) ** 2,
|
||||
2 * cfiber.glova.sps,
|
||||
normalize=False,
|
||||
samplerate=cfiber.glova.symbolrate * cfiber.glova.sps / symbolrate_scale,
|
||||
head=eye_head,
|
||||
ax=axs[1][1],
|
||||
color="C1",
|
||||
show=False,
|
||||
)
|
||||
|
||||
title_map = [
|
||||
["Input x", "Output x"],
|
||||
["Input y", "Output y"],
|
||||
]
|
||||
title_map = np.array(title_map)
|
||||
for ax, title in zip(axs.flatten(), title_map.flatten()):
|
||||
ax.grid(True)
|
||||
ax.set_xlabel("Time [ps]")
|
||||
ax.set_ylabel("Power [mW]")
|
||||
ax.set_title(title)
|
||||
fig.tight_layout()
|
||||
|
||||
plt.show()
|
||||
|
||||
|
||||
def plot_eye_diagram(
|
||||
signal: np.ndarray,
|
||||
eye_width,
|
||||
offset=0,
|
||||
*,
|
||||
head=None,
|
||||
samplerate=1,
|
||||
normalize=True,
|
||||
ax=None,
|
||||
color="C0",
|
||||
show=True,
|
||||
):
|
||||
ax = ax or plt.gca()
|
||||
if head is not None:
|
||||
signal = signal[: head * eye_width]
|
||||
if normalize:
|
||||
signal = signal / np.max(signal)
|
||||
slices = np.lib.stride_tricks.sliding_window_view(signal, eye_width + 1)[
|
||||
offset % (eye_width + 1) :: eye_width
|
||||
]
|
||||
plt_ax = np.arange(-eye_width // 2, eye_width // 2 + 1) / samplerate
|
||||
for slice in slices:
|
||||
ax.plot(plt_ax, slice, color=color, alpha=0.1)
|
||||
ax.grid()
|
||||
if show:
|
||||
plt.show()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
add_pypho.show_log()
|
||||
config = get_config()
|
||||
|
||||
lengths = np.arange(50000, 100000+6000, 1000)
|
||||
length_loop(config, lengths, incremental=False)
|
||||
|
||||
single_run_with_plot(config, save=False)
|
||||
|
||||
@@ -1,7 +1,16 @@
|
||||
import copy
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
import random
|
||||
from typing import Literal
|
||||
import matplotlib
|
||||
from matplotlib.colors import LinearSegmentedColormap
|
||||
import torch.nn.utils.parametrize
|
||||
|
||||
try:
|
||||
matplotlib.use("cairo")
|
||||
except ImportError:
|
||||
matplotlib.use("Agg")
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
import numpy as np
|
||||
@@ -11,35 +20,19 @@ import optuna
|
||||
import warnings
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
# import torch.nn as nn
|
||||
|
||||
# import torch.nn.functional as F # mse_loss doesn't support complex numbers
|
||||
import torch.optim as optim
|
||||
import torch.utils.data
|
||||
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
import hypertraining.models as models
|
||||
|
||||
# from rich.progress import (
|
||||
# Progress,
|
||||
# TextColumn,
|
||||
# BarColumn,
|
||||
# TaskProgressColumn,
|
||||
# TimeRemainingColumn,
|
||||
# MofNCompleteColumn,
|
||||
# TimeElapsedColumn,
|
||||
# )
|
||||
# from rich.console import Console
|
||||
# from rich import print as rprint
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
import multiprocessing
|
||||
|
||||
from util.datasets import FiberRegenerationDataset
|
||||
|
||||
# from util.optuna_helpers import (
|
||||
# suggest_categorical_optional, # noqa: F401
|
||||
# suggest_float_optional, # noqa: F401
|
||||
# suggest_int_optional, # noqa: F401
|
||||
# )
|
||||
from util.optuna_helpers import install_optional_suggests
|
||||
import util
|
||||
|
||||
@@ -65,7 +58,6 @@ class HyperTraining:
|
||||
model_settings,
|
||||
optimizer_settings,
|
||||
optuna_settings,
|
||||
# console=None,
|
||||
):
|
||||
self.global_settings: GlobalSettings = global_settings
|
||||
self.data_settings: DataSettings = data_settings
|
||||
@@ -75,11 +67,8 @@ class HyperTraining:
|
||||
self.optuna_settings: OptunaSettings = optuna_settings
|
||||
self.processes = None
|
||||
|
||||
# self.console = console or Console()
|
||||
|
||||
# set some extra settings to make the code more readable
|
||||
self._extra_optuna_settings()
|
||||
self.stop_study = True
|
||||
self.stop_study = False
|
||||
|
||||
def setup_tb_writer(self, study_name=None, append=None):
|
||||
log_dir = self.pytorch_settings.summary_dir + "/" + (study_name or self.optuna_settings.study_name)
|
||||
@@ -229,7 +218,7 @@ class HyperTraining:
|
||||
self.optuna_settings._parallel = self.optuna_settings._n_threads > 1
|
||||
|
||||
def define_model(self, trial: optuna.Trial, writer=None):
|
||||
n_layers = trial.suggest_int_optional("model_n_hidden_layers", self.model_settings.n_hidden_layers)
|
||||
n_hidden_layers = trial.suggest_int_optional("model_n_hidden_layers", self.model_settings.n_hidden_layers)
|
||||
|
||||
input_dim = trial.suggest_int_optional(
|
||||
"model_input_dim",
|
||||
@@ -245,32 +234,44 @@ class HyperTraining:
|
||||
dtype = getattr(torch, dtype)
|
||||
|
||||
afunc = trial.suggest_categorical_optional("model_activation_func", self.model_settings.model_activation_func)
|
||||
# T0 = trial.suggest_float_optional("T0", self.model_settings.satabsT0 , log=True)
|
||||
afunc = getattr(util.complexNN, afunc)
|
||||
layer_func = trial.suggest_categorical_optional("model_layer_function", self.model_settings.model_layer_function)
|
||||
layer_func = getattr(util.complexNN, layer_func)
|
||||
layer_parametrizations = self.model_settings.model_layer_parametrizations
|
||||
|
||||
layers = []
|
||||
last_dim = input_dim
|
||||
n_nodes = last_dim
|
||||
for i in range(n_layers):
|
||||
scale_layers = trial.suggest_categorical_optional("model_enable_scale_layers", self.model_settings.scale)
|
||||
|
||||
|
||||
hidden_dims = []
|
||||
for i in range(n_hidden_layers):
|
||||
if hidden_dim_override := self.model_settings.overrides.get(f"n_hidden_nodes_{i}", False):
|
||||
hidden_dim = trial.suggest_int_optional(f"model_hidden_dim_{i}", hidden_dim_override)
|
||||
hidden_dims.append(trial.suggest_int_optional(f"model_hidden_dim_{i}", hidden_dim_override))
|
||||
else:
|
||||
hidden_dim = trial.suggest_int_optional(
|
||||
hidden_dims.append(trial.suggest_int_optional(
|
||||
f"model_hidden_dim_{i}",
|
||||
self.model_settings.n_hidden_nodes,
|
||||
)
|
||||
layers.append(util.complexNN.ONNRect(last_dim, hidden_dim, dtype=dtype))
|
||||
last_dim = hidden_dim
|
||||
layers.append(getattr(util.complexNN, afunc)())
|
||||
n_nodes += last_dim
|
||||
|
||||
layers.append(util.complexNN.ONNRect(last_dim, self.model_settings.output_dim, dtype=dtype))
|
||||
|
||||
model = nn.Sequential(*layers)
|
||||
))
|
||||
|
||||
model_kwargs = {
|
||||
"dims": (input_dim, *hidden_dims, self.model_settings.output_dim),
|
||||
"layer_function": layer_func,
|
||||
"layer_func_kwargs": self.model_settings.model_layer_kwargs,
|
||||
"act_function": afunc,
|
||||
"act_func_kwargs": None,
|
||||
"parametrizations": layer_parametrizations,
|
||||
"dtype": dtype,
|
||||
"dropout_prob": self.model_settings.dropout_prob,
|
||||
"scale_layers": scale_layers,
|
||||
"rotate": False,
|
||||
}
|
||||
|
||||
model = models.regenerator(*model_kwargs.pop("dims"), **model_kwargs)
|
||||
n_nodes = sum(hidden_dims)
|
||||
|
||||
if writer is not None:
|
||||
writer.add_graph(model, torch.zeros(1, input_dim, dtype=dtype), use_strict_trace=False)
|
||||
|
||||
n_params = sum(p.numel() for p in model.parameters())
|
||||
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||
trial.set_user_attr("model_n_params", n_params)
|
||||
trial.set_user_attr("model_n_nodes", n_nodes)
|
||||
|
||||
@@ -384,7 +385,11 @@ class HyperTraining:
|
||||
running_loss2 = 0.0
|
||||
running_loss = 0.0
|
||||
model.train()
|
||||
for batch_idx, (x, y) in enumerate(train_loader):
|
||||
loader_len = len(train_loader)
|
||||
for batch_idx, batch in enumerate(train_loader):
|
||||
x = batch["x"]
|
||||
y = batch["y"]
|
||||
|
||||
if batch_idx >= self.optuna_settings._n_train_batches:
|
||||
break
|
||||
model.zero_grad(set_to_none=True)
|
||||
@@ -393,7 +398,7 @@ class HyperTraining:
|
||||
y.to(self.pytorch_settings.device),
|
||||
)
|
||||
y_pred = model(x)
|
||||
loss = util.complexNN.complex_mse_loss(y_pred, y)
|
||||
loss = util.complexNN.complex_mse_loss(y_pred, y, power=True)
|
||||
loss_value = loss.item()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
@@ -408,14 +413,14 @@ class HyperTraining:
|
||||
writer.add_scalar(
|
||||
"training loss",
|
||||
running_loss2 / (self.pytorch_settings.write_every if batch_idx > 0 else 1),
|
||||
epoch * min(len(train_loader), self.optuna_settings._n_train_batches) + batch_idx,
|
||||
epoch * min(loader_len, self.optuna_settings._n_train_batches) + batch_idx,
|
||||
)
|
||||
running_loss2 = 0.0
|
||||
|
||||
# if enable_progress:
|
||||
# progress.stop()
|
||||
|
||||
return running_loss / min(len(train_loader), self.optuna_settings._n_train_batches)
|
||||
return running_loss / min(loader_len, self.optuna_settings._n_train_batches)
|
||||
|
||||
def eval_model(
|
||||
self,
|
||||
@@ -446,9 +451,10 @@ class HyperTraining:
|
||||
|
||||
model.eval()
|
||||
running_error = 0
|
||||
running_error_2 = 0
|
||||
with torch.no_grad():
|
||||
for batch_idx, (x, y) in enumerate(valid_loader):
|
||||
for batch_idx, batch in enumerate(valid_loader):
|
||||
x = batch["x"]
|
||||
y = batch["y"]
|
||||
if batch_idx >= self.optuna_settings._n_valid_batches:
|
||||
break
|
||||
x, y = (
|
||||
@@ -456,72 +462,91 @@ class HyperTraining:
|
||||
y.to(self.pytorch_settings.device),
|
||||
)
|
||||
y_pred = model(x)
|
||||
error = util.complexNN.complex_mse_loss(y_pred, y)
|
||||
error = util.complexNN.complex_mse_loss(y_pred, y, power=True)
|
||||
error_value = error.item()
|
||||
running_error += error_value
|
||||
running_error_2 += error_value
|
||||
|
||||
# if enable_progress:
|
||||
# progress.update(task, advance=1, description=f"{error_value:.3e}")
|
||||
|
||||
if writer is not None:
|
||||
if batch_idx % self.pytorch_settings.write_every == 0:
|
||||
writer.add_scalar(
|
||||
"eval loss",
|
||||
running_error_2 / (self.pytorch_settings.write_every if batch_idx > 0 else 1),
|
||||
epoch * min(len(valid_loader), self.optuna_settings._n_valid_batches) + batch_idx,
|
||||
)
|
||||
running_error_2 = 0.0
|
||||
|
||||
running_error /= min(len(valid_loader), self.optuna_settings._n_valid_batches)
|
||||
|
||||
if writer is not None:
|
||||
title_append, subtitle = self.build_title(trial)
|
||||
writer.add_figure(
|
||||
"fiber response",
|
||||
self.plot_model_response(
|
||||
trial,
|
||||
model=model,
|
||||
title_append=title_append,
|
||||
subtitle=subtitle,
|
||||
show=False,
|
||||
),
|
||||
epoch + 1,
|
||||
writer.add_scalar(
|
||||
"eval loss",
|
||||
running_error,
|
||||
epoch,
|
||||
)
|
||||
# if (epoch + 1) % 10 == 0 or epoch < 10:
|
||||
# # plotting is slow, so only do it every 10 epochs
|
||||
# title_append, subtitle = self.build_title(trial)
|
||||
# head_fig, eye_fig, powers_fig = self.plot_model_response(
|
||||
# model=model,
|
||||
# title_append=title_append,
|
||||
# subtitle=subtitle,
|
||||
# show=False,
|
||||
# )
|
||||
# writer.add_figure(
|
||||
# "fiber response",
|
||||
# head_fig,
|
||||
# epoch + 1,
|
||||
# )
|
||||
# writer.add_figure(
|
||||
# "eye diagram",
|
||||
# eye_fig,
|
||||
# epoch + 1,
|
||||
# )
|
||||
|
||||
# writer.add_figure(
|
||||
# "powers",
|
||||
# powers_fig,
|
||||
# epoch + 1,
|
||||
# )
|
||||
# writer.flush()
|
||||
|
||||
# if enable_progress:
|
||||
# progress.stop()
|
||||
|
||||
return running_error
|
||||
|
||||
def run_model(self, model, loader):
|
||||
def run_model(self, model, loader, trace_powers=False):
|
||||
model.eval()
|
||||
xs = []
|
||||
ys = []
|
||||
y_preds = []
|
||||
fiber_out = []
|
||||
fiber_in = []
|
||||
regen = []
|
||||
timestamps = []
|
||||
|
||||
with torch.no_grad():
|
||||
model = model.to(self.pytorch_settings.device)
|
||||
for x, y in loader:
|
||||
for batch in loader:
|
||||
x = batch["x"]
|
||||
y = batch["y"]
|
||||
timestamp = batch["timestamp"]
|
||||
x, y = (
|
||||
x.to(self.pytorch_settings.device),
|
||||
y.to(self.pytorch_settings.device),
|
||||
)
|
||||
y_pred = model(x).cpu()
|
||||
if trace_powers:
|
||||
y_pred, powers = model(x, trace_powers=True).cpu()
|
||||
else:
|
||||
y_pred = model(x, trace_powers=True).cpu()
|
||||
# x = x.cpu()
|
||||
# y = y.cpu()
|
||||
y_pred = y_pred.view(y_pred.shape[0], -1, 2)
|
||||
y = y.view(y.shape[0], -1, 2)
|
||||
x = x.view(x.shape[0], -1, 2)
|
||||
xs.append(x[:, 0, :].squeeze())
|
||||
ys.append(y.squeeze())
|
||||
y_preds.append(y_pred.squeeze())
|
||||
# timestamp = timestamp.view(-1, 1)
|
||||
fiber_out.append(x[:, x.shape[1] // 2, :].squeeze())
|
||||
fiber_in.append(y.squeeze())
|
||||
regen.append(y_pred.squeeze())
|
||||
timestamps.append(timestamp.squeeze())
|
||||
|
||||
xs = torch.vstack(xs).cpu()
|
||||
ys = torch.vstack(ys).cpu()
|
||||
y_preds = torch.vstack(y_preds).cpu()
|
||||
return ys, xs, y_preds
|
||||
fiber_out = torch.vstack(fiber_out).cpu()
|
||||
fiber_in = torch.vstack(fiber_in).cpu()
|
||||
regen = torch.vstack(regen).cpu()
|
||||
timestamps = torch.concat(timestamps).cpu()
|
||||
if trace_powers:
|
||||
return fiber_in, fiber_out, regen, timestamps, powers
|
||||
return fiber_in, fiber_out, regen, timestamps
|
||||
|
||||
def objective(self, trial: optuna.Trial, plot_before=False):
|
||||
def objective(self, trial: optuna.Trial):
|
||||
if self.stop_study:
|
||||
trial.study.stop()
|
||||
model = None
|
||||
@@ -537,29 +562,54 @@ class HyperTraining:
|
||||
|
||||
title_append, subtitle = self.build_title(trial)
|
||||
|
||||
writer.add_figure(
|
||||
"fiber response",
|
||||
self.plot_model_response(
|
||||
trial,
|
||||
model=model,
|
||||
title_append=title_append,
|
||||
subtitle=subtitle,
|
||||
show=plot_before,
|
||||
),
|
||||
0,
|
||||
)
|
||||
# writer.add_figure(
|
||||
# "fiber response",
|
||||
# self.plot_model_response(
|
||||
# trial,
|
||||
# model=model,
|
||||
# title_append=title_append,
|
||||
# subtitle=subtitle,
|
||||
# show=False,
|
||||
# ),
|
||||
# 0,
|
||||
# )
|
||||
# writer.add_figure(
|
||||
# "eye diagram",
|
||||
# self.plot_model_response(
|
||||
# trial,
|
||||
# model=self.model,
|
||||
# title_append=title_append,
|
||||
# subtitle=subtitle,
|
||||
# mode="eye",
|
||||
# show=False,
|
||||
# ),
|
||||
# 0,
|
||||
# )
|
||||
|
||||
# writer.add_figure(
|
||||
# "powers",
|
||||
# self.plot_model_response(
|
||||
# trial,
|
||||
# model=self.model,
|
||||
# title_append=title_append,
|
||||
# subtitle=subtitle,
|
||||
# mode="powers",
|
||||
# show=False,
|
||||
# ),
|
||||
# 0,
|
||||
# )
|
||||
|
||||
train_loader, valid_loader = self.get_sliced_data(trial)
|
||||
|
||||
optimizer_name = trial.suggest_categorical_optional("optimizer", self.optimizer_settings.optimizer)
|
||||
|
||||
lr = trial.suggest_float_optional("lr", self.optimizer_settings.learning_rate, log=True)
|
||||
lr = trial.suggest_float_optional("lr", self.optimizer_settings.optimizer_kwargs["lr"], log=True)
|
||||
|
||||
optimizer = getattr(optim, optimizer_name)(model.parameters(), lr=lr)
|
||||
if self.optimizer_settings.scheduler is not None:
|
||||
scheduler = getattr(optim.lr_scheduler, self.optimizer_settings.scheduler)(
|
||||
optimizer, **self.optimizer_settings.scheduler_kwargs
|
||||
)
|
||||
# if self.optimizer_settings.scheduler is not None:
|
||||
# scheduler = getattr(optim.lr_scheduler, self.optimizer_settings.scheduler)(
|
||||
# optimizer, **self.optimizer_settings.scheduler_kwargs
|
||||
# )
|
||||
|
||||
for epoch in range(self.pytorch_settings.epochs):
|
||||
trial.set_user_attr("epoch", epoch)
|
||||
@@ -585,8 +635,8 @@ class HyperTraining:
|
||||
writer,
|
||||
# enable_progress=enable_progress,
|
||||
)
|
||||
if self.optimizer_settings.scheduler is not None:
|
||||
scheduler.step(error)
|
||||
# if self.optimizer_settings.scheduler is not None:
|
||||
# scheduler.step(error)
|
||||
|
||||
trial.set_user_attr("mse", error)
|
||||
trial.set_user_attr("log_mse", np.log10(error + np.finfo(float).eps))
|
||||
@@ -602,14 +652,16 @@ class HyperTraining:
|
||||
if self.optuna_settings._multi_objective:
|
||||
return -np.log10(error + np.finfo(float).eps), trial.user_attrs.get("model_n_nodes", -1)
|
||||
|
||||
if self.pytorch_settings.save_models and model is not None:
|
||||
save_path = Path(self.pytorch_settings.model_dir) / f"{self.optuna_settings.study_name}_{trial.number}.pth"
|
||||
save_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
torch.save(model, save_path)
|
||||
# if self.pytorch_settings.save_models and model is not None:
|
||||
# save_path = Path(self.pytorch_settings.model_dir) / f"{self.optuna_settings.study_name}_{trial.number}.pth"
|
||||
# save_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
# torch.save(model, save_path)
|
||||
|
||||
return error
|
||||
|
||||
def _plot_model_response_eye(self, *signals, labels=None, sps=None, title_append="", subtitle="", show=True):
|
||||
def _plot_model_response_eye(
|
||||
self, *signals, timestamps, labels=None, sps=None, title_append="", subtitle="", show=True
|
||||
):
|
||||
if sps is None:
|
||||
raise ValueError("sps must be provided")
|
||||
if not hasattr(labels, "__iter__") or isinstance(labels, (str, type(None))):
|
||||
@@ -624,27 +676,84 @@ class HyperTraining:
|
||||
if not any(labels):
|
||||
labels = [f"signal {i + 1}" for i in range(len(signals))]
|
||||
|
||||
fig, axs = plt.subplots(2, len(signals), sharex=True, sharey=True)
|
||||
x_bins = np.linspace(0, 2, 2 * sps, endpoint=False)
|
||||
y_bins = np.zeros((2 * len(signals), 1000))
|
||||
eye_data = np.zeros((2 * len(signals), 1000, 2 * sps))
|
||||
# signals = [signal.cpu().numpy() for signal in signals]
|
||||
for i in range(len(signals) * 2):
|
||||
eye_signal = signals[i // 2][:, i % 2] # x, y, x, y, ...
|
||||
eye_signal = np.real(np.square(np.abs(eye_signal)))
|
||||
data_min = np.min(eye_signal)
|
||||
data_max = np.max(eye_signal)
|
||||
y_bins[i] = np.linspace(data_min, data_max, 1000, endpoint=False)
|
||||
for j in range(len(timestamps)):
|
||||
t = timestamps[j] / sps
|
||||
val = eye_signal[j]
|
||||
x = np.digitize(t % 2, x_bins) - 1
|
||||
y = np.digitize(val, y_bins[i]) - 1
|
||||
eye_data[i][y][x] += 1
|
||||
|
||||
cmap = LinearSegmentedColormap.from_list(
|
||||
"eyemap",
|
||||
[
|
||||
(0, "white"),
|
||||
(0.001, "dodgerblue"),
|
||||
(0.1, "blue"),
|
||||
(0.2, "cyan"),
|
||||
(0.5, "lime"),
|
||||
(0.8, "gold"),
|
||||
(1, "red"),
|
||||
],
|
||||
)
|
||||
|
||||
# ordering = np.argsort(timestamps)
|
||||
# signals = [signal[ordering] for signal in signals]
|
||||
# timestamps = timestamps[ordering]
|
||||
|
||||
fig, axs = plt.subplots(1, 2 * len(signals), sharex=True, sharey=True)
|
||||
fig.set_figwidth(18)
|
||||
fig.suptitle(f"Eye diagram{f' {title_append}' if title_append else ''}{f'\n{subtitle}' if subtitle else ''}")
|
||||
xaxis = np.linspace(0, 2, 2 * sps, endpoint=False)
|
||||
for j, (label, signal) in enumerate(zip(labels, signals)):
|
||||
# xaxis = timestamps / sps
|
||||
# xaxis = np.arange(2 * sps) / sps
|
||||
for j, label in enumerate(labels):
|
||||
x = eye_data[2 * j]
|
||||
y = eye_data[2 * j + 1]
|
||||
# x, y = signal.T
|
||||
# signal = signal.cpu().numpy()
|
||||
for i in range(len(signal) // sps - 1):
|
||||
x, y = signal[i * sps : (i + 2) * sps].T
|
||||
axs[0, j].plot(xaxis, np.abs(x) ** 2, color="C0", alpha=0.02)
|
||||
axs[1, j].plot(xaxis, np.abs(y) ** 2, color="C0", alpha=0.02)
|
||||
axs[0, j].set_title(label + " x")
|
||||
axs[1, j].set_title(label + " y")
|
||||
axs[0, j].set_xlabel("Symbol")
|
||||
axs[1, j].set_xlabel("Symbol")
|
||||
axs[0, j].set_ylabel("normalized power")
|
||||
axs[1, j].set_ylabel("normalized power")
|
||||
# for i in range(len(signal) // sps - 1):
|
||||
# x, y = signal[i * sps : (i + 2) * sps].T
|
||||
# axs[0 + 2 * j].scatter((timestamps/sps) % 2, np.abs(x) ** 2, color=f"C{j}", alpha=1 / (len(signal) // sps) * 10, s=1)
|
||||
# axs[1 + 2 * j].scatter((timestamps/sps) % 2, np.abs(y) ** 2, color=f"C{j}", alpha=1 / (len(signal) // sps) * 10, s=1)
|
||||
axs[0 + 2 * j].imshow(
|
||||
x, aspect="auto", cmap=cmap, origin="lower", extent=[0, 2, y_bins[2 * j][0], y_bins[2 * j][-1]]
|
||||
)
|
||||
axs[1 + 2 * j].imshow(
|
||||
y, aspect="auto", cmap=cmap, origin="lower", extent=[0, 2, y_bins[2 * j + 1][0], y_bins[2 * j + 1][-1]]
|
||||
)
|
||||
axs[0 + 2 * j].set_xlim((x_bins[0], x_bins[-1]))
|
||||
axs[1 + 2 * j].set_xlim((x_bins[0], x_bins[-1]))
|
||||
ymin = np.min(y_bins[:, 0])
|
||||
ymax = np.max(y_bins[:, -1])
|
||||
ydiff = ymax - ymin
|
||||
axs[0 + 2 * j].set_ylim((ymin - 0.05 * ydiff, ymax + 0.05 * ydiff))
|
||||
axs[1 + 2 * j].set_ylim((ymin - 0.05 * ydiff, ymax + 0.05 * ydiff))
|
||||
axs[0 + 2 * j].set_title(label + " x")
|
||||
axs[1 + 2 * j].set_title(label + " y")
|
||||
axs[0 + 2 * j].set_xlabel("Symbol")
|
||||
axs[1 + 2 * j].set_xlabel("Symbol")
|
||||
axs[0 + 2 * j].set_box_aspect(1)
|
||||
axs[1 + 2 * j].set_box_aspect(1)
|
||||
axs[0].set_ylabel("normalized power")
|
||||
fig.tight_layout()
|
||||
# axs[1+2*len(labels)-1].set_ylabel("normalized power")
|
||||
|
||||
if show:
|
||||
plt.show()
|
||||
return fig
|
||||
|
||||
def _plot_model_response_head(self, *signals, labels=None, sps=None, title_append="", subtitle="", show=True):
|
||||
def _plot_model_response_head(
|
||||
self, *signals, timestamps, labels=None, sps=None, title_append="", subtitle="", show=True
|
||||
):
|
||||
if not hasattr(labels, "__iter__") or isinstance(labels, (str, type(None))):
|
||||
labels = [labels]
|
||||
else:
|
||||
@@ -657,19 +766,31 @@ class HyperTraining:
|
||||
if not any(labels):
|
||||
labels = [f"signal {i + 1}" for i in range(len(signals))]
|
||||
|
||||
ordering = np.argsort(timestamps)
|
||||
signals = [signal[ordering] for signal in signals]
|
||||
timestamps = timestamps[ordering]
|
||||
|
||||
fig, axs = plt.subplots(1, 2, sharex=True, sharey=True)
|
||||
fig.set_size_inches(18, 6)
|
||||
fig.set_figwidth(18)
|
||||
fig.set_figheight(4)
|
||||
fig.suptitle(f"Fiber response{f' {title_append}' if title_append else ''}{f'\n{subtitle}' if subtitle else ''}")
|
||||
for i, ax in enumerate(axs):
|
||||
ax: plt.Axes
|
||||
for signal, label in zip(signals, labels):
|
||||
if sps is not None:
|
||||
xaxis = np.linspace(0, len(signal) / sps, len(signal), endpoint=False)
|
||||
xaxis = timestamps / sps
|
||||
else:
|
||||
xaxis = np.arange(len(signal))
|
||||
xaxis = timestamps
|
||||
ax.plot(xaxis, np.abs(signal[:, i]) ** 2, label=label)
|
||||
ax.set_xlabel("Sample" if sps is None else "Symbol")
|
||||
ax.set_ylabel("normalized power")
|
||||
ax.minorticks_on()
|
||||
ax.tick_params(axis="y", which="minor", left=False, right=False)
|
||||
ax.grid(which="major", axis="x")
|
||||
ax.grid(which="minor", axis="x", linestyle=":")
|
||||
ax.grid(which="major", axis="y")
|
||||
ax.legend(loc="upper right")
|
||||
fig.tight_layout()
|
||||
if show:
|
||||
plt.show()
|
||||
return fig
|
||||
@@ -680,22 +801,52 @@ class HyperTraining:
|
||||
model=None,
|
||||
title_append="",
|
||||
subtitle="",
|
||||
mode: Literal["eye", "head"] = "head",
|
||||
show=True,
|
||||
mode: Literal["eye", "head", "powers"] = "head",
|
||||
show=False,
|
||||
):
|
||||
if mode == "powers":
|
||||
input_data = torch.ones(
|
||||
1, 2 * self.data_settings.output_size, dtype=getattr(torch, self.data_settings.dtype)
|
||||
).to(self.pytorch_settings.device)
|
||||
model = model.to(self.pytorch_settings.device)
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
_, powers = model(input_data, trace_powers=True)
|
||||
|
||||
powers = [power.item() for power in powers]
|
||||
layer_names = ["input", *[str(x).split("(")[0] for x in model._layers._modules.values()]]
|
||||
|
||||
# remove dropout layers
|
||||
mask = [1 if "Dropout" not in layer_name else 0 for layer_name in layer_names]
|
||||
layer_names = [layer_name for layer_name, m in zip(layer_names, mask) if m]
|
||||
powers = [power for power, m in zip(powers, mask) if m]
|
||||
|
||||
fig = self._plot_model_response_powers(
|
||||
powers, layer_names, title_append=title_append, subtitle=subtitle, show=show
|
||||
)
|
||||
return fig
|
||||
|
||||
data_settings_backup = copy.deepcopy(self.data_settings)
|
||||
pytorch_settings_backup = copy.deepcopy(self.pytorch_settings)
|
||||
self.data_settings.drop_first = 100*128
|
||||
self.data_settings.drop_first = 99.5 + random.randint(0, 1000)
|
||||
self.data_settings.shuffle = False
|
||||
self.data_settings.train_split = 1.0
|
||||
self.pytorch_settings.batchsize = (
|
||||
self.pytorch_settings.eye_symbols if mode == "eye" else self.pytorch_settings.head_symbols
|
||||
)
|
||||
plot_loader, _ = self.get_sliced_data(trial, override={"num_symbols": self.pytorch_settings.batchsize})
|
||||
config_path = random.choice(self.data_settings.config_path) if isinstance(self.data_settings.config_path, (list, tuple)) else self.data_settings.config_path
|
||||
fiber_length = int(float(str(config_path).split('-')[-7])/1000)
|
||||
plot_loader, _ = self.get_sliced_data(
|
||||
trial,
|
||||
override={
|
||||
"num_symbols": self.pytorch_settings.batchsize,
|
||||
"config_path": config_path,
|
||||
}
|
||||
)
|
||||
self.data_settings = data_settings_backup
|
||||
self.pytorch_settings = pytorch_settings_backup
|
||||
|
||||
fiber_in, fiber_out, regen = self.run_model(model, plot_loader)
|
||||
fiber_in, fiber_out, regen, timestamps = self.run_model(model, plot_loader)
|
||||
fiber_in = fiber_in.view(-1, 2)
|
||||
fiber_out = fiber_out.view(-1, 2)
|
||||
regen = regen.view(-1, 2)
|
||||
@@ -703,6 +854,7 @@ class HyperTraining:
|
||||
fiber_in = fiber_in.numpy()
|
||||
fiber_out = fiber_out.numpy()
|
||||
regen = regen.numpy()
|
||||
timestamps = timestamps.numpy()
|
||||
|
||||
# https://github.com/matplotlib/matplotlib/issues/27713#issue-2104110987
|
||||
# https://github.com/matplotlib/matplotlib/issues/27713#issuecomment-1915497463
|
||||
@@ -713,9 +865,10 @@ class HyperTraining:
|
||||
fiber_in,
|
||||
fiber_out,
|
||||
regen,
|
||||
timestamps=timestamps,
|
||||
labels=("fiber in", "fiber out", "regen"),
|
||||
sps=plot_loader.dataset.samples_per_symbol,
|
||||
title_append=title_append,
|
||||
title_append=title_append + f" ({fiber_length} km)",
|
||||
subtitle=subtitle,
|
||||
show=show,
|
||||
)
|
||||
@@ -725,9 +878,10 @@ class HyperTraining:
|
||||
fiber_in,
|
||||
fiber_out,
|
||||
regen,
|
||||
timestamps=timestamps,
|
||||
labels=("fiber in", "fiber out", "regen"),
|
||||
sps=plot_loader.dataset.samples_per_symbol,
|
||||
title_append=title_append,
|
||||
title_append=title_append + f" ({fiber_length} km)",
|
||||
subtitle=subtitle,
|
||||
show=show,
|
||||
)
|
||||
@@ -739,7 +893,7 @@ class HyperTraining:
|
||||
|
||||
@staticmethod
|
||||
def build_title(trial: optuna.trial.Trial):
|
||||
title_append = f"for trial {trial.number}"
|
||||
title_append = f"at epoch {trial.user_attrs.get("epoch", -1)} for trial {trial.number}"
|
||||
model_n_hidden_layers = util.misc.multi_getattr((trial.params, trial.user_attrs), "model_n_hidden_layers", 0)
|
||||
input_dim = util.misc.multi_getattr((trial.params, trial.user_attrs), "model_input_dim", 0)
|
||||
model_dims = [
|
||||
|
||||
443
src/single-core-regen/hypertraining/lighning_models.py
Normal file
443
src/single-core-regen/hypertraining/lighning_models.py
Normal file
@@ -0,0 +1,443 @@
|
||||
from typing import Any
|
||||
import lightning as L
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
# import torch.nn.functional as F
|
||||
|
||||
from util.complexNN import DropoutComplex, Scale, ONNRect, EOActivation, energy_conserving, clamp, complex_mse_loss
|
||||
from util.datasets import FiberRegenerationDataset
|
||||
|
||||
|
||||
class regeneratorData(L.LightningDataModule):
|
||||
def __init__(
|
||||
self,
|
||||
config_globs,
|
||||
output_symbols,
|
||||
output_dim,
|
||||
dtype,
|
||||
drop_first,
|
||||
shuffle=True,
|
||||
train_split=None,
|
||||
batch_size=None,
|
||||
loader_settings=None,
|
||||
seed=None,
|
||||
num_symbols=None,
|
||||
test_globs=None,
|
||||
):
|
||||
super().__init__()
|
||||
self._config_globs = config_globs
|
||||
self._test_globs = test_globs
|
||||
self._test_data_available = test_globs is not None
|
||||
if self._test_data_available:
|
||||
self.test_dataloader = self._test_dataloader
|
||||
self._output_symbols = output_symbols
|
||||
self._output_dim = output_dim
|
||||
self._dtype = dtype
|
||||
self._drop_first = drop_first
|
||||
self._seed = seed
|
||||
self._shuffle = shuffle
|
||||
self._num_symbols = num_symbols
|
||||
self._train_split = train_split if train_split is not None else 0.8
|
||||
self.batch_size = batch_size if batch_size is not None else 1024
|
||||
self._loader_settings = loader_settings if loader_settings is not None else {}
|
||||
|
||||
def _get_data(self):
|
||||
self._data_train = FiberRegenerationDataset(
|
||||
file_path=self._config_globs,
|
||||
symbols=self._output_symbols,
|
||||
output_dim=self._output_dim,
|
||||
dtype=self._dtype,
|
||||
real=not self._dtype.is_complex,
|
||||
drop_first=self._drop_first,
|
||||
num_symbols=self._num_symbols,
|
||||
)
|
||||
# self._data_plot = FiberRegenerationDataset(
|
||||
# file_path=self._config_globs,
|
||||
# symbols=self._output_symbols,
|
||||
# output_dim=self._output_dim,
|
||||
# dtype=self._dtype,
|
||||
# real=not self._dtype.is_complex,
|
||||
# drop_first=self._drop_first,
|
||||
# num_symbols=400,
|
||||
# )
|
||||
if self._test_data_available:
|
||||
self._data_test = FiberRegenerationDataset(
|
||||
file_path=self._test_globs,
|
||||
symbols=self._output_symbols,
|
||||
output_dim=self._output_dim,
|
||||
dtype=self._dtype,
|
||||
real=not self._dtype.is_complex,
|
||||
drop_first=self._drop_first,
|
||||
num_symbols=self._num_symbols,
|
||||
)
|
||||
return self._data_train, self._data_test
|
||||
return self._data_train
|
||||
|
||||
def _split_data(self, stage="fit", split=None, shuffle=None):
|
||||
_split = split if split is not None else self._train_split
|
||||
_shuffle = shuffle if shuffle is not None else self._shuffle
|
||||
|
||||
dataset_size = len(self._data_train)
|
||||
indices = list(range(dataset_size))
|
||||
split_index = int(np.floor(_split * dataset_size))
|
||||
train_indices, valid_indices = indices[:split_index], indices[split_index:]
|
||||
if _shuffle:
|
||||
np.random.seed(self._seed)
|
||||
np.random.shuffle(train_indices)
|
||||
|
||||
|
||||
if _shuffle:
|
||||
if stage == "fit" or stage == "predict":
|
||||
self._train_sampler = torch.utils.data.SubsetRandomSampler(train_indices)
|
||||
# if stage == "fit" or stage == "validate":
|
||||
# self._valid_sampler = torch.utils.data.SubsetRandomSampler(valid_indices)
|
||||
else:
|
||||
if stage == "fit" or stage == "predict":
|
||||
self._train_sampler = train_indices
|
||||
if stage == "fit" or stage == "validate":
|
||||
self._valid_sampler = valid_indices
|
||||
|
||||
if stage == "fit":
|
||||
return self._train_sampler, self._valid_sampler
|
||||
elif stage == "validate":
|
||||
return self._valid_sampler
|
||||
elif stage == "predict":
|
||||
return self._train_sampler
|
||||
|
||||
def prepare_data(self):
|
||||
self._get_data()
|
||||
|
||||
def setup(self, stage=None):
|
||||
stage = stage or "fit"
|
||||
self._split_data(stage=stage)
|
||||
|
||||
def train_dataloader(self):
|
||||
return torch.utils.data.DataLoader(
|
||||
self._data_train,
|
||||
batch_size=self.batch_size,
|
||||
sampler=self._train_sampler,
|
||||
**self._loader_settings
|
||||
)
|
||||
|
||||
def val_dataloader(self):
|
||||
return torch.utils.data.DataLoader(
|
||||
self._data_train,
|
||||
batch_size=self.batch_size,
|
||||
sampler=self._valid_sampler,
|
||||
**self._loader_settings
|
||||
)
|
||||
|
||||
def _test_dataloader(self):
|
||||
return torch.utils.data.DataLoader(
|
||||
self._data_test,
|
||||
shuffle=self._shuffle,
|
||||
batch_size=self.batch_size,
|
||||
**self._loader_settings
|
||||
)
|
||||
|
||||
def predict_dataloader(self):
|
||||
return torch.utils.data.DataLoader(
|
||||
self._data_plot,
|
||||
shuffle=False,
|
||||
batch_size=40,
|
||||
pin_memory=True,
|
||||
drop_last=True,
|
||||
num_workers=4,
|
||||
prefetch_factor=2,
|
||||
)
|
||||
|
||||
# def plot_dataloader(self):
|
||||
|
||||
|
||||
|
||||
class regenerator(L.LightningModule):
|
||||
def __init__(
|
||||
self,
|
||||
*dims,
|
||||
layer_function=ONNRect,
|
||||
layer_func_kwargs: dict | None = {"square": True},
|
||||
act_function=EOActivation,
|
||||
act_func_kwargs: dict | None = None,
|
||||
parametrizations: list[dict] | None = [
|
||||
{
|
||||
"tensor_name": "weight",
|
||||
"parametrization": energy_conserving,
|
||||
},
|
||||
{
|
||||
"tensor_name": "alpha",
|
||||
"parametrization": clamp,
|
||||
},
|
||||
{
|
||||
"tensor_name": "alpha",
|
||||
"parametrization": clamp,
|
||||
},
|
||||
],
|
||||
dtype=torch.complex64,
|
||||
dropout_prob=0.01,
|
||||
scale_layers=False,
|
||||
optimizer=torch.optim.AdamW,
|
||||
optimizer_kwargs: dict | None = {
|
||||
"lr": 0.01,
|
||||
"amsgrad": True,
|
||||
},
|
||||
lr_scheduler=None,
|
||||
lr_scheduler_kwargs: dict | None = {
|
||||
"patience": 20,
|
||||
"factor": 0.5,
|
||||
"min_lr": 1e-6,
|
||||
"cooldown": 10,
|
||||
},
|
||||
sps = 128,
|
||||
# **kwargs,
|
||||
):
|
||||
torch.set_float32_matmul_precision('high')
|
||||
layer_func_kwargs = layer_func_kwargs if layer_func_kwargs is not None else {}
|
||||
act_func_kwargs = act_func_kwargs if act_func_kwargs is not None else {}
|
||||
optimizer_kwargs = optimizer_kwargs if optimizer_kwargs is not None else {}
|
||||
lr_scheduler_kwargs = lr_scheduler_kwargs if lr_scheduler_kwargs is not None else {}
|
||||
super().__init__()
|
||||
|
||||
self.example_input_array = torch.randn(1, dims[0], dtype=dtype)
|
||||
self._sps = sps
|
||||
|
||||
self.optimizer_settings = {
|
||||
"optimizer": optimizer,
|
||||
"optimizer_kwargs": optimizer_kwargs,
|
||||
"lr_scheduler": lr_scheduler,
|
||||
"lr_scheduler_kwargs": lr_scheduler_kwargs,
|
||||
}
|
||||
|
||||
# if len(dims) == 0:
|
||||
# try:
|
||||
# dims = kwargs["dims"]
|
||||
# except KeyError:
|
||||
# raise ValueError("dims must be provided")
|
||||
self._n_hidden_layers = len(dims) - 2
|
||||
|
||||
self.build_model(dims, layer_function, layer_func_kwargs, act_function, act_func_kwargs, parametrizations, dtype, dropout_prob, scale_layers)
|
||||
|
||||
def build_model(self, dims, layer_function, layer_func_kwargs, act_function, act_func_kwargs, parametrizations, dtype, dropout_prob, scale_layers):
|
||||
input_layer = nn.Sequential(
|
||||
layer_function(dims[0], dims[1], dtype=dtype, **layer_func_kwargs),
|
||||
act_function(size=dims[1], **act_func_kwargs),
|
||||
DropoutComplex(p=dropout_prob),
|
||||
)
|
||||
|
||||
if scale_layers:
|
||||
input_layer = nn.Sequential(Scale(dims[0]), input_layer)
|
||||
|
||||
self.layer_0 = input_layer
|
||||
|
||||
for i in range(1, self._n_hidden_layers):
|
||||
layer = nn.Sequential(
|
||||
layer_function(dims[i], dims[i + 1], dtype=dtype, **layer_func_kwargs),
|
||||
act_function(size=dims[i + 1], **act_func_kwargs),
|
||||
DropoutComplex(p=dropout_prob),
|
||||
)
|
||||
if scale_layers:
|
||||
layer = nn.Sequential(Scale(dims[i]), layer)
|
||||
setattr(self, f"layer_{i}", layer)
|
||||
|
||||
output_layer = nn.Sequential(
|
||||
layer_function(dims[-2], dims[-1], dtype=dtype, **layer_func_kwargs),
|
||||
act_function(size=dims[-1], **act_func_kwargs),
|
||||
Scale(dims[-1]),
|
||||
)
|
||||
setattr(self, f"layer_{self._n_hidden_layers}", output_layer)
|
||||
|
||||
if parametrizations is not None:
|
||||
self._apply_parametrizations(self, parametrizations)
|
||||
|
||||
def _apply_parametrizations(self, layer, parametrizations):
|
||||
for sub_layer in layer.children():
|
||||
if len(sub_layer._modules) > 0:
|
||||
self._apply_parametrizations(sub_layer, parametrizations)
|
||||
else:
|
||||
for parametrization in parametrizations:
|
||||
tensor_name = parametrization.get("tensor_name", None)
|
||||
if tensor_name is None:
|
||||
continue
|
||||
parametrization_func = parametrization.get("parametrization", None)
|
||||
if parametrization_func is None:
|
||||
continue
|
||||
param_kwargs = parametrization.get("kwargs", {})
|
||||
if tensor_name in sub_layer._parameters:
|
||||
parametrization_func(sub_layer, tensor_name, **param_kwargs)
|
||||
|
||||
def _trace_powers(self, enable, x, powers=None):
|
||||
if not enable:
|
||||
return
|
||||
if powers is None:
|
||||
powers = []
|
||||
powers.append(x.abs().square().sum())
|
||||
return powers
|
||||
|
||||
# def plot(self, mode):
|
||||
# self.predict_step()
|
||||
|
||||
# def validation_epoch_end(self, outputs):
|
||||
# x = torch.vstack([output['x'].view(output['x'].shape[0], -1, 2)[:, output['x'].shape[1]//2, :].squeeze() for output in outputs])
|
||||
# y = torch.vstack([output['y'].view(output['y'].shape[0], -1, 2).squeeze() for output in outputs])
|
||||
# y_hat = torch.vstack([output['y_hat'].view(output['y_hat'].shape[0], -1, 2).squeeze() for output in outputs])
|
||||
# timesteps = torch.vstack([output['timesteps'].squeeze() for output in outputs])
|
||||
# powers = torch.vstack([output['powers'] for output in outputs])
|
||||
|
||||
# return {'x': x, 'y': y, 'y_hat': y_hat, 'timesteps': timesteps, 'powers': powers}
|
||||
|
||||
def on_validation_epoch_end(self):
|
||||
if self.current_epoch % 10 == 0 or self.current_epoch == self.trainer.max_epochs - 1 or self.current_epoch < 10:
|
||||
x = self.val_outputs['x']
|
||||
# x = x.view(x.shape[0], -1, 2)
|
||||
# x = x[:, x.shape[1]//2, :].squeeze()
|
||||
y = self.val_outputs['y']
|
||||
# y = y.view(y.shape[0], -1, 2).squeeze()
|
||||
y_hat = self.val_outputs['y_hat']
|
||||
# y_hat = y_hat.view(y_hat.shape[0], -1, 2).squeeze()
|
||||
timesteps = self.val_outputs['timesteps']
|
||||
# timesteps = timesteps.squeeze()
|
||||
powers = self.val_outputs['powers']
|
||||
# powers = powers.squeeze()
|
||||
|
||||
fiber_in = x.detach().cpu().numpy()
|
||||
fiber_out = y.detach().cpu().numpy()
|
||||
regen = y_hat.detach().cpu().numpy()
|
||||
timesteps = timesteps.detach().cpu().numpy()
|
||||
# powers = np.array([power.detach().cpu().numpy() for power in powers])
|
||||
|
||||
# fiber_in = np.concat(fiber_in, axis=0)
|
||||
# fiber_out = np.concat(fiber_out, axis=0)
|
||||
# regen = np.concat(regen, axis=0)
|
||||
# timesteps = np.concat(timesteps, axis=0)
|
||||
# powers = powers.detach().cpu().numpy()
|
||||
|
||||
|
||||
import gc
|
||||
|
||||
fig = self.plot_model_head(fiber_in, fiber_out, regen, timesteps, sps=self._sps)
|
||||
|
||||
self.logger.experiment.add_figure("model response", fig, self.current_epoch)
|
||||
|
||||
# fig = self.plot_model_eye(fiber_in, fiber_out, regen, timesteps, sps=self._sps)
|
||||
|
||||
# self.logger.experiment.add_figure("model eye", fig, self.current_epoch)
|
||||
|
||||
# fig = self.plot_model_powers(powers)
|
||||
|
||||
# self.logger.experiment.add_figure("powers", fig, self.current_epoch)
|
||||
|
||||
gc.collect()
|
||||
# x, y, y_hat, timesteps, powers = self.validation_epoch_end(self.outputs)
|
||||
# self.plot(x, y, y_hat, timesteps, powers)
|
||||
|
||||
def plot_model_head(self, fiber_in, fiber_out, regen, timesteps, sps):
|
||||
import matplotlib
|
||||
matplotlib.use("TkCairo")
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
ordering = np.argsort(timesteps)
|
||||
signals = [signal[ordering] for signal in [fiber_in, fiber_out, regen]]
|
||||
timesteps = timesteps[ordering]
|
||||
|
||||
signals = [signal[:sps*40] for signal in signals]
|
||||
timesteps = timesteps[:sps*40]
|
||||
|
||||
fig, axs = plt.subplots(1, 2, sharex=True, sharey=True)
|
||||
fig.set_figwidth(16)
|
||||
fig.set_figheight(4)
|
||||
|
||||
for i, ax in enumerate(axs):
|
||||
for j, signal in enumerate(signals):
|
||||
ax.plot(timesteps / sps, np.square(np.abs(signal[:,i])), label=["fiber in", "fiber out", "regen"][j] + [" x", " y"][i])
|
||||
ax.set_xlabel("symbol")
|
||||
ax.set_ylabel("amplitude")
|
||||
ax.minorticks_on()
|
||||
ax.tick_params(axis="y", which="minor", left=False, right=False)
|
||||
ax.grid(which="major", axis="x")
|
||||
ax.grid(which="minor", axis="x", linestyle=":")
|
||||
ax.grid(which="major", axis="y")
|
||||
ax.legend(loc="upper right")
|
||||
fig.tight_layout()
|
||||
|
||||
return fig
|
||||
|
||||
def plot_model_eye(self, fiber_in, fiber_out, regen, timesteps, sps):
|
||||
...
|
||||
|
||||
def plot_model_powers(self, powers):
|
||||
...
|
||||
|
||||
def forward(self, x, trace_powers=False):
|
||||
powers = self._trace_powers(trace_powers, x)
|
||||
x = self.layer_0(x)
|
||||
powers = self._trace_powers(trace_powers, x, powers)
|
||||
for i in range(1, self._n_hidden_layers):
|
||||
x = getattr(self, f"layer_{i}")(x)
|
||||
powers = self._trace_powers(trace_powers, x, powers)
|
||||
x = getattr(self, f"layer_{self._n_hidden_layers}")(x)
|
||||
powers = self._trace_powers(trace_powers, x, powers)
|
||||
if trace_powers:
|
||||
return x, powers
|
||||
return x
|
||||
|
||||
def configure_optimizers(self):
|
||||
optimizer = self.optimizer_settings["optimizer"](
|
||||
self.parameters(), **self.optimizer_settings["optimizer_kwargs"]
|
||||
)
|
||||
if self.optimizer_settings["lr_scheduler"] is not None:
|
||||
lr_scheduler = self.optimizer_settings["lr_scheduler"](
|
||||
optimizer, **self.optimizer_settings["lr_scheduler_kwargs"]
|
||||
)
|
||||
return {
|
||||
"optimizer": optimizer,
|
||||
"lr_scheduler": {
|
||||
"scheduler": lr_scheduler,
|
||||
"monitor": "val_loss",
|
||||
}
|
||||
}
|
||||
return {"optimizer": optimizer}
|
||||
|
||||
def training_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0):
|
||||
x, y, timesteps = batch
|
||||
y_hat = self(x)
|
||||
loss = complex_mse_loss(y_hat, y, power=True)
|
||||
self.log("train_loss", loss, on_epoch=True, on_step=True)
|
||||
return loss
|
||||
|
||||
def validation_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0):
|
||||
x, y, timesteps = batch
|
||||
if batch_idx == 0:
|
||||
y_hat, powers = self.forward(x, trace_powers=True)
|
||||
else:
|
||||
y_hat = self.forward(x)
|
||||
loss = complex_mse_loss(y_hat, y, power=True)
|
||||
self.log("val_loss", loss, on_epoch=True)
|
||||
y = y.view(y.shape[0], -1, 2).squeeze()
|
||||
x = x.view(x.shape[0], -1, 2)
|
||||
x = x[:, x.shape[1]//2, :].squeeze()
|
||||
y_hat = y_hat.view(y_hat.shape[0], -1, 2).squeeze()
|
||||
timesteps = timesteps.squeeze()
|
||||
if batch_idx == 0:
|
||||
powers = np.array([power.detach().cpu() for power in powers])
|
||||
self.val_outputs = {"y": y, "x": x, "y_hat": y_hat, "timesteps": timesteps, "powers": powers}
|
||||
else:
|
||||
self.val_outputs["y"] = torch.vstack([self.val_outputs["y"], y])
|
||||
self.val_outputs["x"] = torch.vstack([self.val_outputs["x"], x])
|
||||
self.val_outputs["y_hat"] = torch.vstack([self.val_outputs["y_hat"], y_hat])
|
||||
self.val_outputs["timesteps"] = torch.concat([self.val_outputs["timesteps"], timesteps], dim=0)
|
||||
return loss
|
||||
|
||||
def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0):
|
||||
x, y, timesteps = batch
|
||||
y_hat = self(x)
|
||||
loss = complex_mse_loss(y_hat, y, power=True)
|
||||
self.log("test_loss", loss, on_epoch=True)
|
||||
return loss
|
||||
|
||||
# def predict_step(self, batch, batch_idx):
|
||||
# x, y, timesteps = batch
|
||||
# y_hat = self(x)
|
||||
# return y, x, y_hat, timesteps
|
||||
|
||||
|
||||
|
||||
229
src/single-core-regen/hypertraining/models.py
Normal file
229
src/single-core-regen/hypertraining/models.py
Normal file
@@ -0,0 +1,229 @@
|
||||
import torch
|
||||
from torch.nn import Module, Sequential
|
||||
|
||||
from util.complexNN import (
|
||||
DropoutComplex,
|
||||
Scale,
|
||||
ONNRect,
|
||||
photodiode,
|
||||
EOActivation,
|
||||
polarimeter,
|
||||
# normalize_by_first,
|
||||
rotate,
|
||||
)
|
||||
|
||||
|
||||
class polarisation_estimator2(Module):
|
||||
def __init__(self):
|
||||
super(polarisation_estimator2, self).__init__()
|
||||
self.layers = Sequential(
|
||||
polarimeter(),
|
||||
torch.nn.Linear(4, 4),
|
||||
torch.nn.ReLU(),
|
||||
# torch.nn.Dropout(p=0.01),
|
||||
torch.nn.Linear(4, 4),
|
||||
torch.nn.ReLU(),
|
||||
# torch.nn.Dropout(p=0.01),
|
||||
torch.nn.Linear(4, 1),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
# x = self.polarimeter(x)
|
||||
for layer in self.layers:
|
||||
x = layer(x)
|
||||
return x
|
||||
|
||||
class polarisation_estimator(Module):
|
||||
def __init__(
|
||||
self,
|
||||
*dims,
|
||||
layer_function=ONNRect,
|
||||
layer_func_kwargs: dict | None = None,
|
||||
output_layer_function=photodiode,
|
||||
# output_layer_func_kwargs: dict | None = None,
|
||||
act_function=EOActivation,
|
||||
act_func_kwargs: dict | None = None,
|
||||
parametrizations: list[dict] = None,
|
||||
dtype=torch.float64,
|
||||
dropout_prob=0.01,
|
||||
scale_layers=False,
|
||||
):
|
||||
super(polarisation_estimator, self).__init__()
|
||||
self._n_hidden_layers = len(dims) - 2
|
||||
|
||||
layer_func_kwargs = layer_func_kwargs or {}
|
||||
act_func_kwargs = act_func_kwargs or {}
|
||||
|
||||
self.build_model(dims, layer_function, layer_func_kwargs, output_layer_function, act_function, act_func_kwargs, parametrizations, dtype, dropout_prob, scale_layers)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.layer_0(x)
|
||||
for i in range(1, self._n_hidden_layers):
|
||||
x = getattr(self, f"layer_{i}")(x)
|
||||
x = getattr(self, f"layer_{self._n_hidden_layers}")(x)
|
||||
x = torch.remainder(x, torch.ones_like(x) * 2 * torch.pi)
|
||||
return x.squeeze()
|
||||
|
||||
def build_model(self, dims, layer_function, layer_func_kwargs, output_layer_function, act_function, act_func_kwargs, parametrizations, dtype, dropout_prob, scale_layers):
|
||||
for i in range(0, self._n_hidden_layers):
|
||||
self.add_module(f"layer_{i}", Sequential())
|
||||
|
||||
if scale_layers:
|
||||
self.get_submodule(f"layer_{i}").add_module("scale", Scale(dims[i]))
|
||||
|
||||
module = layer_function(dims[i], dims[i + 1], dtype=dtype, **layer_func_kwargs)
|
||||
self.get_submodule(f"layer_{i}").add_module("ONN", module)
|
||||
|
||||
module = act_function(size=dims[i + 1], **act_func_kwargs)
|
||||
self.get_submodule(f"layer_{i}").add_module("activation", module)
|
||||
|
||||
module = DropoutComplex(p=dropout_prob)
|
||||
self.get_submodule(f"layer_{i}").add_module("dropout", module)
|
||||
|
||||
self.add_module(f"layer_{self._n_hidden_layers}", Sequential())
|
||||
|
||||
if scale_layers:
|
||||
self.get_submodule(f"layer_{self._n_hidden_layers}").add_module("scale", Scale(dims[-2]))
|
||||
|
||||
module = layer_function(dims[-2], dims[-1], dtype=dtype, **layer_func_kwargs)
|
||||
self.get_submodule(f"layer_{self._n_hidden_layers}").add_module("ONN", module)
|
||||
|
||||
module = output_layer_function(size=dims[-1])
|
||||
self.get_submodule(f"layer_{self._n_hidden_layers}").add_module("photodiode", module)
|
||||
|
||||
# module = normalize_by_first()
|
||||
# self.get_submodule(f"layer_{self._n_hidden_layers}").add_module("normalize", module)
|
||||
|
||||
if parametrizations is not None:
|
||||
self._apply_parametrizations(self, parametrizations)
|
||||
|
||||
def _apply_parametrizations(self, layer, parametrizations):
|
||||
for sub_layer in layer.children():
|
||||
if len(sub_layer._modules) > 0:
|
||||
self._apply_parametrizations(sub_layer, parametrizations)
|
||||
else:
|
||||
for parametrization in parametrizations:
|
||||
tensor_name = parametrization.get("tensor_name", None)
|
||||
if tensor_name is None:
|
||||
continue
|
||||
parametrization_func = parametrization.get("parametrization", None)
|
||||
if parametrization_func is None:
|
||||
continue
|
||||
param_kwargs = parametrization.get("kwargs", {})
|
||||
if tensor_name in sub_layer._parameters:
|
||||
parametrization_func(sub_layer, tensor_name, **param_kwargs)
|
||||
|
||||
class regenerator(Module):
|
||||
def __init__(
|
||||
self,
|
||||
*dims,
|
||||
layer_function=ONNRect,
|
||||
layer_func_kwargs: dict | None = None,
|
||||
act_function=EOActivation,
|
||||
act_func_kwargs: dict | None = None,
|
||||
parametrizations: list[dict] = None,
|
||||
dtype=torch.float64,
|
||||
dropout_prob=0.01,
|
||||
prescale=1,
|
||||
rotate=False,
|
||||
):
|
||||
super(regenerator, self).__init__()
|
||||
self._n_hidden_layers = len(dims) - 2
|
||||
|
||||
layer_func_kwargs = layer_func_kwargs or {}
|
||||
act_func_kwargs = act_func_kwargs or {}
|
||||
|
||||
self.rotation = rotate
|
||||
self.prescale = prescale
|
||||
|
||||
self.build_model(dims, layer_function, layer_func_kwargs, act_function, act_func_kwargs, parametrizations, dtype, dropout_prob)
|
||||
|
||||
def build_model(self, dims, layer_function, layer_func_kwargs, act_function, act_func_kwargs, parametrizations, dtype, dropout_prob):
|
||||
for i in range(0, self._n_hidden_layers):
|
||||
self.add_module(f"layer_{i}", Sequential())
|
||||
|
||||
|
||||
module = layer_function(dims[i], dims[i + 1], dtype=dtype, **layer_func_kwargs)
|
||||
self.get_submodule(f"layer_{i}").add_module("ONN", module)
|
||||
|
||||
module = act_function(size=dims[i + 1], **act_func_kwargs)
|
||||
self.get_submodule(f"layer_{i}").add_module("activation", module)
|
||||
|
||||
if dropout_prob is not None and dropout_prob > 0:
|
||||
module = DropoutComplex(p=dropout_prob)
|
||||
self.get_submodule(f"layer_{i}").add_module("dropout", module)
|
||||
|
||||
self.add_module(f"layer_{self._n_hidden_layers}", Sequential())
|
||||
|
||||
# if scale_layers:
|
||||
# self.get_submodule(f"layer_{self._n_hidden_layers}").add_module("scale", Scale(dims[-2]))
|
||||
|
||||
module = layer_function(dims[-2], dims[-1], dtype=dtype, **layer_func_kwargs)
|
||||
self.get_submodule(f"layer_{self._n_hidden_layers}").add_module("ONN", module)
|
||||
|
||||
module = act_function(size=dims[-1], **act_func_kwargs)
|
||||
self.get_submodule(f"layer_{self._n_hidden_layers}").add_module("activation", module)
|
||||
|
||||
module = Scale(size=dims[-1])
|
||||
self.get_submodule(f"layer_{self._n_hidden_layers}").add_module("out_scale", module)
|
||||
|
||||
if self.rotation:
|
||||
module = rotate()
|
||||
self.add_module("rotate", module)
|
||||
|
||||
|
||||
# module = Scale(size=dims[-1])
|
||||
# self.get_submodule(f"layer_{self._n_hidden_layers}").add_module("out_scale", module)
|
||||
|
||||
if parametrizations is not None:
|
||||
self._apply_parametrizations(self, parametrizations)
|
||||
|
||||
def _apply_parametrizations(self, layer, parametrizations):
|
||||
for sub_layer in layer.children():
|
||||
if len(sub_layer._modules) > 0:
|
||||
self._apply_parametrizations(sub_layer, parametrizations)
|
||||
else:
|
||||
for parametrization in parametrizations:
|
||||
tensor_name = parametrization.get("tensor_name", None)
|
||||
if tensor_name is None:
|
||||
continue
|
||||
parametrization_func = parametrization.get("parametrization", None)
|
||||
if parametrization_func is None:
|
||||
continue
|
||||
param_kwargs = parametrization.get("kwargs", {})
|
||||
if tensor_name in sub_layer._parameters:
|
||||
parametrization_func(sub_layer, tensor_name, **param_kwargs)
|
||||
|
||||
def _trace_powers(self, enable, x, powers=None):
|
||||
if not enable:
|
||||
return
|
||||
if powers is None:
|
||||
powers = []
|
||||
powers.append(x.abs().square().sum())
|
||||
return powers
|
||||
|
||||
def forward(self, x, angle=None, pre_rot=False, trace_powers=False):
|
||||
x = x * self.prescale
|
||||
powers = self._trace_powers(trace_powers, x)
|
||||
# x = self.layer_0(x)
|
||||
# powers = self._trace_powers(trace_powers, x, powers)
|
||||
for i in range(0, self._n_hidden_layers):
|
||||
x = getattr(self, f"layer_{i}")(x)
|
||||
powers = self._trace_powers(trace_powers, x, powers)
|
||||
x = getattr(self, f"layer_{self._n_hidden_layers}")(x)
|
||||
if self.rotation:
|
||||
try:
|
||||
x_rot = self.rotate(x, angle)
|
||||
except AttributeError:
|
||||
pass
|
||||
powers = self._trace_powers(trace_powers, x_rot, powers)
|
||||
else:
|
||||
x_rot = x
|
||||
|
||||
if pre_rot and trace_powers:
|
||||
return x_rot, x, powers
|
||||
if pre_rot and not trace_powers:
|
||||
return x_rot, x
|
||||
if not pre_rot and trace_powers:
|
||||
return x_rot, powers
|
||||
return x_rot
|
||||
@@ -18,8 +18,28 @@ class DataSettings:
|
||||
shuffle: bool = True
|
||||
in_out_delay: float = 0
|
||||
xy_delay: tuple | float | int = 0
|
||||
drop_first: int = 1000
|
||||
drop_first: int = 64
|
||||
drop_last: int = 64
|
||||
train_split: float = 0.8
|
||||
polarisations: tuple | list = (0,)
|
||||
# cross_pol_interference: float = 0
|
||||
randomise_polarisations: bool = False
|
||||
osnr: float | int = None
|
||||
seed: int = None
|
||||
|
||||
"""
|
||||
change to:
|
||||
|
||||
config_path: tuple | list | None = None
|
||||
dtype: torch.dtype | None = None
|
||||
symbols: int | float = 1
|
||||
output_dim: int = 2
|
||||
shuffle: bool = True
|
||||
drop_first: float | int = 0
|
||||
train_split: float = 0.8
|
||||
randomise_polarisations: bool = False
|
||||
|
||||
"""
|
||||
|
||||
|
||||
# pytorch settings
|
||||
@@ -30,8 +50,8 @@ class PytorchSettings:
|
||||
|
||||
device: str = "cuda"
|
||||
|
||||
dataloader_workers: int = 2
|
||||
dataloader_prefetch: int = 2
|
||||
dataloader_workers: int = 1
|
||||
dataloader_prefetch: int = 1
|
||||
|
||||
save_models: bool = True
|
||||
model_dir: str = ".models"
|
||||
@@ -56,6 +76,30 @@ class ModelSettings:
|
||||
model_layer_kwargs: dict | None = None
|
||||
model_layer_parametrizations: list= field(default_factory=list)
|
||||
|
||||
"""
|
||||
change to:
|
||||
|
||||
dims: tuple | list | None = None
|
||||
layer_function: nn.Module | None = None
|
||||
layer_func_kwargs: dict | None = None
|
||||
activation_function: nn.Module | None = None
|
||||
activation_func_kwargs: dict | None = None
|
||||
output_function: nn.Module | None = None
|
||||
output_func_kwargs: dict | None = None
|
||||
dropout_function: nn.Module | None = None
|
||||
dropout_func_kwargs: dict | None = None
|
||||
scale_function: nn.Module | None = None
|
||||
scale_func_kwargs: dict | None = None
|
||||
parametrizations: list | None = None
|
||||
|
||||
"""
|
||||
|
||||
|
||||
def _early_stop_default_kwargs():
|
||||
return {
|
||||
"threshold": 1e-05,
|
||||
"plateau": 25,
|
||||
}
|
||||
|
||||
@dataclass
|
||||
class OptimizerSettings:
|
||||
@@ -65,6 +109,20 @@ class OptimizerSettings:
|
||||
scheduler: str | None = None
|
||||
scheduler_kwargs: dict | None = None
|
||||
|
||||
early_stopping: bool = False
|
||||
early_stop_kwargs: dict = field(default_factory=_early_stop_default_kwargs)
|
||||
|
||||
"""
|
||||
change to:
|
||||
|
||||
optimizer: torch.optim.Optimizer | None = None
|
||||
optimizer_kwargs: dict | None = None
|
||||
learning_rate: float | None = None
|
||||
scheduler: torch.optim.lr_scheduler | None = None
|
||||
scheduler_kwargs: dict | None = None
|
||||
|
||||
"""
|
||||
|
||||
|
||||
def _pruner_default_kwargs():
|
||||
# MedianPruner
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
217
src/single-core-regen/plot_model.py
Normal file
217
src/single-core-regen/plot_model.py
Normal file
@@ -0,0 +1,217 @@
|
||||
from pathlib import Path
|
||||
import sys
|
||||
|
||||
from matplotlib import pyplot as plt
|
||||
import numpy as np
|
||||
import torch
|
||||
import util
|
||||
from hypertraining.settings import GlobalSettings, DataSettings, ModelSettings, OptimizerSettings, PytorchSettings
|
||||
from hypertraining import models
|
||||
|
||||
# def move_to_location_in_size(array, location, size):
|
||||
# array_x, array_y = array.shape
|
||||
# location_x, location_y = location
|
||||
# size_x, size_y = size
|
||||
|
||||
# left = location_x
|
||||
# right = size_x - array_x - location_x
|
||||
|
||||
# top = location_y
|
||||
# bottom = size_y - array_y - location_y
|
||||
|
||||
# return np.pad(
|
||||
# array,
|
||||
# (
|
||||
# (left, right),
|
||||
# (top, bottom),
|
||||
# ),
|
||||
# constant_values=(-np.inf, -np.inf),
|
||||
# )
|
||||
|
||||
def register_puccs_cmap(puccs_path=None):
|
||||
puccs_path = Path(__file__).resolve().parent / 'puccs.csv' if puccs_path is None else puccs_path
|
||||
|
||||
colors = []
|
||||
# keys = None
|
||||
with open(puccs_path, "r") as f:
|
||||
for i, line in enumerate(f.readlines()):
|
||||
elements = tuple(line.split(","))
|
||||
# if i == 0:
|
||||
# # keys = elements
|
||||
# continue
|
||||
# else:
|
||||
try:
|
||||
colors.append(tuple(map(float, elements[4:])))
|
||||
except ValueError:
|
||||
continue
|
||||
# colors = []
|
||||
# for current in puccs_csv_data:
|
||||
# colors.append(tuple(current[4:]))
|
||||
from matplotlib.colors import LinearSegmentedColormap
|
||||
import matplotlib as mpl
|
||||
mpl.colormaps.register(LinearSegmentedColormap.from_list('puccs', colors))
|
||||
|
||||
def pad_to_size(array, size):
|
||||
if not hasattr(size, "__len__"):
|
||||
size = (size, size)
|
||||
|
||||
left = (
|
||||
(size[0] - array.shape[0] + 1) // 2 if size[0] is not None else 0
|
||||
)
|
||||
right = (
|
||||
(size[0] - array.shape[0]) // 2 if size[0] is not None else 0
|
||||
)
|
||||
top = (
|
||||
(size[1] - array.shape[1] + 1) // 2 if size[1] is not None else 0
|
||||
)
|
||||
bottom = (
|
||||
(size[1] - array.shape[1]) // 2 if size[1] is not None else 0
|
||||
)
|
||||
|
||||
array: np.ndarray = array
|
||||
if array.ndim == 2:
|
||||
return np.pad(
|
||||
array,
|
||||
(
|
||||
(left, right),
|
||||
(top, bottom),
|
||||
),
|
||||
constant_values=(np.nan, np.nan),
|
||||
)
|
||||
elif array.ndim == 3:
|
||||
return np.pad(
|
||||
array,
|
||||
(
|
||||
(left, right),
|
||||
(top, bottom),
|
||||
(0,0)
|
||||
),
|
||||
constant_values=(np.nan, np.nan),
|
||||
)
|
||||
|
||||
def model_plot(model_path, show=True):
|
||||
torch.serialization.add_safe_globals([
|
||||
*util.complexNN.__all__,
|
||||
GlobalSettings,
|
||||
DataSettings,
|
||||
ModelSettings,
|
||||
OptimizerSettings,
|
||||
PytorchSettings,
|
||||
models.regenerator,
|
||||
torch.nn.utils.parametrizations.orthogonal,
|
||||
])
|
||||
checkpoint_dict = torch.load(model_path, weights_only=True)
|
||||
|
||||
dims = checkpoint_dict["model_kwargs"].pop("dims")
|
||||
|
||||
model = models.regenerator(*dims, **checkpoint_dict["model_kwargs"])
|
||||
model.load_state_dict(checkpoint_dict["model_state_dict"], strict=False)
|
||||
|
||||
model_params = []
|
||||
plots = []
|
||||
max_size = np.max(dims)
|
||||
# max_act_size = np.max(dims[1:])
|
||||
|
||||
# angles = [None, None]
|
||||
# weights = [None, None]
|
||||
|
||||
for num, (layer_name, layer) in enumerate(model.named_children()):
|
||||
# each layer contains an "ONN" layer and an "activation" layer
|
||||
# activation layer is approximately the same for all layers and nodes -> rotation by 90 degrees
|
||||
onn_weights = layer.ONN.weight
|
||||
onn_weights = onn_weights.detach().cpu().numpy()
|
||||
onn_values = np.abs(onn_weights).real
|
||||
onn_angles = np.mod(np.angle(onn_weights), 2*np.pi).real
|
||||
|
||||
model_params.append({layer_name: onn_weights})
|
||||
plots.append({layer_name: (num, onn_values, onn_angles)})#, act_values, act_angles)})
|
||||
|
||||
# fig, axs = plt.subplots(3, len(model_params)*2-1, figsize=(20, 5))
|
||||
|
||||
for plot in plots:
|
||||
layer_name, (num, onn_values, onn_angles) = plot.popitem()
|
||||
|
||||
if num == 0:
|
||||
value_img = onn_values
|
||||
angle_img = onn_angles
|
||||
onn_angles = pad_to_size(onn_angles, (max_size, None))
|
||||
onn_values = pad_to_size(onn_values, (max_size, None))
|
||||
else:
|
||||
onn_values = pad_to_size(onn_values, (max_size, onn_values.shape[1]+1))
|
||||
onn_angles = pad_to_size(onn_angles, (max_size, onn_angles.shape[1]+1))
|
||||
value_img = np.concatenate((value_img, onn_values), axis=1)
|
||||
angle_img = np.concatenate((angle_img, onn_angles), axis=1)
|
||||
|
||||
value_img = np.ma.array(value_img, mask=np.isnan(value_img))
|
||||
angle_img = np.ma.array(angle_img, mask=np.isnan(angle_img))
|
||||
|
||||
# from cmcrameri import cm
|
||||
from cmap import Colormap as cm
|
||||
import scicomap as sc
|
||||
# from matplotlib import colors as mcolors
|
||||
# alpha_map = mcolors.LinearSegmentedColormap(
|
||||
# 'alphamap',
|
||||
# {
|
||||
# 'red': [(0, 0, 0), (1, 0, 0)],
|
||||
# 'green': [(0, 0, 0), (1, 0, 0)],
|
||||
# 'blue': [(0, 0, 0), (1, 0, 0)],
|
||||
# 'alpha': [
|
||||
# (0, 1, 1),
|
||||
# # (0.2, 0.2, 0.1),
|
||||
# (1, 0, 0)
|
||||
# ]
|
||||
# }
|
||||
# )
|
||||
# alpha_map.set_bad(color="#AAAAAA")
|
||||
|
||||
|
||||
from mpl_toolkits.axes_grid1 import make_axes_locatable
|
||||
|
||||
fig, axs = plt.subplots(1, 2, sharex=True, sharey=True, figsize=(16, 5))
|
||||
# fig.tight_layout()
|
||||
dividers = map(make_axes_locatable, axs)
|
||||
caxs = list(map(lambda d: d.append_axes("right", size="5%", pad=0.05), dividers))
|
||||
# masked_value_img = np.ma.masked_where(np.isnan(value_img), value_img)
|
||||
masked_value_img = value_img
|
||||
cmap = cm('google:turbo').to_matplotlib()
|
||||
# cmap = sc.ScicoSequential("rainbow").get_mpl_color_map()
|
||||
cmap.set_bad(color="#AAAAAA")
|
||||
im_val = axs[0].imshow(masked_value_img, cmap=cmap, vmin=0, vmax=1)
|
||||
fig.colorbar(im_val, cax=caxs[0], orientation="vertical")
|
||||
|
||||
|
||||
masked_angle_img = np.ma.masked_where(np.isnan(angle_img), angle_img)
|
||||
# cmap = cm('crameri:romao').to_matplotlib()
|
||||
# cmap = plt.get_cmap('puccs')
|
||||
# cmap = sc.ScicoCircular("colorwheel").get_mpl_color_map()
|
||||
cmap = cm('colorcet:CET_C8').to_matplotlib()
|
||||
cmap.set_bad(color="#AAAAAA")
|
||||
im_ang = axs[1].imshow(masked_angle_img, cmap=cmap, vmin=0, vmax=2*np.pi)
|
||||
cbar = fig.colorbar(im_ang, cax=caxs[1], orientation="vertical", ticks=[i/8 * 2*np.pi for i in range(9)])
|
||||
cbar.ax.set_yticklabels(["0", "π/4", "π/2", "3π/4", "π", "5π/4", "3π/2", "7π/4", "2π"])
|
||||
# im_ang_w = axs[2].imshow(masked_angle_img, cmap=cmap)
|
||||
# im_ang_w = axs[2].imshow(masked_value_img, cmap=alpha_map)
|
||||
|
||||
axs[0].axis("off")
|
||||
axs[1].axis("off")
|
||||
# axs[2].axis("off")
|
||||
|
||||
axs[0].set_title("Values")
|
||||
axs[1].set_title("Angles")
|
||||
# axs[2].set_title("Values and Angles")
|
||||
|
||||
|
||||
...
|
||||
if show:
|
||||
plt.show()
|
||||
return fig
|
||||
|
||||
# model = models.regenerator(*dims, **model_kwargs)
|
||||
|
||||
if __name__ == "__main__":
|
||||
register_puccs_cmap()
|
||||
if len(sys.argv) > 1:
|
||||
model_plot(sys.argv[1])
|
||||
else:
|
||||
print("Please provide a model path as an argument")
|
||||
# model_plot(".models/best_20250114_224234.tar")
|
||||
102
src/single-core-regen/puccs.csv
Normal file
102
src/single-core-regen/puccs.csv
Normal file
@@ -0,0 +1,102 @@
|
||||
"x","L","a","b","R","G","B"
|
||||
0.,0.5187848173343539,0.6399990176455989,0.67,0.8889427469969852,0.22673227640012172,0.
|
||||
0.01,0.5374499525557803,0.604014067614707,0.6777967519386492,0.8956274406155226,0.27553288030331824,0.
|
||||
0.02,0.5560867887452998,0.5680836759482211,0.6855816828789898,0.9019507507843885,0.318608215541461,0.
|
||||
0.03,0.5746877595125583,0.5322224300667823,0.6933516322080414,0.907905487190649,0.3580633000693721,0.
|
||||
0.04,0.5932314662487472,0.49647158484797804,0.7010976613543587,0.9134808162089558,0.3949845524063657,0.
|
||||
0.05,0.6117000836392819,0.46086550613202343,0.7088123243737041,0.918668356138916,0.43002019316005363,0.
|
||||
0.06,0.6300828534995973,0.4254249348741487,0.7164911273850869,0.923462736751354,0.4635961938811463,0.
|
||||
0.07,0.6483763163456417,0.3901565406944371,0.7241326253017896,0.9278609626724071,0.49601354353255284,0.
|
||||
0.08,0.6665840140182806,0.3550534951951814,0.7317382976124045,0.9318616057744784,0.5274983630587982,0.
|
||||
0.09,0.6847162776119433,0.3200958808181962,0.7393124597949372,0.9354640163365924,0.5582303922647159,0.
|
||||
0.1,0.7027902128942014,0.2852507189547545,0.7468622572263107,0.9386675557407496,0.5883604892249517,0.004034952213848706
|
||||
0.11,0.7208298719332069,0.25047163906104203,0.7543977368741345,0.9414708123927996,0.6180221032545026,0.016031521294251994
|
||||
0.12,0.7388665670611175,0.2156982733607376,0.7619319784446927,0.943870754968487,0.6473392272576862,0.029857267582036696
|
||||
0.13,0.7569392765472108,0.18085547473834482,0.7694812638396673,0.9458617774020323,0.676432172396153,0.045365670193636125
|
||||
0.14,0.7750950944867471,0.14585244938794778,0.7770652650825484,0.9474345911958609,0.7054219201084561,0.06017985923530026
|
||||
0.15,0.793389684293558,0.11058188251425949,0.7847072337503834,0.9485749196617762,0.7344334940032564,0.07418869502646075
|
||||
0.16,0.8117919447684838,0.07510373484536464,0.792394178330817,0.9492596163836376,0.7634480277996188,0.08767517868137237
|
||||
0.17,0.8293050962981561,0.03629277424762101,0.799038155466063,0.9462308253550155,0.7922009241807345,0.10066327128139077
|
||||
0.18,0.8213303100752708,-0.0062517290795987,0.7879999288492758,0.9088702681901394,0.7940579017644396,0.10139639009534024
|
||||
0.19,0.8134831311534617,-0.048115463155645855,0.7771383286984362,0.8716809050191757,0.7954897210083888,0.10232311621802098
|
||||
0.2,0.80558613530069,-0.0902449644291895,0.7662077749032042,0.8337524177888596,0.7965471523787845,0.10344968926026826
|
||||
0.21,0.7975860185564765,-0.13292460297117392,0.7551344872795225,0.7947193410849823,0.7972381033243311,0.10477682283894393
|
||||
0.22,0.7894147026971006,-0.17651756772919341,0.7438242359834689,0.7540941866826836,0.7975605026647324,0.10631182441371936
|
||||
0.23,0.7809997374598548,-0.2214103719409295,0.7321767396537806,0.7112894518675287,0.7974995317311054,0.1080672415170634
|
||||
0.24,0.7722646970273015,-0.2680107379394189,0.7200862142018722,0.6655745739336695,0.7970267795229349,0.11006041388465265
|
||||
0.25,0.7631307298557146,-0.3167393290089981,0.7074435179925446,0.6160047476007512,0.7960993904970947,0.11231257117602686
|
||||
0.26,0.7535192192483822,-0.36801555555407994,0.6941398344519211,0.5612859274945571,0.794659599537827,0.11484733363789801
|
||||
0.27,0.7433557597838075,-0.42223636134393283,0.6800721760037781,0.4994862901720824,0.7926351396848288,0.11768844813479104
|
||||
0.28,0.732575139048096,-0.479749646583324,0.6651502794883674,0.42731393423789277,0.7899410218414098,0.12085678487511567
|
||||
0.29,0.7211269294461059,-0.5408244362880141,0.6493043460161184,0.3378265607222193,0.786483110019224,0.124366774034814
|
||||
0.3,0.7090756028785993,-0.6051167807996883,0.6326236137723747,0.2098475715121697,0.7821998608677176,0.12819222127525928
|
||||
0.31,0.7094510768540225,-0.6165036055456403,0.5630307498747129,0.15061488620640032,0.7845112116922692,0.21943537230975235
|
||||
0.32,0.7174669421288304,-0.5917687864932311,0.4797229624661701,0.18766933782916642,0.7905828987725732,0.31091344246312086
|
||||
0.33,0.7249009746435938,-0.5688293479200438,0.40246208306061504,0.21160609617940718,0.7962175427587832,0.38519766326885596
|
||||
0.34,0.7317072855135611,-0.5478268906666535,0.3317250285377912,0.22717569971119178,0.8013847719431052,0.4490960048955565
|
||||
0.35,0.7379328517830899,-0.5286164561226088,0.26702357292455026,0.23690087622812972,0.8061220291668977,0.5056371468159843
|
||||
0.36,0.7436229063122554,-0.5110584677642499,0.20788761731555405,0.24226377668817778,0.8104638164122776,0.5563570758573497
|
||||
0.37,0.7488251728809415,-0.4950056627547577,0.15382117501783654,0.24424372086048424,0.8144455902164638,0.6022301663745243
|
||||
0.38,0.7535943992285348,-0.48028910419451787,0.10425526029155024,0.24352232677523483,0.818107753931944,0.6440238320299774
|
||||
0.39,0.757994865186593,-0.4667104416936734,0.05852182167144754,0.240562414747303,0.8214980148949816,0.6824536572462205
|
||||
0.4,0.7620994844391137,-0.4540446830999986,0.015863077249098356,0.2356325204239052,0.8246710357361025,0.7182393675419642
|
||||
0.41,0.7659871096124125,-0.4420485102716773,-0.024540477496154123,0.22880568593963535,0.8276865975886148,0.7521146815529202
|
||||
0.42,0.7697410958994951,-0.4304647113488041,-0.06355514164248566,0.21993360985514526,0.8306086550266585,0.7848331944479765
|
||||
0.43,0.773446484628189,-0.4190308715098135,-0.10206473803580057,0.20858849290850018,0.833503273690861,0.8171544357676854
|
||||
0.44,0.7771893686864673,-0.4074813310994203,-0.14096401824224686,0.1939295692427068,0.8364382500400466,0.8498448067259188
|
||||
0.45,0.7810574093604746,-0.3955455908045306,-0.18116403397486242,0.17438366103820427,0.839483669055626,0.8836865023336339
|
||||
0.46,0.7851360804917298,-0.3829599011818591,-0.2235531031349741,0.14679145002531463,0.8427091517444469,0.9194481212717681
|
||||
0.47,0.789525027020907,-0.369416784561489,-0.26916682191206776,0.10278921007810798,0.8461971304126237,0.9580316568065935
|
||||
0.48,0.7942371698732826,-0.35487637041943493,-0.3181394757087982,0.0013920913109500188,0.8499626968466341,0.9995866371771526
|
||||
0.49,0.7773897680996302,-0.31852357140025195,-0.34537976514700053,0.10740420703601522,0.8254781216972907,1.
|
||||
0.5,0.7604011244310231,-0.28211213216592784,-0.3722846952738428,0.1581725581872408,0.8008522647497104,1.
|
||||
0.51,0.7433440454962605,-0.2455540169176899,-0.3992980063927199,0.19300141807932156,0.7761561224913385,1.
|
||||
0.52,0.7262590833969331,-0.20893614020926626,-0.42635547610418184,0.2194621842292243,0.751443124097109,1.
|
||||
0.53,0.709058602701224,-0.17207067467417486,-0.453595892719742,0.2405673704012788,0.7265803324554873,1.
|
||||
0.54,0.6915768892539101,-0.1346024482921609,-0.48128169789479536,0.25788347992973676,0.701321051230534,1.
|
||||
0.55,0.6736331627810209,-0.09614399811510127,-0.5096991935104321,0.2722888922216317,0.6753950894563805,1.
|
||||
0.56,0.6551463184003872,-0.05652149358027936,-0.5389768254408652,0.28422807900785235,0.6486730893521468,1.
|
||||
0.57,0.6361671326276888,-0.01584376303510615,-0.5690341788729347,0.293907374075009,0.6212117649042732,1.
|
||||
0.58,0.6168396823565967,0.025580396234342995,-0.5996430791016598,0.301442767979156,0.5931976878638505,1.
|
||||
0.59,0.5973210287815495,0.06741435793529688,-0.6305547881733555,0.30694603901024253,0.5648312189065924,1.
|
||||
0.6,0.5777303704171711,0.10940264614179468,-0.661580531294122,0.3105418468883679,0.5362525958007331,1.
|
||||
0.61,0.5581475370499237,0.15137416317967575,-0.6925938819599547,0.3123531986526998,0.5075386530652202,1.
|
||||
0.62,0.5386227795100639,0.19322120739317136,-0.7235152578861672,0.31248922600720636,0.4787151440558522,1.
|
||||
0.63,0.5191666876024412,0.23492108185347996,-0.754327887989376,0.31103663081260624,0.44973844514160927,1.
|
||||
0.64,0.4996990584326256,0.2766456839100268,-0.7851587896650079,0.30803814950244496,0.4204116611935119,1.
|
||||
0.65,0.479957679121191,0.3189570094767831,-0.8164232296840259,0.30343473603466015,0.390226489453496,1.
|
||||
0.66,0.4600072725872886,0.3617163391430824,-0.8480187063016573,0.29717122075330515,0.3591178757512998,1.
|
||||
0.67,0.44600100870220305,0.4113853615984094,-0.8697728377551008,0.3178994129506999,0.3295740682997879,1.
|
||||
0.68,0.4574651571354146,0.44026390446569547,-0.8504539292487465,0.3842479358768364,0.3280946443367561,1.
|
||||
0.69,0.4691809168948424,0.46977626401045774,-0.830711015748157,0.44293649140770447,0.3260767554252525,1.
|
||||
0.7,0.4811696900083858,0.49997635259991063,-0.8105080314416201,0.49708450874457527,0.3234487047238236,1.
|
||||
0.71,0.49350094811609174,0.5310391714342613,-0.7897279055963483,0.5485591109413528,0.3201099534066949,1.
|
||||
0.72,0.5062548753068121,0.5631667067020758,-0.7682355153041539,0.5985798481027601,0.3159263917472715,1.
|
||||
0.73,0.5195243020949684,0.5965928013272943,-0.7458744264238399,0.6480500606439057,0.31071717884730565,1.
|
||||
0.74,0.5334043922713477,0.6315571758288618,-0.7224842728734379,0.6976685401842261,0.3042411890803418,1.
|
||||
0.75,0.5479805812358602,0.6682750446095802,-0.697921082452685,0.7479712773579563,0.29618040787504757,1.
|
||||
0.76,0.5633244502526606,0.7069267230777347,-0.6720642293775535,0.7993701361353484,0.28611136999256687,1.
|
||||
0.77,0.5794956601139,0.7476624986056212,-0.6448131757501174,0.8521918014427678,0.2734527325942473,1.
|
||||
0.78,0.5965429098573916,0.7906050455688622,-0.6160858559672187,0.9067003897516911,0.2573693489198746,1.
|
||||
0.79,0.6145761476424179,0.8360313267658297,-0.5856969899409387,0.963334644317004,0.23648492980159264,1.
|
||||
0.8,0.6232910688128902,0.859291371252556,-0.5300995185388214,1.,0.21867949406239662,0.9712088595948508
|
||||
0.81,0.6159984336377875,0.8439887543380684,-0.44635440435952856,1.,0.21606849746358275,0.9041480210597966
|
||||
0.82,0.6091642745073532,0.8296481879180277,-0.36787420852419694,1.,0.21421830096504035,0.8419706002336461
|
||||
0.83,0.6025478038652375,0.8157644115969636,-0.2918938425681935,1.,0.21295365915197917,0.7823908751330636
|
||||
0.84,0.5961857222953111,0.8024144366282877,-0.21883475834162458,0.9971140114799418,0.21220068235083267,0.7256713129328118
|
||||
0.85,0.5900921771070883,0.7896279492437488,-0.1488594167412921,0.993273906363258,0.2118788857127918,0.671860243327784
|
||||
0.86,0.5842771639541229,0.7774259239818333,-0.08208260304413262,0.9887084084529413,0.21191070453347688,0.6209624706933893
|
||||
0.87,0.578741582584259,0.7658102488427286,-0.018514649521559012,0.9835846378805114,0.2122246941077346,0.5728987835613306
|
||||
0.88,0.5734741590353537,0.7547572669288056,0.04197390858426542,0.9780378159372328,0.21275878699579343,0.5274829957183049
|
||||
0.89,0.5684517008574971,0.7442183119942206,0.09964940221121898,0.9721670725313721,0.21346242315895625,0.4844270603851604
|
||||
0.9,0.5636419856510335,0.7341257696545772,0.15488185789614228,0.9660363209686843,0.21429691147008262,0.4433660148378527
|
||||
0.91,0.5590069340453534,0.7243997354573974,0.20810856081277884,0.9596781387247791,0.2152344151262528,0.4038812338146013
|
||||
0.92,0.5545051525321143,0.7149533506766244,0.25980485409830323,0.9530986696850675,0.21625626438013962,0.3655130449917989
|
||||
0.93,0.5500961975299247,0.705701749880514,0.3104351723857584,0.9462863346513658,0.21735046958786286,0.327780364198278
|
||||
0.94,0.545740378056064,0.6965616468647046,0.36045530782708896,0.93921469089265,0.21851014470332586,0.29014917175372823
|
||||
0.95,0.5414004092067859,0.6874548042588865,0.41029342232076466,0.9318478255642132,0.21973168075163751,0.2519897371806688
|
||||
0.96,0.5370416605957644,0.6783085548415655,0.46034719456417006,0.9241434776436454,0.22101341980094052,0.2124579038400577
|
||||
0.97,0.5326309593934517,0.6690532898786764,0.5109975653738162,0.9160532016485884,0.22235495330179011,0.17018252385769012
|
||||
0.98,0.5281374148557197,0.6596241892863608,0.5625992691950712,0.90752576202319,0.22375597459867458,0.1223073280126531
|
||||
0.99,0.5235317096396147,0.6499597345521199,0.615488972291106,0.8985077346125597,0.22521565729028564,0.05933950582860665
|
||||
1.,0.5187848173343539,0.6399990176455989,0.67,0.8889427469969852,0.22673227640012172,0.
|
||||
|
@@ -1,6 +1,8 @@
|
||||
from datetime import datetime
|
||||
|
||||
import optuna
|
||||
import torch
|
||||
import util
|
||||
from hypertraining.hypertraining import HyperTraining
|
||||
from hypertraining.settings import (
|
||||
GlobalSettings,
|
||||
@@ -16,24 +18,29 @@ global_settings = GlobalSettings(
|
||||
)
|
||||
|
||||
data_settings = DataSettings(
|
||||
config_path="data/*-128-16384-100000-0-0-17-0-PAM4-0.ini",
|
||||
# config_path="data/*-128-16384-100000-0-0-17-0-PAM4-0.ini",
|
||||
config_path="data/20241204-131003-128-16384-100000-0-0-17-0-PAM4-0.ini",
|
||||
dtype="complex64",
|
||||
# symbols = (9, 20), # 13 symbol @ 10GBd <-> 1.3ns <-> 0.26m of fiber
|
||||
symbols=13, # study: single_core_regen_20241123_011232
|
||||
# symbols=13, # study: single_core_regen_20241123_011232
|
||||
# symbols = (3, 13),
|
||||
symbols=4,
|
||||
# output_size = (11, 32), # ballpark 26 taps -> 2 taps per input symbol -> 1 tap every 0.01m (model has 52 inputs)
|
||||
output_size=26, # study: single_core_regen_20241123_011232 (model_input_dim/2)
|
||||
# output_size=26, # study: single_core_regen_20241123_011232 (model_input_dim/2)
|
||||
output_size=(8, 30),
|
||||
shuffle=True,
|
||||
in_out_delay=0,
|
||||
xy_delay=0,
|
||||
drop_first=128 * 100,
|
||||
drop_first=256,
|
||||
train_split=0.8,
|
||||
randomise_polarisations=False,
|
||||
)
|
||||
|
||||
pytorch_settings = PytorchSettings(
|
||||
epochs=10000,
|
||||
epochs=10,
|
||||
batchsize=2**10,
|
||||
device="cuda",
|
||||
dataloader_workers=12,
|
||||
dataloader_workers=4,
|
||||
dataloader_prefetch=4,
|
||||
summary_dir=".runs",
|
||||
write_every=2**5,
|
||||
@@ -43,28 +50,70 @@ pytorch_settings = PytorchSettings(
|
||||
|
||||
model_settings = ModelSettings(
|
||||
output_dim=2,
|
||||
# n_hidden_layers = (3, 8),
|
||||
n_hidden_layers=4,
|
||||
overrides={
|
||||
"n_hidden_nodes_0": 8,
|
||||
"n_hidden_nodes_1": 6,
|
||||
"n_hidden_nodes_2": 4,
|
||||
"n_hidden_nodes_3": 8,
|
||||
},
|
||||
model_activation_func="Mag",
|
||||
# satabsT0=(1e-6, 1),
|
||||
n_hidden_layers = (2, 5),
|
||||
n_hidden_nodes=(2, 16),
|
||||
model_activation_func="EOActivation",
|
||||
dropout_prob=0,
|
||||
model_layer_function="ONNRect",
|
||||
model_layer_kwargs={"square": True},
|
||||
# scale=(False, True),
|
||||
scale=False,
|
||||
model_layer_parametrizations=[
|
||||
{
|
||||
"tensor_name": "weight",
|
||||
"parametrization": util.complexNN.energy_conserving,
|
||||
},
|
||||
{
|
||||
"tensor_name": "alpha",
|
||||
"parametrization": util.complexNN.clamp,
|
||||
},
|
||||
{
|
||||
"tensor_name": "gain",
|
||||
"parametrization": util.complexNN.clamp,
|
||||
"kwargs": {
|
||||
"min": 0,
|
||||
"max": float("inf"),
|
||||
},
|
||||
},
|
||||
{
|
||||
"tensor_name": "phase_bias",
|
||||
"parametrization": util.complexNN.clamp,
|
||||
"kwargs": {
|
||||
"min": 0,
|
||||
"max": 2 * torch.pi,
|
||||
},
|
||||
},
|
||||
{
|
||||
"tensor_name": "scales",
|
||||
"parametrization": util.complexNN.clamp,
|
||||
},
|
||||
{
|
||||
"tensor_name": "angle",
|
||||
"parametrization": util.complexNN.clamp,
|
||||
"kwargs": {
|
||||
"min": -torch.pi,
|
||||
"max": torch.pi,
|
||||
},
|
||||
},
|
||||
{
|
||||
"tensor_name": "loss",
|
||||
"parametrization": util.complexNN.clamp,
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
optimizer_settings = OptimizerSettings(
|
||||
optimizer="Adam",
|
||||
# learning_rate = (1e-5, 1e-1),
|
||||
learning_rate=5e-3
|
||||
# learning_rate=5e-4,
|
||||
optimizer="AdamW",
|
||||
optimizer_kwargs={
|
||||
"lr": 5e-3,
|
||||
"amsgrad": True,
|
||||
# "weight_decay": 1e-7,
|
||||
},
|
||||
)
|
||||
|
||||
optuna_settings = OptunaSettings(
|
||||
n_trials=1,
|
||||
n_workers=1,
|
||||
n_trials=1024,
|
||||
n_workers=8,
|
||||
timeout=3600,
|
||||
directions=("minimize",),
|
||||
metrics_names=("mse",),
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
# from datetime import datetime
|
||||
from pathlib import Path
|
||||
import matplotlib
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.utils.tensorboard
|
||||
import torch.utils.tensorboard.summary
|
||||
from hypertraining.settings import (
|
||||
GlobalSettings,
|
||||
DataSettings,
|
||||
@@ -9,7 +13,7 @@ from hypertraining.settings import (
|
||||
OptimizerSettings,
|
||||
)
|
||||
|
||||
from hypertraining.training import Trainer
|
||||
from hypertraining.training import RegenerationTrainer#, PolarizationTrainer
|
||||
|
||||
# import torch
|
||||
import json
|
||||
@@ -22,26 +26,39 @@ global_settings = GlobalSettings(
|
||||
)
|
||||
|
||||
data_settings = DataSettings(
|
||||
# config_path="data/*-128-16384-50000-0-0-17-0-PAM4-0.ini",
|
||||
config_path=[f"data/*-128-16384-{length}-0-0-17-0-PAM4-0.ini" for length in (40000, 50000, 60000)],
|
||||
# config_path="data/20250115-114319-128-16384-inf-100000-0-0-0-0-PAM4-1-0-10.ini", # no effects enabled - baseline
|
||||
# config_path = "data/20250115-233553-128-16384-1060.0-100000-0-0.2-17.0-0.058-PAM4-1.0-0.0-10.ini", # dispersion + slope only
|
||||
# config_path="data/20250115-115836-128-16384-60.0-100000-0-0.2-17-0.058-PAM4-1000-0.2-10.ini", # all linear effects enabled with realistic values + noise + pmd (delta_beta=0.2) + ortho_error = 0.1
|
||||
# config_path="data/20250118-225840-128-16384-inf-100000-0-0.2-17.0-0.058-PAM4-0-0-10.ini", # a)
|
||||
# config_path="data/20250116-214606-128-16384-inf-100000-0-0.2-17.0-0.058-PAM4-0-0-10.ini", # b)
|
||||
# config_path="data/20250116-214547-128-16384-inf-100000-0-0.2-17.0-0.058-PAM4-0-0-10.ini", # c)
|
||||
# config_path="data/20250117-143926-128-16384-inf-100000-0-0.2-17.0-0.058-PAM4-0-0-10.ini", # d) 10ps dgd
|
||||
config_path="data/20250120-105720-128-16384-inf-100000-0-0.2-17-0.058-PAM4-0-0-10.ini", # d) 10ns
|
||||
|
||||
# config_path="data/20250114-215547-128-16384-60.0-100000-1.15-0.2-17-0.058-PAM4-1-0-10.ini", # with gamma=1.15, 2.5dBm launch power, no pmd
|
||||
|
||||
|
||||
dtype="complex64",
|
||||
# symbols = (9, 20), # 13 symbol @ 10GBd <-> 1.3ns <-> 0.26m of fiber
|
||||
symbols=13, # study: single_core_regen_20241123_011232
|
||||
symbols=4, # study: single_core_regen_20241123_011232 -> taps spread over 4 symbols @ 10GBd
|
||||
# output_size = (11, 32), # 26 taps -> 2 taps per input symbol -> 1 tap every 1cm (model has 52 inputs (x/y))
|
||||
output_size=26, # study: single_core_regen_20241123_011232 (model_input_dim/2)
|
||||
shuffle=True,
|
||||
in_out_delay=0,
|
||||
xy_delay=0,
|
||||
drop_first=128 * 64,
|
||||
output_size=20, # = model_input_dim/2 study: single_core_regen_20241123_011232
|
||||
shuffle=False,
|
||||
drop_first=256,
|
||||
drop_last=256,
|
||||
train_split=0.8,
|
||||
randomise_polarisations=False,
|
||||
polarisations=False,
|
||||
# cross_pol_interference=0.01,
|
||||
osnr=16, #16dB due to amplification with NF 5
|
||||
)
|
||||
|
||||
pytorch_settings = PytorchSettings(
|
||||
epochs=10000,
|
||||
batchsize=2**12,
|
||||
epochs=1000,
|
||||
batchsize=2**13,
|
||||
device="cuda",
|
||||
dataloader_workers=12,
|
||||
dataloader_prefetch=8,
|
||||
dataloader_workers=32,
|
||||
dataloader_prefetch=4,
|
||||
summary_dir=".runs",
|
||||
write_every=2**5,
|
||||
save_models=True,
|
||||
@@ -50,70 +67,51 @@ pytorch_settings = PytorchSettings(
|
||||
|
||||
model_settings = ModelSettings(
|
||||
output_dim=2,
|
||||
n_hidden_layers=4,
|
||||
n_hidden_layers=3,
|
||||
overrides={
|
||||
"n_hidden_nodes_0": 4,
|
||||
"n_hidden_nodes_1": 4,
|
||||
"n_hidden_nodes_2": 4,
|
||||
"n_hidden_nodes_3": 4,
|
||||
# "hidden_layer_dims": (8, 8, 4, 4),
|
||||
"n_hidden_nodes_0": 16,
|
||||
"n_hidden_nodes_1": 8,
|
||||
"n_hidden_nodes_2": 8,
|
||||
# "n_hidden_nodes_3": 4,
|
||||
# "n_hidden_nodes_4": 2,
|
||||
},
|
||||
model_activation_func="EOActivation",
|
||||
dropout_prob=0.01,
|
||||
dropout_prob=0,
|
||||
model_layer_function="ONNRect",
|
||||
model_layer_kwargs={"square": True},
|
||||
scale=True,
|
||||
scale=2.0,
|
||||
model_layer_parametrizations=[
|
||||
{
|
||||
"tensor_name": "weight",
|
||||
"parametrization": util.complexNN.energy_conserving,
|
||||
},
|
||||
# EOactivation
|
||||
{
|
||||
"tensor_name": "alpha",
|
||||
"parametrization": util.complexNN.clamp,
|
||||
"kwargs": {
|
||||
"min": 0,
|
||||
"max": 1,
|
||||
},
|
||||
},
|
||||
# ONNRect
|
||||
{
|
||||
"tensor_name": "gain",
|
||||
"tensor_name": "weight",
|
||||
"parametrization": torch.nn.utils.parametrizations.orthogonal,
|
||||
},
|
||||
# Scale
|
||||
{
|
||||
"tensor_name": "scale",
|
||||
"parametrization": util.complexNN.clamp,
|
||||
"kwargs": {
|
||||
"min": 0,
|
||||
"max": float("inf"),
|
||||
"max": 10,
|
||||
},
|
||||
},
|
||||
{
|
||||
"tensor_name": "phase_bias",
|
||||
"parametrization": util.complexNN.clamp,
|
||||
"kwargs": {
|
||||
"min": 0,
|
||||
"max": 2 * torch.pi,
|
||||
},
|
||||
},
|
||||
{
|
||||
"tensor_name": "scales",
|
||||
"parametrization": util.complexNN.clamp,
|
||||
},
|
||||
# {
|
||||
# "tensor_name": "scale",
|
||||
# "parametrization": util.complexNN.clamp,
|
||||
# },
|
||||
# {
|
||||
# "tensor_name": "bias",
|
||||
# "parametrization": util.complexNN.clamp,
|
||||
# },
|
||||
# {
|
||||
# "tensor_name": "V",
|
||||
# "parametrization": torch.nn.utils.parametrizations.orthogonal,
|
||||
# },
|
||||
{
|
||||
"tensor_name": "loss",
|
||||
"parametrization": util.complexNN.clamp,
|
||||
},
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
optimizer_settings = OptimizerSettings(
|
||||
optimizer="AdamW",
|
||||
optimizer_kwargs={
|
||||
"lr": 0.05,
|
||||
"lr": 0.005,
|
||||
"amsgrad": True,
|
||||
# "weight_decay": 1e-7,
|
||||
},
|
||||
@@ -121,96 +119,35 @@ optimizer_settings = OptimizerSettings(
|
||||
scheduler="ReduceLROnPlateau",
|
||||
scheduler_kwargs={
|
||||
"patience": 2**6,
|
||||
"factor": 0.75,
|
||||
"factor": 0.5,
|
||||
# "threshold": 1e-3,
|
||||
"min_lr": 1e-6,
|
||||
"cooldown": 10,
|
||||
},
|
||||
early_stopping=True,
|
||||
early_stop_kwargs={
|
||||
"threshold": 1e-06,
|
||||
"plateau": 2**7,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def save_dict_to_file(dictionary, filename):
|
||||
"""
|
||||
Save the best dictionary to a JSON file.
|
||||
|
||||
:param best: Dictionary containing the best training results.
|
||||
:type best: dict
|
||||
:param filename: Path to the JSON file where the dictionary will be saved.
|
||||
:type filename: str
|
||||
"""
|
||||
with open(filename, "w") as f:
|
||||
json.dump(dictionary, f, indent=4)
|
||||
|
||||
|
||||
def sweep_lengths(*lengths, model=None):
|
||||
assert model is not None, "Model must be provided."
|
||||
model = model
|
||||
|
||||
fiber_ins = {}
|
||||
fiber_outs = {}
|
||||
regens = {}
|
||||
timestampss = {}
|
||||
|
||||
for length in lengths:
|
||||
trainer = Trainer(
|
||||
checkpoint_path=model,
|
||||
settings_override={
|
||||
"data_settings": {
|
||||
"config_path": f"data/*-128-16384-{length}-0-0-17-0-PAM4-0.ini",
|
||||
"train_split": 1,
|
||||
"shuffle": True,
|
||||
}
|
||||
},
|
||||
)
|
||||
trainer.define_model()
|
||||
loader, _ = trainer.get_sliced_data()
|
||||
fiber_in, fiber_out, regen, timestamps = trainer.run_model(trainer.model, loader=loader)
|
||||
|
||||
fiber_ins[length] = fiber_in
|
||||
fiber_outs[length] = fiber_out
|
||||
regens[length] = regen
|
||||
timestampss[length] = timestamps
|
||||
|
||||
data = torch.zeros(2 * len(lengths), 2, fiber_out.shape[0])
|
||||
channel_names = ["" for _ in range(2 * len(lengths))]
|
||||
|
||||
for li, length in enumerate(lengths):
|
||||
data[2 * li, 0, :] = timestampss[length] / 128
|
||||
data[2 * li, 1, :] = regens[length][:, 0].abs().square()
|
||||
data[2 * li + 1, 0, :] = timestampss[length] / 128
|
||||
data[2 * li + 1, 1, :] = regens[length][:, 1].abs().square()
|
||||
|
||||
channel_names[2 * li] = f"regen x {length}"
|
||||
channel_names[2 * li + 1] = f"regen y {length}"
|
||||
|
||||
# get current backend
|
||||
backend = matplotlib.get_backend()
|
||||
|
||||
matplotlib.use("TkCairo")
|
||||
eye = util.eye_diagram.eye_diagram(data.to(dtype=torch.float32).detach().cpu().numpy(), channel_names=channel_names)
|
||||
|
||||
print_attrs = ("channel", "success", "min_area")
|
||||
with np.printoptions(precision=3, suppress=True, formatter={'float': '{:0.3e}'.format}):
|
||||
for result in eye.eye_stats:
|
||||
print_dict = {attr: result[attr] for attr in print_attrs}
|
||||
rprint(print_dict)
|
||||
rprint()
|
||||
|
||||
eye.plot()
|
||||
matplotlib.use(backend)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
# sweep_lengths(30000, 40000, 50000, 60000, 70000, model=".models/best_20241202_143149.tar")
|
||||
|
||||
trainer = Trainer(
|
||||
|
||||
trainer = RegenerationTrainer(
|
||||
global_settings=global_settings,
|
||||
data_settings=data_settings,
|
||||
pytorch_settings=pytorch_settings,
|
||||
model_settings=model_settings,
|
||||
optimizer_settings=optimizer_settings,
|
||||
# checkpoint_path=".models/best_20241202_143149.tar",
|
||||
# 20241202_143149
|
||||
checkpoint_path=".models/best_20250117_144001.tar",
|
||||
new_model=True,
|
||||
settings_override={
|
||||
"data_settings": data_settings.__dict__,
|
||||
# "optimizer_settings": {
|
||||
# "early_stop_kwargs":{
|
||||
# "plateau": 2**8,
|
||||
# }
|
||||
# }
|
||||
}
|
||||
)
|
||||
trainer.train()
|
||||
trainer.train()
|
||||
|
||||
@@ -26,7 +26,7 @@ while not (__parent_dir / "pypho" / "pypho").exists() and __parent_dir != Path("
|
||||
__parent_dir = __parent_dir.parent
|
||||
|
||||
if __parent_dir != Path("/"):
|
||||
sys.path.append(str(__parent_dir / "pypho"))
|
||||
sys.path.insert(0, str(__parent_dir / "pypho"))
|
||||
__log.append(f"Added '{__parent_dir/ "pypho"}' to 'PATH'")
|
||||
else:
|
||||
__log.append('pypho not found')
|
||||
729
src/single-core-regen/signal_gen/generate_signal.py
Normal file
729
src/single-core-regen/signal_gen/generate_signal.py
Normal file
@@ -0,0 +1,729 @@
|
||||
"""
|
||||
generate_signal.py
|
||||
|
||||
This file is part of the repo "optical-regeneration"
|
||||
https://git.suuppl.dev/seppl/optical-regeneration.git
|
||||
|
||||
Joseph Hopfmüller
|
||||
Copyright 2024
|
||||
Licensed under the EUPL
|
||||
|
||||
Full license text in LICENSE file
|
||||
"""
|
||||
|
||||
import configparser
|
||||
# import copy
|
||||
from datetime import datetime
|
||||
import hashlib
|
||||
from pathlib import Path
|
||||
import time
|
||||
import h5py
|
||||
from matplotlib import pyplot as plt # noqa: F401
|
||||
import numpy as np
|
||||
|
||||
from . import add_pypho # noqa: F401
|
||||
import pypho
|
||||
|
||||
default_config = f"""
|
||||
[glova]
|
||||
sps = 128
|
||||
nos = 16384
|
||||
f0 = 193414489032258.06
|
||||
symbolrate = 10e9
|
||||
wisdom_dir = "{str((Path.home() / ".pypho"))}"
|
||||
flags = "FFTW_PATIENT"
|
||||
nthreads = 32
|
||||
|
||||
[fiber]
|
||||
length = 10000
|
||||
gamma = 1.14
|
||||
alpha = 0.2
|
||||
D = 17
|
||||
S = 0.058
|
||||
bireflength = 10
|
||||
pmd_q = 0.2
|
||||
; birefseed = 0xC0FFEE
|
||||
|
||||
[signal]
|
||||
; seed = 0xC0FFEE
|
||||
|
||||
modulation = "pam"
|
||||
mod_order = 4
|
||||
mod_depth = 1
|
||||
max_jitter = 0.02
|
||||
; jitter_seed = 0xC0FFEE
|
||||
laser_power = 0
|
||||
edfa_power = 0
|
||||
edfa_nf = 5
|
||||
pulse_shape = "gauss"
|
||||
fwhm = 0.33
|
||||
osnr = "inf"
|
||||
|
||||
[data]
|
||||
dir = "data"
|
||||
npy_dir = "npys"
|
||||
"""
|
||||
|
||||
|
||||
def get_config(config_file=None):
|
||||
"""
|
||||
DANGER! The function uses eval() to parse the config file. Do not use this function with untrusted input.
|
||||
"""
|
||||
if config_file is None:
|
||||
config_file = Path(__file__).parent / "signal_generation.ini"
|
||||
config_file = Path(config_file)
|
||||
if not config_file.exists():
|
||||
with open(config_file, "w") as f:
|
||||
f.write(default_config)
|
||||
config = configparser.ConfigParser()
|
||||
config.read(config_file)
|
||||
|
||||
conf = {}
|
||||
for section in config.sections():
|
||||
# print(f"[{section}]")
|
||||
conf[section] = {}
|
||||
for key in config[section]:
|
||||
# print(f"{key} = {config[section][key]}")
|
||||
try:
|
||||
conf[section][key] = eval(config[section][key])
|
||||
except NameError:
|
||||
conf[section][key] = float(config[section][key])
|
||||
# if isinstance(conf[section][key], str):
|
||||
# conf[section][key] = config[section][key].strip('"')
|
||||
return conf
|
||||
|
||||
|
||||
class PDM_IM_IPM:
|
||||
def __init__(
|
||||
self,
|
||||
glova,
|
||||
mod_order=8,
|
||||
seed=None,
|
||||
):
|
||||
assert np.cbrt(mod_order) == int(np.cbrt(mod_order)) and mod_order > 1, (
|
||||
"mod_order must be a cube of an integer greater than 1"
|
||||
)
|
||||
self.glova = glova
|
||||
self.mod_order = mod_order
|
||||
self.symbols_per_dim = int(np.cbrt(mod_order))
|
||||
self.seed = seed
|
||||
|
||||
def generate_symbols(self, n):
|
||||
rs = np.random.RandomState(self.seed)
|
||||
symbols = rs.randint(0, self.mod_order, n)
|
||||
return symbols
|
||||
|
||||
|
||||
class pam_generator:
|
||||
def __init__(
|
||||
self, glova, mod_order=None, mod_depth=0.5, pulse_shape="gauss", fwhm=0.33, seed=None, single_channel=False
|
||||
) -> None:
|
||||
self.glova = glova
|
||||
self.pulse_shape = pulse_shape
|
||||
self.modulation_depth = mod_depth
|
||||
self.mod_order = mod_order
|
||||
self.fwhm = fwhm
|
||||
self.seed = seed
|
||||
self.single_channel = single_channel
|
||||
|
||||
def __call__(self, E, symbols, max_jitter=0):
|
||||
max_jitter = int(round(max_jitter * self.glova.sps))
|
||||
if self.pulse_shape == "gauss":
|
||||
wavelet = self.gauss(oversampling=6)
|
||||
else:
|
||||
raise ValueError(f"Unknown pulse shape: {self.pulse_shape}")
|
||||
|
||||
# prepare symbols
|
||||
symbols_x = symbols[0] / (self.mod_order)
|
||||
diffs_x = np.diff(symbols_x, prepend=symbols_x[0])
|
||||
digital_x = self.generate_digital_signal(diffs_x, max_jitter)
|
||||
digital_x = np.pad(digital_x, (0, self.glova.sps // 2), "constant", constant_values=(0, 0))
|
||||
|
||||
# create analog signal of diff of symbols
|
||||
E_x = np.convolve(digital_x, wavelet)
|
||||
|
||||
# convert to pam and set modulation depth (scale and move up such that 1 stays at 1)
|
||||
E_x = np.cumsum(E_x) * self.modulation_depth + (1 - self.modulation_depth)
|
||||
|
||||
# cut off the wavelet tails
|
||||
E_x = E_x[self.glova.sps // 2 + len(wavelet) // 2 - 1 : -len(wavelet) // 2]
|
||||
|
||||
# modulate the laser
|
||||
E[0]["E"][0] = np.sqrt(np.multiply(np.square(E[0]["E"][0]), E_x))
|
||||
|
||||
if not self.single_channel:
|
||||
symbols_y = symbols[1] / (self.mod_order)
|
||||
diffs_y = np.diff(symbols_y, prepend=symbols_y[0])
|
||||
digital_y = self.generate_digital_signal(diffs_y, max_jitter)
|
||||
digital_y = np.pad(digital_y, (0, self.glova.sps // 2), "constant", constant_values=(0, 0))
|
||||
E_y = np.convolve(digital_y, wavelet)
|
||||
|
||||
E_y = np.cumsum(E_y) * self.modulation_depth + (1 - self.modulation_depth)
|
||||
|
||||
E_y = E_y[self.glova.sps // 2 + len(wavelet) // 2 - 1 : -len(wavelet) // 2]
|
||||
|
||||
E[0]["E"][1] = np.sqrt(np.multiply(np.square(E[0]["E"][1]), E_y))
|
||||
|
||||
# rotate the signal on the y-polarisation by 90°
|
||||
# E[0]["E"][1] *= 1j
|
||||
else:
|
||||
E[0]["E"][1] = np.zeros_like(E[0]["E"][0], dtype=E[0]["E"][0].dtype)
|
||||
return E
|
||||
|
||||
def generate_digital_signal(self, symbols, max_jitter=0):
|
||||
rs = np.random.RandomState(self.seed)
|
||||
signal = np.zeros(self.glova.nos * self.glova.sps)
|
||||
for index in range(self.glova.nos):
|
||||
jitter = max_jitter != 0 and rs.randint(-max_jitter, max_jitter)
|
||||
signal_index = index * self.glova.sps + jitter
|
||||
if signal_index < 0:
|
||||
continue
|
||||
if signal_index >= len(signal):
|
||||
continue
|
||||
signal[signal_index] = symbols[index]
|
||||
return signal
|
||||
|
||||
def gauss(self, oversampling=1):
|
||||
sample_points = np.linspace(
|
||||
-oversampling * self.glova.sps,
|
||||
oversampling * self.glova.sps,
|
||||
oversampling * 2 * self.glova.sps,
|
||||
endpoint=True,
|
||||
)
|
||||
sigma = self.fwhm / (1 * np.sqrt(2 * np.log(2))) * self.glova.sps
|
||||
pulse = 1 / (sigma * np.sqrt(2 * np.pi)) * np.exp(-np.square(sample_points) / (2 * np.square(sigma)))
|
||||
return pulse
|
||||
|
||||
|
||||
def initialize_fiber_and_data(config):
|
||||
f0 = config["glova"].get("f0", None)
|
||||
if f0 is None:
|
||||
f0 = 299792458/(config["glova"].get("lambda0", 1550)*1e-9)
|
||||
config["glova"]["f0"] = f0
|
||||
py_glova = pypho.setup(
|
||||
nos=config["glova"]["nos"],
|
||||
sps=config["glova"]["sps"],
|
||||
f0=f0,
|
||||
symbolrate=config["glova"]["symbolrate"],
|
||||
wisdom_dir=config["glova"]["wisdom_dir"],
|
||||
flags=config["glova"]["flags"],
|
||||
nthreads=config["glova"]["nthreads"],
|
||||
)
|
||||
|
||||
c_glova = pypho.cfiber.GlovaWrapper.from_setup(py_glova)
|
||||
c_data = pypho.cfiber.DataWrapper(py_glova.sps * py_glova.nos)
|
||||
py_edfa = pypho.oamp(py_glova, Pmean=config["signal"]["edfa_power"], NF=config["signal"]["edfa_nf"])
|
||||
|
||||
osnr = config["signal"]["osnr"] = float(config["signal"].get("osnr", "inf"))
|
||||
|
||||
config["signal"]["seed"] = config["signal"].get("seed", (int(time.time() * 1000)) % 2**32)
|
||||
config["signal"]["jitter_seed"] = config["signal"].get("jitter_seed", (int(time.time() * 1000)) % 2**32)
|
||||
symbolsrc = pypho.symbols(
|
||||
py_glova, py_glova.nos, pattern="ones", p1=config["signal"]["mod_order"], seed=config["signal"]["seed"]
|
||||
)
|
||||
laserx = pypho.lasmod(py_glova, power=0, Df=0, theta=np.pi/4)
|
||||
# lasery = pypho.lasmod(py_glova, power=0, Df=25, theta=0)
|
||||
|
||||
modulator = pam_generator(
|
||||
py_glova,
|
||||
mod_depth=config["signal"]["mod_depth"],
|
||||
pulse_shape=config["signal"]["pulse_shape"],
|
||||
fwhm=config["signal"]["fwhm"],
|
||||
seed=config["signal"]["jitter_seed"],
|
||||
mod_order=config["signal"]["mod_order"],
|
||||
)
|
||||
|
||||
symbols_x = symbolsrc(pattern="random")
|
||||
symbols_y = symbolsrc(pattern="random")
|
||||
symbols_x[:3] = 0
|
||||
symbols_y[:3] = 0
|
||||
# symbols_x += 1
|
||||
|
||||
|
||||
cw = laserx()
|
||||
# cwy = lasery()
|
||||
# cw[0]['E'][0] = cw[0]['E'][0]
|
||||
# cw[0]['E'][1] = cwy[0]['E'][0]
|
||||
|
||||
|
||||
source_signal = modulator(E=cw, symbols=(symbols_x, symbols_y))
|
||||
|
||||
if osnr != float("inf"):
|
||||
osnr_lin = 10 ** (osnr / 10)
|
||||
signal_power = np.sum(pypho.functions.getpower_W(source_signal[0]["E"]))
|
||||
noise_power = signal_power / osnr_lin
|
||||
noise = np.random.normal(0, 1, source_signal[0]["E"].shape) + 1j * np.random.normal(
|
||||
0, 1, source_signal[0]["E"].shape
|
||||
)
|
||||
noise_power_is = np.sum(pypho.functions.getpower_W(noise))
|
||||
noise = noise * np.sqrt(noise_power / noise_power_is)
|
||||
noise_power_is = np.sum(pypho.functions.getpower_W(noise))
|
||||
source_signal[0]["E"] += noise
|
||||
source_signal[0]["noise"] = noise_power_is
|
||||
|
||||
# source_signal[0]['E'][1] += source_signal[0]['E'][1][np.argmin(np.abs(source_signal[0]['E'][1]))]
|
||||
|
||||
## side channels
|
||||
# df = 100
|
||||
# signal_power = pypho.functions.W_to_dBm(np.sum(pypho.functions.getpower_W(source_signal[0]["E"])))
|
||||
|
||||
|
||||
# symbols_x_side = symbolsrc(pattern="random")
|
||||
# symbols_y_side = symbolsrc(pattern="random")
|
||||
# symbols_x_side[:3] = 0
|
||||
# symbols_y_side[:3] = 0
|
||||
|
||||
# cw_left = laser(Df=-df)
|
||||
# source_signal_left = modulator(E=cw_left, symbols=(symbols_x_side, symbols_y_side))
|
||||
|
||||
# cw_right = laser(Df=df)
|
||||
# source_signal_right = modulator(E=cw_right, symbols=(symbols_y_side, symbols_x_side))
|
||||
|
||||
E_in_pure = source_signal[0]["E"]
|
||||
|
||||
nf = py_edfa.NF
|
||||
pmean = py_edfa.Pmean
|
||||
|
||||
# ideal amplification to launch power into fiber
|
||||
source_signal = py_edfa(E=source_signal, NF=0, Pmean=config["signal"]["laser_power"])
|
||||
# source_signal_left = py_edfa(E=source_signal_left, NF=0, Pmean=config["signal"]["laser_power"])
|
||||
# source_signal_right = py_edfa(E=source_signal_right, NF=0, Pmean=config["signal"]["laser_power"])
|
||||
|
||||
# source_signal[0]["E"][0] += source_signal_left[0]["E"][0] + source_signal_right[0]["E"][0]
|
||||
# source_signal[0]["E"][1] += source_signal_left[0]["E"][1] + source_signal_right[0]["E"][1]
|
||||
|
||||
c_data.E_in = source_signal[0]["E"]
|
||||
noise = source_signal[0]["noise"]
|
||||
|
||||
py_edfa.NF = nf
|
||||
py_edfa.Pmean = pmean
|
||||
|
||||
py_fiber = pypho.fiber(
|
||||
glova=py_glova,
|
||||
l=config["fiber"]["length"],
|
||||
alpha=pypho.functions.dB_to_Neper(config["fiber"]["alpha"]) / 1000,
|
||||
gamma=config["fiber"]["gamma"],
|
||||
D=config["fiber"]["d"],
|
||||
S=config["fiber"]["s"],
|
||||
phi_max=0.02,
|
||||
)
|
||||
|
||||
config["fiber"]["birefsteps"] = config["fiber"].get(
|
||||
"birefsteps", config["fiber"]["length"] // config["fiber"].get("bireflength", config["fiber"]["length"])
|
||||
)
|
||||
if config["fiber"]["birefsteps"] > 0:
|
||||
config["fiber"]["bireflength"] = config["fiber"].get("bireflength", config["fiber"]["length"] / config["fiber"]["birefsteps"])
|
||||
seed = config["fiber"].get("birefseed", (int(time.time() * 1000)) % 2**32)
|
||||
py_fiber.birefarray = pypho.birefringence_segment.create_pmd_fibre(
|
||||
config["fiber"]["length"],
|
||||
config["fiber"]["bireflength"],
|
||||
maxDeltaBeta=config["fiber"].get("pmd_q", 0)/np.sqrt(2*config["fiber"]["bireflength"]),
|
||||
seed=seed,
|
||||
)
|
||||
elif (dgd := config['fiber'].get('dgd', 0)) > 0:
|
||||
py_fiber.birefarray = [
|
||||
pypho.birefringence_segment(z_point=0, angle=np.pi/2, delta_beta=1000*dgd/config["fiber"]["length"])
|
||||
]
|
||||
c_params = pypho.cfiber.ParamsWrapper.from_fiber(py_fiber, max_step=config["fiber"]["length"] if py_fiber.gamma == 0 else 200)
|
||||
c_fiber = pypho.cfiber.FiberWrapper(c_data, c_params, c_glova)
|
||||
|
||||
return c_fiber, c_data, noise, py_edfa, (symbols_x, symbols_y), py_glova, E_in_pure
|
||||
|
||||
|
||||
def save_data(data, config, **metadata):
|
||||
data_dir = Path(config["data"]["dir"])
|
||||
npy_dir = config["data"].get("npy_dir", "")
|
||||
save_dir = data_dir / npy_dir if len(npy_dir) else data_dir
|
||||
save_dir.mkdir(parents=True, exist_ok=True)
|
||||
save_data = np.column_stack([
|
||||
data.E_in[0],
|
||||
data.E_in[1],
|
||||
data.E_out[0],
|
||||
data.E_out[1],
|
||||
])
|
||||
timestamp = datetime.now()
|
||||
seed = config["signal"].get("seed", False)
|
||||
jitter_seed = config["signal"].get("jitter_seed", False)
|
||||
birefseed = config["fiber"].get("birefseed", False)
|
||||
osnr = float(config["signal"].get("osnr", "inf"))
|
||||
|
||||
config_content = "\n".join((
|
||||
f"; Generated by {str(Path(__file__).name)} @ {timestamp.strftime('%Y-%m-%d %H:%M:%S')}",
|
||||
"[glova]",
|
||||
f"sps = {config['glova']['sps']}",
|
||||
f"nos = {config['glova']['nos']}",
|
||||
f"f0 = {config['glova']['f0']}",
|
||||
f"symbolrate = {config['glova']['symbolrate']}",
|
||||
f'wisdom_dir = "{config["glova"]["wisdom_dir"]}"',
|
||||
f'flags = "{config["glova"]["flags"]}"',
|
||||
f"nthreads = {config['glova']['nthreads']}",
|
||||
"",
|
||||
"[fiber]",
|
||||
f"length = {config['fiber']['length']}",
|
||||
f"gamma = {config['fiber']['gamma']}",
|
||||
f"alpha = {config['fiber']['alpha']}",
|
||||
f"D = {config['fiber']['d']}",
|
||||
f"S = {config['fiber']['s']}",
|
||||
f"birefsteps = {config['fiber'].get('birefsteps', 0)}",
|
||||
f"pmd_q = {config['fiber'].get('pmd_q', 0)}",
|
||||
f"birefseed = {hex(birefseed)}" if birefseed else "; birefseed = not set",
|
||||
f"dgd = {config['fiber'].get('dgd', 0)}",
|
||||
f"ortho_error = {config['fiber'].get('ortho_error', 0)}",
|
||||
f"pol_error = {config['fiber'].get('pol_error', 0)}",
|
||||
"",
|
||||
"[signal]",
|
||||
f"seed = {hex(seed)}" if seed else "; seed = not set",
|
||||
"",
|
||||
f'modulation = "{config["signal"]["modulation"]}"',
|
||||
f"mod_order = {config['signal']['mod_order']}",
|
||||
f"mod_depth = {config['signal']['mod_depth']}",
|
||||
"",
|
||||
f"max_jitter = {config['signal']['max_jitter']}",
|
||||
f"jitter_seed = {hex(jitter_seed)}" if jitter_seed else "; jitter_seed = not set",
|
||||
"",
|
||||
f"laser_power = {config['signal']['laser_power']}",
|
||||
f"edfa_power = {config['signal']['edfa_power']}",
|
||||
f"edfa_nf = {config['signal']['edfa_nf']}",
|
||||
f"osnr = {osnr}",
|
||||
"",
|
||||
f'pulse_shape = "{config["signal"]["pulse_shape"]}"',
|
||||
f"fwhm = {config['signal']['fwhm']}",
|
||||
"",
|
||||
"[data]",
|
||||
f'dir = "{str(data_dir)}"',
|
||||
f'npy_dir = "{npy_dir}"',
|
||||
"file = ",
|
||||
))
|
||||
config_hash = hashlib.md5(config_content.encode()).hexdigest()
|
||||
save_file = f"{config_hash}.h5"
|
||||
config_content += f'"{str(save_file)}"\n'
|
||||
|
||||
config_filename:Path = create_config_filename(config, data_dir, timestamp)
|
||||
while config_filename.exists():
|
||||
time.sleep(1)
|
||||
config_filename = create_config_filename(config, data_dir=data_dir)
|
||||
|
||||
|
||||
with open(config_filename, "w") as f:
|
||||
f.write(config_content)
|
||||
|
||||
with h5py.File(save_dir / save_file, "w") as outfile:
|
||||
outfile.create_dataset("data", data=save_data)
|
||||
outfile.create_dataset("symbols", data=metadata.pop("symbols"))
|
||||
for key, value in metadata.items():
|
||||
# if isinstance(value, dict):
|
||||
# value = json.dumps(model_runner.convert_arrays(value))
|
||||
outfile.attrs[key] = value
|
||||
# np.save(save_dir / save_file, save_data)
|
||||
|
||||
# print("Saved config to", config_filename)
|
||||
# print("Saved data to", save_dir / save_file)
|
||||
|
||||
return config_filename
|
||||
|
||||
def create_config_filename(config, data_dir:Path, timestamp=None):
|
||||
if timestamp is None:
|
||||
timestamp = datetime.now()
|
||||
filename_components = (
|
||||
timestamp.strftime("%Y%m%d-%H%M%S"),
|
||||
config["glova"]["sps"],
|
||||
config["glova"]["nos"],
|
||||
config["signal"]["osnr"],
|
||||
config["fiber"]["length"],
|
||||
config["fiber"]["gamma"],
|
||||
config["fiber"]["alpha"],
|
||||
config["fiber"]["d"],
|
||||
config["fiber"]["s"],
|
||||
f"{config['signal']['modulation'].upper()}{config['signal']['mod_order']}",
|
||||
config["fiber"].get("birefsteps", 0),
|
||||
config["fiber"].get("pmd_q", 0),
|
||||
int(config["glova"]["symbolrate"] / 1e9),
|
||||
)
|
||||
lookup_file = "-".join(map(str, filename_components)) + ".ini"
|
||||
return data_dir / lookup_file
|
||||
|
||||
def length_loop(config, lengths, save=True):
|
||||
lengths = sorted(lengths)
|
||||
for length in lengths:
|
||||
print(f"\nGenerating data for fiber length {length}m")
|
||||
config["fiber"]["length"] = length
|
||||
|
||||
cfiber, cdata, noise, edfa, symbols, py_glova = initialize_fiber_and_data(config)
|
||||
|
||||
mean_power_in = np.sum(pypho.functions.getpower_W(cdata.E_in))
|
||||
cfiber()
|
||||
mean_power_out = np.sum(pypho.functions.getpower_W(cdata.E_out))
|
||||
|
||||
print(f"Mean power in: {mean_power_in:.3e} W ({pypho.functions.W_to_dBm(mean_power_in):.3e} dBm)")
|
||||
print(f"Mean power out: {mean_power_out:.3e} W ({pypho.functions.W_to_dBm(mean_power_out):.3e} dBm)")
|
||||
|
||||
E_tmp = [{"E": cdata.E_out, "noise": noise * (-cfiber.params.l * cfiber.params.alpha)}]
|
||||
E_tmp = edfa(E=E_tmp)
|
||||
cdata.E_out = E_tmp[0]["E"]
|
||||
|
||||
mean_power_amp = np.sum(pypho.functions.getpower_W(cdata.E_out))
|
||||
print(f"Mean power after EDFA: {mean_power_amp:.3e} W ({pypho.functions.W_to_dBm(mean_power_amp):.3e} dBm)")
|
||||
|
||||
if save:
|
||||
save_data(cdata, config)
|
||||
|
||||
in_out_eyes(cfiber, cdata)
|
||||
|
||||
|
||||
def single_run_with_plot(config, save=True):
|
||||
cfiber, cdata, config_filename = single_run(config, save)
|
||||
|
||||
in_out_eyes(cfiber, cdata, show_pols=False)
|
||||
return config_filename
|
||||
|
||||
|
||||
def single_run(config, save=True, silent=True):
|
||||
cfiber, cdata, noise, edfa, symbols, glova, E_in = initialize_fiber_and_data(config)
|
||||
|
||||
# transmit
|
||||
cfiber()
|
||||
|
||||
# amplify
|
||||
E_tmp = [{"E": cdata.E_out, "noise": noise}]
|
||||
|
||||
E_tmp = edfa(E=E_tmp)
|
||||
|
||||
|
||||
# rotate
|
||||
# ortho error
|
||||
ortho_error = config["fiber"].get("ortho_error", 0)
|
||||
|
||||
E_tmp[0]["E"] = np.stack((
|
||||
E_tmp[0]["E"][0] * np.cos(ortho_error/2) + E_tmp[0]["E"][1] * np.sin(ortho_error/2),
|
||||
E_tmp[0]["E"][0] * np.sin(ortho_error/2) + E_tmp[0]["E"][1] * np.cos(ortho_error/2)
|
||||
), axis=0)
|
||||
|
||||
|
||||
pol_error = config['fiber'].get('pol_error', 0)
|
||||
|
||||
E_tmp[0]["E"] = np.stack((
|
||||
E_tmp[0]["E"][0] * np.cos(pol_error) - E_tmp[0]["E"][1] * np.sin(pol_error),
|
||||
E_tmp[0]["E"][0] * np.sin(pol_error) + E_tmp[0]["E"][1] * np.cos(pol_error)
|
||||
), axis=0)
|
||||
|
||||
|
||||
|
||||
|
||||
# output
|
||||
cdata.E_out = E_tmp[0]["E"]
|
||||
|
||||
config_filename = None
|
||||
symbols = np.array(symbols)
|
||||
if save:
|
||||
config_filename = save_data(cdata, config, **{"symbols": symbols})
|
||||
if not silent:
|
||||
print(f"Saved config to {config_filename}")
|
||||
return cfiber, cdata, config_filename
|
||||
|
||||
|
||||
def in_out_eyes(cfiber, cdata, show_pols=False):
|
||||
fig, axs = plt.subplots(2, 2, sharex=True, sharey=True)
|
||||
eye_head = min(cfiber.glova.nos, 2000)
|
||||
symbolrate_scale = 1e12
|
||||
amplitude_scale = 1e3
|
||||
plot_eye_diagram(
|
||||
amplitude_scale * np.abs(cdata.E_in[0]) ** 2,
|
||||
2 * cfiber.glova.sps,
|
||||
normalize=False,
|
||||
samplerate=cfiber.glova.symbolrate * cfiber.glova.sps / symbolrate_scale,
|
||||
head=eye_head,
|
||||
ax=axs[0][0],
|
||||
show=False,
|
||||
color="C0",
|
||||
)
|
||||
if show_pols:
|
||||
plot_eye_diagram(
|
||||
amplitude_scale * np.abs(cdata.E_in[0].real) ** 2,
|
||||
2 * cfiber.glova.sps,
|
||||
normalize=False,
|
||||
samplerate=cfiber.glova.symbolrate * cfiber.glova.sps / symbolrate_scale,
|
||||
head=eye_head,
|
||||
ax=axs[0][0],
|
||||
color="C2",
|
||||
show=False,
|
||||
)
|
||||
plot_eye_diagram(
|
||||
amplitude_scale * np.abs(cdata.E_in[0].imag) ** 2,
|
||||
2 * cfiber.glova.sps,
|
||||
normalize=False,
|
||||
samplerate=cfiber.glova.symbolrate * cfiber.glova.sps / symbolrate_scale,
|
||||
head=eye_head,
|
||||
ax=axs[0][0],
|
||||
color="C3",
|
||||
show=False,
|
||||
)
|
||||
plot_eye_diagram(
|
||||
amplitude_scale * np.abs(cdata.E_out[0]) ** 2,
|
||||
2 * cfiber.glova.sps,
|
||||
normalize=False,
|
||||
samplerate=cfiber.glova.symbolrate * cfiber.glova.sps / symbolrate_scale,
|
||||
head=eye_head,
|
||||
ax=axs[0][1],
|
||||
color="C1",
|
||||
show=False,
|
||||
)
|
||||
if show_pols:
|
||||
plot_eye_diagram(
|
||||
amplitude_scale * np.abs(cdata.E_out[0].real) ** 2,
|
||||
2 * cfiber.glova.sps,
|
||||
normalize=False,
|
||||
samplerate=cfiber.glova.symbolrate * cfiber.glova.sps / symbolrate_scale,
|
||||
head=eye_head,
|
||||
ax=axs[0][1],
|
||||
color="C4",
|
||||
show=False,
|
||||
)
|
||||
plot_eye_diagram(
|
||||
amplitude_scale * np.abs(cdata.E_out[0].imag) ** 2,
|
||||
2 * cfiber.glova.sps,
|
||||
normalize=False,
|
||||
samplerate=cfiber.glova.symbolrate * cfiber.glova.sps / symbolrate_scale,
|
||||
head=eye_head,
|
||||
ax=axs[0][1],
|
||||
color="C5",
|
||||
show=False,
|
||||
)
|
||||
plot_eye_diagram(
|
||||
amplitude_scale * np.abs(cdata.E_in[1]) ** 2,
|
||||
2 * cfiber.glova.sps,
|
||||
normalize=False,
|
||||
samplerate=cfiber.glova.symbolrate * cfiber.glova.sps / symbolrate_scale,
|
||||
head=eye_head,
|
||||
ax=axs[1][0],
|
||||
color="C0",
|
||||
show=False,
|
||||
)
|
||||
if show_pols:
|
||||
plot_eye_diagram(
|
||||
amplitude_scale * np.abs(cdata.E_in[1].real) ** 2,
|
||||
2 * cfiber.glova.sps,
|
||||
normalize=False,
|
||||
samplerate=cfiber.glova.symbolrate * cfiber.glova.sps / symbolrate_scale,
|
||||
head=eye_head,
|
||||
ax=axs[1][0],
|
||||
color="C2",
|
||||
show=False,
|
||||
)
|
||||
plot_eye_diagram(
|
||||
amplitude_scale * np.abs(cdata.E_in[1].imag) ** 2,
|
||||
2 * cfiber.glova.sps,
|
||||
normalize=False,
|
||||
samplerate=cfiber.glova.symbolrate * cfiber.glova.sps / symbolrate_scale,
|
||||
head=eye_head,
|
||||
ax=axs[1][0],
|
||||
color="C3",
|
||||
show=False,
|
||||
)
|
||||
plot_eye_diagram(
|
||||
amplitude_scale * np.abs(cdata.E_out[1]) ** 2,
|
||||
2 * cfiber.glova.sps,
|
||||
normalize=False,
|
||||
samplerate=cfiber.glova.symbolrate * cfiber.glova.sps / symbolrate_scale,
|
||||
head=eye_head,
|
||||
ax=axs[1][1],
|
||||
color="C1",
|
||||
show=False,
|
||||
)
|
||||
if show_pols:
|
||||
plot_eye_diagram(
|
||||
amplitude_scale * np.abs(cdata.E_out[1].real) ** 2,
|
||||
2 * cfiber.glova.sps,
|
||||
normalize=False,
|
||||
samplerate=cfiber.glova.symbolrate * cfiber.glova.sps / symbolrate_scale,
|
||||
head=eye_head,
|
||||
ax=axs[0][1],
|
||||
color="C4",
|
||||
show=False,
|
||||
)
|
||||
plot_eye_diagram(
|
||||
amplitude_scale * np.abs(cdata.E_out[1].imag) ** 2,
|
||||
2 * cfiber.glova.sps,
|
||||
normalize=False,
|
||||
samplerate=cfiber.glova.symbolrate * cfiber.glova.sps / symbolrate_scale,
|
||||
head=eye_head,
|
||||
ax=axs[0][1],
|
||||
color="C5",
|
||||
show=False,
|
||||
)
|
||||
|
||||
title_map = [
|
||||
["Input x", "Output x"],
|
||||
["Input y", "Output y"],
|
||||
]
|
||||
title_map = np.array(title_map)
|
||||
for ax, title in zip(axs.flatten(), title_map.flatten()):
|
||||
ax.grid(True)
|
||||
ax.set_xlabel("Time [ps]")
|
||||
ax.set_ylabel("Power [mW]")
|
||||
ax.set_title(title)
|
||||
fig.tight_layout()
|
||||
|
||||
plt.show()
|
||||
|
||||
|
||||
def plot_eye_diagram(
|
||||
signal: np.ndarray,
|
||||
eye_width,
|
||||
offset=0,
|
||||
*,
|
||||
head=None,
|
||||
samplerate=1,
|
||||
normalize=True,
|
||||
ax=None,
|
||||
color="C0",
|
||||
show=True,
|
||||
):
|
||||
ax = ax or plt.gca()
|
||||
if head is not None:
|
||||
signal = signal[: head * eye_width]
|
||||
if normalize:
|
||||
signal = signal / np.max(signal)
|
||||
slices = np.lib.stride_tricks.sliding_window_view(signal, eye_width + 1)[offset % (eye_width + 1) :: eye_width]
|
||||
plt_ax = np.arange(-eye_width // 2, eye_width // 2 + 1) / samplerate
|
||||
for slice in slices:
|
||||
ax.plot(plt_ax, slice, color=color, alpha=0.1)
|
||||
ax.grid()
|
||||
if show:
|
||||
plt.show()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
add_pypho.show_log()
|
||||
config = get_config()
|
||||
|
||||
# ranges = (1000,10000)
|
||||
# scales = tuple(range(1, 10))
|
||||
# scales = (1,)
|
||||
# lengths = [range_ * scale for range_ in ranges for scale in scales]
|
||||
# lengths.append(10*max(ranges))
|
||||
# lengths = [*lengths, *lengths]
|
||||
lengths = (
|
||||
# 8000, 9000,
|
||||
10000,
|
||||
20000,
|
||||
30000,
|
||||
40000,
|
||||
50000,
|
||||
60000,
|
||||
70000,
|
||||
80000,
|
||||
90000,
|
||||
95000,
|
||||
100000,
|
||||
105000,
|
||||
110000,
|
||||
115000,
|
||||
120000,
|
||||
)
|
||||
|
||||
# lengths = (10000,100000)
|
||||
|
||||
# length_loop(config, lengths, save=True)
|
||||
# birefringence is constant over coupling length -> several 100m -> bireflength=1000 (m)
|
||||
|
||||
single_run_with_plot(config, save=False)
|
||||
@@ -39,7 +39,7 @@ import numpy as np
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
dataset = FiberRegenerationDataset("data/20241115-175517-128-16384-10000-0-0-17-0-PAM4-0.ini", symbols=13, drop_first=100, output_dim=26, num_symbols=100)
|
||||
dataset = FiberRegenerationDataset("data/202412*-128-16384-50000-0-0-17-0-PAM4-0.ini", symbols=13, drop_first=100, output_dim=26, num_symbols=100)
|
||||
|
||||
loader = DataLoader(dataset, batch_size=10, shuffle=True)
|
||||
|
||||
|
||||
138
src/single-core-regen/testing/prob_dens.ipynb
Normal file
138
src/single-core-regen/testing/prob_dens.ipynb
Normal file
File diff suppressed because one or more lines are too long
1351
src/single-core-regen/tolerance_testing.py
Normal file
1351
src/single-core-regen/tolerance_testing.py
Normal file
File diff suppressed because it is too large
Load Diff
234
src/single-core-regen/train_pol_estimator.py
Normal file
234
src/single-core-regen/train_pol_estimator.py
Normal file
@@ -0,0 +1,234 @@
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
import matplotlib
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.utils.tensorboard
|
||||
import torch.utils.tensorboard.summary
|
||||
from hypertraining.settings import (
|
||||
GlobalSettings,
|
||||
DataSettings,
|
||||
PytorchSettings,
|
||||
ModelSettings,
|
||||
OptimizerSettings,
|
||||
)
|
||||
|
||||
from hypertraining.training import RegenerationTrainer, PolarizationTrainer
|
||||
|
||||
# import torch
|
||||
import json
|
||||
import util
|
||||
|
||||
from rich import print as rprint
|
||||
|
||||
global_settings = GlobalSettings(
|
||||
seed=0xC0FFEE,
|
||||
)
|
||||
|
||||
data_settings = DataSettings(
|
||||
config_path="data/20241211-105524-128-16384-1-0-0-0-0-PAM4-0-0.ini",
|
||||
# config_path=[f"data/20241202-*-128-16384-{length}-0-0-17-0-PAM4-0.ini" for length in range(48000, 53000, 1000)],
|
||||
dtype="complex64",
|
||||
# symbols = (9, 20), # 13 symbol @ 10GBd <-> 1.3ns <-> 0.26m of fiber
|
||||
symbols=13, # study: single_core_regen_20241123_011232
|
||||
# output_size = (11, 32), # 26 taps -> 2 taps per input symbol -> 1 tap every 1cm (model has 52 inputs (x/y))
|
||||
output_size=26, # study: single_core_regen_20241123_011232 (model_input_dim/2)
|
||||
shuffle=True,
|
||||
drop_first=64,
|
||||
train_split=0.8,
|
||||
# polarisations=tuple(np.random.rand(2)*2*np.pi),
|
||||
randomise_polarisations=True,
|
||||
)
|
||||
|
||||
pytorch_settings = PytorchSettings(
|
||||
epochs=10000,
|
||||
batchsize=2**12,
|
||||
device="cuda",
|
||||
dataloader_workers=16,
|
||||
dataloader_prefetch=8,
|
||||
summary_dir=".runs",
|
||||
write_every=2**5,
|
||||
save_models=True,
|
||||
model_dir=".models",
|
||||
)
|
||||
|
||||
model_settings = ModelSettings(
|
||||
output_dim=1,
|
||||
n_hidden_layers=3,
|
||||
overrides={
|
||||
"n_hidden_nodes_0": 4,
|
||||
"n_hidden_nodes_1": 4,
|
||||
"n_hidden_nodes_2": 4,
|
||||
},
|
||||
dropout_prob=0,
|
||||
model_layer_function="ONNRect",
|
||||
model_activation_func="EOActivation",
|
||||
model_layer_kwargs={"square": True},
|
||||
scale=False,
|
||||
model_layer_parametrizations=[
|
||||
{
|
||||
"tensor_name": "weight",
|
||||
"parametrization": util.complexNN.energy_conserving,
|
||||
},
|
||||
{
|
||||
"tensor_name": "alpha",
|
||||
"parametrization": util.complexNN.clamp,
|
||||
},
|
||||
{
|
||||
"tensor_name": "gain",
|
||||
"parametrization": util.complexNN.clamp,
|
||||
"kwargs": {
|
||||
"min": 0,
|
||||
"max": float("inf"),
|
||||
},
|
||||
},
|
||||
{
|
||||
"tensor_name": "phase_bias",
|
||||
"parametrization": util.complexNN.clamp,
|
||||
"kwargs": {
|
||||
"min": 0,
|
||||
"max": 2 * torch.pi,
|
||||
},
|
||||
},
|
||||
{
|
||||
"tensor_name": "scales",
|
||||
"parametrization": util.complexNN.clamp,
|
||||
},
|
||||
{
|
||||
"tensor_name": "angle",
|
||||
"parametrization": util.complexNN.clamp,
|
||||
"kwargs": {
|
||||
"min": 0,
|
||||
"max": 2*torch.pi,
|
||||
},
|
||||
},
|
||||
{
|
||||
"tensor_name": "loss",
|
||||
"parametrization": util.complexNN.clamp,
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
optimizer_settings = OptimizerSettings(
|
||||
optimizer="RMSprop",
|
||||
# optimizer="AdamW",
|
||||
optimizer_kwargs={
|
||||
"lr": 0.01,
|
||||
"alpha": 0.9,
|
||||
"momentum": 0.1,
|
||||
"eps": 1e-8,
|
||||
"centered": True,
|
||||
# "amsgrad": True,
|
||||
# "weight_decay": 1e-7,
|
||||
},
|
||||
scheduler="ReduceLROnPlateau",
|
||||
scheduler_kwargs={
|
||||
"patience": 2**5,
|
||||
"factor": 0.75,
|
||||
# "threshold": 1e-3,
|
||||
"min_lr": 1e-6,
|
||||
# "cooldown": 10,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def save_dict_to_file(dictionary, filename):
|
||||
"""
|
||||
Save the best dictionary to a JSON file.
|
||||
|
||||
:param best: Dictionary containing the best training results.
|
||||
:type best: dict
|
||||
:param filename: Path to the JSON file where the dictionary will be saved.
|
||||
:type filename: str
|
||||
"""
|
||||
with open(filename, "w") as f:
|
||||
json.dump(dictionary, f, indent=4)
|
||||
|
||||
|
||||
def sweep_lengths(*lengths, model=None, data_glob: str = None, strategy="newest"):
|
||||
assert model is not None, "Model must be provided."
|
||||
assert data_glob is not None, "Data glob must be provided."
|
||||
model = model
|
||||
|
||||
fiber_ins = {}
|
||||
fiber_outs = {}
|
||||
regens = {}
|
||||
timestampss = {}
|
||||
|
||||
trainer = RegenerationTrainer(
|
||||
checkpoint_path=model,
|
||||
)
|
||||
trainer.define_model()
|
||||
|
||||
for length in lengths:
|
||||
data_glob_length = data_glob.replace("{length}", str(length))
|
||||
files = list(Path.cwd().glob(data_glob_length))
|
||||
if len(files) == 0:
|
||||
continue
|
||||
if strategy == "newest":
|
||||
sorted_kwargs = {
|
||||
"key": lambda x: x.stat().st_mtime,
|
||||
"reverse": True,
|
||||
}
|
||||
elif strategy == "oldest":
|
||||
sorted_kwargs = {
|
||||
"key": lambda x: x.stat().st_mtime,
|
||||
"reverse": False,
|
||||
}
|
||||
else:
|
||||
raise ValueError(f"Unknown strategy {strategy}.")
|
||||
file = sorted(files, **sorted_kwargs)[0]
|
||||
|
||||
loader, _ = trainer.get_sliced_data(override={"config_path": file})
|
||||
fiber_in, fiber_out, regen, timestamps = trainer.run_model(trainer.model, loader=loader)
|
||||
|
||||
fiber_ins[length] = fiber_in
|
||||
fiber_outs[length] = fiber_out
|
||||
regens[length] = regen
|
||||
timestampss[length] = timestamps
|
||||
|
||||
data = torch.zeros(2 * len(timestampss.keys()) + 2, 2, tuple(fiber_outs.values())[-1].shape[0])
|
||||
channel_names = ["" for _ in range(2 * len(timestampss.keys()) + 2)]
|
||||
|
||||
data[1, 0, :] = timestampss[tuple(timestampss.keys())[-1]] / 128
|
||||
data[1, 1, :] = fiber_ins[tuple(timestampss.keys())[-1]][:, 0].abs().square()
|
||||
|
||||
channel_names[1] = "fiber in x"
|
||||
|
||||
for li, length in enumerate(timestampss.keys()):
|
||||
data[2 + 2 * li, 0, :] = timestampss[length] / 128
|
||||
data[2 + 2 * li, 1, :] = fiber_outs[length][:, 0].abs().square()
|
||||
data[2 + 2 * li + 1, 0, :] = timestampss[length] / 128
|
||||
data[2 + 2 * li + 1, 1, :] = regens[length][:, 0].abs().square()
|
||||
|
||||
channel_names[2 + 2 * li + 1] = f"regen x {length}"
|
||||
channel_names[2 + 2 * li] = f"fiber out x {length}"
|
||||
|
||||
# get current backend
|
||||
backend = matplotlib.get_backend()
|
||||
|
||||
matplotlib.use("TkCairo")
|
||||
eye = util.eye_diagram.eye_diagram(data.to(dtype=torch.float32).detach().cpu().numpy(), channel_names=channel_names)
|
||||
|
||||
print_attrs = ("channel_name", "success", "min_area")
|
||||
with np.printoptions(precision=3, suppress=True, formatter={"float": "{:0.3e}".format}):
|
||||
for result in eye.eye_stats:
|
||||
print_dict = {attr: result[attr] for attr in print_attrs}
|
||||
rprint(print_dict)
|
||||
rprint()
|
||||
|
||||
eye.plot(all_stats=False)
|
||||
matplotlib.use(backend)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
trainer = PolarizationTrainer(
|
||||
global_settings=global_settings,
|
||||
data_settings=data_settings,
|
||||
pytorch_settings=pytorch_settings,
|
||||
model_settings=model_settings,
|
||||
optimizer_settings=optimizer_settings,
|
||||
# checkpoint_path='.models/pol_pol_20241208_122418_1116.tar',
|
||||
# reset_epoch=True
|
||||
)
|
||||
trainer.train()
|
||||
@@ -260,12 +260,117 @@ class ONNRect(nn.Module):
|
||||
self.crop = lambda x: x
|
||||
self.crop.__doc__ = "No cropping"
|
||||
|
||||
|
||||
def forward(self, x):
|
||||
x = self.pad(x)
|
||||
x = self.pad(x).to(dtype=self.weight.dtype)
|
||||
out = self.crop((self.weight @ x.mT).mT)
|
||||
return out
|
||||
|
||||
class polarimeter(nn.Module):
|
||||
def __init__(self):
|
||||
super(polarimeter, self).__init__()
|
||||
# self.input_length = input_length
|
||||
|
||||
def forward(self, data):
|
||||
# S0 = I
|
||||
# S1 = (2*I_x - I)/I
|
||||
# S2 = (2*I_45 - I)/I
|
||||
# S3 = (2*I_RHC - I)/I
|
||||
|
||||
# # data: (batch, input_length*2) -> (batch, input_length, 2)
|
||||
data = data.view(data.shape[0], -1, 2)
|
||||
x = data[:, :, 0].mean(dim=1)
|
||||
y = data[:, :, 1].mean(dim=1)
|
||||
|
||||
# x = x.mean(dim=1)
|
||||
# y = y.mean(dim=1)
|
||||
|
||||
# angle = torch.atan2(y.abs().square().real, x.abs().square().real)
|
||||
|
||||
# return torch.stack([angle, angle, angle, angle], dim=1)
|
||||
|
||||
# horizontal polarisation
|
||||
I_x = x.abs().square()
|
||||
|
||||
# vertical polarisation
|
||||
I_y = y.abs().square()
|
||||
|
||||
# 45 degree polarisation
|
||||
I_45 = (x + y).abs().square()
|
||||
|
||||
|
||||
# right hand circular polarisation
|
||||
I_RHC = (x + 1j*y).abs().square()
|
||||
|
||||
# S0 = I_x + I_y
|
||||
# S1 = I_x - I_y
|
||||
# S2 = I_45 - I_m45
|
||||
# S3 = I_RHC - I_LHC
|
||||
|
||||
S0 = (I_x + I_y)
|
||||
S1 = ((2*I_x - S0)/S0)
|
||||
S2 = ((2*I_45 - S0)/S0)
|
||||
S3 = ((2*I_RHC - S0)/S0)
|
||||
|
||||
return torch.stack([S0/S0, S1/S0, S2/S0, S3/S0], dim=1)
|
||||
|
||||
class normalize_by_first(nn.Module):
|
||||
def __init__(self):
|
||||
super(normalize_by_first, self).__init__()
|
||||
|
||||
def forward(self, data):
|
||||
return data / data[:, 0].unsqueeze(1)
|
||||
|
||||
class rotate(nn.Module):
|
||||
def __init__(self):
|
||||
super(rotate, self).__init__()
|
||||
|
||||
def forward(self, data, angle):
|
||||
# data -> (batch, n*2)
|
||||
# angle -> (batch, n)
|
||||
data_ = data
|
||||
if angle.ndim == 1:
|
||||
angle_ = angle.unsqueeze(1)
|
||||
else:
|
||||
angle_ = angle
|
||||
angle_ = angle_.expand(-1, data_.shape[1]//2)
|
||||
c = torch.cos(angle_)
|
||||
s = torch.sin(angle_)
|
||||
rot = torch.stack([torch.stack([c, -s], dim=2),
|
||||
torch.stack([s, c], dim=2)], dim=3)
|
||||
d = torch.bmm(data_.reshape(-1, 1, 2), rot.view(-1, 2, 2).to(dtype=data_.dtype)).reshape(*data.shape)
|
||||
# d = torch.bmm(data.unsqueeze(-1).mT, rot.to(dtype=data.dtype).mT).mT.squeeze(-1)
|
||||
|
||||
return d
|
||||
|
||||
|
||||
class photodiode(nn.Module):
|
||||
def __init__(self, size, bias=True):
|
||||
super(photodiode, self).__init__()
|
||||
self.input_dim = size
|
||||
self.scale = nn.Parameter(torch.rand(size))
|
||||
self.pd_bias = nn.Parameter(torch.rand(size))
|
||||
|
||||
def forward(self, x):
|
||||
return x.abs().square().to(dtype=x.dtype.to_real()).mul(self.scale).add(self.pd_bias)
|
||||
|
||||
|
||||
class input_rotator(nn.Module):
|
||||
def __init__(self, input_dim):
|
||||
super(input_rotator, self).__init__()
|
||||
assert input_dim % 2 == 0, "Input dimension must be even"
|
||||
self.input_dim = input_dim
|
||||
# self.angle = nn.Parameter(torch.randn(1, dtype=self.dtype.to_real()))
|
||||
|
||||
def forward(self, x, angle=None):
|
||||
# take channels (0,1), (2,3), ... and rotate them by the angle
|
||||
angle = angle or self.angle
|
||||
sine = torch.sin(angle)
|
||||
cosine = torch.cos(angle)
|
||||
rot = torch.tensor([[cosine, -sine], [sine, cosine]], dtype=self.dtype)
|
||||
return torch.matmul(x.view(-1, 2), rot).view(x.shape)
|
||||
|
||||
|
||||
|
||||
# def __repr__(self):
|
||||
# return f"ONNRect({self.input_dim}, {self.output_dim})"
|
||||
|
||||
@@ -336,8 +441,7 @@ class ONNRect(nn.Module):
|
||||
# return out
|
||||
|
||||
|
||||
#### as defined by zhang et al
|
||||
|
||||
#### as defined by zhang et alas
|
||||
|
||||
class DropoutComplex(nn.Module):
|
||||
def __init__(self, p=0.5):
|
||||
@@ -359,7 +463,7 @@ class Scale(nn.Module):
|
||||
self.scale = nn.Parameter(torch.ones(size, dtype=torch.float32))
|
||||
|
||||
def forward(self, x):
|
||||
return x * self.scale
|
||||
return x * torch.sqrt(self.scale)
|
||||
|
||||
def __repr__(self):
|
||||
return f"Scale({self.size})"
|
||||
@@ -371,11 +475,20 @@ class Identity(nn.Module):
|
||||
M(z) = z
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, size=None):
|
||||
super(Identity, self).__init__()
|
||||
|
||||
def forward(self, x):
|
||||
return x
|
||||
|
||||
class phase_shift(nn.Module):
|
||||
def __init__(self, size):
|
||||
super(phase_shift, self).__init__()
|
||||
self.size = size
|
||||
self.phase = nn.Parameter(torch.rand(size))
|
||||
|
||||
def forward(self, x):
|
||||
return x * torch.exp(1j*self.phase)
|
||||
|
||||
|
||||
class PowRot(nn.Module):
|
||||
@@ -404,54 +517,68 @@ class MZISingle(nn.Module):
|
||||
def forward(self, x: torch.Tensor):
|
||||
return x * torch.exp(1j * self.phi) * torch.sin(self.omega + self.func(x))
|
||||
|
||||
def naive_angle_loss(x: torch.Tensor, target: torch.Tensor, mod=2*torch.pi):
|
||||
return torch.fmod((x.abs().real - target.abs().real), mod).abs().mean()
|
||||
|
||||
def cosine_loss(x: torch.Tensor, target: torch.Tensor):
|
||||
return (2*(1 - torch.cos(x - target))).mean()
|
||||
|
||||
def angle_mse_loss(x: torch.Tensor, target: torch.Tensor):
|
||||
x = torch.fmod(x, 2*torch.pi)
|
||||
target = torch.fmod(target, 2*torch.pi)
|
||||
|
||||
x_cos = torch.cos(x)
|
||||
x_sin = torch.sin(x)
|
||||
target_cos = torch.cos(target)
|
||||
target_sin = torch.sin(target)
|
||||
|
||||
cos_diff = x_cos - target_cos
|
||||
sin_diff = x_sin - target_sin
|
||||
squared_diff = cos_diff**2 + sin_diff**2
|
||||
return squared_diff.mean()
|
||||
|
||||
class EOActivation(nn.Module):
|
||||
def __init__(self, bias, size=None):
|
||||
# 10.1109/SiPhotonics60897.2024.10543376
|
||||
def __init__(self, size=None):
|
||||
# 10.1109/JSTQE.2019.2930455
|
||||
super(EOActivation, self).__init__()
|
||||
if size is None:
|
||||
raise ValueError("Size must be specified")
|
||||
self.size = size
|
||||
self.alpha = nn.Parameter(torch.ones(size))
|
||||
self.V_bias = nn.Parameter(torch.ones(size))
|
||||
self.gain = nn.Parameter(torch.ones(size))
|
||||
# if bias:
|
||||
# self.phase_bias = nn.Parameter(torch.zeros(size))
|
||||
# else:
|
||||
# self.register_buffer("phase_bias", torch.zeros(size))
|
||||
self.register_buffer("phase_bias", torch.clamp(torch.ones(size) + torch.randn(size)*0.1, 0, 1)*torch.pi)
|
||||
self.register_buffer("responsivity", torch.ones(size)*0.9)
|
||||
self.register_buffer("V_pi", torch.ones(size)*3)
|
||||
self.alpha = nn.Parameter(torch.rand(size))
|
||||
self.gain = nn.Parameter(torch.rand(size))
|
||||
self.V_bias = nn.Parameter(torch.rand(size))
|
||||
# self.register_buffer("gain", torch.ones(size))
|
||||
# self.register_buffer("responsivity", torch.ones(size))
|
||||
# self.register_buffer("V_pi", torch.ones(size))
|
||||
|
||||
self.reset_weights()
|
||||
|
||||
def reset_weights(self):
|
||||
if "alpha" in self._parameters:
|
||||
self.alpha.data = torch.ones(self.size)*0.5
|
||||
if "V_pi" in self._parameters:
|
||||
self.V_pi.data = torch.ones(self.size)*3
|
||||
self.alpha.data = torch.rand(self.size)
|
||||
# if "V_pi" in self._parameters:
|
||||
# self.V_pi.data = torch.rand(self.size)*3
|
||||
if "V_bias" in self._parameters:
|
||||
self.V_bias.data = torch.zeros(self.size)
|
||||
self.V_bias.data = torch.randn(self.size)
|
||||
if "gain" in self._parameters:
|
||||
self.gain.data = torch.ones(self.size)
|
||||
if "responsivity" in self._parameters:
|
||||
self.responsivity.data = torch.ones(self.size)*0.9
|
||||
if "bias" in self._parameters:
|
||||
self.phase_bias.data = torch.zeros(self.size)
|
||||
self.gain.data = torch.rand(self.size)
|
||||
# if "responsivity" in self._parameters:
|
||||
# self.responsivity.data = torch.ones(self.size)*0.9
|
||||
# if "bias" in self._parameters:
|
||||
# self.phase_bias.data = torch.zeros(self.size)
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
phi_b = torch.pi * self.V_bias / (self.V_pi + 1e-8)
|
||||
g_phi = torch.pi * (self.alpha * self.gain * self.responsivity) / (self.V_pi + 1e-8)
|
||||
phi_b = torch.pi * self.V_bias# / (self.V_pi)
|
||||
g_phi = torch.pi * (self.alpha * self.gain)# * self.responsivity)# / (self.V_pi)
|
||||
intermediate = g_phi * x.abs().square() + phi_b
|
||||
return (
|
||||
1j
|
||||
* torch.sqrt(1 - self.alpha)
|
||||
* torch.exp(-0.5j * (intermediate + self.phase_bias))
|
||||
* torch.exp(-0.5j * intermediate)
|
||||
* torch.cos(0.5 * intermediate)
|
||||
* x
|
||||
)
|
||||
|
||||
|
||||
class Pow(nn.Module):
|
||||
"""
|
||||
implements the activation function
|
||||
@@ -574,6 +701,7 @@ class ZReLU(nn.Module):
|
||||
__all__ = [
|
||||
complex_sse_loss,
|
||||
complex_mse_loss,
|
||||
angle_mse_loss,
|
||||
UnitaryLayer,
|
||||
unitary,
|
||||
energy_conserving,
|
||||
@@ -590,6 +718,8 @@ __all__ = [
|
||||
ZReLU,
|
||||
MZISingle,
|
||||
EOActivation,
|
||||
photodiode,
|
||||
phase_shift,
|
||||
# SaturableAbsorberLambertW,
|
||||
# SaturableAbsorber,
|
||||
# SpreadLayer,
|
||||
|
||||
105
src/single-core-regen/util/core.py
Normal file
105
src/single-core-regen/util/core.py
Normal file
@@ -0,0 +1,105 @@
|
||||
# Copyright (c) 2015, Warren Weckesser. All rights reserved.
|
||||
# This software is licensed according to the "BSD 2-clause" license.
|
||||
|
||||
import hashlib
|
||||
import h5py
|
||||
import numpy as _np
|
||||
from scipy.interpolate import interp1d as _interp1d
|
||||
from scipy.ndimage import gaussian_filter as _gaussian_filter
|
||||
from ._brescount import bres_curve_count as _bres_curve_count
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
__all__ = ['grid_count']
|
||||
|
||||
|
||||
def grid_count(y, window_size, offset=0, size=None, fuzz=True, blur=0, bounds=None):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
`y` is the 1-d array of signal samples.
|
||||
|
||||
`window_size` is the number of samples to show horizontally in the
|
||||
eye diagram. Typically this is twice the number of samples in a
|
||||
"symbol" (i.e. in a data bit).
|
||||
|
||||
`offset` is the number of initial samples to skip before computing
|
||||
the eye diagram. This allows the overall phase of the diagram to
|
||||
be adjusted.
|
||||
|
||||
`size` must be a tuple of two integers. It sets the size of the
|
||||
array of counts, (height, width). The default is (800, 640).
|
||||
|
||||
`fuzz`: If True, the values in `y` are reinterpolated with a
|
||||
random "fuzz factor" before plotting in the eye diagram. This
|
||||
reduces an aliasing-like effect that arises with the use of
|
||||
Bresenham's algorithm.
|
||||
|
||||
`bounds` must be a tuple of two floating point values, (ymin, ymax).
|
||||
These set the y range of the returned array. If not given, the
|
||||
bounds are `(y.min() - 0.05*A, y.max() + 0.05*A)`, where `A` is
|
||||
`y.max() - y.min()`.
|
||||
|
||||
Return Value
|
||||
------------
|
||||
Returns a numpy array of integers.
|
||||
|
||||
"""
|
||||
# hash input params
|
||||
param_ob = (y, window_size, offset, size, fuzz, blur, bounds)
|
||||
param_hash = hashlib.md5(str(param_ob).encode()).hexdigest()
|
||||
cache_dir = Path.home()/".eyediagram"/".cache"
|
||||
cache_dir.mkdir(parents=True, exist_ok=True)
|
||||
if (cache_dir/param_hash).is_file():
|
||||
try:
|
||||
with h5py.File(cache_dir/param_hash, "r") as infile:
|
||||
counts = infile["counts"][:]
|
||||
if counts.len() != 0:
|
||||
return counts
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
|
||||
if size is None:
|
||||
size = (800, 640)
|
||||
height, width = size
|
||||
dt = width / window_size
|
||||
counts = _np.zeros((width, height), dtype=_np.int32)
|
||||
|
||||
if bounds is None:
|
||||
ymin = y.min()
|
||||
ymax = y.max()
|
||||
yamp = ymax - ymin
|
||||
ymin = ymin - 0.05*yamp
|
||||
ymax = ymax + 0.05*yamp
|
||||
ymax = _np.ceil(ymax*10)/10
|
||||
ymin = _np.floor(ymin*10)/10
|
||||
else:
|
||||
ymin, ymax = bounds
|
||||
|
||||
start = offset
|
||||
while start + window_size < len(y):
|
||||
end = start + window_size
|
||||
yy = y[start:end+1]
|
||||
k = _np.arange(len(yy))
|
||||
xx = dt*k
|
||||
if fuzz:
|
||||
f = _interp1d(xx, yy, kind='cubic')
|
||||
jiggle = dt*(_np.random.beta(a=3, b=3, size=len(xx)-2) - 0.5)
|
||||
xx[1:-1] += jiggle
|
||||
yd = f(xx)
|
||||
else:
|
||||
yd = yy
|
||||
iyd = (height * (yd - ymin)/(ymax - ymin)).astype(_np.int32)
|
||||
_bres_curve_count(xx.astype(_np.int32), iyd, counts)
|
||||
|
||||
start = end
|
||||
|
||||
if blur != 0:
|
||||
counts = _gaussian_filter(counts, sigma=blur)
|
||||
|
||||
with h5py.File(cache_dir/param_hash, "w") as outfile:
|
||||
outfile.create_dataset("data", data=counts)
|
||||
|
||||
return counts
|
||||
@@ -1,10 +1,12 @@
|
||||
from pathlib import Path
|
||||
import h5py
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
# from torch.utils.data import Sampler
|
||||
import numpy as np
|
||||
import configparser
|
||||
import multiprocessing as mp
|
||||
|
||||
# class SubsetSampler(Sampler[int]):
|
||||
# """
|
||||
@@ -24,7 +26,22 @@ import configparser
|
||||
# return len(self.indices)
|
||||
|
||||
|
||||
def load_data(config_path, skipfirst=0, symbols=None, real=False, normalize=False, device=None, dtype=None):
|
||||
def load_from_file(datapath):
|
||||
if str(datapath).endswith(".h5"):
|
||||
symbols = None
|
||||
with h5py.File(datapath, "r") as infile:
|
||||
data = infile["data"][:]
|
||||
try:
|
||||
symbols = np.swapaxes(infile["symbols"][:], 0, 1)
|
||||
except KeyError:
|
||||
pass
|
||||
else:
|
||||
symbols = None
|
||||
data = np.load(datapath)
|
||||
return data, symbols
|
||||
|
||||
|
||||
def load_data(config_path, skipfirst=0, skiplast=0, symbols=None, real=False, normalize=1, device=None, dtype=None):
|
||||
filepath = Path(config_path)
|
||||
filepath = filepath.parent.glob(filepath.name)
|
||||
config = configparser.ConfigParser()
|
||||
@@ -40,25 +57,39 @@ def load_data(config_path, skipfirst=0, symbols=None, real=False, normalize=Fals
|
||||
if symbols is None:
|
||||
symbols = int(config["glova"]["nos"]) - skipfirst
|
||||
|
||||
data = np.load(datapath)[int(skipfirst * sps) : int(symbols * sps + skipfirst * sps)]
|
||||
timestamps = np.arange(int(skipfirst * sps), int(symbols * sps + skipfirst * sps))
|
||||
data, orig_symbols = load_from_file(datapath)
|
||||
|
||||
if normalize:
|
||||
# square gets normalized to 1, as the power is (proportional to) the square of the amplitude
|
||||
a, b, c, d = np.square(data.T)
|
||||
a, b, c, d = a / np.max(np.abs(a)), b / np.max(np.abs(b)), c / np.max(np.abs(c)), d / np.max(np.abs(d))
|
||||
data = np.sqrt(np.array([a, b, c, d]).T)
|
||||
data = data[int(skipfirst * sps) : int(symbols * sps + skipfirst * sps - skiplast * sps)]
|
||||
orig_symbols = orig_symbols[skipfirst : symbols + skipfirst - skiplast]
|
||||
timestamps = np.arange(int(skipfirst * sps), int(symbols * sps + skipfirst * sps - skiplast * sps))
|
||||
|
||||
data *= np.sqrt(normalize)
|
||||
|
||||
launch_power = float(config["signal"]["laser_power"])
|
||||
output_power = float(config["signal"]["edfa_power"])
|
||||
|
||||
target_normalization = 10 ** (output_power / 10) / 10 ** (launch_power / 10)
|
||||
# target_normalization *= 0.5 # allow 50% power loss, so the network can ignore parts of the signal
|
||||
|
||||
data[:, 0:2] *= np.sqrt(target_normalization)
|
||||
|
||||
# if normalize:
|
||||
# # square gets normalized to 1, as the power is (proportional to) the square of the amplitude
|
||||
# a, b, c, d = data.T
|
||||
# a, b, c, d = a - np.min(np.abs(a)), b - np.min(np.abs(b)), c - np.min(np.abs(c)), d - np.min(np.abs(d))
|
||||
# a, b, c, d = a / np.max(np.abs(a)), b / np.max(np.abs(b)), c / np.max(np.abs(c)), d / np.max(np.abs(d))
|
||||
# data = np.array([a, b, c, d]).T
|
||||
|
||||
if real:
|
||||
data = np.abs(data)
|
||||
|
||||
config["glova"]["nos"] = str(symbols)
|
||||
|
||||
data = np.concatenate([data, timestamps.reshape(-1,1)], axis=-1)
|
||||
data = np.concatenate([data, timestamps.reshape(-1, 1)], axis=-1)
|
||||
|
||||
data = torch.tensor(data, device=device, dtype=dtype)
|
||||
|
||||
return data, config
|
||||
return data, config, orig_symbols
|
||||
|
||||
|
||||
def roll_along(arr, shifts, dim):
|
||||
@@ -110,9 +141,15 @@ class FiberRegenerationDataset(Dataset):
|
||||
target_delay: float | int = 0,
|
||||
xy_delay: float | int = 0,
|
||||
drop_first: float | int = 0,
|
||||
drop_last=0,
|
||||
dtype: torch.dtype = None,
|
||||
real: bool = False,
|
||||
device=None,
|
||||
# osnr: float|None = None,
|
||||
polarisations=None,
|
||||
randomise_polarisations: bool = False,
|
||||
repeat_randoms: int = 1,
|
||||
# cross_pol_interference: float = 0,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
@@ -145,50 +182,54 @@ class FiberRegenerationDataset(Dataset):
|
||||
assert output_dim is None or output_dim > 0, "output_len must be positive or None"
|
||||
assert drop_first >= 0, "drop_first must be non-negative"
|
||||
|
||||
faux = kwargs.pop("faux", False)
|
||||
self.randomise_polarisations = randomise_polarisations
|
||||
# self.cross_pol_interference = cross_pol_interference
|
||||
|
||||
if faux:
|
||||
data_raw = np.array(
|
||||
[[i + 0.1j, i + 0.2j, i + 1.1j, i + 1.2j] for i in range(12800)],
|
||||
dtype=np.complex128,
|
||||
data_raw = None
|
||||
self.config = None
|
||||
files = []
|
||||
self.orig_symbols = None
|
||||
for file_path in file_path if isinstance(file_path, (tuple, list)) else [file_path]:
|
||||
data, config, orig_syms = load_data(
|
||||
file_path,
|
||||
skipfirst=drop_first,
|
||||
skiplast=drop_last,
|
||||
symbols=kwargs.get("num_symbols", None),
|
||||
real=real,
|
||||
normalize=1000,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
data_raw = torch.tensor(data_raw, device=device, dtype=dtype)
|
||||
timestamps = torch.arange(12800)
|
||||
|
||||
data_raw = torch.concatenate([data_raw, timestamps.reshape(-1, 1)], axis=-1)
|
||||
|
||||
self.config = {
|
||||
"data": {"dir": '"."', "npy_dir": '"."', "file": "faux"},
|
||||
"glova": {"sps": 128},
|
||||
}
|
||||
else:
|
||||
data_raw = None
|
||||
self.config = None
|
||||
files = []
|
||||
for file_path in (file_path if isinstance(file_path, (tuple, list)) else [file_path]):
|
||||
data, config = load_data(
|
||||
file_path,
|
||||
skipfirst=drop_first,
|
||||
symbols=kwargs.get("num_symbols", None),
|
||||
real=real,
|
||||
normalize=True,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
if data_raw is None:
|
||||
data_raw = data
|
||||
if orig_syms is not None:
|
||||
if self.orig_symbols is None:
|
||||
self.orig_symbols = orig_syms
|
||||
else:
|
||||
data_raw = torch.cat([data_raw, data], dim=0)
|
||||
if self.config is None:
|
||||
self.config = config
|
||||
else:
|
||||
assert self.config["glova"]["sps"] == config["glova"]["sps"], "samples per symbol must be the same"
|
||||
files.append(config["data"]["file"].strip('"'))
|
||||
self.config["data"]["file"] = str(files)
|
||||
|
||||
self.orig_symbols = np.concat((self.orig_symbols, orig_syms), axis=-1)
|
||||
|
||||
if data_raw is None:
|
||||
data_raw = data
|
||||
else:
|
||||
data_raw = torch.cat([data_raw, data], dim=0)
|
||||
if self.config is None:
|
||||
self.config = config
|
||||
else:
|
||||
assert self.config["glova"]["sps"] == config["glova"]["sps"], "samples per symbol must be the same"
|
||||
files.append(config["data"]["file"].strip('"'))
|
||||
self.config["data"]["file"] = str(files)
|
||||
|
||||
# if polarisations is not None:
|
||||
# data_raw_clone = data_raw.clone()
|
||||
# # rotate the polarisation by 180 degrees
|
||||
# data_raw_clone[2, :] *= -1
|
||||
# data_raw_clone[3, :] *= -1
|
||||
# data_raw = torch.cat([data_raw, data_raw_clone], dim=0)
|
||||
|
||||
self.polarisations = bool(polarisations)
|
||||
|
||||
self.device = data_raw.device
|
||||
|
||||
self.samples_per_symbol = int(self.config["glova"]["sps"])
|
||||
# self.num_symbols = int(self.config["glova"]["nos"])
|
||||
self.samples_per_slice = int(symbols * self.samples_per_symbol)
|
||||
self.symbols_per_slice = self.samples_per_slice / self.samples_per_symbol
|
||||
|
||||
@@ -258,17 +299,98 @@ class FiberRegenerationDataset(Dataset):
|
||||
elif self.target_delay_samples < 0:
|
||||
data_raw = data_raw[:, : self.target_delay_samples]
|
||||
|
||||
timestamps = data_raw[-1, :]
|
||||
data_raw = data_raw[:-1, :]
|
||||
timestamps = data_raw[4, :]
|
||||
data_raw = data_raw[:4, :]
|
||||
data_raw = data_raw.view(2, 2, -1)
|
||||
timestamps_doubled = torch.cat([timestamps.unsqueeze(dim=0), timestamps.unsqueeze(dim=0)], dim=0).unsqueeze(dim=1)
|
||||
data_raw = torch.cat([data_raw, timestamps_doubled], dim=1)
|
||||
fiber_in = data_raw[0, :, :]
|
||||
fiber_out = data_raw[1, :, :]
|
||||
# timestamps_doubled = torch.cat([timestamps.unsqueeze(dim=0), timestamps.unsqueeze(dim=0)], dim=0).unsqueeze(
|
||||
# dim=1
|
||||
# )
|
||||
fiber_in = torch.cat([fiber_in, timestamps.unsqueeze(0)], dim=0)
|
||||
fiber_out = torch.cat([fiber_out, timestamps.unsqueeze(0)], dim=0)
|
||||
|
||||
# fiber_out: [E_out_x, E_out_y, timestamps]
|
||||
|
||||
# add noise related to amplification necessary due to splitting of the signal
|
||||
# gain_lin = output_dim*2
|
||||
# gain_lin = 1
|
||||
# edfa_nf = float(self.config["signal"]["edfa_nf"])
|
||||
# nf_lin = 10**(edfa_nf/10)
|
||||
# f0 = float(self.config["glova"]["f0"])
|
||||
|
||||
# noise_add = (gain_lin-1)*nf_lin*6.626e-34*f0*12.5e9
|
||||
|
||||
# noise = torch.randn_like(fiber_out[:2, :])
|
||||
# noise_power = torch.sum(torch.mean(noise.abs().square().squeeze(), dim=-1))
|
||||
# noise = noise * torch.sqrt(noise_add / noise_power)
|
||||
# fiber_out[:2, :] += noise
|
||||
|
||||
# if osnr is None:
|
||||
# noisy = fiber_out[:2, :]
|
||||
# else:
|
||||
# noisy = self.add_noise(fiber_out[:2, :], osnr)
|
||||
|
||||
# fiber_out = torch.cat([fiber_out, noisy], dim=0)
|
||||
|
||||
# fiber_out: [E_out_x, E_out_y, timestamps, E_out_x_noisy, E_out_y_noisy]
|
||||
|
||||
if repeat_randoms > 1:
|
||||
fiber_in = fiber_in.repeat(1, 1, repeat_randoms)
|
||||
fiber_out = fiber_out.repeat(1, 1, repeat_randoms)
|
||||
# review: potential problems with repeated timestamps when plotting
|
||||
else:
|
||||
repeat_randoms = 1
|
||||
|
||||
if self.randomise_polarisations:
|
||||
angles = torch.fmod(torch.randperm(data_raw.shape[-1]*repeat_randoms, device=fiber_out.device), 2) * torch.pi
|
||||
start_angle = torch.rand(1) * 2 * torch.pi
|
||||
angles = start_angle + torch.cumsum(torch.randn(data_raw.shape[-1])/333, dim=0) * torch.pi * 2 # random walk
|
||||
angles = torch.randn(data_raw.shape[-1], device=fiber_out.device) * 2*torch.pi / 36 # sigma = 10 degrees
|
||||
# self.angles = torch.rand(self.data.shape[0]) * 2 * torch.pi
|
||||
else:
|
||||
angles = torch.zeros(data_raw.shape[-1], device=fiber_out.device)
|
||||
|
||||
sin = torch.sin(angles)
|
||||
cos = torch.cos(angles)
|
||||
rot = torch.stack([torch.stack([cos, -sin], dim=1), torch.stack([sin, cos], dim=1)], dim=2)
|
||||
data_rot = torch.bmm(fiber_out[:2, :].T.unsqueeze(1), rot.to(dtype=data_raw.dtype)).squeeze(1).T
|
||||
# data_rot_noisy = torch.bmm(fiber_out[3:5, :].T.unsqueeze(1), rot.to(dtype=data_raw.dtype)).squeeze(1).T
|
||||
fiber_out = torch.cat((fiber_out, data_rot), dim=0)
|
||||
fiber_out = torch.cat([fiber_out, angles.unsqueeze(0)], dim=0)
|
||||
|
||||
# fiber_in:
|
||||
# 0 E_in_x,
|
||||
# 1 E_in_y,
|
||||
# 2 timestamps
|
||||
|
||||
# fiber_out:
|
||||
# 0 E_out_x,
|
||||
# 1 E_out_y,
|
||||
# 2 timestamps,
|
||||
# 3 E_out_x_rot,
|
||||
# 4 E_out_y_rot,
|
||||
# 5 angle
|
||||
|
||||
# data_raw = torch.cat([data_raw, timestamps_doubled], dim=1)
|
||||
# data layout
|
||||
# [ [E_in_x, E_in_y, timestamps],
|
||||
# [E_out_x, E_out_y, timestamps] ]
|
||||
|
||||
self.data = data_raw.unfold(dimension=-1, size=self.samples_per_slice, step=1)
|
||||
self.data = self.data.movedim(-2, 0)
|
||||
self.fiber_in = fiber_in.unfold(dimension=-1, size=self.samples_per_slice, step=1)
|
||||
self.fiber_in = self.fiber_in.movedim(-2, 0)
|
||||
|
||||
self.fiber_out = fiber_out.unfold(dimension=-1, size=self.samples_per_slice, step=1)
|
||||
self.fiber_out = self.fiber_out.movedim(-2, 0)
|
||||
|
||||
# if self.randomise_polarisations:
|
||||
# self.angles = torch.cumsum((torch.rand(self.fiber_out.shape[0]) - 0.5) * 2 * torch.pi * 2 / 5000, dim=0)
|
||||
|
||||
# self.data = data_raw.unfold(dimension=-1, size=self.samples_per_slice, step=1)
|
||||
# self.data = self.data.movedim(-2, 0)
|
||||
# self.angles = torch.zeros(self.data.shape[0])
|
||||
...
|
||||
# ...
|
||||
# -> [no_slices, 2, 3, samples_per_slice]
|
||||
|
||||
# data layout
|
||||
@@ -278,33 +400,160 @@ class FiberRegenerationDataset(Dataset):
|
||||
# ...
|
||||
# ] -> [no_slices, 2, 3, samples_per_slice]
|
||||
|
||||
...
|
||||
|
||||
def __len__(self):
|
||||
return self.data.shape[0]
|
||||
return self.fiber_in.shape[0]
|
||||
|
||||
def add_noise(self, data, osnr):
|
||||
osnr_lin = 10 ** (osnr / 10)
|
||||
popt = torch.mean(data.abs().square().squeeze(), dim=-1)
|
||||
noise = torch.randn_like(data)
|
||||
pn = torch.mean(noise.abs().square().squeeze(), dim=-1)
|
||||
|
||||
mult = torch.sqrt(popt / (pn * osnr_lin))
|
||||
mult = mult * torch.eye(popt.shape[0], device=mult.device)
|
||||
mult = mult.to(dtype=noise.dtype)
|
||||
|
||||
noise = mult @ noise
|
||||
pn = torch.mean(noise.abs().square().squeeze(), dim=-1)
|
||||
noisy = data + noise
|
||||
return noisy
|
||||
|
||||
def __getitem__(self, idx):
|
||||
if isinstance(idx, slice):
|
||||
return [self.__getitem__(i) for i in range(*idx.indices(len(self)))]
|
||||
else:
|
||||
data_slice = self.data[idx].squeeze()
|
||||
# fiber in: [E_in_x, E_in_y, timestamps]
|
||||
# fiber out: [E_out_x, E_out_y, timestamps, E_out_x_rot, E_out_y_rot, angle]
|
||||
|
||||
data_slice = data_slice[:, :, :data_slice.shape[2] // self.output_dim * self.output_dim]
|
||||
# if self.polarisations:
|
||||
output_dim = self.output_dim // 2
|
||||
self.output_dim = output_dim * 2
|
||||
|
||||
data_slice = data_slice.view(data_slice.shape[0], data_slice.shape[1], self.output_dim, -1)
|
||||
|
||||
target = data_slice[0, :, self.output_dim//2, 0]
|
||||
data = data_slice[1, :, :, 0]
|
||||
|
||||
# data_timestamps = data[-1,:].real
|
||||
data = data[:-1, :]
|
||||
target_timestamp = target[-1].real
|
||||
target = target[:-1]
|
||||
if not self.polarisations:
|
||||
output_dim = 2 * output_dim
|
||||
|
||||
|
||||
fiber_in = self.fiber_in[idx].squeeze()
|
||||
fiber_out = self.fiber_out[idx].squeeze()
|
||||
|
||||
fiber_in = fiber_in[..., : fiber_in.shape[-1] // output_dim * output_dim]
|
||||
fiber_out = fiber_out[..., : fiber_out.shape[-1] // output_dim * output_dim]
|
||||
|
||||
fiber_in = fiber_in.view(fiber_in.shape[0], output_dim, -1)
|
||||
fiber_out = fiber_out.view(fiber_out.shape[0], output_dim, -1)
|
||||
|
||||
center_angle = fiber_out[5, output_dim // 2, 0]
|
||||
angles = fiber_out[5, :, 0]
|
||||
plot_data_rot = fiber_out[3:5, output_dim // 2, 0].detach().clone()
|
||||
data = fiber_out[0:2, :, 0]
|
||||
plot_data = fiber_out[0:2, output_dim // 2, 0].detach().clone()
|
||||
|
||||
target = fiber_in[:2, output_dim // 2, 0]
|
||||
plot_target = fiber_in[:2, output_dim // 2, 0].detach().clone()
|
||||
target_timestamp = fiber_in[2, output_dim // 2, 0].real
|
||||
...
|
||||
|
||||
if self.polarisations:
|
||||
rot = int(np.random.randint(2) * 2 - 1)
|
||||
data = rot * data
|
||||
target = rot * target
|
||||
plot_data_rot = rot * plot_data_rot
|
||||
center_angle = center_angle + (rot - 1) * torch.pi / 2 # rot: -1 or 1 -> -2 or 0 -> -pi or 0
|
||||
angles = angles + (rot - 1) * torch.pi / 2
|
||||
|
||||
pol_flipped_data = -data
|
||||
pol_flipped_target = -target
|
||||
|
||||
# transpose to interleave the x and y data in the output tensor
|
||||
data = data.transpose(0, 1).flatten().squeeze()
|
||||
data = data / torch.sqrt(torch.ones(1) * len(data)) # power loss due to splitting
|
||||
pol_flipped_data = pol_flipped_data.transpose(0, 1).flatten().squeeze()
|
||||
pol_flipped_data = pol_flipped_data / torch.sqrt(
|
||||
torch.ones(1) * len(pol_flipped_data)
|
||||
) # power loss due to splitting
|
||||
# angle_data = angle_data.transpose(0, 1).flatten().squeeze()
|
||||
# angle_data2 = angle_data2.transpose(0,1).flatten().squeeze()
|
||||
center_angle = center_angle.flatten().squeeze()
|
||||
angles = angles.flatten().squeeze()
|
||||
# data_timestamps = data_timestamps.flatten().squeeze()
|
||||
# target = target.transpose(0,1).flatten().squeeze()
|
||||
target = target.flatten().squeeze()
|
||||
pol_flipped_target = pol_flipped_target.flatten().squeeze()
|
||||
target_timestamp = target_timestamp.flatten().squeeze()
|
||||
plot_target = plot_target.flatten().squeeze()
|
||||
plot_data = plot_data.flatten().squeeze()
|
||||
plot_data_rot = plot_data_rot.flatten().squeeze()
|
||||
|
||||
return data, target, target_timestamp
|
||||
return {
|
||||
"x": data,
|
||||
"x_flipped": pol_flipped_data,
|
||||
"x_stacked": torch.cat([data, pol_flipped_data], dim=-1),
|
||||
"y": target,
|
||||
"y_flipped": pol_flipped_target,
|
||||
"y_stacked": torch.cat([target, pol_flipped_target], dim=-1),
|
||||
"center_angle": center_angle,
|
||||
"angles": angles,
|
||||
"mean_angle": angles.mean(),
|
||||
# "sop": sop,
|
||||
# "angle_data": angle_data,
|
||||
# "angle_data2": angle_data2,
|
||||
"timestamp": target_timestamp,
|
||||
"plot_target": plot_target,
|
||||
"plot_data": plot_data,
|
||||
"plot_data_rot": plot_data_rot,
|
||||
# "plot_clean": fiber_out_plot_clean,
|
||||
}
|
||||
|
||||
def complex_max(self, data, dim=-1):
|
||||
# returns element(s) with the maximum absolute value along a given dimension
|
||||
# ind = torch.argmax(data.abs(), dim=dim, keepdim=True)
|
||||
# max_values = torch.gather(data, dim, ind).squeeze(dim=dim)
|
||||
# return max_values
|
||||
return torch.gather(data, dim, torch.argmax(data.abs(), dim=dim, keepdim=True)).squeeze(dim=dim)
|
||||
|
||||
def rotate(self, data, angle):
|
||||
# rotates a 2d tensor by a given angle
|
||||
# data: [2, ...]
|
||||
# angle: [1]
|
||||
# returns: [2, ...]
|
||||
|
||||
# get sine and cosine of the angle
|
||||
sine = torch.sin(angle)
|
||||
cosine = torch.cos(angle)
|
||||
|
||||
return torch.stack([data[0] * cosine - data[1] * sine, data[0] * sine + data[1] * cosine], dim=0)
|
||||
|
||||
def rotate_all(self):
|
||||
def do_rotation(j, num_processes):
|
||||
for i in range(len(self) // num_processes):
|
||||
index = i * num_processes + j
|
||||
self.data[index, 1, :2, :] = self.rotate(self.data[index, 1, :2, :], self.angles[index])
|
||||
|
||||
self.processes = []
|
||||
|
||||
for j in range(mp.cpu_count()):
|
||||
self.processes.append(mp.Process(target=do_rotation, args=(j, mp.cpu_count())))
|
||||
self.processes[-1].start()
|
||||
|
||||
for p in self.processes:
|
||||
p.join()
|
||||
|
||||
for i in range(len(self) // mp.cpu_count() * mp.cpu_count(), len(self)):
|
||||
self.data[i, 1, :2, :] = self.rotate(self.data[i, 1, :2, :], self.angles[i])
|
||||
|
||||
def polarimeter(self, data):
|
||||
# data: [2, ...] -> x, y
|
||||
# returns [4] -> S0, S1, S2, S3
|
||||
x = data[0].mean()
|
||||
y = data[1].mean()
|
||||
I_X = x.abs().square()
|
||||
I_Y = y.abs().square()
|
||||
I_45 = (x + y).abs().square()
|
||||
I_RHC = (x + 1j * y).abs().square()
|
||||
|
||||
S0 = I_X + I_Y
|
||||
S1 = (2 * I_X - S0) / S0
|
||||
S2 = (2 * I_45 - S0) / S0
|
||||
S3 = (2 * I_RHC - S0) / S0
|
||||
|
||||
return torch.stack([S0, S1, S2, S3], dim=0)
|
||||
|
||||
@@ -1,15 +1,23 @@
|
||||
from datetime import datetime
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Literal
|
||||
import h5py
|
||||
from matplotlib import pyplot as plt
|
||||
from matplotlib.colors import LinearSegmentedColormap
|
||||
# from cmap import Colormap as cm
|
||||
import numpy as np
|
||||
from scipy.cluster.vq import kmeans2
|
||||
import warnings
|
||||
import multiprocessing
|
||||
|
||||
from rich.traceback import install
|
||||
from rich import pretty
|
||||
from rich import print
|
||||
|
||||
install()
|
||||
pretty.install()
|
||||
# from rich import pretty
|
||||
# from rich import print
|
||||
|
||||
# pretty.install()
|
||||
|
||||
|
||||
def generate_sample_data(n_symbols=1000, sps=128, noise=0.01, skew=1):
|
||||
@@ -20,6 +28,7 @@ def generate_sample_data(n_symbols=1000, sps=128, noise=0.01, skew=1):
|
||||
xaxis = np.arange(0, len(signal)) / sps
|
||||
return np.vstack([xaxis, signal])
|
||||
|
||||
|
||||
def create_symbol_sequence(n_symbols, skew=1):
|
||||
np.random.seed(42)
|
||||
data = np.random.randint(0, 4, n_symbols) / 4
|
||||
@@ -38,6 +47,14 @@ def generate_signal(data, sps):
|
||||
signal = np.convolve(data_padded, wavelet)
|
||||
signal = np.cumsum(signal)
|
||||
signal = signal[sps // 2 + len(wavelet) // 2 - 1 : -len(wavelet) // 2]
|
||||
mi, ma = np.min(signal), np.max(signal)
|
||||
|
||||
signal = (signal - mi) / (ma - mi)
|
||||
|
||||
mod = 0.8
|
||||
|
||||
signal *= mod
|
||||
signal += 1 - mod
|
||||
|
||||
return signal
|
||||
|
||||
@@ -48,8 +65,8 @@ def normalization_with_noise(signal, noise=0):
|
||||
signal += awgn
|
||||
|
||||
# min-max normalization
|
||||
signal = signal - np.min(signal)
|
||||
signal = signal / np.max(signal)
|
||||
# signal = signal - np.min(signal)
|
||||
# signal = signal / np.max(signal)
|
||||
return signal
|
||||
|
||||
|
||||
@@ -67,197 +84,424 @@ def generate_wavelet(sps, oversample=3):
|
||||
|
||||
|
||||
class eye_diagram:
|
||||
def __init__(self, data, *, channel_names=None, horizontal_bins=256, vertical_bins=1000, n_levels=4):
|
||||
def __init__(
|
||||
self,
|
||||
data,
|
||||
*,
|
||||
channel_names=None,
|
||||
horizontal_bins=256,
|
||||
vertical_bins=1000,
|
||||
n_levels=4,
|
||||
multithreaded=True,
|
||||
save_file_or_dir=None,
|
||||
):
|
||||
# data has shape [channels, 2, samples]
|
||||
# each sample has a timestamp and a value
|
||||
if data.ndim == 2:
|
||||
data = data[np.newaxis, :, :]
|
||||
self.channel_names = channel_names
|
||||
self.raw_data = data
|
||||
self.channels = data.shape[0]
|
||||
|
||||
self.y_bins = np.zeros(1)
|
||||
self.x_bins = np.zeros(1)
|
||||
self.eye_data = np.zeros(1)
|
||||
self.channel_names = channel_names
|
||||
self.n_channels = data.shape[0]
|
||||
self.n_levels = n_levels
|
||||
self.eye_stats = [{"success": False} for _ in range(self.channels)]
|
||||
self.eye_stats = [{"success": False} for _ in range(self.n_channels)]
|
||||
self.horizontal_bins = horizontal_bins
|
||||
self.vertical_bins = vertical_bins
|
||||
self.multi_threaded = multithreaded
|
||||
self.analysed = False
|
||||
self.eye_built = False
|
||||
self.analyse(self.n_levels)
|
||||
|
||||
def generate_eye_data(self):
|
||||
self.x_bins = np.linspace(0, 2, self.horizontal_bins, endpoint=False)
|
||||
self.y_bins = np.zeros((self.channels, self.vertical_bins))
|
||||
self.eye_data = np.zeros((self.channels, self.vertical_bins, self.horizontal_bins))
|
||||
for i in range(self.channels):
|
||||
data_min = np.min(self.raw_data[i, 1, :])
|
||||
data_max = np.max(self.raw_data[i, 1, :])
|
||||
self.y_bins[i] = np.linspace(data_min, data_max, self.vertical_bins, endpoint=False)
|
||||
|
||||
t_vals = self.raw_data[i, 0, :] % 2
|
||||
val_vals = self.raw_data[i, 1, :]
|
||||
|
||||
x_indices = np.digitize(t_vals, self.x_bins) - 1
|
||||
y_indices = np.digitize(val_vals, self.y_bins[i]) - 1
|
||||
|
||||
np.add.at(self.eye_data[i], (y_indices, x_indices), 1)
|
||||
self.eye_built = True
|
||||
self.save_file = save_file_or_dir
|
||||
|
||||
def load_data(self, file=None):
|
||||
file = self.save_file if file is None else file
|
||||
|
||||
if file is None:
|
||||
raise FileNotFoundError("No file specified.")
|
||||
|
||||
self.save_file = str(file)
|
||||
# self.file_or_dir = self.save_file
|
||||
with h5py.File(file, "r") as infile:
|
||||
self.y_bins = infile["y_bins"][:]
|
||||
self.x_bins = infile["x_bins"][:]
|
||||
self.eye_data = infile["eye_data"][:]
|
||||
self.channel_names = infile.attrs["channel_names"]
|
||||
self.n_channels = infile.attrs["n_channels"]
|
||||
self.n_levels = infile.attrs["n_levels"]
|
||||
self.eye_stats = infile.attrs["eye_stats"]
|
||||
self.eye_stats = [json.loads(stat) for stat in self.eye_stats]
|
||||
self.horizontal_bins = infile.attrs["horizontal_bins"]
|
||||
self.vertical_bins = infile.attrs["vertical_bins"]
|
||||
self.multi_threaded = infile.attrs["multithreaded"]
|
||||
self.analysed = infile.attrs["analysed"]
|
||||
self.eye_built = infile.attrs["eye_built"]
|
||||
|
||||
def save_data(self, file_or_dir=None):
|
||||
file_or_dir = self.save_file if file_or_dir is None else file_or_dir
|
||||
if file_or_dir is None:
|
||||
file = Path(f"{datetime.now().strftime('%Y%m%d_%H%M%S')}_eye_data.eye.h5")
|
||||
elif Path(file_or_dir).is_dir():
|
||||
file = Path(file_or_dir) / f"{datetime.now().strftime('%Y%m%d_%H%M%S')}_eye_data.eye.h5"
|
||||
else:
|
||||
file = Path(file_or_dir)
|
||||
|
||||
# file.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self.save_file = str(file)
|
||||
|
||||
with h5py.File(file, "w") as outfile:
|
||||
outfile.create_dataset("eye_data", data=self.eye_data)
|
||||
outfile.create_dataset("y_bins", data=self.y_bins)
|
||||
outfile.create_dataset("x_bins", data=self.x_bins)
|
||||
outfile.attrs["channel_names"] = self.channel_names
|
||||
outfile.attrs["n_channels"] = self.n_channels
|
||||
outfile.attrs["n_levels"] = self.n_levels
|
||||
self.eye_stats = eye_diagram.convert_arrays(self.eye_stats)
|
||||
outfile.attrs["eye_stats"] = [json.dumps(stat) for stat in self.eye_stats]
|
||||
outfile.attrs["horizontal_bins"] = self.horizontal_bins
|
||||
outfile.attrs["vertical_bins"] = self.vertical_bins
|
||||
outfile.attrs["multithreaded"] = self.multi_threaded
|
||||
outfile.attrs["analysed"] = self.analysed
|
||||
outfile.attrs["eye_built"] = self.eye_built
|
||||
|
||||
@staticmethod
|
||||
def convert_arrays(input_object):
|
||||
"""
|
||||
convert ndarrays in (nested) dict to lists
|
||||
"""
|
||||
|
||||
if isinstance(input_object, np.ndarray):
|
||||
return input_object.tolist()
|
||||
elif isinstance(input_object, list):
|
||||
return [eye_diagram.convert_arrays(old) for old in input_object]
|
||||
elif isinstance(input_object, tuple):
|
||||
return tuple(eye_diagram.convert_arrays(old) for old in input_object)
|
||||
elif isinstance(input_object, dict):
|
||||
dict_out = {}
|
||||
for key, value in input_object.items():
|
||||
dict_out[key] = eye_diagram.convert_arrays(value)
|
||||
return dict_out
|
||||
return input_object
|
||||
|
||||
def generate_eye_data(
|
||||
self, mode: Literal["default", "load", "save", "nosave"] = "default", file_or_dir: Path | None | str = None
|
||||
):
|
||||
# modes:
|
||||
# default: try to load eye data from file, if not found, generate and save
|
||||
# load: try to load eye data from file, if not found, generate but don't save
|
||||
# save: generate eye data and save
|
||||
update_save = True
|
||||
if mode == "load":
|
||||
self.load_data(file_or_dir)
|
||||
update_save = False
|
||||
elif mode == "default":
|
||||
try:
|
||||
self.load_data(file_or_dir)
|
||||
update_save = False
|
||||
except (FileNotFoundError, IsADirectoryError):
|
||||
pass
|
||||
|
||||
def plot(self, title="Eye Diagram", stats=True, show=True):
|
||||
if not self.eye_built:
|
||||
self.generate_eye_data()
|
||||
update_save = True
|
||||
self.x_bins = np.linspace(0, 2, self.horizontal_bins, endpoint=False)
|
||||
self.y_bins = np.zeros((self.n_channels, self.vertical_bins))
|
||||
self.eye_data = np.zeros((self.n_channels, self.vertical_bins, self.horizontal_bins))
|
||||
datas = [self.raw_data[i] for i in range(self.n_channels)]
|
||||
if self.multi_threaded:
|
||||
with multiprocessing.Pool() as pool:
|
||||
results = pool.map(self.generate_eye_data_single, datas)
|
||||
for i, result in enumerate(results):
|
||||
self.eye_data[i], self.y_bins[i] = result
|
||||
else:
|
||||
for i, data in enumerate(datas):
|
||||
self.eye_data[i], self.y_bins[i] = self.generate_eye_data_single(data)
|
||||
self.eye_built = True
|
||||
|
||||
if mode == "save" or (mode == "default" and update_save):
|
||||
self.save_data(file_or_dir)
|
||||
|
||||
def generate_eye_data_single(self, data):
|
||||
eye_data = np.zeros((self.vertical_bins, self.horizontal_bins))
|
||||
data_min = np.min(data[1, :])
|
||||
data_max = np.max(data[1, :])
|
||||
# round down/up to 1 decimal
|
||||
data_min = np.floor(data_min*10)/10
|
||||
data_max = np.ceil(data_max*10)/10
|
||||
# data_range = data_max - data_min
|
||||
# data_min -= 0.1 * data_range
|
||||
# data_max += 0.1 * data_range
|
||||
# data_min = -0.05
|
||||
# data_max += 0.05
|
||||
# data[1,:] -= np.min(data[1, :])
|
||||
# data[1,:] /= np.max(data[1, :])
|
||||
# data_min = 0
|
||||
# data_max = 1
|
||||
y_bins = np.linspace(data_min, data_max, self.vertical_bins, endpoint=False)
|
||||
t_vals = data[0, :] % 2 # + np.random.randn(*data[0, :].shape) * 1 / (512)
|
||||
val_vals = data[1, :] # + np.random.randn(*data[1, :].shape) * 1 / (320)
|
||||
x_indices = np.digitize(t_vals, self.x_bins) - 1
|
||||
y_indices = np.digitize(val_vals, y_bins) - 1
|
||||
np.add.at(eye_data, (y_indices, x_indices), 1)
|
||||
return eye_data, y_bins
|
||||
|
||||
def plot(
|
||||
self,
|
||||
title="Eye Diagram",
|
||||
stats=True,
|
||||
all_stats=True,
|
||||
show=True,
|
||||
mode: Literal["default", "load", "save", "nosave"] = "default",
|
||||
# save_images = False,
|
||||
# image_dir = None,
|
||||
# cmap=None,
|
||||
):
|
||||
if stats and not self.analysed:
|
||||
self.analyse(mode=mode)
|
||||
if not self.eye_built:
|
||||
self.generate_eye_data(mode=mode)
|
||||
cmap = LinearSegmentedColormap.from_list(
|
||||
"eyemap",
|
||||
[(0, "white"), (0.1, "blue"), (0.2, "cyan"), (0.5, "green"), (0.8, "yellow"), (0.9, "red"), (1, "magenta")],
|
||||
[
|
||||
(0, "#FFFFFF00"),
|
||||
(0.1, "blue"),
|
||||
(0.2, "cyan"),
|
||||
(0.5, "green"),
|
||||
(0.8, "yellow"),
|
||||
(0.9, "red"),
|
||||
(1, "magenta"),
|
||||
],
|
||||
)
|
||||
if self.channels % 2 == 0:
|
||||
# cmap = cm('google:turbo_r' if cmap is None else cmap)
|
||||
# first = cmap(-1)
|
||||
# cmap = cmap.to_mpl()
|
||||
# cmap.set_under(first, alpha=0)
|
||||
if self.n_channels % 2 == 0:
|
||||
rows = 2
|
||||
cols = self.channels // 2
|
||||
cols = self.n_channels // 2
|
||||
else:
|
||||
cols = int(np.ceil(np.sqrt(self.channels)))
|
||||
rows = int(np.ceil(self.channels / cols))
|
||||
cols = int(np.ceil(np.sqrt(self.n_channels)))
|
||||
rows = int(np.ceil(self.n_channels / cols))
|
||||
fig, ax = plt.subplots(rows, cols, sharex=True, sharey=False)
|
||||
fig.suptitle(title)
|
||||
fig.tight_layout()
|
||||
ax = np.atleast_1d(ax).transpose().flatten()
|
||||
for i in range(self.channels):
|
||||
ax[i].set_title(self.channel_names[i] if self.channel_names is not None else f"Channel {i+1}")
|
||||
ax[i].set_xlabel("Symbol")
|
||||
ax[i].set_ylabel("Amplitude")
|
||||
for i in range(self.n_channels):
|
||||
ax[i].set_title(self.channel_names[i] if self.channel_names is not None else f"Channel {i + 1}")
|
||||
if (i + 1) % rows == 0:
|
||||
ax[i].set_xlabel("Symbol")
|
||||
if i < rows:
|
||||
ax[i].set_ylabel("Amplitude")
|
||||
ax[i].grid()
|
||||
ax[i].set_axisbelow(True)
|
||||
ax[i].imshow(
|
||||
self.eye_data[i],
|
||||
self.eye_data[i] - 0.1,
|
||||
origin="lower",
|
||||
aspect="auto",
|
||||
cmap=cmap,
|
||||
extent=[self.x_bins[0], self.x_bins[-1], self.y_bins[i][0], self.y_bins[i][-1]],
|
||||
interpolation="gaussian",
|
||||
vmin=0,
|
||||
zorder=3,
|
||||
)
|
||||
ax[i].set_xlim((self.x_bins[0], self.x_bins[-1]))
|
||||
ymin = np.min(self.y_bins[:, 0])
|
||||
ymax = np.max(self.y_bins[:, -1])
|
||||
yspan = ymax - ymin
|
||||
ax[i].set_ylim((ymin - 0.1 * yspan, ymax + 0.1 * yspan))
|
||||
# if save_images:
|
||||
# image_dir = "images_out" if image_dir is None else image_dir
|
||||
# image_path = Path(image_dir) / (slugify(f"{datetime.now().strftime("%Y%m%d_%H%M%S")}_{title.replace(" ","_")}_{self.channel_names[i].replace(" ", "_") if self.channel_names is not None else f"{i + 1}"}_{ymin:.1f}_{ymax:.1f}") + ".png")
|
||||
# image_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
# # plt.imsave(
|
||||
# # image_path,
|
||||
# # self.eye_data[i] - 0.1,
|
||||
# # origin="lower",
|
||||
# # # aspect="auto",
|
||||
# # cmap=cmap,
|
||||
# # # extent=[self.x_bins[0], self.x_bins[-1], self.y_bins[i][0], self.y_bins[i][-1]],
|
||||
# # # interpolation="gaussian",
|
||||
# # vmin=0,
|
||||
# # # zorder=3,
|
||||
# # )
|
||||
if stats and self.eye_stats[i]["success"]:
|
||||
ax[i].plot([0, 2], [self.eye_stats[i]["levels"], self.eye_stats[i]["levels"]], "k--")
|
||||
ax[i].set_yticks(self.eye_stats[i]["levels"])
|
||||
# add arrows for amplitudes
|
||||
for j in range(len(self.eye_stats[i]["amplitudes"])):
|
||||
ax[i].annotate(
|
||||
"",
|
||||
xy=(0.05, self.eye_stats[i]["levels"][j]),
|
||||
xytext=(0.05, self.eye_stats[i]["levels"][j + 1]),
|
||||
arrowprops=dict(arrowstyle="<->", facecolor="black"),
|
||||
)
|
||||
ax[i].annotate(
|
||||
f"{self.eye_stats[i]['amplitudes'][j]:.2e}",
|
||||
xy=(0.06, (self.eye_stats[i]["levels"][j] + self.eye_stats[i]["levels"][j + 1]) / 2),
|
||||
)
|
||||
# add arrows for eye heights
|
||||
for j in range(len(self.eye_stats[i]["heights"])):
|
||||
try:
|
||||
bot = np.max(self.eye_stats[i]["amplitude_clusters"][j])
|
||||
top = np.min(self.eye_stats[i]["amplitude_clusters"][j + 1])
|
||||
# # add min_area above the plot
|
||||
# ax[i].annotate(
|
||||
# f"Min Area: {self.eye_stats[i]['min_area']:.2e}",
|
||||
# xy=(0.05, ymax + 0.05 * yspan),
|
||||
# # xycoords="axes fraction",
|
||||
# ha="left",
|
||||
# va="center",
|
||||
# bbox=dict(facecolor="white", alpha=0.5, edgecolor="none"),
|
||||
# )
|
||||
|
||||
if all_stats:
|
||||
ax[i].plot([0, 2], [self.eye_stats[i]["levels"], self.eye_stats[i]["levels"]], "k--")
|
||||
y_ticks = (*self.eye_stats[i]["levels"], *self.eye_stats[i]["thresholds"])
|
||||
# y_ticks = np.sort(y_ticks)
|
||||
ax[i].set_yticks(y_ticks)
|
||||
# add arrows for amplitudes
|
||||
for j in range(len(self.eye_stats[i]["amplitudes"])):
|
||||
ax[i].annotate(
|
||||
"",
|
||||
xy=(self.eye_stats[i]["time_midpoint"], bot),
|
||||
xytext=(self.eye_stats[i]["time_midpoint"], top),
|
||||
xy=(0.05, self.eye_stats[i]["levels"][j]),
|
||||
xytext=(0.05, self.eye_stats[i]["levels"][j + 1]),
|
||||
arrowprops=dict(arrowstyle="<->", facecolor="black"),
|
||||
)
|
||||
ax[i].annotate(
|
||||
f"{self.eye_stats[i]['heights'][j]:.2e}",
|
||||
xy=(self.eye_stats[i]["time_midpoint"] + 0.015, (bot + top) / 2 + 0.04),
|
||||
f"{self.eye_stats[i]['amplitudes'][j]:.2e}",
|
||||
xy=(0.06, (self.eye_stats[i]["levels"][j] + self.eye_stats[i]["levels"][j + 1]) / 2),
|
||||
bbox=dict(facecolor="white", alpha=0.5, edgecolor="none"),
|
||||
)
|
||||
except (ValueError, IndexError):
|
||||
pass
|
||||
# add arrows for eye widths
|
||||
for j in range(len(self.eye_stats[i]["widths"])):
|
||||
try:
|
||||
left = np.max(self.eye_stats[i]["time_clusters"][j][0])
|
||||
right = np.min(self.eye_stats[i]["time_clusters"][j][1])
|
||||
vertical = (self.eye_stats[i]["levels"][j] + self.eye_stats[i]["levels"][j + 1]) / 2
|
||||
# add arrows for eye heights
|
||||
for j in range(len(self.eye_stats[i]["heights"])):
|
||||
try:
|
||||
bot = np.max(self.eye_stats[i]["amplitude_clusters"][j])
|
||||
top = np.min(self.eye_stats[i]["amplitude_clusters"][j + 1])
|
||||
|
||||
ax[i].annotate(
|
||||
"",
|
||||
xy=(left, vertical),
|
||||
xytext=(right, vertical),
|
||||
arrowprops=dict(arrowstyle="<->", facecolor="black"),
|
||||
)
|
||||
ax[i].annotate(
|
||||
f"{self.eye_stats[i]['widths'][j]:.2e}",
|
||||
xy=((left + right) / 2 - 0.15, vertical + 0.01),
|
||||
)
|
||||
except (ValueError, IndexError):
|
||||
pass
|
||||
ax[i].annotate(
|
||||
"",
|
||||
xy=(self.eye_stats[i]["time_midpoint"], bot),
|
||||
xytext=(self.eye_stats[i]["time_midpoint"], top),
|
||||
arrowprops=dict(arrowstyle="<->", facecolor="black"),
|
||||
)
|
||||
ax[i].annotate(
|
||||
f"{self.eye_stats[i]['heights'][j]:.2e}",
|
||||
xy=(self.eye_stats[i]["time_midpoint"] + 0.015, (bot + top) / 2 + 0.04),
|
||||
bbox=dict(facecolor="white", alpha=0.5, edgecolor="none"),
|
||||
)
|
||||
except (ValueError, IndexError):
|
||||
pass
|
||||
# add arrows for eye widths
|
||||
# for j in range(len(self.eye_stats[i]["widths"])):
|
||||
# try:
|
||||
# left = np.max(self.eye_stats[i]["time_clusters"][j][0])
|
||||
# right = np.min(self.eye_stats[i]["time_clusters"][j][1])
|
||||
# vertical = (self.eye_stats[i]["levels"][j] + self.eye_stats[i]["levels"][j + 1]) / 2
|
||||
|
||||
# ax[i].annotate(
|
||||
# "",
|
||||
# xy=(left, vertical),
|
||||
# xytext=(right, vertical),
|
||||
# arrowprops=dict(arrowstyle="<->", facecolor="black"),
|
||||
# )
|
||||
# ax[i].annotate(
|
||||
# f"{self.eye_stats[i]['widths'][j]:.2e}",
|
||||
# xy=((left + right) / 2 - 0.15, vertical + 0.01),
|
||||
# bbox=dict(facecolor="white", alpha=0.5, edgecolor="none"),
|
||||
# )
|
||||
# except (ValueError, IndexError):
|
||||
# pass
|
||||
|
||||
# # add area
|
||||
# for j in range(len(self.eye_stats[i]["areas"])):
|
||||
# horizontal = self.eye_stats[i]["time_midpoint"]
|
||||
# vertical = (self.eye_stats[i]["levels"][j] + self.eye_stats[i]["levels"][j + 1]) / 2
|
||||
# ax[i].annotate(
|
||||
# f"{self.eye_stats[i]['areas'][j]:.2e}",
|
||||
# xy=(horizontal + 0.035, vertical - 0.07),
|
||||
# bbox=dict(facecolor="white", alpha=0.5, edgecolor="none"),
|
||||
# )
|
||||
|
||||
# add area
|
||||
for j in range(len(self.eye_stats[i]["areas"])):
|
||||
horizontal = self.eye_stats[i]["time_midpoint"]
|
||||
vertical = (self.eye_stats[i]["levels"][j] + self.eye_stats[i]["levels"][j + 1]) / 2
|
||||
ax[i].annotate(
|
||||
f"{self.eye_stats[i]['areas'][j]:.2e}",
|
||||
xy=(horizontal + 0.035, vertical - 0.07),
|
||||
)
|
||||
|
||||
# add min_area above the plot
|
||||
ax[i].annotate(
|
||||
f"Min Area: {self.eye_stats[i]['min_area']:.2e}",
|
||||
xy=(0.05, ymax + 0.05 * yspan),
|
||||
# xycoords="axes fraction",
|
||||
ha="left",
|
||||
va="center",
|
||||
)
|
||||
fig.tight_layout()
|
||||
|
||||
if show:
|
||||
plt.show()
|
||||
return fig
|
||||
|
||||
def analyse(self, n_levels=4):
|
||||
@staticmethod
|
||||
def calculate_thresholds(levels):
|
||||
ret = np.cumsum(levels, dtype=float)
|
||||
ret[2:] = ret[2:] - ret[:-2]
|
||||
return ret[1:] / 2
|
||||
|
||||
def analyse_single(self, data, index):
|
||||
warnings.filterwarnings("error")
|
||||
for i in range(self.channels):
|
||||
self.eye_stats[i]["channel"] = str(i+1) if self.channel_names is None else self.channel_names[i]
|
||||
try:
|
||||
approx_levels = eye_diagram.approximate_levels(self.raw_data[i], n_levels)
|
||||
eye_stats = {}
|
||||
eye_stats["channel_name"] = str(index + 1) if self.channel_names is None else self.channel_names[index]
|
||||
try:
|
||||
approx_levels = eye_diagram.approximate_levels(data, self.n_levels)
|
||||
|
||||
time_bounds = eye_diagram.calculate_time_bounds(self.raw_data[i], approx_levels)
|
||||
time_bounds = eye_diagram.calculate_time_bounds(data, approx_levels)
|
||||
|
||||
self.eye_stats[i]["time_midpoint"] = (time_bounds[0] + time_bounds[1]) / 2
|
||||
eye_stats["time_midpoint"] = float((time_bounds[0] + time_bounds[1]) / 2)
|
||||
# eye_stats["time_midpoint"] = 1.0
|
||||
|
||||
self.eye_stats[i]["levels"], self.eye_stats[i]["amplitude_clusters"] = eye_diagram.calculate_levels(
|
||||
self.raw_data[i], approx_levels, time_bounds
|
||||
)
|
||||
eye_stats["levels"], eye_stats["amplitude_clusters"] = eye_diagram.calculate_levels(
|
||||
data, approx_levels, time_bounds
|
||||
)
|
||||
|
||||
self.eye_stats[i]["amplitudes"] = np.diff(self.eye_stats[i]["levels"])
|
||||
eye_stats["thresholds"] = self.calculate_thresholds(eye_stats["levels"])
|
||||
|
||||
self.eye_stats[i]["heights"] = eye_diagram.calculate_eye_heights(
|
||||
self.eye_stats[i]["amplitude_clusters"]
|
||||
)
|
||||
eye_stats["amplitudes"] = np.diff(eye_stats["levels"])
|
||||
|
||||
self.eye_stats[i]["widths"], self.eye_stats[i]["time_clusters"] = eye_diagram.calculate_eye_widths(
|
||||
self.raw_data[i], self.eye_stats[i]["levels"]
|
||||
)
|
||||
eye_stats["heights"] = eye_diagram.calculate_eye_heights(eye_stats["amplitude_clusters"])
|
||||
|
||||
# # check if time clusters are valid (upper bound > time_midpoint > lower bound)
|
||||
# # if not: raise ValueError
|
||||
# for j in range(len(self.eye_stats[i]['time_clusters'])):
|
||||
# if not (np.max(self.eye_stats[i]['time_clusters'][j][0]) < self.eye_stats[i]["time_midpoint"] < np.min(self.eye_stats[i]['time_clusters'][j][1])):
|
||||
# raise ValueError
|
||||
eye_stats["widths"], eye_stats["time_clusters"] = eye_diagram.calculate_eye_widths(
|
||||
data, eye_stats["levels"]
|
||||
)
|
||||
|
||||
self.eye_stats[i]["areas"] = self.eye_stats[i]["heights"] * self.eye_stats[i]["widths"]
|
||||
self.eye_stats[i]["mean_area"] = np.mean(self.eye_stats[i]["areas"])
|
||||
self.eye_stats[i]["min_area"] = np.min(self.eye_stats[i]["areas"])
|
||||
# # check if time clusters are valid (upper bound > time_midpoint > lower bound)
|
||||
# # if not: raise ValueError
|
||||
# for j in range(len(eye_stats['time_clusters'])):
|
||||
# if not (np.max(eye_stats['time_clusters'][j][0]) < eye_stats["time_midpoint"] < np.min(eye_stats['time_clusters'][j][1])):
|
||||
# raise ValueError
|
||||
|
||||
self.eye_stats[i]["success"] = True
|
||||
except (RuntimeWarning, UserWarning, ValueError):
|
||||
self.eye_stats[i]["success"] = False
|
||||
self.eye_stats[i]["time_midpoint"] = 0
|
||||
self.eye_stats[i]["levels"] = np.zeros(n_levels)
|
||||
self.eye_stats[i]["amplitude_clusters"] = []
|
||||
self.eye_stats[i]["amplitudes"] = np.zeros(n_levels - 1)
|
||||
self.eye_stats[i]["heights"] = np.zeros(n_levels - 1)
|
||||
self.eye_stats[i]["widths"] = np.zeros(n_levels - 1)
|
||||
self.eye_stats[i]["areas"] = np.zeros(n_levels - 1)
|
||||
self.eye_stats[i]["mean_area"] = 0
|
||||
self.eye_stats[i]["min_area"] = 0
|
||||
# eye_stats["areas"] = eye_stats["heights"] * eye_stats["widths"]
|
||||
# eye_stats["mean_area"] = np.mean(eye_stats["areas"])
|
||||
# eye_stats["min_area"] = np.min(eye_stats["areas"])
|
||||
|
||||
eye_stats["success"] = True
|
||||
except (RuntimeWarning, UserWarning, ValueError):
|
||||
eye_stats["success"] = False
|
||||
eye_stats["time_midpoint"] = None
|
||||
eye_stats["levels"] = None
|
||||
eye_stats["thresholds"] = None
|
||||
eye_stats["amplitude_clusters"] = None
|
||||
eye_stats["amplitudes"] = None
|
||||
eye_stats["heights"] = None
|
||||
eye_stats["widths"] = None
|
||||
# eye_stats["areas"] = np.zeros(self.n_levels - 1)
|
||||
# eye_stats["mean_area"] = 0
|
||||
# eye_stats["min_area"] = 0
|
||||
warnings.resetwarnings()
|
||||
return eye_stats
|
||||
|
||||
def analyse(
|
||||
self, mode: Literal["default", "load", "save", "nosave"] = "default", file_or_dir: Path | None | str = None
|
||||
):
|
||||
# modes:
|
||||
# default: try to load eye data from file, if not found, generate and save
|
||||
# load: try to load eye data from file, if not found, generate but don't save
|
||||
# save: generate eye data and save
|
||||
update_save = True
|
||||
if mode == "load":
|
||||
self.load_data(file_or_dir)
|
||||
update_save = False
|
||||
elif mode == "default":
|
||||
try:
|
||||
self.load_data(file_or_dir)
|
||||
update_save = False
|
||||
except (FileNotFoundError, IsADirectoryError):
|
||||
pass
|
||||
|
||||
if not self.analysed:
|
||||
update_save = True
|
||||
self.eye_stats = []
|
||||
if self.multi_threaded:
|
||||
with multiprocessing.Pool() as pool:
|
||||
results = pool.starmap(self.analyse_single, [(self.raw_data[i], i) for i in range(self.n_channels)])
|
||||
for i, result in enumerate(results):
|
||||
self.eye_stats.append(result)
|
||||
else:
|
||||
for i in range(self.n_channels):
|
||||
self.eye_stats.append(self.analyse_single(self.raw_data[i], i))
|
||||
self.analysed = True
|
||||
|
||||
if mode == "save" or (mode == "default" and update_save):
|
||||
self.save_data(file_or_dir)
|
||||
|
||||
@staticmethod
|
||||
def approximate_levels(data, levels):
|
||||
@@ -399,7 +643,7 @@ class eye_diagram:
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
length = int(2**14)
|
||||
length = int(2**16)
|
||||
# data = generate_sample_data(length, noise=1)
|
||||
# data1 = generate_sample_data(length, noise=0.01)
|
||||
# data2 = generate_sample_data(length, noise=0.01, skew=1.2)
|
||||
@@ -407,12 +651,13 @@ if __name__ == "__main__":
|
||||
|
||||
# data = np.stack([data, data1, data2, data3])
|
||||
|
||||
data = generate_sample_data(length, noise=0.005)
|
||||
eye = eye_diagram(data, horizontal_bins=256, vertical_bins=256)
|
||||
attrs = ("success", "amplitudes", "time_midpoint", "levels", "heights", "widths", "area", "mean_area", "min_area")
|
||||
for i, channel in enumerate(eye.eye_stats):
|
||||
print(f"Channel {i}")
|
||||
print_data = {attr: channel[attr] for attr in attrs}
|
||||
print(print_data)
|
||||
data = generate_sample_data(length, noise=0.0000)
|
||||
eye = eye_diagram(data, horizontal_bins=1056, vertical_bins=1200)
|
||||
eye.plot(mode="nosave", stats=False)
|
||||
# attrs = ("success", "amplitudes", "time_midpoint", "levels", "heights", "widths")#, "area", "mean_area", "min_area")
|
||||
# for i, channel in enumerate(eye.eye_stats):
|
||||
# print(f"Channel {i}")
|
||||
# print_data = {attr: channel[attr] for attr in attrs}
|
||||
# print(print_data)
|
||||
|
||||
eye.plot()
|
||||
# eye.plot()
|
||||
|
||||
122
src/single-core-regen/util/mpl.py
Normal file
122
src/single-core-regen/util/mpl.py
Normal file
@@ -0,0 +1,122 @@
|
||||
# Copyright (c) 2015, Warren Weckesser. All rights reserved.
|
||||
# This software is licensed according to the "BSD 2-clause" license.
|
||||
|
||||
# modified by Joseph Hopfmüller in 2025,
|
||||
# for integration into optical regeneration analysis scripts
|
||||
|
||||
from pathlib import Path
|
||||
from matplotlib.colors import LinearSegmentedColormap
|
||||
import matplotlib.colors as colors
|
||||
import numpy as _np
|
||||
from .core import grid_count as _grid_count
|
||||
import matplotlib.pyplot as _plt
|
||||
import numpy as np
|
||||
from scipy.ndimage import gaussian_filter
|
||||
|
||||
|
||||
# from ._common import _common_doc
|
||||
|
||||
|
||||
__all__ = ["eyediagram"] # , 'eyediagram_lines']
|
||||
|
||||
|
||||
# def eyediagram_lines(y, window_size, offset=0, **plotkwargs):
|
||||
# """
|
||||
# Plot an eye diagram using matplotlib by repeatedly calling the `plot`
|
||||
# function.
|
||||
# <common>
|
||||
|
||||
# """
|
||||
# start = offset
|
||||
# while start < len(y):
|
||||
# end = start + window_size
|
||||
# if end > len(y):
|
||||
# end = len(y)
|
||||
# yy = y[start:end+1]
|
||||
# _plt.plot(_np.arange(len(yy)), yy, 'k', **plotkwargs)
|
||||
# start = end
|
||||
|
||||
# eyediagram_lines.__doc__ = eyediagram_lines.__doc__.replace("<common>",
|
||||
# _common_doc)
|
||||
|
||||
|
||||
eyemap = LinearSegmentedColormap.from_list(
|
||||
"eyemap",
|
||||
[
|
||||
(0, "#0000FF00"),
|
||||
(0.1, "blue"),
|
||||
(0.2, "cyan"),
|
||||
(0.5, "green"),
|
||||
(0.8, "yellow"),
|
||||
(0.9, "red"),
|
||||
(1, "magenta"),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def eyediagram(
|
||||
y,
|
||||
window_size,
|
||||
offset=0,
|
||||
colorbar=False,
|
||||
show=False,
|
||||
save_im=False,
|
||||
overwrite=False,
|
||||
blur: int | bool = True,
|
||||
save_path="out.png",
|
||||
bounds=None,
|
||||
**imshowkwargs,
|
||||
):
|
||||
"""
|
||||
Plot an eye diagram using matplotlib by creating an image and calling
|
||||
the `imshow` function.
|
||||
<common>
|
||||
"""
|
||||
if bounds is None:
|
||||
ymax = y.max()
|
||||
ymin = y.min()
|
||||
yamp = ymax - ymin
|
||||
ymin = ymin - 0.05 * yamp
|
||||
ymax = ymax + 0.05 * yamp
|
||||
ymin = np.floor(ymin * 10) / 10
|
||||
ymax = np.ceil(ymax * 10) / 10
|
||||
bounds = (ymin, ymax)
|
||||
counts = _grid_count(y, window_size, offset, bounds=bounds, size=(1000, 1200), blur=int(blur))
|
||||
counts = counts.astype(_np.float32)
|
||||
origin = imshowkwargs.pop("origin", "lower")
|
||||
cmap: colors.Colormap = imshowkwargs.pop("cmap", eyemap)
|
||||
vmin = imshowkwargs.pop("vmin", 1)
|
||||
vmax = imshowkwargs.pop("vmax", None)
|
||||
cmap.set_under("white", alpha=0)
|
||||
|
||||
if show:
|
||||
_plt.imshow(
|
||||
counts.T[::-1, :],
|
||||
extent=[0, 2, *bounds],
|
||||
origin=origin,
|
||||
cmap=cmap,
|
||||
vmin=vmin,
|
||||
vmax=vmax,
|
||||
**imshowkwargs,
|
||||
)
|
||||
_plt.grid()
|
||||
if colorbar:
|
||||
_plt.colorbar()
|
||||
|
||||
if Path(save_path).is_file() and not overwrite:
|
||||
save_im = False
|
||||
if save_im:
|
||||
from PIL import Image
|
||||
arr = counts.T[::-1, :]
|
||||
if origin == "lower":
|
||||
arr = arr[::-1]
|
||||
arr = (arr-arr.min())/(arr.max()-arr.min())
|
||||
image = Image.fromarray((cmap(arr)[:, :, :] * 255).astype(np.uint8))
|
||||
image.save(save_path)
|
||||
# print("-")
|
||||
|
||||
if show:
|
||||
_plt.show()
|
||||
|
||||
|
||||
# eyediagram.__doc__ = eyediagram.__doc__.replace("<common>", _common_doc)
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user