define new activation functions and parametrizations

This commit is contained in:
Joseph Hopfmüller
2024-11-29 15:51:25 +01:00
parent bdf6f5bfb8
commit 487288c923

View File

@@ -1,16 +1,26 @@
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
# from torchlambertw.special import lambertw
def complex_mse_loss(input, target): def complex_mse_loss(input, target, power=False, reduction="mean"):
""" """
Compute the mean squared error between two complex tensors. Compute the mean squared error between two complex tensors.
If power is set to True, the loss is computed as |input|^2 - |target|^2
""" """
if input.is_complex(): reduce = getattr(torch, reduction)
return torch.mean(torch.square(input.real - target.real) + torch.square(input.imag - target.imag))
if power:
input = (input * input.conj()).real.to(dtype=input.dtype.to_real())
target = (target * target.conj()).real.to(dtype=target.dtype.to_real())
if input.is_complex() and target.is_complex():
return reduce(torch.square(input.real - target.real) + torch.square(input.imag - target.imag))
elif input.is_complex() or target.is_complex():
raise ValueError("Input and target must have the same type (real or complex)")
else: else:
return F.mse_loss(input, target) return F.mse_loss(input, target, reduction=reduction)
def complex_sse_loss(input, target): def complex_sse_loss(input, target):
@@ -43,6 +53,174 @@ class UnitaryLayer(nn.Module):
return f"UnitaryLayer({self.in_features}, {self.out_features})" return f"UnitaryLayer({self.in_features}, {self.out_features})"
class _Unitary(nn.Module):
def forward(self, X:torch.Tensor):
if X.ndim < 2:
raise ValueError(
"Only tensors with 2 or more dimensions are supported. "
f"Got a tensor of shape {X.shape}"
)
n, k = X.size(-2), X.size(-1)
transpose = n<k
if transpose:
X = X.transpose(-2, -1)
q, r = torch.linalg.qr(X)
# q: torch.Tensor = q
# r: torch.Tensor = r
d = r.diagonal(dim1=-2, dim2=-1).sgn()
q*=d.unsqueeze(-2)
if transpose:
q = q.transpose(-2, -1)
if n == k:
mask = (torch.linalg.det(q).abs() >= 0).to(q.dtype.to_real())
mask[mask == 0] = -1
mask = mask.unsqueeze(-1)
q[..., 0] *= mask
# X.copy_(q)
return q
def unitary(module: nn.Module, name: str = "weight") -> nn.Module:
weight = getattr(module, name, None)
if not isinstance(weight, torch.Tensor):
raise ValueError(f"Module '{module}' has no parameter or buffer '{name}'")
if weight.ndim < 2:
raise ValueError(f"Expected a matrix or batch of matrices. Got a tensor of {weight.ndim} dimensions.")
if weight.shape[-2] != weight.shape[-1]:
raise ValueError(f"Expected a square matrix or batch of square matrices. Got a tensor of shape {weight.shape}")
unit = _Unitary()
nn.utils.parametrize.register_parametrization(module, name, unit)
return module
class _SpecialUnitary(nn.Module):
def __init__(self):
super().__init__()
def forward(self, X:torch.Tensor):
n, k = X.size(-2), X.size(-1)
if n != k:
raise ValueError(f"Expected a square matrix. Got a tensor of shape {X.shape}")
q, _ = torch.linalg.qr(X)
q = q / torch.linalg.det(q).pow(1/n)
return q
def special_unitary(module: nn.Module, name: str = "weight") -> nn.Module:
weight = getattr(module, name, None)
if not isinstance(weight, torch.Tensor):
raise ValueError(f"Module '{module}' has no parameter or buffer '{name}'")
if weight.ndim < 2:
raise ValueError(f"Expected a matrix or batch of matrices. Got a tensor of {weight.ndim} dimensions.")
if weight.shape[-2] != weight.shape[-1]:
raise ValueError(f"Expected a square matrix or batch of square matrices. Got a tensor of shape {weight.shape}")
unit = _SpecialUnitary()
nn.utils.parametrize.register_parametrization(module, name, unit)
return module
class _Clamp(nn.Module):
def __init__(self, min, max):
super(_Clamp, self).__init__()
self.min = min
self.max = max
def forward(self, x):
if x.is_complex():
# clamp magnitude, ignore phase
return torch.clamp(x.abs(), self.min, self.max) * x / x.abs()
return torch.clamp(x, self.min, self.max)
def clamp(module: nn.Module, name: str = "scale", min=0, max=1) -> nn.Module:
scale = getattr(module, name, None)
if not isinstance(scale, torch.Tensor):
raise ValueError(f"Module '{module}' has no parameter or buffer '{name}'")
cl = _Clamp(min, max)
nn.utils.parametrize.register_parametrization(module, name, cl)
return module
class ONNMiller(nn.Module):
def __init__(self, input_dim, output_dim, dtype=None) -> None:
super(ONNMiller, self).__init__()
self.input_dim = input_dim
self.output_dim = output_dim
self.dtype = dtype
self.dim = max(input_dim, output_dim)
# zero pad input to internal size if smaller
if self.input_dim < self.dim:
self.pad = lambda x: F.pad(x, ((self.dim - self.input_dim) // 2, (self.dim - self.input_dim + 1) // 2))
else:
self.pad = lambda x: x
self.pad.__doc__ = f"Zero pad input from {self.input_dim} to {self.dim}"
# crop output to desired size
if self.output_dim < self.dim:
self.crop = lambda x: x[:, (self.dim - self.output_dim) // 2 : (x.shape[1] - (self.dim - self.output_dim + 1) // 2)]
else:
self.crop = lambda x: x
self.crop.__doc__ = f"Crop output from {self.dim} to {self.output_dim}"
self.U = nn.Parameter(torch.randn(self.dim, self.dim, dtype=self.dtype)) # -> parametrization: Unitary
self.S = nn.Parameter(torch.randn(self.dim, dtype=self.dtype)) # -> parametrization: Clamp (magnitude 0..1)
self.V = nn.Parameter(torch.randn(self.dim, self.dim, dtype=self.dtype)) # -> parametrization: Unitary
self.register_buffer("MZI_scale", torch.tensor(2, dtype=self.dtype.to_real()).sqrt())
# V is actually V.H, but
def forward(self, x_in):
x = x_in
x = self.pad(x)
x = x @ self.U
x = x * (self.S.squeeze() / self.MZI_scale)
x = x @ self.V
x = self.crop(x)
return x
class ONN(nn.Module):
def __init__(self, input_dim, output_dim, dtype=None) -> None:
super(ONN, self).__init__()
self.input_dim = input_dim
self.output_dim = output_dim
self.dtype = dtype
self.dim = max(input_dim, output_dim)
# zero pad input to internal size if smaller
if self.input_dim < self.dim:
self.pad = lambda x: F.pad(x, ((self.dim - self.input_dim) // 2, (self.dim - self.input_dim + 1) // 2))
self.pad.__doc__ = f"Zero pad input from {self.input_dim} to {self.dim}"
else:
self.pad = lambda x: x
self.pad.__doc__ = f"Input size equals internal size {self.dim}"
# crop output to desired size
if self.output_dim < self.dim:
self.crop = lambda x: x[:, (self.dim - self.output_dim) // 2 : (x.shape[1] - (self.dim - self.output_dim + 1) // 2)]
self.crop.__doc__ = f"Crop output from {self.dim} to {self.output_dim}"
else:
self.crop = lambda x: x
self.crop.__doc__ = f"Output size equals internal size {self.dim}"
self.weight = nn.Parameter(torch.randn(self.dim, self.dim, dtype=self.dtype))
def reset_parameters(self):
q, _ = torch.linalg.qr(self.weight)
self.weight.data = q
# def get_M(self):
# return self.U @ self.sigma @ self.V
def forward(self, x):
return self.crop(self.pad(x) @ self.weight)
class SemiUnitaryLayer(nn.Module): class SemiUnitaryLayer(nn.Module):
def __init__(self, input_dim, output_dim, dtype=None): def __init__(self, input_dim, output_dim, dtype=None):
super(SemiUnitaryLayer, self).__init__() super(SemiUnitaryLayer, self).__init__()
@@ -51,24 +229,84 @@ class SemiUnitaryLayer(nn.Module):
# Create a larger square matrix for QR decomposition # Create a larger square matrix for QR decomposition
self.weight = nn.Parameter(torch.randn(max(input_dim, output_dim), max(input_dim, output_dim), dtype=dtype)) self.weight = nn.Parameter(torch.randn(max(input_dim, output_dim), max(input_dim, output_dim), dtype=dtype))
self.scale = nn.Parameter(torch.tensor(1.0, dtype=dtype.to_real()))
self.reset_parameters() self.reset_parameters()
def reset_parameters(self): def reset_parameters(self):
# Ensure the weights are semi-unitary by QR decomposition # Ensure the weights are unitary by QR decomposition
q, _ = torch.linalg.qr(self.weight) q, _ = torch.linalg.qr(self.weight)
# A = QR with A being a complex square matrix -> Q is unitary, R is upper triangular
# truncate the matrix to the desired size
if self.input_dim > self.output_dim: if self.input_dim > self.output_dim:
self.weight.data = q[: self.input_dim, : self.output_dim] self.weight.data = q[: self.input_dim, : self.output_dim]
else: else:
self.weight.data = q[: self.output_dim, : self.input_dim].t() self.weight.data = q[: self.output_dim, : self.input_dim].t()
...
def forward(self, x): def forward(self, x):
out = torch.matmul(x, self.weight) with torch.no_grad():
scale = torch.clamp(self.scale, 0.0, 1.0)
out = torch.matmul(x, scale * self.weight)
return out return out
def __repr__(self): def __repr__(self):
return f"SemiUnitaryLayer({self.input_dim}, {self.output_dim})" return f"SemiUnitaryLayer({self.input_dim}, {self.output_dim})"
# class SaturableAbsorberLambertW(nn.Module):
# """
# Implements the activation function for an optical saturable absorber
# base eqn: sigma*tau*I0 = 0.5*(log(Tm/T0))/(1-Tm),
# where: sigma is the absorption cross section
# tau is the radiative lifetime of the absorber material
# T0 is the initial transmittance
# I0 is the input intensity
# Tm is the transmittance of the absorber
# The activation function is defined as:
# Iout = I0 * Tm(I0)
# where Tm(I0) is the transmittance of the absorber as a function of the input intensity I0
# for a unit sigma*tau product, he solution Tm(I0) is given by:
# Tm(I0) = (W(2*exp(2*I0)*I0*T0))/(2*I0),
# where W is the Lambert W function
# if sigma*tau is not 1, I0 has to be scaled by sigma*tau
# (-> x has to be scaled by sqrt(sigma*tau))
# """
# def __init__(self, T0):
# super(SaturableAbsorberLambertW, self).__init__()
# self.register_buffer("T0", torch.tensor(T0))
# def forward(self, x: torch.Tensor):
# xc = x.conj()
# two_x_xc = (2 * x * xc).real
# return (lambertw(2 * torch.exp(two_x_xc) * (x * self.T0 * xc).real) / two_x_xc).to(dtype=x.dtype)
# def backward(self, x):
# xc = x.conj()
# lambert_eval = lambertw(2 * torch.exp(2 * x * xc).real * (x * self.T0 * xc).real)
# return (((xc * (-2 * lambert_eval + 2 * torch.square(x) - 1) + 2 * x * torch.square(xc) + x) * lambert_eval) / (
# 2 * torch.pow(x, 3) * xc * (lambert_eval + 1)
# )).to(dtype=x.dtype)
# class SaturableAbsorber(nn.Module):
# def __init__(self, alpha, I0):
# super(SaturableAbsorber, self).__init__()
# self.register_buffer("alpha", torch.tensor(alpha))
# self.register_buffer("I0", torch.tensor(I0))
# def forward(self, x):
# I = (x*x.conj()).to(dtype=x.dtype.to_real())
# A = self.alpha/(1+I/self.I0)
# class SpreadLayer(nn.Module): # class SpreadLayer(nn.Module):
# def __init__(self, in_features, out_features, dtype=None): # def __init__(self, in_features, out_features, dtype=None):
# super(SpreadLayer, self).__init__() # super(SpreadLayer, self).__init__()
@@ -85,6 +323,19 @@ class SemiUnitaryLayer(nn.Module):
#### as defined by zhang et al #### as defined by zhang et al
class DropoutComplex(nn.Module):
def __init__(self, p=0.5):
super(DropoutComplex, self).__init__()
self.dropout = nn.Dropout(p=p)
def forward(self, x):
if x.is_complex():
mask = self.dropout(torch.ones_like(x.real))
return x * mask
else:
return self.dropout(x)
class Identity(nn.Module): class Identity(nn.Module):
""" """
implements the "activation" function implements the "activation" function
@@ -97,18 +348,76 @@ class Identity(nn.Module):
def forward(self, x): def forward(self, x):
return x return x
class PowRot(nn.Module):
def __init__(self, bias=False):
super(PowRot, self).__init__()
self.scale = nn.Parameter(torch.tensor(1.0))
if bias:
self.bias = nn.Parameter(torch.tensor(0.0))
else:
self.register_buffer("bias", torch.tensor(0.0))
def forward(self, x: torch.Tensor):
if x.is_complex():
return x * torch.exp(-self.scale*1j*x.abs().square()+self.bias.to(dtype=x.dtype))
else:
return x
class Pow(nn.Module):
"""
implements the activation function
M(z) = ||z||^2 + b
"""
def __init__(self, bias=False):
super(Pow, self).__init__()
if bias:
self.bias = nn.Parameter(torch.tensor(0.0))
else:
self.register_buffer("bias", torch.tensor(0.0))
def forward(self, x: torch.Tensor):
return x.abs().square().add(self.bias).to(dtype=x.dtype)
class Mag(nn.Module): class Mag(nn.Module):
""" """
implements the activation function implements the activation function
M(z) = ||z|| M(z) = ||z||+b
""" """
def __init__(self): def __init__(self, bias=False):
super(Mag, self).__init__() super(Mag, self).__init__()
if bias:
self.bias = nn.Parameter(torch.tensor(0.0))
else:
self.register_buffer("bias", torch.tensor(0.0))
def forward(self, x): def forward(self, x: torch.Tensor):
return torch.abs(x).to(dtype=x.dtype) return x.abs().add(self.bias).to(dtype=x.dtype)
class MagScale(nn.Module):
def __init__(self, bias=False):
super(MagScale, self).__init__()
if bias:
self.bias = nn.Parameter(torch.tensor(0.0))
else:
self.register_buffer("bias", torch.tensor(0.0))
def forward(self, x: torch.Tensor):
return x.abs().add(self.bias).to(dtype=x.dtype).sin().mul(x)
class PowScale(nn.Module):
def __init__(self, bias=False):
super(PowScale, self).__init__()
if bias:
self.bias = nn.Parameter(torch.tensor(0.0))
else:
self.register_buffer("bias", torch.tensor(0.0))
def forward(self, x: torch.Tensor):
return x.mul(x.abs().square().add(self.bias).to(dtype=x.dtype).sin())
class ModReLU(nn.Module): class ModReLU(nn.Module):
@@ -118,17 +427,21 @@ class ModReLU(nn.Module):
= ReLU(||z|| + b)*z/||z|| = ReLU(||z|| + b)*z/||z||
""" """
def __init__(self, b=0): def __init__(self, bias=True):
super(ModReLU, self).__init__() super(ModReLU, self).__init__()
self.b = torch.tensor(b) if bias:
self.bias = nn.Parameter(torch.tensor(0.0))
else:
self.register_buffer("bias", torch.tensor(0.0))
def forward(self, x): def forward(self, x):
if x.is_complex(): if x.is_complex():
mod = torch.abs(x.real**2 + x.imag**2) mod = x.abs()
return torch.relu(mod + self.b) * x / mod out = torch.relu(mod + self.bias) * x / mod
return out.to(dtype=x.dtype)
else: else:
return torch.relu(x + self.b) return torch.relu(x + self.bias).to(dtype=x.dtype)
def __repr__(self): def __repr__(self):
return f"ModReLU(b={self.b})" return f"ModReLU(b={self.b})"
@@ -166,3 +479,26 @@ class ZReLU(nn.Module):
return x * (torch.angle(x) >= 0) * (torch.angle(x) <= torch.pi / 2) return x * (torch.angle(x) >= 0) * (torch.angle(x) <= torch.pi / 2)
else: else:
return torch.relu(x) return torch.relu(x)
__all__ = [
complex_sse_loss,
complex_mse_loss,
UnitaryLayer,
unitary,
clamp,
ONN,
ONNMiller,
SemiUnitaryLayer,
DropoutComplex,
Identity,
Pow,
PowRot,
Mag,
ModReLU,
CReLU,
ZReLU,
# SaturableAbsorberLambertW,
# SaturableAbsorber,
# SpreadLayer,
]