import torch
import torch.nn as nn
import torch.nn.functional as F

# from torchlambertw.special import lambertw


def complex_mse_loss(input, target, power=False, normalize=False, reduction="mean"):
    """
    Compute the mean squared error between two complex tensors.
    If power is set to True, the loss is computed between the squared magnitudes
    |input|^2 and |target|^2 instead of the complex fields themselves.
    """
    reduce = getattr(torch, reduction)
    power_penalty = 0
    if power:
        input = (input * input.conj()).real.to(dtype=input.dtype.to_real())
        target = (target * target.conj()).real.to(dtype=target.dtype.to_real())
        if normalize:
            power_penalty = ((input.max() - input.min()) - (target.max() - target.min())) ** 2
            power_penalty += (input.min() - target.min()) ** 2
            input = input - input.min()
            input = input / input.max()
            target = target - target.min()
            target = target / target.max()
    else:
        if normalize:
            power_penalty = (input.abs().max() - target.abs().max()) ** 2
            input = input / input.abs().max()
            target = target / target.abs().max()
    if input.is_complex() and target.is_complex():
        return reduce(torch.square(input.real - target.real) + torch.square(input.imag - target.imag)) + power_penalty
    elif input.is_complex() or target.is_complex():
        raise ValueError("Input and target must have the same type (real or complex)")
    else:
        return F.mse_loss(input, target, reduction=reduction) + power_penalty


def complex_sse_loss(input, target):
    """
    Compute the sum squared error between two complex tensors.
    """
    if input.is_complex():
        return torch.sum(torch.square(input.real - target.real) + torch.square(input.imag - target.imag))
    else:
        return torch.sum(torch.square(input - target))


class UnitaryLayer(nn.Module):
    def __init__(self, in_features, out_features, dtype=None):
        assert in_features >= out_features
        super(UnitaryLayer, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = nn.Parameter(torch.randn(in_features, out_features, dtype=dtype))
        self.reset_parameters()

    def reset_parameters(self):
        # re-orthonormalise the weight via a QR decomposition
        q, _ = torch.linalg.qr(self.weight)
        self.weight.data = q

    def forward(self, x):
        return torch.matmul(x, self.weight)

    def __repr__(self):
        return f"UnitaryLayer({self.in_features}, {self.out_features})"


class _Unitary(nn.Module):
    def forward(self, X: torch.Tensor):
        if X.ndim < 2:
            raise ValueError(f"Only tensors with 2 or more dimensions are supported. Got a tensor of shape {X.shape}")
        n, k = X.size(-2), X.size(-1)
        transpose = n < k
        if transpose:
            X = X.transpose(-2, -1)
        q, r = torch.linalg.qr(X)
        # q: torch.Tensor = q
        # r: torch.Tensor = r
        # make the QR decomposition unique by fixing the signs/phases of R's diagonal
        d = r.diagonal(dim1=-2, dim2=-1).sgn()
        q *= d.unsqueeze(-2)
        if transpose:
            q = q.transpose(-2, -1)
        if n == k:
            mask = (torch.linalg.det(q).abs() >= 0).to(q.dtype.to_real())
            mask[mask == 0] = -1
            mask = mask.unsqueeze(-1)
            q[..., 0] *= mask
        # X.copy_(q)
        return q


def unitary(module: nn.Module, name: str = "weight") -> nn.Module:
    weight = getattr(module, name, None)
    if not isinstance(weight, torch.Tensor):
        raise ValueError(f"Module '{module}' has no parameter or buffer '{name}'")
    if weight.ndim < 2:
        raise ValueError(f"Expected a matrix or batch of matrices. Got a tensor of {weight.ndim} dimensions.")
    if weight.shape[-2] != weight.shape[-1]:
        raise ValueError(f"Expected a square matrix or batch of square matrices. Got a tensor of shape {weight.shape}")
    unit = _Unitary()
    nn.utils.parametrize.register_parametrization(module, name, unit)
    return module
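

# Illustrative sketch (layer size and dtype are assumptions, not part of the
# original module): wrap a square UnitaryLayer weight with the unitary()
# parametrization and verify that the effective weight W satisfies W^H W = I,
# i.e. the layer preserves total optical power.
def _demo_unitary_parametrization():
    layer = unitary(UnitaryLayer(4, 4, dtype=torch.cfloat), "weight")
    w = layer.weight  # recomputed through the _Unitary parametrization on access
    eye = torch.eye(4, dtype=torch.cfloat)
    assert torch.allclose(w.conj().mT @ w, eye, atol=1e-4)
    return w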


class _SpecialUnitary(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, X: torch.Tensor):
        n, k = X.size(-2), X.size(-1)
        if n != k:
            raise ValueError(f"Expected a square matrix. Got a tensor of shape {X.shape}")
        q, _ = torch.linalg.qr(X)
        # divide by the n-th root of the determinant so that det(q) == 1
        q = q / torch.linalg.det(q).pow(1 / n)[..., None, None]
        return q


def special_unitary(module: nn.Module, name: str = "weight") -> nn.Module:
    weight = getattr(module, name, None)
    if not isinstance(weight, torch.Tensor):
        raise ValueError(f"Module '{module}' has no parameter or buffer '{name}'")
    if weight.ndim < 2:
        raise ValueError(f"Expected a matrix or batch of matrices. Got a tensor of {weight.ndim} dimensions.")
    if weight.shape[-2] != weight.shape[-1]:
        raise ValueError(f"Expected a square matrix or batch of square matrices. Got a tensor of shape {weight.shape}")
    unit = _SpecialUnitary()
    nn.utils.parametrize.register_parametrization(module, name, unit)
    return module


class _Clamp(nn.Module):
    def __init__(self, min, max):
        super(_Clamp, self).__init__()
        self.min = min
        self.max = max

    def forward(self, x):
        if x.is_complex():
            # clamp magnitude, ignore phase
            return torch.clamp(x.abs(), self.min, self.max) * x / x.abs()
        return torch.clamp(x, self.min, self.max)


def clamp(module: nn.Module, name: str = "scale", min=0, max=1) -> nn.Module:
    scale = getattr(module, name, None)
    if not isinstance(scale, torch.Tensor):
        raise ValueError(f"Module '{module}' has no parameter or buffer '{name}'")
    cl = _Clamp(min, max)
    nn.utils.parametrize.register_parametrization(module, name, cl)
    return module


class _EnergyConserving(nn.Module):
    def __init__(self):
        super(_EnergyConserving, self).__init__()

    def forward(self, X: torch.Tensor):
        squeeze = X.ndim == 2
        if squeeze:
            X = X.unsqueeze(0)
        # normalise each matrix by its largest singular value (spectral norm)
        spectral_norm = torch.linalg.svdvals(X)[:, 0]
        X = X / spectral_norm[:, None, None]
        return X.squeeze(0) if squeeze else X


def energy_conserving(module: nn.Module, name: str = "weight") -> nn.Module:
    param = getattr(module, name, None)
    if not isinstance(param, torch.Tensor):
        raise ValueError(f"Module '{module}' has no parameter or buffer '{name}'")
    if not (2 <= param.ndim <= 3):
        raise ValueError(f"Expected a matrix or batch of matrices. Got a tensor of {param.ndim} dimensions.")
    unit = _EnergyConserving()
    nn.utils.parametrize.register_parametrization(module, name, unit)
    return module
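

# Illustrative usage sketch (layer sizes are assumptions, not part of the
# original API surface): register the energy-conserving parametrization on a
# plain nn.Linear and check that the resulting weight has unit spectral norm,
# i.e. the layer cannot amplify the total signal power.
def _demo_energy_conserving():
    layer = nn.Linear(4, 3, bias=False)
    energy_conserving(layer, "weight")
    # the parametrized weight is rescaled by its largest singular value
    sigma_max = torch.linalg.matrix_norm(layer.weight, ord=2)
    assert torch.allclose(sigma_max, torch.tensor(1.0), atol=1e-5)
    return layer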


class ONN(nn.Module):
    def __init__(self, input_dim, output_dim, dtype=None) -> None:
        super(ONN, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.dtype = dtype
        self.dim = max(input_dim, output_dim)
        # zero pad input to internal size if smaller
        if self.input_dim < self.dim:
            self.pad = lambda x: F.pad(x, ((self.dim - self.input_dim) // 2, (self.dim - self.input_dim + 1) // 2))
            self.pad.__doc__ = f"Zero pad input from {self.input_dim} to {self.dim}"
        else:
            self.pad = lambda x: x
            self.pad.__doc__ = f"Input size equals internal size {self.dim}"
        # crop output to desired size
        if self.output_dim < self.dim:
            self.crop = lambda x: x[
                :, (self.dim - self.output_dim) // 2 : (x.shape[1] - (self.dim - self.output_dim + 1) // 2)
            ]
            self.crop.__doc__ = f"Crop output from {self.dim} to {self.output_dim}"
        else:
            self.crop = lambda x: x
            self.crop.__doc__ = f"Output size equals internal size {self.dim}"
        self.weight = nn.Parameter(torch.randn(self.dim, self.dim, dtype=self.dtype))
        # self.scale = nn.Parameter(torch.randn(1, dtype=self.dtype.to_real())+0.5)

    def reset_parameters(self):
        q, _ = torch.linalg.qr(self.weight)
        self.weight.data = q

    # def get_M(self):
    #     return self.U @ self.sigma @ self.V

    def forward(self, x):
        return self.crop(self.pad(x) @ self.weight)


class ONNRect(nn.Module):
    def __init__(self, input_dim, output_dim, square=False, dtype=None):
        super(ONNRect, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        if square:
            dim = max(input_dim, output_dim)
            self.weight = nn.Parameter(torch.randn(dim, dim, dtype=dtype))
            # zero pad input to internal size if smaller
            if self.input_dim < dim:
                self.pad = lambda x: F.pad(x, ((dim - self.input_dim) // 2, (dim - self.input_dim + 1) // 2))
                self.pad.__doc__ = f"Zero pad input from {self.input_dim} to {dim}"
            else:
                self.pad = lambda x: x
                self.pad.__doc__ = f"Input size equals internal size {dim}"
            # crop output to desired size
            if self.output_dim < dim:
                self.crop = lambda x: x[
                    :, (dim - self.output_dim) // 2 : (x.shape[1] - (dim - self.output_dim + 1) // 2)
                ]
                self.crop.__doc__ = f"Crop output from {dim} to {self.output_dim}"
            else:
                self.crop = lambda x: x
                self.crop.__doc__ = f"Output size equals internal size {dim}"
        else:
            self.weight = nn.Parameter(torch.randn(output_dim, input_dim, dtype=dtype))
            self.pad = lambda x: x
            self.pad.__doc__ = "No padding"
            self.crop = lambda x: x
            self.crop.__doc__ = "No cropping"

    def forward(self, x):
        x = self.pad(x).to(dtype=self.weight.dtype)
        out = self.crop((self.weight @ x.mT).mT)
        return out
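

# Illustrative sketch (assumed sizes/dtype, not part of the original module):
# ONN works on an internal square matrix of size max(input_dim, output_dim),
# zero-padding smaller inputs and cropping larger outputs, so a 6 -> 4 layer
# multiplies by a 6x6 matrix and keeps the 4 central outputs. Combined with
# unitary(), the internal matrix stays unitary during training.
def _demo_onn_shapes():
    layer = unitary(ONN(6, 4, dtype=torch.cfloat), "weight")
    x = torch.randn(2, 6, dtype=torch.cfloat)
    y = layer(x)
    assert y.shape == (2, 4)
    return y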


class polarimeter(nn.Module):
    def __init__(self):
        super(polarimeter, self).__init__()
        # self.input_length = input_length

    def forward(self, data):
        # S0 = I
        # S1 = (2*I_x - I)/I
        # S2 = (2*I_45 - I)/I
        # S3 = (2*I_RHC - I)/I
        #
        # data: (batch, input_length*2) -> (batch, input_length, 2)
        data = data.view(data.shape[0], -1, 2)
        x = data[:, :, 0].mean(dim=1)
        y = data[:, :, 1].mean(dim=1)
        # x = x.mean(dim=1)
        # y = y.mean(dim=1)
        # angle = torch.atan2(y.abs().square().real, x.abs().square().real)
        # return torch.stack([angle, angle, angle, angle], dim=1)
        # horizontal polarisation
        I_x = x.abs().square()
        # vertical polarisation
        I_y = y.abs().square()
        # 45 degree polarisation
        I_45 = (x + y).abs().square()
        # right hand circular polarisation
        I_RHC = (x + 1j*y).abs().square()
        # S0 = I_x + I_y
        # S1 = I_x - I_y
        # S2 = I_45 - I_m45
        # S3 = I_RHC - I_LHC
        S0 = I_x + I_y
        # un-normalised Stokes components; the stack below divides by S0 once
        S1 = 2*I_x - S0
        S2 = 2*I_45 - S0
        S3 = 2*I_RHC - S0
        return torch.stack([S0/S0, S1/S0, S2/S0, S3/S0], dim=1)


class normalize_by_first(nn.Module):
    def __init__(self):
        super(normalize_by_first, self).__init__()

    def forward(self, data):
        return data / data[:, 0].unsqueeze(1)


class rotate(nn.Module):
    def __init__(self):
        super(rotate, self).__init__()

    def forward(self, data, angle):
        # data -> (batch, n*2)
        # angle -> (batch, n)
        data_ = data
        if angle.ndim == 1:
            angle_ = angle.unsqueeze(1)
        else:
            angle_ = angle
        angle_ = angle_.expand(-1, data_.shape[1]//2)
        c = torch.cos(angle_)
        s = torch.sin(angle_)
        rot = torch.stack([torch.stack([c, -s], dim=2), torch.stack([s, c], dim=2)], dim=3)
        d = torch.bmm(data_.reshape(-1, 1, 2), rot.view(-1, 2, 2).to(dtype=data_.dtype)).reshape(*data.shape)
        # d = torch.bmm(data.unsqueeze(-1).mT, rot.to(dtype=data.dtype).mT).mT.squeeze(-1)
        return d


class photodiode(nn.Module):
    def __init__(self, size, bias=True):
        super(photodiode, self).__init__()
        self.input_dim = size
        self.scale = nn.Parameter(torch.rand(size))
        self.pd_bias = nn.Parameter(torch.rand(size))

    def forward(self, x):
        return x.abs().square().to(dtype=x.dtype.to_real()).mul(self.scale).add(self.pd_bias)


class input_rotator(nn.Module):
    def __init__(self, input_dim):
        super(input_rotator, self).__init__()
        assert input_dim % 2 == 0, "Input dimension must be even"
        self.input_dim = input_dim
        # learnable fallback angle, used when no angle is passed to forward()
        self.angle = nn.Parameter(torch.randn(()))

    def forward(self, x, angle=None):
        # take channels (0, 1), (2, 3), ... and rotate each pair by the angle
        if angle is None:
            angle = self.angle
        angle = torch.as_tensor(angle).reshape(())
        sine = torch.sin(angle)
        cosine = torch.cos(angle)
        rot = torch.stack([torch.stack([cosine, -sine]), torch.stack([sine, cosine])]).to(dtype=x.dtype)
        return torch.matmul(x.view(-1, 2), rot).view(x.shape)

    # def __repr__(self):
    #     return f"ONNRect({self.input_dim}, {self.output_dim})"


# class SaturableAbsorberLambertW(nn.Module):
#     """
#     Implements the activation function for an optical saturable absorber
#     base eqn: sigma*tau*I0 = 0.5*(log(Tm/T0))/(1-Tm),
#     where: sigma is the absorption cross section
#            tau is the radiative lifetime of the absorber material
#            T0 is the initial transmittance
#            I0 is the input intensity
#            Tm is the transmittance of the absorber
#     The activation function is defined as:
#     Iout = I0 * Tm(I0)
#     where Tm(I0) is the transmittance of the absorber as a function of the input intensity I0
#     for a unit sigma*tau product, the solution Tm(I0) is given by:
#     Tm(I0) = (W(2*exp(2*I0)*I0*T0))/(2*I0),
#     where W is the Lambert W function
#     if sigma*tau is not 1, I0 has to be scaled by sigma*tau
#     (-> x has to be scaled by sqrt(sigma*tau))
#     """
#     def __init__(self, T0):
#         super(SaturableAbsorberLambertW, self).__init__()
#         self.register_buffer("T0", torch.tensor(T0))

#     def forward(self, x: torch.Tensor):
#         xc = x.conj()
#         two_x_xc = (2 * x * xc).real
#         return (lambertw(2 * torch.exp(two_x_xc) * (x * self.T0 * xc).real) / two_x_xc).to(dtype=x.dtype)

#     def backward(self, x):
#         xc = x.conj()
#         lambert_eval = lambertw(2 * torch.exp(2 * x * xc).real * (x * self.T0 * xc).real)
#         return (((xc * (-2 * lambert_eval + 2 * torch.square(x) - 1) + 2 * x * torch.square(xc) + x) * lambert_eval) / (
#             2 * torch.pow(x, 3) * xc * (lambert_eval + 1)
#         )).to(dtype=x.dtype)


# class SaturableAbsorber(nn.Module):
#     def __init__(self, alpha, I0):
#         super(SaturableAbsorber, self).__init__()
#         self.register_buffer("alpha", torch.tensor(alpha))
#         self.register_buffer("I0", torch.tensor(I0))

#     def forward(self, x):
#         I = (x*x.conj()).to(dtype=x.dtype.to_real())
#         A = self.alpha/(1+I/self.I0)
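

# Illustrative sketch (not part of the original module): the rotate module
# treats the input as interleaved (x, y) channel pairs and applies a standard
# 2D rotation to each pair. Rotating the horizontal Jones vector (1, 0) by
# pi/2 should give (0, 1) up to numerical precision.
def _demo_rotate():
    r = rotate()
    data = torch.tensor([[1.0, 0.0]])          # one sample, one (x, y) pair
    angle = torch.tensor([torch.pi / 2])       # 90 degree rotation per sample
    out = r(data, angle)
    assert torch.allclose(out, torch.tensor([[0.0, 1.0]]), atol=1e-6)
    return out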


# class SpreadLayer(nn.Module):
#     def __init__(self, in_features, out_features, dtype=None):
#         super(SpreadLayer, self).__init__()
#         self.in_features = in_features
#         self.out_features = out_features
#         self.mat = torch.ones(in_features, out_features, dtype=dtype)*torch.sqrt(torch.tensor(in_features/out_features))

#     def forward(self, x):
#         # N in_features -> M out_features, Energy is preserved (P = abs(x)^2)
#         out = torch.matmul(x, self.mat)
#         return out


#### as defined by Zhang et al.
class DropoutComplex(nn.Module):
    def __init__(self, p=0.5):
        super(DropoutComplex, self).__init__()
        self.dropout = nn.Dropout(p=p)

    def forward(self, x):
        if x.is_complex():
            # apply the same (scaled) dropout mask to real and imaginary parts
            mask = self.dropout(torch.ones_like(x.real))
            return x * mask
        else:
            return self.dropout(x)


class Scale(nn.Module):
    def __init__(self, size):
        super(Scale, self).__init__()
        self.size = size
        self.scale = nn.Parameter(torch.ones(size, dtype=torch.float32))

    def forward(self, x):
        return x * torch.sqrt(self.scale)

    def __repr__(self):
        return f"Scale({self.size})"


class Identity(nn.Module):
    """
    implements the "activation" function M(z) = z
    """
    def __init__(self, size=None):
        super(Identity, self).__init__()

    def forward(self, x):
        return x


class phase_shift(nn.Module):
    def __init__(self, size):
        super(phase_shift, self).__init__()
        self.size = size
        self.phase = nn.Parameter(torch.rand(size))

    def forward(self, x):
        return x * torch.exp(1j*self.phase)


class PowRot(nn.Module):
    def __init__(self, bias=False):
        super(PowRot, self).__init__()
        self.scale = nn.Parameter(torch.tensor(1.0))
        if bias:
            self.bias = nn.Parameter(torch.tensor(0.0))
        else:
            self.register_buffer("bias", torch.tensor(0.0))

    def forward(self, x: torch.Tensor):
        if x.is_complex():
            return x * torch.exp(-self.scale * 1j * x.abs().square() + self.bias.to(dtype=x.dtype))
        else:
            return x


class MZISingle(nn.Module):
    def __init__(self, bias, size, func=None):
        super(MZISingle, self).__init__()
        self.omega = nn.Parameter(torch.randn(size))
        self.phi = nn.Parameter(torch.randn(size))
        self.func = func or (lambda x: x.abs().square())  # default to |z|^2

    def forward(self, x: torch.Tensor):
        return x * torch.exp(1j * self.phi) * torch.sin(self.omega + self.func(x))


def naive_angle_loss(x: torch.Tensor, target: torch.Tensor, mod=2*torch.pi):
    return torch.fmod((x.abs().real - target.abs().real), mod).abs().mean()


def cosine_loss(x: torch.Tensor, target: torch.Tensor):
    return (2*(1 - torch.cos(x - target))).mean()


def angle_mse_loss(x: torch.Tensor, target: torch.Tensor):
    x = torch.fmod(x, 2*torch.pi)
    target = torch.fmod(target, 2*torch.pi)
    x_cos = torch.cos(x)
    x_sin = torch.sin(x)
    target_cos = torch.cos(target)
    target_sin = torch.sin(target)
    cos_diff = x_cos - target_cos
    sin_diff = x_sin - target_sin
    squared_diff = cos_diff**2 + sin_diff**2
    return squared_diff.mean()
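

# Illustrative sketch (not part of the original module): angle_mse_loss
# compares angles through their sine/cosine, so values that differ by a full
# turn of 2*pi are treated as identical, unlike a plain MSE on the raw angles.
def _demo_angle_mse_loss():
    a = torch.tensor([0.1, 1.0])
    b = a + 2 * torch.pi            # same angles, wrapped once around
    wrapped = angle_mse_loss(a, b)  # ~0: the wrap-around is ignored
    plain = F.mse_loss(a, b)        # ~(2*pi)^2: raw MSE sees a large error
    assert wrapped < 1e-6 < plain
    return wrapped, plain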


class EOActivation(nn.Module):
    def __init__(self, size=None):
        # 10.1109/JSTQE.2019.2930455
        super(EOActivation, self).__init__()
        if size is None:
            raise ValueError("Size must be specified")
        self.size = size
        self.alpha = nn.Parameter(torch.rand(size))
        self.gain = nn.Parameter(torch.rand(size))
        self.V_bias = nn.Parameter(torch.rand(size))
        # self.register_buffer("gain", torch.ones(size))
        # self.register_buffer("responsivity", torch.ones(size))
        # self.register_buffer("V_pi", torch.ones(size))
        self.reset_weights()

    def reset_weights(self):
        if "alpha" in self._parameters:
            self.alpha.data = torch.rand(self.size)
        # if "V_pi" in self._parameters:
        #     self.V_pi.data = torch.rand(self.size)*3
        if "V_bias" in self._parameters:
            self.V_bias.data = torch.randn(self.size)
        if "gain" in self._parameters:
            self.gain.data = torch.rand(self.size)
        # if "responsivity" in self._parameters:
        #     self.responsivity.data = torch.ones(self.size)*0.9
        # if "bias" in self._parameters:
        #     self.phase_bias.data = torch.zeros(self.size)

    def forward(self, x: torch.Tensor):
        phi_b = torch.pi * self.V_bias  # / (self.V_pi)
        g_phi = torch.pi * (self.alpha * self.gain)  # * self.responsivity) / (self.V_pi)
        intermediate = g_phi * x.abs().square() + phi_b
        return (
            1j
            * torch.sqrt(1 - self.alpha)
            * torch.exp(-0.5j * intermediate)
            * torch.cos(0.5 * intermediate)
            * x
        )


class Pow(nn.Module):
    """
    implements the activation function M(z) = ||z||^2 + b
    """
    def __init__(self, bias=False):
        super(Pow, self).__init__()
        if bias:
            self.bias = nn.Parameter(torch.tensor(0.0))
        else:
            self.register_buffer("bias", torch.tensor(0.0))

    def forward(self, x: torch.Tensor):
        return x.abs().square().add(self.bias).to(dtype=x.dtype)


class Mag(nn.Module):
    """
    implements the activation function M(z) = ||z|| + b
    """
    def __init__(self, bias=False):
        super(Mag, self).__init__()
        if bias:
            self.bias = nn.Parameter(torch.tensor(0.0))
        else:
            self.register_buffer("bias", torch.tensor(0.0))

    def forward(self, x: torch.Tensor):
        return x.abs().add(self.bias).to(dtype=x.dtype)


class MagScale(nn.Module):
    def __init__(self, bias=False):
        super(MagScale, self).__init__()
        if bias:
            self.bias = nn.Parameter(torch.tensor(0.0))
        else:
            self.register_buffer("bias", torch.tensor(0.0))

    def forward(self, x: torch.Tensor):
        return x.abs().add(self.bias).to(dtype=x.dtype).sin().mul(x)


class PowScale(nn.Module):
    def __init__(self, bias=False):
        super(PowScale, self).__init__()
        if bias:
            self.bias = nn.Parameter(torch.tensor(0.0))
        else:
            self.register_buffer("bias", torch.tensor(0.0))

    def forward(self, x: torch.Tensor):
        return x.mul(x.abs().square().add(self.bias).to(dtype=x.dtype).sin())


class ModReLU(nn.Module):
    """
    implements the activation function
    M(z) = ReLU(||z|| + b)*exp(j*theta_z) = ReLU(||z|| + b)*z/||z||
    """
    def __init__(self, bias=True):
        super(ModReLU, self).__init__()
        if bias:
            self.bias = nn.Parameter(torch.tensor(0.0))
        else:
            self.register_buffer("bias", torch.tensor(0.0))

    def forward(self, x):
        if x.is_complex():
            mod = x.abs()
            out = torch.relu(mod + self.bias) * x / mod
            return out.to(dtype=x.dtype)
        else:
            return torch.relu(x + self.bias).to(dtype=x.dtype)

    def __repr__(self):
        return f"ModReLU(b={self.bias})"


class CReLU(nn.Module):
    """
    implements the activation function M(z) = ReLU(Re(z)) + j*ReLU(Im(z))
    """
    def __init__(self):
        super(CReLU, self).__init__()

    def forward(self, x):
        if x.is_complex():
            return torch.relu(x.real) + 1j * torch.relu(x.imag)
        else:
            return torch.relu(x)


class ZReLU(nn.Module):
    """
    implements the activation function
    M(z) = z if 0 <= angle(z) <= pi/2
         = 0 otherwise
    """
    def __init__(self):
        super(ZReLU, self).__init__()

    def forward(self, x):
        if x.is_complex():
            return x * (torch.angle(x) >= 0) * (torch.angle(x) <= torch.pi / 2)
        else:
            return torch.relu(x)


__all__ = [
    "complex_sse_loss",
    "complex_mse_loss",
    "angle_mse_loss",
    "UnitaryLayer",
    "unitary",
    "energy_conserving",
    "clamp",
    "ONN",
    "ONNRect",
    "DropoutComplex",
    "Identity",
    "Pow",
    "PowRot",
    "Mag",
    "ModReLU",
    "CReLU",
    "ZReLU",
    "MZISingle",
    "EOActivation",
    "photodiode",
    "phase_shift",
    # "SaturableAbsorberLambertW",
    # "SaturableAbsorber",
    # "SpreadLayer",
]
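

# Illustrative end-to-end sketch (layer sizes and dtype are assumptions, not
# part of the original module): a complex field propagates through a
# rectangular ONN layer, the electro-optic nonlinearity, and a photodiode
# readout, yielding a real-valued output per channel.
def _demo_onn_pipeline():
    mesh = ONNRect(8, 4, dtype=torch.cfloat)
    activation = EOActivation(size=4)
    readout = photodiode(4)
    x = torch.randn(2, 8, dtype=torch.cfloat)   # batch of 2 complex input fields
    y = readout(activation(mesh(x)))
    assert y.shape == (2, 4) and not y.is_complex()
    return y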