model robustness testing

This commit is contained in:
Joseph Hopfmüller
2025-01-10 23:40:54 +01:00
parent 3af73343c1
commit f38d0ca3bb
13 changed files with 1558 additions and 334 deletions

View File

@@ -481,6 +481,15 @@ class Identity(nn.Module):
def forward(self, x):
return x
class phase_shift(nn.Module):
    """Learnable per-element phase rotation for complex-valued inputs.

    Each of the ``size`` elements owns one trainable phase parameter,
    initialised uniformly in [0, 1); the forward pass multiplies the
    input elementwise by exp(1j * phase), which preserves magnitude
    and shifts only the argument.
    """

    def __init__(self, size):
        super(phase_shift, self).__init__()
        # keep the element count around for introspection / resets
        self.size = size
        # one trainable phase per element, uniform in [0, 1)
        self.phase = nn.Parameter(torch.rand(size))

    def forward(self, x):
        # unit-modulus rotation: |out| == |x|, arg(out) == arg(x) + phase
        rotation = torch.exp(1j * self.phase)
        return x * rotation
class PowRot(nn.Module):
@@ -531,19 +540,19 @@ def angle_mse_loss(x: torch.Tensor, target: torch.Tensor):
# NOTE(review): this span is a rendered git diff hunk whose +/- markers and
# indentation were stripped by the web view. Consecutive near-duplicate
# assignments below are old/new line pairs — presumably the first of each
# pair is the pre-commit version and the second is the post-commit one.
# Do not treat this span as runnable source.
class EOActivation(nn.Module):
def __init__(self, size=None):
# Electro-optic activation; the two DOIs below cite the modulator model.
# 10.1109/SiPhotonics60897.2024.10543376
# 10.1109/JSTQE.2019.2930455
super(EOActivation, self).__init__()
if size is None:
raise ValueError("Size must be specified")
self.size = size
# pre-commit initialisation (likely the removed lines): constant ones
self.alpha = nn.Parameter(torch.ones(size))
self.V_bias = nn.Parameter(torch.ones(size))
self.gain = nn.Parameter(torch.ones(size))
# post-commit initialisation (likely the added lines): uniform in [0, 1)
self.alpha = nn.Parameter(torch.rand(size))
self.V_bias = nn.Parameter(torch.rand(size))
self.gain = nn.Parameter(torch.rand(size))
# if bias:
# self.phase_bias = nn.Parameter(torch.zeros(size))
# else:
# self.register_buffer("phase_bias", torch.zeros(size))
# NOTE(review): the pair below suggests the commit disabled the phase_bias
# buffer (active line removed, commented copy added) — consistent with the
# forward hunk further down, which drops self.phase_bias from the phase term.
self.register_buffer("phase_bias", torch.clamp(torch.ones(size) + torch.randn(size)*0.1, 0, 1)*torch.pi)
# self.register_buffer("phase_bias", torch.clamp(torch.ones(size) + torch.randn(size)*0.1, 0, 1)*torch.pi)
# fixed (non-trainable) device constants: photodiode responsivity and V_pi
self.register_buffer("responsivity", torch.ones(size)*0.9)
self.register_buffer("V_pi", torch.ones(size)*3)
@@ -551,17 +560,17 @@ class EOActivation(nn.Module):
# NOTE(review): diff hunk with stripped +/- markers — each near-duplicate
# assignment pair below is an old/new line pair (second line presumably the
# post-commit version, switching constant re-init to random re-init).
def reset_weights(self):
if "alpha" in self._parameters:
self.alpha.data = torch.ones(self.size)*0.5
self.alpha.data = torch.rand(self.size)
# NOTE(review): V_pi is registered via register_buffer in __init__, so it
# would live in _buffers rather than _parameters — verify this guard ever
# fires in the full file.
if "V_pi" in self._parameters:
self.V_pi.data = torch.ones(self.size)*3
self.V_pi.data = torch.rand(self.size)*3
if "V_bias" in self._parameters:
self.V_bias.data = torch.zeros(self.size)
self.V_bias.data = torch.randn(self.size)
if "gain" in self._parameters:
self.gain.data = torch.ones(self.size)
self.gain.data = torch.rand(self.size)
# NOTE(review): responsivity is also buffer-registered in __init__; same
# _parameters-vs-_buffers caveat as V_pi above.
if "responsivity" in self._parameters:
self.responsivity.data = torch.ones(self.size)*0.9
# NOTE(review): the commit appears to remove the phase_bias reset (active
# pair removed, commented pair added); also the key checked was "bias"
# while the attribute is phase_bias — worth confirming upstream.
if "bias" in self._parameters:
self.phase_bias.data = torch.zeros(self.size)
# if "bias" in self._parameters:
# self.phase_bias.data = torch.zeros(self.size)
# NOTE(review): diff hunk, markers stripped; the body is split by the hunk
# header below, so an intermediate computation (defining `intermediate`)
# is not visible in this view.
def forward(self, x: torch.Tensor):
# bias phase from the bias voltage; 1e-8 guards against division by zero
phi_b = torch.pi * self.V_bias / (self.V_pi + 1e-8)
@@ -570,12 +579,11 @@ class EOActivation(nn.Module):
return (
1j
* torch.sqrt(1 - self.alpha)
# NOTE(review): old/new pair — the commit presumably drops phase_bias
# from the exponent (first line removed, second added), matching the
# disabled phase_bias buffer in __init__.
* torch.exp(-0.5j * (intermediate + self.phase_bias))
* torch.exp(-0.5j * intermediate)
* torch.cos(0.5 * intermediate)
* x
)
class Pow(nn.Module):
"""
implements the activation function
@@ -716,6 +724,7 @@ __all__ = [
MZISingle,
EOActivation,
photodiode,
phase_shift,
# SaturableAbsorberLambertW,
# SaturableAbsorber,
# SpreadLayer,