model robustness testing
This commit is contained in:
@@ -481,6 +481,15 @@ class Identity(nn.Module):
|
||||
|
||||
def forward(self, x):
|
||||
return x
|
||||
|
||||
class phase_shift(nn.Module):
|
||||
def __init__(self, size):
|
||||
super(phase_shift, self).__init__()
|
||||
self.size = size
|
||||
self.phase = nn.Parameter(torch.rand(size))
|
||||
|
||||
def forward(self, x):
|
||||
return x * torch.exp(1j*self.phase)
|
||||
|
||||
|
||||
class PowRot(nn.Module):
|
||||
@@ -531,19 +540,19 @@ def angle_mse_loss(x: torch.Tensor, target: torch.Tensor):
|
||||
|
||||
class EOActivation(nn.Module):
|
||||
def __init__(self, size=None):
|
||||
# 10.1109/SiPhotonics60897.2024.10543376
|
||||
# 10.1109/JSTQE.2019.2930455
|
||||
super(EOActivation, self).__init__()
|
||||
if size is None:
|
||||
raise ValueError("Size must be specified")
|
||||
self.size = size
|
||||
self.alpha = nn.Parameter(torch.ones(size))
|
||||
self.V_bias = nn.Parameter(torch.ones(size))
|
||||
self.gain = nn.Parameter(torch.ones(size))
|
||||
self.alpha = nn.Parameter(torch.rand(size))
|
||||
self.V_bias = nn.Parameter(torch.rand(size))
|
||||
self.gain = nn.Parameter(torch.rand(size))
|
||||
# if bias:
|
||||
# self.phase_bias = nn.Parameter(torch.zeros(size))
|
||||
# else:
|
||||
# self.register_buffer("phase_bias", torch.zeros(size))
|
||||
self.register_buffer("phase_bias", torch.clamp(torch.ones(size) + torch.randn(size)*0.1, 0, 1)*torch.pi)
|
||||
# self.register_buffer("phase_bias", torch.clamp(torch.ones(size) + torch.randn(size)*0.1, 0, 1)*torch.pi)
|
||||
self.register_buffer("responsivity", torch.ones(size)*0.9)
|
||||
self.register_buffer("V_pi", torch.ones(size)*3)
|
||||
|
||||
@@ -551,17 +560,17 @@ class EOActivation(nn.Module):
|
||||
|
||||
def reset_weights(self):
|
||||
if "alpha" in self._parameters:
|
||||
self.alpha.data = torch.ones(self.size)*0.5
|
||||
self.alpha.data = torch.rand(self.size)
|
||||
if "V_pi" in self._parameters:
|
||||
self.V_pi.data = torch.ones(self.size)*3
|
||||
self.V_pi.data = torch.rand(self.size)*3
|
||||
if "V_bias" in self._parameters:
|
||||
self.V_bias.data = torch.zeros(self.size)
|
||||
self.V_bias.data = torch.randn(self.size)
|
||||
if "gain" in self._parameters:
|
||||
self.gain.data = torch.ones(self.size)
|
||||
self.gain.data = torch.rand(self.size)
|
||||
if "responsivity" in self._parameters:
|
||||
self.responsivity.data = torch.ones(self.size)*0.9
|
||||
if "bias" in self._parameters:
|
||||
self.phase_bias.data = torch.zeros(self.size)
|
||||
# if "bias" in self._parameters:
|
||||
# self.phase_bias.data = torch.zeros(self.size)
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
phi_b = torch.pi * self.V_bias / (self.V_pi + 1e-8)
|
||||
@@ -570,12 +579,11 @@ class EOActivation(nn.Module):
|
||||
return (
|
||||
1j
|
||||
* torch.sqrt(1 - self.alpha)
|
||||
* torch.exp(-0.5j * (intermediate + self.phase_bias))
|
||||
* torch.exp(-0.5j * intermediate)
|
||||
* torch.cos(0.5 * intermediate)
|
||||
* x
|
||||
)
|
||||
|
||||
|
||||
class Pow(nn.Module):
|
||||
"""
|
||||
implements the activation function
|
||||
@@ -716,6 +724,7 @@ __all__ = [
|
||||
MZISingle,
|
||||
EOActivation,
|
||||
photodiode,
|
||||
phase_shift,
|
||||
# SaturableAbsorberLambertW,
|
||||
# SaturableAbsorber,
|
||||
# SpreadLayer,
|
||||
|
||||
Reference in New Issue
Block a user