add training script for polarisation estimation, refactor model definitions, randomised polarisation support in data_loader

Joseph Hopfmüller
2024-12-11 09:48:38 +01:00
parent 0e29b87395
commit 39ae13d0af
8 changed files with 1899 additions and 259 deletions


@@ -260,12 +260,94 @@ class ONNRect(nn.Module):
        self.crop = lambda x: x
        self.crop.__doc__ = "No cropping"

    def forward(self, x):
-        x = self.pad(x)
+        x = self.pad(x).to(dtype=self.weight.dtype)
        out = self.crop((self.weight @ x.mT).mT)
        return out
class polarimeter(nn.Module):
    def __init__(self):
        super(polarimeter, self).__init__()

    def forward(self, data):
        # normalised Stokes parameters:
        # S0 = I
        # S1 = (2*I_x - I)/I
        # S2 = (2*I_45 - I)/I
        # S3 = (2*I_RHC - I)/I
        # data: (batch, input_length*2) -> (batch, input_length, 2)
        data = data.view(data.shape[0], -1, 2)
        x = data[:, :, 0].mean(dim=1)
        y = data[:, :, 1].mean(dim=1)
        # horizontal polarisation
        I_x = x.abs().square()
        # vertical polarisation
        I_y = y.abs().square()
        # 45 degree polarisation: projection onto (x + y)/sqrt(2)
        I_45 = (x + y).abs().square() / 2
        # right hand circular polarisation: projection onto (x + i*y)/sqrt(2)
        I_RHC = (x + 1j*y).abs().square() / 2
        # raw Stokes parameters: S1 = I_x - I_y, S2 = I_45 - I_m45, S3 = I_RHC - I_LHC
        S0 = I_x + I_y
        S1 = 2*I_x - S0
        S2 = 2*I_45 - S0
        S3 = 2*I_RHC - S0
        # normalise once by the total intensity S0
        return torch.stack([S0/S0, S1/S0, S2/S0, S3/S0], dim=1)
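# [editor's sketch, not part of this commit] A quick sanity check of the
# polarimeter head above on hypothetical interleaved (x, y) field samples;
# x-polarised light should give the Stokes vector [1, 1, 0, 0] and light
# polarised at 45 degrees should give [1, 0, 1, 0].
x_pol = torch.tensor([1.0 + 0j, 0j] * 4).reshape(1, -1)
d_pol = torch.tensor([0.5**0.5 + 0j, 0.5**0.5 + 0j] * 4).reshape(1, -1)
stokes = polarimeter()(torch.cat([x_pol, d_pol]))
# stokes ~ [[1, 1, 0, 0], [1, 0, 1, 0]]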
class normalize_by_first(nn.Module):
    def __init__(self):
        super(normalize_by_first, self).__init__()

    def forward(self, data):
        # divide every channel by the first one (e.g. the S0 component)
        return data / data[:, 0].unsqueeze(1)
class photodiode(nn.Module):
    def __init__(self, size, bias=True):
        super(photodiode, self).__init__()
        self.input_dim = size
        # per-channel responsivity and optional dark-current offset
        self.scale = nn.Parameter(torch.rand(size))
        self.pd_bias = nn.Parameter(torch.rand(size)) if bias else None

    def forward(self, x):
        # square-law detection: complex field in, real intensity out
        out = x.abs().square().to(dtype=x.dtype.to_real()).mul(self.scale)
        return out.add(self.pd_bias) if self.pd_bias is not None else out
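# [editor's sketch, not part of this commit] The photodiode above maps a
# complex optical field to a real-valued, per-channel intensity reading:
pd = photodiode(size=8)
field = torch.randn(4, 8, dtype=torch.complex64)
reading = pd(field)
assert reading.dtype == torch.float32 and reading.shape == (4, 8)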
class input_rotator(nn.Module):
    def __init__(self, input_dim, dtype=torch.float64):
        super(input_rotator, self).__init__()
        assert input_dim % 2 == 0, "Input dimension must be even"
        self.input_dim = input_dim
        self.angle = nn.Parameter(torch.randn((), dtype=dtype))

    def forward(self, x, angle=None):
        # rotate channel pairs (0, 1), (2, 3), ... by the angle
        if angle is None:  # `angle or self.angle` would misfire for angle == 0
            angle = self.angle
        angle = torch.as_tensor(angle, dtype=self.angle.dtype).reshape(())
        sine, cosine = torch.sin(angle), torch.cos(angle)
        # build the 2x2 rotation matrix with stack so gradients reach the angle
        rot = torch.stack([
            torch.stack([cosine, -sine]),
            torch.stack([sine, cosine]),
        ])
        return torch.matmul(x.view(-1, 2), rot.to(x.dtype)).view(x.shape)
@@ -371,7 +453,7 @@ class Identity(nn.Module):
    M(z) = z
    """
-    def __init__(self):
+    def __init__(self, size=None):
        super(Identity, self).__init__()

    def forward(self, x):
@@ -404,9 +486,28 @@ class MZISingle(nn.Module):
    def forward(self, x: torch.Tensor):
        return x * torch.exp(1j * self.phi) * torch.sin(self.omega + self.func(x))
def naive_angle_loss(x: torch.Tensor, target: torch.Tensor, mod=2*torch.pi):
    # naive: fmod keeps the sign of (x - target), so wrap-around errors persist
    return torch.fmod((x - target), mod).square().mean()

def cosine_loss(x: torch.Tensor, target: torch.Tensor):
    return (2*(1 - torch.cos(x - target))).mean()

def angle_mse_loss(x: torch.Tensor, target: torch.Tensor):
    # MSE between the unit vectors (cos, sin); cos and sin are 2*pi-periodic,
    # so no explicit fmod wrapping is needed
    x_cos = torch.cos(x)
    x_sin = torch.sin(x)
    target_cos = torch.cos(target)
    target_sin = torch.sin(target)
    cos_diff = x_cos - target_cos
    sin_diff = x_sin - target_sin
    squared_diff = cos_diff**2 + sin_diff**2
    return squared_diff.mean()
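# [editor's note, not part of this commit] angle_mse_loss and cosine_loss are
# analytically the same quantity, since
# (cos x - cos t)^2 + (sin x - sin t)^2 = 2 - 2*cos(x - t) = 2*(1 - cos(x - t)).
# A quick numerical confirmation:
a = torch.randn(1000) * 10
b = torch.randn(1000) * 10
assert torch.allclose(angle_mse_loss(a, b), cosine_loss(a, b))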
class EOActivation(nn.Module):
-    def __init__(self, bias, size=None):
+    def __init__(self, size=None):
        # 10.1109/SiPhotonics60897.2024.10543376
        super(EOActivation, self).__init__()
        if size is None:
@@ -569,83 +670,12 @@ class ZReLU(nn.Module):
            return x * (torch.angle(x) >= 0) * (torch.angle(x) <= torch.pi / 2)
        else:
            return torch.relu(x)
class regenerator(nn.Module):
    def __init__(
        self,
        *dims,
        layer_function=ONN,
        layer_kwargs: dict | None = None,
        layer_parametrizations: list[dict] = None,
        activation_function=Pow,
        dtype=torch.float64,
        dropout_prob=0.01,
        scale=False,
        **kwargs,
    ):
        super(regenerator, self).__init__()
        if len(dims) == 0:
            try:
                dims = kwargs["dims"]
            except KeyError:
                raise ValueError("dims must be provided")
        self._n_hidden_layers = len(dims) - 2
        self._layers = nn.Sequential()
        if layer_kwargs is None:
            layer_kwargs = {}
        # self.powers = []
        for i in range(self._n_hidden_layers + 1):
            if scale:
                self._layers.append(Scale(dims[i]))
            self._layers.append(layer_function(dims[i], dims[i + 1], dtype=dtype, **layer_kwargs))
            if i < self._n_hidden_layers:
                if dropout_prob is not None:
                    self._layers.append(DropoutComplex(p=dropout_prob))
                self._layers.append(activation_function(bias=True, size=dims[i + 1]))
        self._layers.append(Scale(dims[-1]))
        # add parametrizations
        if layer_parametrizations is not None:
            for layer in self._layers:
                for layer_parametrization in layer_parametrizations:
                    tensor_name = layer_parametrization.get("tensor_name", None)
                    parametrization = layer_parametrization.get("parametrization", None)
                    param_kwargs = layer_parametrization.get("kwargs", {})
                    if tensor_name is not None and tensor_name in layer._parameters and parametrization is not None:
                        parametrization(layer, tensor_name, **param_kwargs)

    # def __call__(self, input_x, **kwargs):
    #     return self.forward(input_x, **kwargs)

    def forward(self, input_x, trace_powers=False):
        x = input_x
        if trace_powers:
            powers = [x.abs().square().sum()]
        # check if tracing
        if torch.jit.is_tracing():
            for layer in self._layers:
                x = layer(x)
                if trace_powers:
                    powers.append(x.abs().square().sum())
        else:
            # with torch.nn.utils.parametrize.cached():
            for layer in self._layers:
                x = layer(x)
                if trace_powers:
                    powers.append(x.abs().square().sum())
        if trace_powers:
            return x, powers
        return x
__all__ = [
    complex_sse_loss,
    complex_mse_loss,
    angle_mse_loss,
    UnitaryLayer,
    unitary,
    energy_conserving,
@@ -662,6 +692,7 @@ __all__ = [
    ZReLU,
    MZISingle,
    EOActivation,
    photodiode,
    # SaturableAbsorberLambertW,
    # SaturableAbsorber,
    # SpreadLayer,