add training script for polarisation estimation, refactor model definitions, add randomised polarisation support in data_loader
@@ -260,12 +260,94 @@ class ONNRect(nn.Module):
         self.crop = lambda x: x
         self.crop.__doc__ = "No cropping"

     def forward(self, x):
-        x = self.pad(x)
+        x = self.pad(x).to(dtype=self.weight.dtype)
         out = self.crop((self.weight @ x.mT).mT)
         return out
+
+
+class polarimeter(nn.Module):
+    def __init__(self):
+        super(polarimeter, self).__init__()
+        # self.input_length = input_length
+
+    def forward(self, data):
+        # normalised Stokes parameters, with I the total intensity:
+        # S0 = I
+        # S1 = (2*I_x - I)/I
+        # S2 = (2*I_45 - I)/I
+        # S3 = (2*I_RHC - I)/I
+
+        # data: (batch, input_length*2) -> (batch, input_length, 2)
+        data = data.view(data.shape[0], -1, 2)
+        x = data[:, :, 0].mean(dim=1)
+        y = data[:, :, 1].mean(dim=1)
+
+        # angle = torch.atan2(y.abs().square().real, x.abs().square().real)
+        # return torch.stack([angle, angle, angle, angle], dim=1)
+
+        # horizontal polarisation
+        I_x = x.abs().square()
+        # vertical polarisation
+        I_y = y.abs().square()
+        # 45 degree polarisation (projection onto the 45 degree axis carries a factor 1/2)
+        I_45 = (x + y).abs().square() / 2
+        # right hand circular polarisation (same factor 1/2 for the circular projection)
+        I_RHC = (x + 1j*y).abs().square() / 2
+
+        # S0 = I_x + I_y
+        # S1 = I_x - I_y
+        # S2 = I_45 - I_m45
+        # S3 = I_RHC - I_LHC
+        S0 = I_x + I_y
+        S1 = (2*I_x - S0) / S0
+        S2 = (2*I_45 - S0) / S0
+        S3 = (2*I_RHC - S0) / S0
+
+        # S1..S3 are already normalised by S0, so they are not divided a second time
+        return torch.stack([S0/S0, S1, S2, S3], dim=1)
+
+
+class normalize_by_first(nn.Module):
+    def __init__(self):
+        super(normalize_by_first, self).__init__()
+
+    def forward(self, data):
+        # divide every channel by the first channel of each sample
+        return data / data[:, 0].unsqueeze(1)
+
+
+class photodiode(nn.Module):
+    def __init__(self, size, bias=True):
+        # note: the bias flag is currently unused; pd_bias is always created
+        super(photodiode, self).__init__()
+        self.input_dim = size
+        self.scale = nn.Parameter(torch.rand(size))
+        self.pd_bias = nn.Parameter(torch.rand(size))
+
+    def forward(self, x):
+        # detected power |x|^2, cast to the matching real dtype, then scaled and offset
+        return x.abs().square().to(dtype=x.dtype.to_real()).mul(self.scale).add(self.pd_bias)
+
+
+class input_rotator(nn.Module):
+    def __init__(self, input_dim):
+        super(input_rotator, self).__init__()
+        assert input_dim % 2 == 0, "Input dimension must be even"
+        self.input_dim = input_dim
+        # self.angle = nn.Parameter(torch.randn(1, dtype=self.dtype.to_real()))
+
+    def forward(self, x, angle=None):
+        # take channels (0,1), (2,3), ... and rotate them by the angle
+        if angle is None:
+            angle = self.angle  # requires the learnable angle above to be re-enabled
+        sine = torch.sin(angle)
+        cosine = torch.cos(angle)
+        # stack/reshape (rather than torch.tensor) keeps the graph differentiable w.r.t. angle
+        rot = torch.stack([cosine, -sine, sine, cosine]).reshape(2, 2).to(dtype=x.dtype)
+        return torch.matmul(x.view(-1, 2), rot).view(x.shape)
+
+
     # def __repr__(self):
     #     return f"ONNRect({self.input_dim}, {self.output_dim})"
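
The three new modules above form a small detection chain: input_rotator applies a Jones rotation to interleaved (x, y) channel pairs, polarimeter averages each component and reports the normalised Stokes vector, and photodiode converts fields to per-channel powers. A minimal sketch of how they compose, assuming the classes are importable from this module (batch size, sample count, and the test states are illustrative, not taken from the training script):

import torch

# a batch of 4 waveforms, 16 complex samples per polarisation,
# interleaved as (x0, y0, x1, y1, ...) along the last dimension
batch, length = 4, 16
x_field = torch.ones(batch, length, dtype=torch.complex128)   # horizontal component
y_field = torch.zeros(batch, length, dtype=torch.complex128)  # vertical component
data = torch.stack([x_field, y_field], dim=2).reshape(batch, -1)

pol = polarimeter()
rotator = input_rotator(length * 2)

# horizontally polarised input: expect rows close to (1, 1, 0, 0)
print(pol(data))

# v @ R(theta) rotates the Jones vector by -theta, so -pi/4 gives +45 degree
# linear polarisation: expect rows close to (1, 0, 1, 0)
rotated = rotator(data, angle=torch.tensor(-torch.pi / 4))
print(pol(rotated))

# per-channel detection: real-valued powers with learnable scale and offset
pd = photodiode(length * 2)
print(pd(rotated).shape)  # torch.Size([4, 32])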
@@ -371,7 +453,7 @@ class Identity(nn.Module):
     M(z) = z
     """

-    def __init__(self):
+    def __init__(self, size=None):
         super(Identity, self).__init__()

     def forward(self, x):
@@ -404,9 +486,28 @@ class MZISingle(nn.Module):
     def forward(self, x: torch.Tensor):
         return x * torch.exp(1j * self.phi) * torch.sin(self.omega + self.func(x))


+def naive_angle_loss(x: torch.Tensor, target: torch.Tensor, mod=2*torch.pi):
+    return torch.fmod((x - target), mod).square().mean()
+
+
+def cosine_loss(x: torch.Tensor, target: torch.Tensor):
+    return (2*(1 - torch.cos(x - target))).mean()
+
+
+def angle_mse_loss(x: torch.Tensor, target: torch.Tensor):
+    # wrap into (-2*pi, 2*pi); cos/sin below are periodic anyway
+    x = torch.fmod(x, 2*torch.pi)
+    target = torch.fmod(target, 2*torch.pi)
+
+    x_cos = torch.cos(x)
+    x_sin = torch.sin(x)
+    target_cos = torch.cos(target)
+    target_sin = torch.sin(target)
+
+    # squared chord distance between unit vectors at angles x and target
+    cos_diff = x_cos - target_cos
+    sin_diff = x_sin - target_sin
+    squared_diff = cos_diff**2 + sin_diff**2
+    return squared_diff.mean()
+
+
 class EOActivation(nn.Module):
-    def __init__(self, bias, size=None):
+    def __init__(self, size=None):
         # 10.1109/SiPhotonics60897.2024.10543376
         super(EOActivation, self).__init__()
         if size is None:
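
All three losses target angle regression but differ at the 2*pi wrap-around: torch.fmod keeps the sign of its first argument, so naive_angle_loss penalises a difference that straddles the branch cut as if it were almost a full turn, while the other two stay small. Expanding the squares also shows that angle_mse_loss and cosine_loss compute the same quantity, since (cos x - cos t)^2 + (sin x - sin t)^2 = 2*(1 - cos(x - t)). A quick check (values rounded, inputs illustrative):

import torch

pred = torch.tensor([0.1])
target = torch.tensor([2*torch.pi - 0.1])  # true angular error is only 0.2 rad

print(naive_angle_loss(pred, target))  # ~37.0, i.e. (2*pi - 0.2)**2: the wrap is not handled
print(cosine_loss(pred, target))       # ~0.0399, i.e. 2*(1 - cos(0.2))
print(angle_mse_loss(pred, target))    # ~0.0399, the identical chord-distance value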
@@ -569,83 +670,12 @@ class ZReLU(nn.Module):
             return x * (torch.angle(x) >= 0) * (torch.angle(x) <= torch.pi / 2)
         else:
             return torch.relu(x)


-class regenerator(nn.Module):
-    def __init__(
-        self,
-        *dims,
-        layer_function=ONN,
-        layer_kwargs: dict | None = None,
-        layer_parametrizations: list[dict] = None,
-        activation_function=Pow,
-        dtype=torch.float64,
-        dropout_prob=0.01,
-        scale=False,
-        **kwargs,
-    ):
-        super(regenerator, self).__init__()
-        if len(dims) == 0:
-            try:
-                dims = kwargs["dims"]
-            except KeyError:
-                raise ValueError("dims must be provided")
-        self._n_hidden_layers = len(dims) - 2
-        self._layers = nn.Sequential()
-        if layer_kwargs is None:
-            layer_kwargs = {}
-        # self.powers = []
-
-        for i in range(self._n_hidden_layers + 1):
-            if scale:
-                self._layers.append(Scale(dims[i]))
-            self._layers.append(layer_function(dims[i], dims[i + 1], dtype=dtype, **layer_kwargs))
-            if i < self._n_hidden_layers:
-                if dropout_prob is not None:
-                    self._layers.append(DropoutComplex(p=dropout_prob))
-                self._layers.append(activation_function(bias=True, size=dims[i + 1]))
-
-        self._layers.append(Scale(dims[-1]))
-
-        # add parametrizations
-        if layer_parametrizations is not None:
-            for layer in self._layers:
-                for layer_parametrization in layer_parametrizations:
-                    tensor_name = layer_parametrization.get("tensor_name", None)
-                    parametrization = layer_parametrization.get("parametrization", None)
-                    param_kwargs = layer_parametrization.get("kwargs", {})
-                    if tensor_name is not None and tensor_name in layer._parameters and parametrization is not None:
-                        parametrization(layer, tensor_name, **param_kwargs)
-
-    # def __call__(self, input_x, **kwargs):
-    #     return self.forward(input_x, **kwargs)
-
-    def forward(self, input_x, trace_powers=False):
-        x = input_x
-
-        if trace_powers:
-            powers = [x.abs().square().sum()]
-
-        # check if tracing
-        if torch.jit.is_tracing():
-            for layer in self._layers:
-                x = layer(x)
-                if trace_powers:
-                    powers.append(x.abs().square().sum())
-        else:
-            # with torch.nn.utils.parametrize.cached():
-            for layer in self._layers:
-                x = layer(x)
-                if trace_powers:
-                    powers.append(x.abs().square().sum())
-        if trace_powers:
-            return x, powers
-        return x


 __all__ = [
     complex_sse_loss,
     complex_mse_loss,
+    angle_mse_loss,
     UnitaryLayer,
     unitary,
     energy_conserving,
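
The deleted regenerator attached parametrizations by calling parametrization(layer, tensor_name, **kwargs) on each layer owning the named tensor; that calling convention matches the helpers in torch.nn.utils.parametrizations. A minimal sketch of the same mechanism on a plain nn.Linear (the dict layout mirrors a layer_parametrizations entry; the Linear layer is an illustrative stand-in, not a layer from this file):

import torch
import torch.nn as nn
from torch.nn.utils import parametrizations

layer = nn.Linear(8, 8, bias=False)

# one entry in the same shape the regenerator expected
spec = {"tensor_name": "weight", "parametrization": parametrizations.orthogonal, "kwargs": {}}

name = spec["tensor_name"]
if name in layer._parameters:  # the same guard the regenerator applied
    spec["parametrization"](layer, name, **spec["kwargs"])

# the weight is now reconstructed as an orthogonal matrix on every access
w = layer.weight
print(torch.allclose(w @ w.T, torch.eye(8), atol=1e-5))  # True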
@@ -662,6 +692,7 @@ __all__ = [
     ZReLU,
     MZISingle,
     EOActivation,
+    photodiode,
     # SaturableAbsorberLambertW,
     # SaturableAbsorber,
     # SpreadLayer,
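
The data_loader changes mentioned in the commit message are not part of this excerpt. For context, one plausible way to randomise the input polarisation is to sample a unit-power Jones vector (a normalised complex Gaussian 2-vector is uniform over the Poincaré sphere) and label it with the normalised Stokes vector the polarimeter should recover; the helper below is a hypothetical sketch, not the repository's code:

import torch

def random_jones_vector(batch: int, dtype=torch.complex128) -> torch.Tensor:
    # complex Gaussian Jones vectors, normalised to unit power
    v = torch.randn(batch, 2, dtype=dtype)
    return v / v.abs().square().sum(dim=1, keepdim=True).sqrt()

jones = random_jones_vector(32)
x, y = jones[:, 0], jones[:, 1]

# normalised Stokes labels (S0 = 1 by construction)
xy = x * y.conj()
S1 = x.abs().square() - y.abs().square()
S2 = 2 * xy.real
S3 = 2 * xy.imag
labels = torch.stack([torch.ones_like(S1), S1, S2, S3], dim=1)  # (batch, 4)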