import torch
import torch.nn as nn
import torch.nn.functional as F

# from torchlambertw.special import lambertw


def complex_mse_loss(input, target, power=False, normalize=False, reduction="mean"):
    """
    Compute the mean squared error between two complex tensors.

    If power is True, the loss is computed between the powers |input|^2 and |target|^2
    instead of the complex fields.
    If normalize is True, input and target are rescaled before the error is computed and
    a penalty on the mismatch of their original scales is added to the loss.
    """
    reduce = getattr(torch, reduction)
    power_penalty = 0

    if power:
        # work on the detected power |z|^2 instead of the complex field
        input = (input * input.conj()).real.to(dtype=input.dtype.to_real())
        target = (target * target.conj()).real.to(dtype=target.dtype.to_real())
        if normalize:
            # penalize mismatched dynamic range and offset, then rescale both to [0, 1]
            power_penalty = ((input.max() - input.min()) - (target.max() - target.min())) ** 2
            power_penalty += (input.min() - target.min()) ** 2
            input = input - input.min()
            input = input / input.max()
            target = target - target.min()
            target = target / target.max()
    else:
        if normalize:
            # penalize mismatched peak magnitude, then rescale both to unit peak magnitude
            power_penalty = (input.abs().max() - target.abs().max()) ** 2
            input = input / input.abs().max()
            target = target / target.abs().max()

    if input.is_complex() and target.is_complex():
        return reduce(torch.square(input.real - target.real) + torch.square(input.imag - target.imag)) + power_penalty
    elif input.is_complex() or target.is_complex():
        raise ValueError("Input and target must have the same type (real or complex)")
    else:
        return F.mse_loss(input, target, reduction=reduction) + power_penalty


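# Illustrative sketch (not part of the original API): how `complex_mse_loss` is meant to
# be called on complex field tensors. The shapes and dtype below are assumptions for the
# example only.
def _example_complex_mse_loss():
    pred = torch.randn(8, 16, dtype=torch.cfloat)
    ref = torch.randn(8, 16, dtype=torch.cfloat)
    field_loss = complex_mse_loss(pred, ref)              # MSE on the complex field
    power_loss = complex_mse_loss(pred, ref, power=True)  # MSE on the detected power |.|^2
    return field_loss, power_loss

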
def complex_sse_loss(input, target):
    """
    Compute the sum of squared errors between two complex tensors.
    """
    if input.is_complex():
        return torch.sum(torch.square(input.real - target.real) + torch.square(input.imag - target.imag))
    else:
        return torch.sum(torch.square(input - target))


class UnitaryLayer(nn.Module):
    def __init__(self, in_features, out_features, dtype=None):
        assert in_features >= out_features
        super(UnitaryLayer, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = nn.Parameter(torch.randn(in_features, out_features, dtype=dtype))
        self.reset_parameters()

    def reset_parameters(self):
        q, _ = torch.linalg.qr(self.weight)
        self.weight.data = q

    def forward(self, x):
        return torch.matmul(x, self.weight)

    def __repr__(self):
        return f"UnitaryLayer({self.in_features}, {self.out_features})"


class _Unitary(nn.Module):
    def forward(self, X: torch.Tensor):
        if X.ndim < 2:
            raise ValueError(f"Only tensors with 2 or more dimensions are supported. Got a tensor of shape {X.shape}")
        n, k = X.size(-2), X.size(-1)
        transpose = n < k
        if transpose:
            X = X.transpose(-2, -1)
        q, r = torch.linalg.qr(X)
        # make the decomposition unique by absorbing the phases of the diagonal of R into Q;
        # avoid modifying the QR output in-place so autograd can still backprop through qr
        d = r.diagonal(dim1=-2, dim2=-1).sgn()
        q = q * d.unsqueeze(-2)
        if transpose:
            q = q.transpose(-2, -1)
        if n == k:
            # flip the sign of the first column where det(q) has negative real part,
            # mirroring the sign fix in torch's orthogonal parametrization
            mask = (torch.linalg.det(q).real >= 0).to(q.dtype.to_real())
            mask[mask == 0] = -1
            mask = mask.unsqueeze(-1)
            q[..., 0] *= mask
        return q


def unitary(module: nn.Module, name: str = "weight") -> nn.Module:
    weight = getattr(module, name, None)
    if not isinstance(weight, torch.Tensor):
        raise ValueError(f"Module '{module}' has no parameter or buffer '{name}'")

    if weight.ndim < 2:
        raise ValueError(f"Expected a matrix or batch of matrices. Got a tensor of {weight.ndim} dimensions.")

    if weight.shape[-2] != weight.shape[-1]:
        raise ValueError(f"Expected a square matrix or batch of square matrices. Got a tensor of shape {weight.shape}")

    unit = _Unitary()
    nn.utils.parametrize.register_parametrization(module, name, unit)
    return module


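# Illustrative sketch (the square complex nn.Linear below is an assumption for the
# example): after registering the `unitary` parametrization, the re-parametrized weight
# satisfies W^H W ~= I every time it is accessed.
def _example_unitary_parametrization():
    layer = nn.Linear(4, 4, bias=False, dtype=torch.cfloat)
    unitary(layer, "weight")
    w = layer.weight  # recomputed through the parametrization on access
    return torch.allclose(w.mH @ w, torch.eye(4, dtype=torch.cfloat), atol=1e-5)

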
class _SpecialUnitary(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, X: torch.Tensor):
        n, k = X.size(-2), X.size(-1)
        if n != k:
            raise ValueError(f"Expected a square matrix. Got a tensor of shape {X.shape}")
        q, _ = torch.linalg.qr(X)
        # rescale so that det(q) == 1; the determinant is broadcast over the matrix dimensions
        q = q / torch.linalg.det(q).pow(1 / n)[..., None, None]

        return q


def special_unitary(module: nn.Module, name: str = "weight") -> nn.Module:
    weight = getattr(module, name, None)
    if not isinstance(weight, torch.Tensor):
        raise ValueError(f"Module '{module}' has no parameter or buffer '{name}'")

    if weight.ndim < 2:
        raise ValueError(f"Expected a matrix or batch of matrices. Got a tensor of {weight.ndim} dimensions.")

    if weight.shape[-2] != weight.shape[-1]:
        raise ValueError(f"Expected a square matrix or batch of square matrices. Got a tensor of shape {weight.shape}")

    unit = _SpecialUnitary()
    nn.utils.parametrize.register_parametrization(module, name, unit)
    return module


class _Clamp(nn.Module):
    def __init__(self, min, max):
        super(_Clamp, self).__init__()
        self.min = min
        self.max = max

    def forward(self, x):
        if x.is_complex():
            # clamp magnitude, ignore phase
            return torch.clamp(x.abs(), self.min, self.max) * x / x.abs()
        return torch.clamp(x, self.min, self.max)


def clamp(module: nn.Module, name: str = "scale", min=0, max=1) -> nn.Module:
    scale = getattr(module, name, None)
    if not isinstance(scale, torch.Tensor):
        raise ValueError(f"Module '{module}' has no parameter or buffer '{name}'")

    cl = _Clamp(min, max)
    nn.utils.parametrize.register_parametrization(module, name, cl)
    return module


class _EnergyConserving(nn.Module):
    def __init__(self):
        super(_EnergyConserving, self).__init__()

    def forward(self, X: torch.Tensor):
        if X.ndim == 2:
            X = X.unsqueeze(0)
        # divide each matrix by its largest singular value (spectral norm),
        # broadcast over the matrix dimensions
        spectral_norm = torch.linalg.svdvals(X)[:, 0]
        return (X / spectral_norm[:, None, None]).squeeze()


def energy_conserving(module: nn.Module, name: str = "weight") -> nn.Module:
    param = getattr(module, name, None)
    if not isinstance(param, torch.Tensor):
        raise ValueError(f"Module '{module}' has no parameter or buffer '{name}'")

    if not (2 <= param.ndim <= 3):
        raise ValueError(f"Expected a matrix or batch of matrices. Got a tensor of {param.ndim} dimensions.")

    unit = _EnergyConserving()
    nn.utils.parametrize.register_parametrization(module, name, unit)
    return module


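# Illustrative sketch: after `energy_conserving`, the largest singular value of the weight
# is 1, so the layer cannot amplify total power. The 5->3 linear layer is an assumption
# chosen only for this example.
def _example_energy_conserving():
    layer = nn.Linear(5, 3, bias=False)
    energy_conserving(layer, "weight")
    s_max = torch.linalg.svdvals(layer.weight)[0]  # largest singular value
    return torch.allclose(s_max, torch.tensor(1.0), atol=1e-5)

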
class ONN(nn.Module):
    def __init__(self, input_dim, output_dim, dtype=None) -> None:
        super(ONN, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.dtype = dtype

        self.dim = max(input_dim, output_dim)

        # zero pad input to internal size if smaller
        if self.input_dim < self.dim:
            self.pad = lambda x: F.pad(x, ((self.dim - self.input_dim) // 2, (self.dim - self.input_dim + 1) // 2))
            self.pad.__doc__ = f"Zero pad input from {self.input_dim} to {self.dim}"
        else:
            self.pad = lambda x: x
            self.pad.__doc__ = f"Input size equals internal size {self.dim}"

        # crop output to desired size
        if self.output_dim < self.dim:
            self.crop = lambda x: x[
                :, (self.dim - self.output_dim) // 2 : (x.shape[1] - (self.dim - self.output_dim + 1) // 2)
            ]
            self.crop.__doc__ = f"Crop output from {self.dim} to {self.output_dim}"
        else:
            self.crop = lambda x: x
            self.crop.__doc__ = f"Output size equals internal size {self.dim}"

        self.weight = nn.Parameter(torch.randn(self.dim, self.dim, dtype=self.dtype))
        # self.scale = nn.Parameter(torch.randn(1, dtype=self.dtype.to_real())+0.5)

    def reset_parameters(self):
        q, _ = torch.linalg.qr(self.weight)
        self.weight.data = q

    # def get_M(self):
    #     return self.U @ self.sigma @ self.V

    def forward(self, x):
        return self.crop(self.pad(x) @ self.weight)


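# Illustrative sketch combining ONN with the unitary parametrization defined above; the
# sizes and complex dtype are assumptions for the example.
def _example_onn():
    model = ONN(input_dim=6, output_dim=4, dtype=torch.cfloat)
    unitary(model, "weight")  # constrain the internal 6x6 matrix to be unitary
    x = torch.randn(2, 6, dtype=torch.cfloat)
    return model(x).shape  # torch.Size([2, 4])

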
class ONNRect(nn.Module):
    def __init__(self, input_dim, output_dim, square=False, dtype=None):
        super(ONNRect, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim

        if square:
            dim = max(input_dim, output_dim)
            self.weight = nn.Parameter(torch.randn(dim, dim, dtype=dtype))

            # zero pad input to internal size if smaller
            if self.input_dim < dim:
                self.pad = lambda x: F.pad(x, ((dim - self.input_dim) // 2, (dim - self.input_dim + 1) // 2))
                self.pad.__doc__ = f"Zero pad input from {self.input_dim} to {dim}"
            else:
                self.pad = lambda x: x
                self.pad.__doc__ = f"Input size equals internal size {dim}"

            # crop output to desired size
            if self.output_dim < dim:
                self.crop = lambda x: x[
                    :, (dim - self.output_dim) // 2 : (x.shape[1] - (dim - self.output_dim + 1) // 2)
                ]
                self.crop.__doc__ = f"Crop output from {dim} to {self.output_dim}"
            else:
                self.crop = lambda x: x
                self.crop.__doc__ = f"Output size equals internal size {dim}"

        else:
            self.weight = nn.Parameter(torch.randn(output_dim, input_dim, dtype=dtype))
            self.pad = lambda x: x
            self.pad.__doc__ = "No padding"
            self.crop = lambda x: x
            self.crop.__doc__ = "No cropping"

    def forward(self, x):
        x = self.pad(x).to(dtype=self.weight.dtype)
        out = self.crop((self.weight @ x.mT).mT)
        return out


class polarimeter(nn.Module):
    def __init__(self):
        super(polarimeter, self).__init__()
        # self.input_length = input_length

    def forward(self, data):
        # Normalized Stokes parameters:
        # S0 = I
        # S1 = (2*I_x - I)/I
        # S2 = (2*I_45 - I)/I
        # S3 = (2*I_RHC - I)/I

        # data: (batch, input_length*2) -> (batch, input_length, 2)
        data = data.view(data.shape[0], -1, 2)
        x = data[:, :, 0].mean(dim=1)
        y = data[:, :, 1].mean(dim=1)

        # horizontal polarisation
        I_x = x.abs().square()

        # vertical polarisation
        I_y = y.abs().square()

        # 45 degree polarisation (factor 1/2 from the projection onto the 45-degree analyzer)
        I_45 = 0.5 * (x + y).abs().square()

        # right hand circular polarisation (factor 1/2 from the projection onto the circular analyzer)
        I_RHC = 0.5 * (x + 1j * y).abs().square()

        # unnormalized forms for reference:
        # S0 = I_x + I_y
        # S1 = I_x - I_y
        # S2 = I_45 - I_m45
        # S3 = I_RHC - I_LHC

        S0 = I_x + I_y
        S1 = (2 * I_x - S0) / S0
        S2 = (2 * I_45 - S0) / S0
        S3 = (2 * I_RHC - S0) / S0

        # S1, S2, S3 are already normalized by S0 above, so they are not divided again here
        return torch.stack([S0 / S0, S1, S2, S3], dim=1)


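# Illustrative check of the Stokes-parameter formulas above: a purely horizontally
# polarized field (y = 0) should map to S = [1, 1, 0, 0]. The batch size and number of
# samples are assumptions for the example.
def _example_polarimeter():
    x = torch.ones(1, 4, dtype=torch.cfloat)   # E_x samples
    y = torch.zeros(1, 4, dtype=torch.cfloat)  # E_y samples
    data = torch.stack([x, y], dim=2).reshape(1, -1)  # interleave to (batch, n*2)
    return polarimeter()(data)  # approximately [[1., 1., 0., 0.]]

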
class normalize_by_first(nn.Module):
    def __init__(self):
        super(normalize_by_first, self).__init__()

    def forward(self, data):
        return data / data[:, 0].unsqueeze(1)


class rotate(nn.Module):
    def __init__(self):
        super(rotate, self).__init__()

    def forward(self, data, angle):
        # data  -> (batch, n*2)
        # angle -> (batch, n)
        data_ = data
        if angle.ndim == 1:
            angle_ = angle.unsqueeze(1)
        else:
            angle_ = angle
        angle_ = angle_.expand(-1, data_.shape[1] // 2)
        c = torch.cos(angle_)
        s = torch.sin(angle_)
        rot = torch.stack([torch.stack([c, -s], dim=2),
                           torch.stack([s, c], dim=2)], dim=3)
        d = torch.bmm(data_.reshape(-1, 1, 2), rot.view(-1, 2, 2).to(dtype=data_.dtype)).reshape(*data.shape)
        # d = torch.bmm(data.unsqueeze(-1).mT, rot.to(dtype=data.dtype).mT).mT.squeeze(-1)

        return d


class photodiode(nn.Module):
    def __init__(self, size, bias=True):
        super(photodiode, self).__init__()
        self.input_dim = size
        self.scale = nn.Parameter(torch.rand(size))
        self.pd_bias = nn.Parameter(torch.rand(size))

    def forward(self, x):
        return x.abs().square().to(dtype=x.dtype.to_real()).mul(self.scale).add(self.pd_bias)


class input_rotator(nn.Module):
    def __init__(self, input_dim):
        super(input_rotator, self).__init__()
        assert input_dim % 2 == 0, "Input dimension must be even"
        self.input_dim = input_dim
        # learnable default angle, used when no angle is passed to forward
        self.angle = nn.Parameter(torch.randn(1))

    def forward(self, x, angle=None):
        # take channels (0,1), (2,3), ... and rotate them by the angle
        if angle is None:
            angle = self.angle
        sine = torch.sin(angle)
        cosine = torch.cos(angle)
        # build the 2x2 rotation matrix from the angle without breaking the autograd graph
        rot = torch.stack([torch.stack([cosine, -sine]),
                           torch.stack([sine, cosine])]).reshape(2, 2).to(dtype=x.dtype)
        return torch.matmul(x.view(-1, 2), rot).view(x.shape)


# def __repr__(self):
#     return f"ONNRect({self.input_dim}, {self.output_dim})"


# class SaturableAbsorberLambertW(nn.Module):
#     """
#     Implements the activation function for an optical saturable absorber
#
#     base eqn: sigma*tau*I0 = 0.5*(log(Tm/T0))/(1-Tm),
#     where: sigma is the absorption cross section
#            tau is the radiative lifetime of the absorber material
#            T0 is the initial transmittance
#            I0 is the input intensity
#            Tm is the transmittance of the absorber
#
#     The activation function is defined as:
#     Iout = I0 * Tm(I0)
#
#     where Tm(I0) is the transmittance of the absorber as a function of the input intensity I0
#
#     for a unit sigma*tau product, the solution Tm(I0) is given by:
#     Tm(I0) = (W(2*exp(2*I0)*I0*T0))/(2*I0),
#     where W is the Lambert W function
#
#     if sigma*tau is not 1, I0 has to be scaled by sigma*tau
#     (-> x has to be scaled by sqrt(sigma*tau))
#     """
#
#     def __init__(self, T0):
#         super(SaturableAbsorberLambertW, self).__init__()
#         self.register_buffer("T0", torch.tensor(T0))
#
#     def forward(self, x: torch.Tensor):
#         xc = x.conj()
#         two_x_xc = (2 * x * xc).real
#         return (lambertw(2 * torch.exp(two_x_xc) * (x * self.T0 * xc).real) / two_x_xc).to(dtype=x.dtype)
#
#     def backward(self, x):
#         xc = x.conj()
#         lambert_eval = lambertw(2 * torch.exp(2 * x * xc).real * (x * self.T0 * xc).real)
#         return (((xc * (-2 * lambert_eval + 2 * torch.square(x) - 1) + 2 * x * torch.square(xc) + x) * lambert_eval) / (
#             2 * torch.pow(x, 3) * xc * (lambert_eval + 1)
#         )).to(dtype=x.dtype)


# class SaturableAbsorber(nn.Module):
#     def __init__(self, alpha, I0):
#         super(SaturableAbsorber, self).__init__()
#         self.register_buffer("alpha", torch.tensor(alpha))
#         self.register_buffer("I0", torch.tensor(I0))
#
#     def forward(self, x):
#         I = (x*x.conj()).to(dtype=x.dtype.to_real())
#         A = self.alpha/(1+I/self.I0)


# class SpreadLayer(nn.Module):
#     def __init__(self, in_features, out_features, dtype=None):
#         super(SpreadLayer, self).__init__()
#         self.in_features = in_features
#         self.out_features = out_features
#         self.mat = torch.ones(in_features, out_features, dtype=dtype)*torch.sqrt(torch.tensor(in_features/out_features))
#
#     def forward(self, x):
#         # N in_features -> M out_features, energy is preserved (P = abs(x)^2)
#         out = torch.matmul(x, self.mat)
#         return out


#### as defined by Zhang et al.

class DropoutComplex(nn.Module):
    def __init__(self, p=0.5):
        super(DropoutComplex, self).__init__()
        self.dropout = nn.Dropout(p=p)

    def forward(self, x):
        if x.is_complex():
            mask = self.dropout(torch.ones_like(x.real))
            return x * mask
        else:
            return self.dropout(x)


class Scale(nn.Module):
    def __init__(self, size):
        super(Scale, self).__init__()
        self.size = size
        self.scale = nn.Parameter(torch.ones(size, dtype=torch.float32))

    def forward(self, x):
        return x * torch.sqrt(self.scale)

    def __repr__(self):
        return f"Scale({self.size})"


class Identity(nn.Module):
    """
    implements the "activation" function
    M(z) = z
    """

    def __init__(self, size=None):
        super(Identity, self).__init__()

    def forward(self, x):
        return x


class phase_shift(nn.Module):
    def __init__(self, size):
        super(phase_shift, self).__init__()
        self.size = size
        self.phase = nn.Parameter(torch.rand(size))

    def forward(self, x):
        return x * torch.exp(1j * self.phase)


class PowRot(nn.Module):
    def __init__(self, bias=False):
        super(PowRot, self).__init__()
        self.scale = nn.Parameter(torch.tensor(1.0))
        if bias:
            self.bias = nn.Parameter(torch.tensor(0.0))
        else:
            self.register_buffer("bias", torch.tensor(0.0))

    def forward(self, x: torch.Tensor):
        if x.is_complex():
            return x * torch.exp(-self.scale * 1j * x.abs().square() + self.bias.to(dtype=x.dtype))
        else:
            return x


class MZISingle(nn.Module):
    def __init__(self, bias, size, func=None):
        super(MZISingle, self).__init__()
        self.omega = nn.Parameter(torch.randn(size))
        self.phi = nn.Parameter(torch.randn(size))
        self.func = func or (lambda x: x.abs().square())  # default to |z|^2

    def forward(self, x: torch.Tensor):
        return x * torch.exp(1j * self.phi) * torch.sin(self.omega + self.func(x))


def naive_angle_loss(x: torch.Tensor, target: torch.Tensor, mod=2 * torch.pi):
    return torch.fmod((x.abs().real - target.abs().real), mod).abs().mean()


def cosine_loss(x: torch.Tensor, target: torch.Tensor):
    return (2 * (1 - torch.cos(x - target))).mean()


def angle_mse_loss(x: torch.Tensor, target: torch.Tensor):
    x = torch.fmod(x, 2 * torch.pi)
    target = torch.fmod(target, 2 * torch.pi)

    x_cos = torch.cos(x)
    x_sin = torch.sin(x)
    target_cos = torch.cos(target)
    target_sin = torch.sin(target)

    cos_diff = x_cos - target_cos
    sin_diff = x_sin - target_sin
    squared_diff = cos_diff**2 + sin_diff**2
    return squared_diff.mean()


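# Small illustrative check of the angular losses above: angles that differ only by a
# multiple of 2*pi should give (near-)zero loss. The values are arbitrary examples.
def _example_angle_losses():
    a = torch.tensor([0.1, 3.0])
    b = a + 2 * torch.pi
    return cosine_loss(a, b), angle_mse_loss(a, b)  # both approximately 0

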
class EOActivation(nn.Module):
    def __init__(self, size=None):
        # 10.1109/JSTQE.2019.2930455
        super(EOActivation, self).__init__()
        if size is None:
            raise ValueError("Size must be specified")
        self.size = size
        self.alpha = nn.Parameter(torch.rand(size))
        self.gain = nn.Parameter(torch.rand(size))
        self.V_bias = nn.Parameter(torch.rand(size))
        # self.register_buffer("gain", torch.ones(size))
        # self.register_buffer("responsivity", torch.ones(size))
        # self.register_buffer("V_pi", torch.ones(size))

        self.reset_weights()

    def reset_weights(self):
        if "alpha" in self._parameters:
            self.alpha.data = torch.rand(self.size)
        # if "V_pi" in self._parameters:
        #     self.V_pi.data = torch.rand(self.size)*3
        if "V_bias" in self._parameters:
            self.V_bias.data = torch.randn(self.size)
        if "gain" in self._parameters:
            self.gain.data = torch.rand(self.size)
        # if "responsivity" in self._parameters:
        #     self.responsivity.data = torch.ones(self.size)*0.9
        # if "bias" in self._parameters:
        #     self.phase_bias.data = torch.zeros(self.size)

    def forward(self, x: torch.Tensor):
        phi_b = torch.pi * self.V_bias  # / (self.V_pi)
        g_phi = torch.pi * (self.alpha * self.gain)  # * self.responsivity) / (self.V_pi)
        intermediate = g_phi * x.abs().square() + phi_b
        return (
            1j
            * torch.sqrt(1 - self.alpha)
            * torch.exp(-0.5j * intermediate)
            * torch.cos(0.5 * intermediate)
            * x
        )


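# Illustrative sketch: EOActivation acts element-wise, so `size` must match the last
# dimension of the complex input. The batch size and feature size are assumptions for
# this example only.
def _example_eo_activation():
    act = EOActivation(size=16)
    x = torch.randn(8, 16, dtype=torch.cfloat)
    return act(x).shape  # torch.Size([8, 16])

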
class Pow(nn.Module):
    """
    implements the activation function
    M(z) = ||z||^2 + b
    """

    def __init__(self, bias=False):
        super(Pow, self).__init__()
        if bias:
            self.bias = nn.Parameter(torch.tensor(0.0))
        else:
            self.register_buffer("bias", torch.tensor(0.0))

    def forward(self, x: torch.Tensor):
        return x.abs().square().add(self.bias).to(dtype=x.dtype)


class Mag(nn.Module):
    """
    implements the activation function
    M(z) = ||z|| + b
    """

    def __init__(self, bias=False):
        super(Mag, self).__init__()
        if bias:
            self.bias = nn.Parameter(torch.tensor(0.0))
        else:
            self.register_buffer("bias", torch.tensor(0.0))

    def forward(self, x: torch.Tensor):
        return x.abs().add(self.bias).to(dtype=x.dtype)


class MagScale(nn.Module):
    def __init__(self, bias=False):
        super(MagScale, self).__init__()
        if bias:
            self.bias = nn.Parameter(torch.tensor(0.0))
        else:
            self.register_buffer("bias", torch.tensor(0.0))

    def forward(self, x: torch.Tensor):
        return x.abs().add(self.bias).to(dtype=x.dtype).sin().mul(x)


class PowScale(nn.Module):
    def __init__(self, bias=False):
        super(PowScale, self).__init__()
        if bias:
            self.bias = nn.Parameter(torch.tensor(0.0))
        else:
            self.register_buffer("bias", torch.tensor(0.0))

    def forward(self, x: torch.Tensor):
        return x.mul(x.abs().square().add(self.bias).to(dtype=x.dtype).sin())


class ModReLU(nn.Module):
    """
    implements the activation function
    M(z) = ReLU(||z|| + b)*exp(j*theta_z)
         = ReLU(||z|| + b)*z/||z||
    """

    def __init__(self, bias=True):
        super(ModReLU, self).__init__()
        if bias:
            self.bias = nn.Parameter(torch.tensor(0.0))
        else:
            self.register_buffer("bias", torch.tensor(0.0))

    def forward(self, x):
        if x.is_complex():
            mod = x.abs()
            out = torch.relu(mod + self.bias) * x / mod
            return out.to(dtype=x.dtype)

        else:
            return torch.relu(x + self.bias).to(dtype=x.dtype)

    def __repr__(self):
        return f"ModReLU(b={self.bias})"


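# Illustrative sketch: ModReLU shrinks the magnitude by |bias| (for a negative bias) while
# preserving the phase of each complex element. The bias value and input are assumptions
# for this example.
def _example_modrelu():
    act = ModReLU()
    act.bias.data = torch.tensor(-0.5)
    z = torch.tensor([3.0 + 4.0j])  # magnitude 5
    return act(z)                   # magnitude 4.5, phase unchanged: approximately 2.7 + 3.6j

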
class CReLU(nn.Module):
    """
    implements the activation function
    M(z) = ReLU(Re(z)) + j*ReLU(Im(z))
    """

    def __init__(self):
        super(CReLU, self).__init__()

    def forward(self, x):
        if x.is_complex():
            return torch.relu(x.real) + 1j * torch.relu(x.imag)
        else:
            return torch.relu(x)


class ZReLU(nn.Module):
    """
    implements the activation function

    M(z) = z if 0 <= angle(z) <= pi/2
         = 0 otherwise
    """

    def __init__(self):
        super(ZReLU, self).__init__()

    def forward(self, x):
        if x.is_complex():
            return x * (torch.angle(x) >= 0) * (torch.angle(x) <= torch.pi / 2)
        else:
            return torch.relu(x)


__all__ = [
    "complex_sse_loss",
    "complex_mse_loss",
    "angle_mse_loss",
    "UnitaryLayer",
    "unitary",
    "energy_conserving",
    "clamp",
    "ONN",
    "ONNRect",
    "DropoutComplex",
    "Identity",
    "Pow",
    "PowRot",
    "Mag",
    "ModReLU",
    "CReLU",
    "ZReLU",
    "MZISingle",
    "EOActivation",
    "photodiode",
    "phase_shift",
    # "SaturableAbsorberLambertW",
    # "SaturableAbsorber",
    # "SpreadLayer",
]