optical-regeneration/src/single-core-regen/util/complexNN.py
Joseph Hopfmüller 674033ac2e move hypertraining class into separate file;
move settings dataclasses into separate file;
add SemiUnitaryLayer;
clean up model response plotting code;
cnt hyperparameter search
2024-11-20 22:49:31 +01:00


import torch
import torch.nn as nn
import torch.nn.functional as F


def complex_mse_loss(input, target):
    """
    Compute the mean squared error between two complex tensors.
    """
    if input.is_complex():
        return torch.mean(
            torch.square(input.real - target.real)
            + torch.square(input.imag - target.imag)
        )
    else:
        return F.mse_loss(input, target)


def complex_sse_loss(input, target):
    """
    Compute the sum of squared errors between two complex tensors.
    """
    if input.is_complex():
        return torch.sum(
            torch.square(input.real - target.real)
            + torch.square(input.imag - target.imag)
        )
    else:
        return torch.sum(torch.square(input - target))
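
# Minimal usage sketch (illustrative, not part of the original module): the complex
# losses accumulate the real-part and imaginary-part errors separately, e.g.
# >>> pred = torch.tensor([1 + 1j, 2 + 0j])
# >>> target = torch.tensor([1 + 0j, 2 + 0j])
# >>> complex_mse_loss(pred, target)   # mean of |1j|**2 and 0 -> tensor(0.5000)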


class UnitaryLayer(nn.Module):
    """
    Linear layer whose weight is initialized with orthonormal columns
    via QR decomposition (requires in_features >= out_features).
    """

    def __init__(self, in_features, out_features, dtype=None):
        assert in_features >= out_features
        super(UnitaryLayer, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = nn.Parameter(torch.randn(in_features, out_features, dtype=dtype))
        self.reset_parameters()

    def reset_parameters(self):
        # Orthonormalize the columns of the random initial weight.
        q, _ = torch.linalg.qr(self.weight)
        self.weight.data = q

    def forward(self, x):
        return torch.matmul(x, self.weight)

    def __repr__(self):
        return f"UnitaryLayer({self.in_features}, {self.out_features})"


class SemiUnitaryLayer(nn.Module):
    def __init__(self, input_dim, output_dim, dtype=None):
        super(SemiUnitaryLayer, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        # Create a larger square matrix for QR decomposition
        self.weight = nn.Parameter(
            torch.randn(max(input_dim, output_dim), max(input_dim, output_dim), dtype=dtype)
        )
        self.reset_parameters()

    def reset_parameters(self):
        # Ensure the weights are semi-unitary by QR decomposition
        q, _ = torch.linalg.qr(self.weight)
        if self.input_dim > self.output_dim:
            self.weight.data = q[:self.input_dim, :self.output_dim]
        else:
            self.weight.data = q[:self.output_dim, :self.input_dim].t()

    def forward(self, x):
        out = torch.matmul(x, self.weight)
        return out

    def __repr__(self):
        return f"SemiUnitaryLayer({self.input_dim}, {self.output_dim})"


# class SpreadLayer(nn.Module):
#     def __init__(self, in_features, out_features, dtype=None):
#         super(SpreadLayer, self).__init__()
#         self.in_features = in_features
#         self.out_features = out_features
#         self.mat = torch.ones(in_features, out_features, dtype=dtype) * torch.sqrt(
#             torch.tensor(in_features / out_features)
#         )
#
#     def forward(self, x):
#         # N in_features -> M out_features, energy is preserved (P = abs(x)^2)
#         out = torch.matmul(x, self.mat)
#         return out


#### Activation functions as defined by Zhang et al.
class Identity(nn.Module):
    """
    implements the "activation" function
        M(z) = z
    """

    def __init__(self):
        super(Identity, self).__init__()

    def forward(self, x):
        return x


class Mag(nn.Module):
    """
    implements the activation function
        M(z) = ||z||
    """

    def __init__(self):
        super(Mag, self).__init__()

    def forward(self, x):
        # Cast the (real-valued) magnitude back to the input dtype.
        return torch.abs(x).to(dtype=x.dtype)


class ModReLU(nn.Module):
    """
    implements the activation function
        M(z) = ReLU(||z|| + b)*exp(j*theta_z)
             = ReLU(||z|| + b)*z/||z||
    """

    def __init__(self, b=0):
        super(ModReLU, self).__init__()
        self.b = torch.tensor(b)

    def forward(self, x):
        if x.is_complex():
            mod = torch.abs(x)  # ||z||
            return torch.relu(mod + self.b) * x / mod
        else:
            return torch.relu(x + self.b)

    def __repr__(self):
        return f"ModReLU(b={self.b})"


class CReLU(nn.Module):
    """
    implements the activation function
        M(z) = ReLU(Re(z)) + j*ReLU(Im(z))
    """

    def __init__(self):
        super(CReLU, self).__init__()

    def forward(self, x):
        if x.is_complex():
            return torch.relu(x.real) + 1j * torch.relu(x.imag)
        else:
            return torch.relu(x)


class ZReLU(nn.Module):
    """
    implements the activation function
        M(z) = z  if 0 <= angle(z) <= pi/2
             = 0  otherwise
    """

    def __init__(self):
        super(ZReLU, self).__init__()

    def forward(self, x):
        if x.is_complex():
            angle = torch.angle(x)
            return x * (angle >= 0) * (angle <= torch.pi / 2)
        else:
            return torch.relu(x)
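

# Minimal smoke-test sketch (illustrative, not part of the original module): push a
# random complex batch through each activation and report output shape and dtype.
if __name__ == "__main__":
    z = torch.randn(8, 4, dtype=torch.cfloat)
    for act in (Identity(), Mag(), ModReLU(b=0.1), CReLU(), ZReLU()):
        out = act(z)
        print(f"{act}: shape={tuple(out.shape)}, dtype={out.dtype}")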