update submodule configuration and enhance model settings; add eye diagram functionality
@@ -4,23 +4,36 @@ import torch.nn.functional as F
# from torchlambertw.special import lambertw


def complex_mse_loss(input, target, power=False, reduction="mean"):
def complex_mse_loss(input, target, power=False, normalize=False, reduction="mean"):
    """
    Compute the mean squared error between two complex tensors.
    If power is set to True, the loss is computed as |input|^2 - |target|^2
    """
    reduce = getattr(torch, reduction)
    power_penalty = 0

    if power:
        input = (input * input.conj()).real.to(dtype=input.dtype.to_real())
        target = (target * target.conj()).real.to(dtype=target.dtype.to_real())
        if normalize:
            power_penalty = ((input.max() - input.min()) - (target.max() - target.min())) ** 2
            power_penalty += (input.min() - target.min()) ** 2
            input = input - input.min()
            input = input / input.max()
            target = target - target.min()
            target = target / target.max()
    else:
        if normalize:
            power_penalty = (input.abs().max() - target.abs().max()) ** 2
            input = input / input.abs().max()
            target = target / target.abs().max()

    if input.is_complex() and target.is_complex():
        return reduce(torch.square(input.real - target.real) + torch.square(input.imag - target.imag))
        return reduce(torch.square(input.real - target.real) + torch.square(input.imag - target.imag)) + power_penalty
    elif input.is_complex() or target.is_complex():
        raise ValueError("Input and target must have the same type (real or complex)")
    else:
        return F.mse_loss(input, target, reduction=reduction)
        return F.mse_loss(input, target, reduction=reduction) + power_penalty
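
For context (not part of this commit): a minimal usage sketch of the extended loss; the shapes, dtype, and tensor names below are illustrative only.

# usage sketch -- illustration only, not part of the diff
pred = torch.randn(8, 64, dtype=torch.complex64)     # e.g. simulated output field
ref = torch.randn(8, 64, dtype=torch.complex64)      # reference field
loss = complex_mse_loss(pred, ref)                                    # complex MSE on real/imag parts
loss_pow = complex_mse_loss(pred, ref, power=True)                    # compare detected power |.|^2
loss_norm = complex_mse_loss(pred, ref, power=True, normalize=True)   # adds the range/offset penalty
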
def complex_sse_loss(input, target):
@@ -53,23 +66,19 @@ class UnitaryLayer(nn.Module):
        return f"UnitaryLayer({self.in_features}, {self.out_features})"


class _Unitary(nn.Module):
    def forward(self, X:torch.Tensor):
    def forward(self, X: torch.Tensor):
        if X.ndim < 2:
            raise ValueError(
                "Only tensors with 2 or more dimensions are supported. "
                f"Got a tensor of shape {X.shape}"
            )
            raise ValueError(f"Only tensors with 2 or more dimensions are supported. Got a tensor of shape {X.shape}")
        n, k = X.size(-2), X.size(-1)
        transpose = n<k
        transpose = n < k
        if transpose:
            X = X.transpose(-2, -1)
        q, r = torch.linalg.qr(X)
        # q: torch.Tensor = q
        # r: torch.Tensor = r
        d = r.diagonal(dim1=-2, dim2=-1).sgn()
        q*=d.unsqueeze(-2)
        q *= d.unsqueeze(-2)
        if transpose:
            q = q.transpose(-2, -1)
        if n == k:
@@ -80,6 +89,7 @@ class _Unitary(nn.Module):
        # X.copy_(q)
        return q


def unitary(module: nn.Module, name: str = "weight") -> nn.Module:
    weight = getattr(module, name, None)
    if not isinstance(weight, torch.Tensor):
@@ -87,27 +97,29 @@ def unitary(module: nn.Module, name: str = "weight") -> nn.Module:
    if weight.ndim < 2:
        raise ValueError(f"Expected a matrix or batch of matrices. Got a tensor of {weight.ndim} dimensions.")

    if weight.shape[-2] != weight.shape[-1]:
        raise ValueError(f"Expected a square matrix or batch of square matrices. Got a tensor of shape {weight.shape}")

    unit = _Unitary()
    nn.utils.parametrize.register_parametrization(module, name, unit)
    return module
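
For context (not part of this commit): a minimal sketch of applying the unitary parametrization to a square complex layer; the layer size and dtype are illustrative.

# usage sketch -- illustration only, not part of the diff
layer = nn.Linear(16, 16, bias=False, dtype=torch.complex64)
unitary(layer, "weight")            # weight is re-orthogonalized via QR on every access
W = layer.weight
err = (W @ W.conj().T - torch.eye(16, dtype=W.dtype)).abs().max()   # ~0 up to float precision
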
class _SpecialUnitary(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, X:torch.Tensor):
    def forward(self, X: torch.Tensor):
        n, k = X.size(-2), X.size(-1)
        if n != k:
            raise ValueError(f"Expected a square matrix. Got a tensor of shape {X.shape}")
        q, _ = torch.linalg.qr(X)
        q = q / torch.linalg.det(q).pow(1/n)
        q = q / torch.linalg.det(q).pow(1 / n)
        return q


def special_unitary(module: nn.Module, name: str = "weight") -> nn.Module:
    weight = getattr(module, name, None)
    if not isinstance(weight, torch.Tensor):
@@ -115,73 +127,61 @@ def special_unitary(module: nn.Module, name: str = "weight") -> nn.Module:
    if weight.ndim < 2:
        raise ValueError(f"Expected a matrix or batch of matrices. Got a tensor of {weight.ndim} dimensions.")

    if weight.shape[-2] != weight.shape[-1]:
        raise ValueError(f"Expected a square matrix or batch of square matrices. Got a tensor of shape {weight.shape}")

    unit = _SpecialUnitary()
    nn.utils.parametrize.register_parametrization(module, name, unit)
    return module


class _Clamp(nn.Module):
    def __init__(self, min, max):
        super(_Clamp, self).__init__()
        self.min = min
        self.max = max

    def forward(self, x):
        if x.is_complex():
            # clamp magnitude, ignore phase
            return torch.clamp(x.abs(), self.min, self.max) * x / x.abs()
        return torch.clamp(x, self.min, self.max)


def clamp(module: nn.Module, name: str = "scale", min=0, max=1) -> nn.Module:
    scale = getattr(module, name, None)
    if not isinstance(scale, torch.Tensor):
        raise ValueError(f"Module '{module}' has no parameter or buffer '{name}'")

    cl = _Clamp(min, max)
    nn.utils.parametrize.register_parametrization(module, name, cl)
    return module
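
For context (not part of this commit): a sketch of constraining a scalar scale parameter with the clamp helper; the module and attribute below are illustrative.

# usage sketch -- illustration only, not part of the diff
m = nn.Linear(4, 4)
m.scale = nn.Parameter(torch.tensor(1.7))
clamp(m, "scale", min=0, max=1)     # reads of m.scale now pass through _Clamp
print(m.scale)                      # clamped view (1.0) of the raw parameter
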
class ONNMiller(nn.Module):
    def __init__(self, input_dim, output_dim, dtype=None) -> None:
        super(ONNMiller, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.dtype = dtype
class _EnergyConserving(nn.Module):
    def __init__(self):
        super(_EnergyConserving, self).__init__()

        self.dim = max(input_dim, output_dim)
    def forward(self, X: torch.Tensor):
        if X.ndim == 2:
            X = X.unsqueeze(0)
        spectral_norm = torch.linalg.svdvals(X)[:, 0]
        return (X / spectral_norm).squeeze()

        # zero pad input to internal size if smaller
        if self.input_dim < self.dim:
            self.pad = lambda x: F.pad(x, ((self.dim - self.input_dim) // 2, (self.dim - self.input_dim + 1) // 2))
        else:
            self.pad = lambda x: x
        self.pad.__doc__ = f"Zero pad input from {self.input_dim} to {self.dim}"

        # crop output to desired size
        if self.output_dim < self.dim:
            self.crop = lambda x: x[:, (self.dim - self.output_dim) // 2 : (x.shape[1] - (self.dim - self.output_dim + 1) // 2)]
        else:
            self.crop = lambda x: x
        self.crop.__doc__ = f"Crop output from {self.dim} to {self.output_dim}"
def energy_conserving(module: nn.Module, name: str = "weight") -> nn.Module:
    param = getattr(module, name, None)
    if not isinstance(param, torch.Tensor):
        raise ValueError(f"Module '{module}' has no parameter or buffer '{name}'")

        self.U = nn.Parameter(torch.randn(self.dim, self.dim, dtype=self.dtype))  # -> parametrization: Unitary
        self.S = nn.Parameter(torch.randn(self.dim, dtype=self.dtype))  # -> parametrization: Clamp (magnitude 0..1)
        self.V = nn.Parameter(torch.randn(self.dim, self.dim, dtype=self.dtype))  # -> parametrization: Unitary
        self.register_buffer("MZI_scale", torch.tensor(2, dtype=self.dtype.to_real()).sqrt())
        # V is actually V.H, but
    if not (2 <= param.ndim <= 3):
        raise ValueError(f"Expected a matrix or batch of matrices. Got a tensor of {param.ndim} dimensions.")

    unit = _EnergyConserving()
    nn.utils.parametrize.register_parametrization(module, name, unit)
    return module

    def forward(self, x_in):
        x = x_in
        x = self.pad(x)
        x = x @ self.U
        x = x * (self.S.squeeze() / self.MZI_scale)
        x = x @ self.V
        x = self.crop(x)
        return x
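
For context (not part of this commit): a sketch of the new energy_conserving parametrization added above, which divides a weight by its largest singular value so the layer cannot amplify; the layer size and dtype are illustrative.

# usage sketch -- illustration only, not part of the diff
lin = nn.Linear(8, 8, bias=False, dtype=torch.complex64)
energy_conserving(lin, "weight")            # weight is rescaled by its spectral norm on access
s = torch.linalg.svdvals(lin.weight)        # largest singular value ~1, so ||W x|| <= ||x||
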
class ONN(nn.Module):
    def __init__(self, input_dim, output_dim, dtype=None) -> None:
@@ -202,56 +202,72 @@ class ONN(nn.Module):
        # crop output to desired size
        if self.output_dim < self.dim:
            self.crop = lambda x: x[:, (self.dim - self.output_dim) // 2 : (x.shape[1] - (self.dim - self.output_dim + 1) // 2)]
            self.crop = lambda x: x[
                :, (self.dim - self.output_dim) // 2 : (x.shape[1] - (self.dim - self.output_dim + 1) // 2)
            ]
            self.crop.__doc__ = f"Crop output from {self.dim} to {self.output_dim}"
        else:
            self.crop = lambda x: x
            self.crop.__doc__ = f"Output size equals internal size {self.dim}"

        self.weight = nn.Parameter(torch.randn(self.dim, self.dim, dtype=self.dtype))
        # self.scale = nn.Parameter(torch.randn(1, dtype=self.dtype.to_real())+0.5)

    def reset_parameters(self):
        q, _ = torch.linalg.qr(self.weight)
        self.weight.data = q

    # def get_M(self):
    #     return self.U @ self.sigma @ self.V
    #     return self.U @ self.sigma @ self.V

    def forward(self, x):
        return self.crop(self.pad(x) @ self.weight)


class SemiUnitaryLayer(nn.Module):
    def __init__(self, input_dim, output_dim, dtype=None):
        super(SemiUnitaryLayer, self).__init__()
class ONNRect(nn.Module):
    def __init__(self, input_dim, output_dim, square=False, dtype=None):
        super(ONNRect, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim

        # Create a larger square matrix for QR decomposition
        self.weight = nn.Parameter(torch.randn(max(input_dim, output_dim), max(input_dim, output_dim), dtype=dtype))
        self.scale = nn.Parameter(torch.tensor(1.0, dtype=dtype.to_real()))
        self.reset_parameters()
        if square:
            dim = max(input_dim, output_dim)
            self.weight = nn.Parameter(torch.randn(dim, dim, dtype=dtype))

            # zero pad input to internal size if smaller
            if self.input_dim < dim:
                self.pad = lambda x: F.pad(x, ((dim - self.input_dim) // 2, (dim - self.input_dim + 1) // 2))
                self.pad.__doc__ = f"Zero pad input from {self.input_dim} to {dim}"
            else:
                self.pad = lambda x: x
                self.pad.__doc__ = f"Input size equals internal size {dim}"

            # crop output to desired size
            if self.output_dim < dim:
                self.crop = lambda x: x[
                    :, (dim - self.output_dim) // 2 : (x.shape[1] - (dim - self.output_dim + 1) // 2)
                ]
                self.crop.__doc__ = f"Crop output from {dim} to {self.output_dim}"
            else:
                self.crop = lambda x: x
                self.crop.__doc__ = f"Output size equals internal size {dim}"

    def reset_parameters(self):
        # Ensure the weights are unitary by QR decomposition
        q, _ = torch.linalg.qr(self.weight)
        # A = QR with A being a complex square matrix -> Q is unitary, R is upper triangular

        # truncate the matrix to the desired size
        if self.input_dim > self.output_dim:
            self.weight.data = q[: self.input_dim, : self.output_dim]
        else:
            self.weight.data = q[: self.output_dim, : self.input_dim].t()
        ...
        self.weight = nn.Parameter(torch.randn(output_dim, input_dim, dtype=dtype))
        self.pad = lambda x: x
        self.pad.__doc__ = "No padding"
        self.crop = lambda x: x
        self.crop.__doc__ = "No cropping"

    def forward(self, x):
        with torch.no_grad():
            scale = torch.clamp(self.scale, 0.0, 1.0)
        out = torch.matmul(x, scale * self.weight)
        x = self.pad(x)
        out = self.crop((self.weight @ x.mT).mT)
        return out

    def __repr__(self):
        return f"SemiUnitaryLayer({self.input_dim}, {self.output_dim})"
    # def __repr__(self):
    #     return f"ONNRect({self.input_dim}, {self.output_dim})"
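
For context (not part of this commit): a sketch of how the new ONNRect layer appears to be used in its rectangular (square=False) path, assuming that path initializes the weight with shape (output_dim, input_dim) as the diff suggests; all sizes and the dtype are illustrative.

# usage sketch -- illustration only, not part of the diff
layer = ONNRect(input_dim=16, output_dim=10, dtype=torch.complex64)
energy_conserving(layer, "weight")                 # optional: keep the transform passive
x = torch.randn(32, 16, dtype=torch.complex64)     # batch of 32 input fields
y = layer(x)                                       # shape (32, 10): crop((W @ x.mT).mT)
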
# class SaturableAbsorberLambertW(nn.Module):
@@ -336,6 +352,19 @@ class DropoutComplex(nn.Module):
        return self.dropout(x)


class Scale(nn.Module):
    def __init__(self, size):
        super(Scale, self).__init__()
        self.size = size
        self.scale = nn.Parameter(torch.ones(size, dtype=torch.float32))

    def forward(self, x):
        return x * self.scale

    def __repr__(self):
        return f"Scale({self.size})"


class Identity(nn.Module):
    """
    implements the "activation" function
@@ -348,6 +377,7 @@ class Identity(nn.Module):
    def forward(self, x):
        return x


class PowRot(nn.Module):
    def __init__(self, bias=False):
        super(PowRot, self).__init__()
@@ -359,15 +389,75 @@ class PowRot(nn.Module):
    def forward(self, x: torch.Tensor):
        if x.is_complex():
            return x * torch.exp(-self.scale*1j*x.abs().square()+self.bias.to(dtype=x.dtype))
            return x * torch.exp(-self.scale * 1j * x.abs().square() + self.bias.to(dtype=x.dtype))
        else:
            return x
            return x


class MZISingle(nn.Module):
    def __init__(self, bias, size, func=None):
        super(MZISingle, self).__init__()
        self.omega = nn.Parameter(torch.randn(size))
        self.phi = nn.Parameter(torch.randn(size))
        self.func = func or (lambda x: x.abs().square())  # default to |z|^2

    def forward(self, x: torch.Tensor):
        return x * torch.exp(1j * self.phi) * torch.sin(self.omega + self.func(x))
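
For context (not part of this commit): a sketch of the new MZISingle nonlinearity; the channel count and input shape are illustrative.

# usage sketch -- illustration only, not part of the diff
act = MZISingle(bias=False, size=64)
x = torch.randn(32, 64, dtype=torch.complex64)
y = act(x)          # x * exp(j*phi) * sin(omega + |x|^2), element-wise per channel
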
class EOActivation(nn.Module):
    def __init__(self, bias, size=None):
        # 10.1109/SiPhotonics60897.2024.10543376
        super(EOActivation, self).__init__()
        if size is None:
            raise ValueError("Size must be specified")
        self.size = size
        self.alpha = nn.Parameter(torch.ones(size))
        self.V_bias = nn.Parameter(torch.ones(size))
        self.gain = nn.Parameter(torch.ones(size))
        # if bias:
        #     self.phase_bias = nn.Parameter(torch.zeros(size))
        # else:
        #     self.register_buffer("phase_bias", torch.zeros(size))
        self.register_buffer("phase_bias", torch.clamp(torch.ones(size) + torch.randn(size)*0.1, 0, 1)*torch.pi)
        self.register_buffer("responsivity", torch.ones(size)*0.9)
        self.register_buffer("V_pi", torch.ones(size)*3)

        self.reset_weights()

    def reset_weights(self):
        if "alpha" in self._parameters:
            self.alpha.data = torch.ones(self.size)*0.5
        if "V_pi" in self._parameters:
            self.V_pi.data = torch.ones(self.size)*3
        if "V_bias" in self._parameters:
            self.V_bias.data = torch.zeros(self.size)
        if "gain" in self._parameters:
            self.gain.data = torch.ones(self.size)
        if "responsivity" in self._parameters:
            self.responsivity.data = torch.ones(self.size)*0.9
        if "bias" in self._parameters:
            self.phase_bias.data = torch.zeros(self.size)

    def forward(self, x: torch.Tensor):
        phi_b = torch.pi * self.V_bias / (self.V_pi + 1e-8)
        g_phi = torch.pi * (self.alpha * self.gain * self.responsivity) / (self.V_pi + 1e-8)
        intermediate = g_phi * x.abs().square() + phi_b
        return (
            1j
            * torch.sqrt(1 - self.alpha)
            * torch.exp(-0.5j * (intermediate + self.phase_bias))
            * torch.cos(0.5 * intermediate)
            * x
        )
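
For context (not part of this commit): a sketch of the new electro-optic activation (cf. the DOI referenced in its __init__); the channel count and input shape are illustrative.

# usage sketch -- illustration only, not part of the diff
act = EOActivation(bias=False, size=64)
x = torch.randn(32, 64, dtype=torch.complex64)
y = act(x)          # complex output: |x|^2-driven phase and transmission applied to x
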
class Pow(nn.Module):
    """
    implements the activation function
    M(z) = ||z||^2 + b
    """

    def __init__(self, bias=False):
        super(Pow, self).__init__()
        if bias:
@@ -375,7 +465,6 @@ class Pow(nn.Module):
        else:
            self.register_buffer("bias", torch.tensor(0.0))

    def forward(self, x: torch.Tensor):
        return x.abs().square().add(self.bias).to(dtype=x.dtype)
@@ -395,7 +484,7 @@ class Mag(nn.Module):
    def forward(self, x: torch.Tensor):
        return x.abs().add(self.bias).to(dtype=x.dtype)


class MagScale(nn.Module):
    def __init__(self, bias=False):
@@ -404,10 +493,11 @@ class MagScale(nn.Module):
            self.bias = nn.Parameter(torch.tensor(0.0))
        else:
            self.register_buffer("bias", torch.tensor(0.0))

    def forward(self, x: torch.Tensor):
        return x.abs().add(self.bias).to(dtype=x.dtype).sin().mul(x)


class PowScale(nn.Module):
    def __init__(self, bias=False):
        super(PowScale, self).__init__()
@@ -415,7 +505,7 @@ class PowScale(nn.Module):
            self.bias = nn.Parameter(torch.tensor(0.0))
        else:
            self.register_buffer("bias", torch.tensor(0.0))

    def forward(self, x: torch.Tensor):
        return x.mul(x.abs().square().add(self.bias).to(dtype=x.dtype).sin())
@@ -486,10 +576,10 @@ __all__ = [
    complex_mse_loss,
    UnitaryLayer,
    unitary,
    energy_conserving,
    clamp,
    ONN,
    ONNMiller,
    SemiUnitaryLayer,
    ONNRect,
    DropoutComplex,
    Identity,
    Pow,
@@ -498,7 +588,9 @@ __all__ = [
    ModReLU,
    CReLU,
    ZReLU,
    MZISingle,
    EOActivation,
    # SaturableAbsorberLambertW,
    # SaturableAbsorber,
    # SpreadLayer,
]
]