import torch
import torch.nn as nn
import torch.nn.functional as F


def complex_mse_loss(input, target):
    """Compute the mean squared error between two complex tensors."""
    if input.is_complex():
        return torch.mean(
            torch.square(input.real - target.real)
            + torch.square(input.imag - target.imag)
        )
    else:
        return F.mse_loss(input, target)


def complex_sse_loss(input, target):
    """Compute the sum of squared errors between two complex tensors."""
    if input.is_complex():
        return torch.sum(
            torch.square(input.real - target.real)
            + torch.square(input.imag - target.imag)
        )
    else:
        return torch.sum(torch.square(input - target))


class UnitaryLayer(nn.Module):
    """Linear layer whose weight is initialized with orthonormal columns."""

    def __init__(self, in_features, out_features, dtype=None):
        assert in_features >= out_features
        super(UnitaryLayer, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = nn.Parameter(
            torch.randn(in_features, out_features, dtype=dtype)
        )
        self.reset_parameters()

    def reset_parameters(self):
        # QR decomposition yields a factor q with orthonormal columns.
        q, _ = torch.linalg.qr(self.weight)
        self.weight.data = q

    def forward(self, x):
        return torch.matmul(x, self.weight)

    def __repr__(self):
        return f"UnitaryLayer({self.in_features}, {self.out_features})"


class SemiUnitaryLayer(nn.Module):
    """Linear layer initialized with a semi-unitary weight matrix."""

    def __init__(self, input_dim, output_dim, dtype=None):
        super(SemiUnitaryLayer, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        # Create a square matrix of the larger dimension for QR decomposition.
        dim = max(input_dim, output_dim)
        self.weight = nn.Parameter(torch.randn(dim, dim, dtype=dtype))
        self.reset_parameters()

    def reset_parameters(self):
        # Make the weights semi-unitary via QR decomposition, then slice
        # (or slice and transpose) down to the requested shape.
        q, _ = torch.linalg.qr(self.weight)
        if self.input_dim > self.output_dim:
            self.weight.data = q[:self.input_dim, :self.output_dim]
        else:
            self.weight.data = q[:self.output_dim, :self.input_dim].t()

    def forward(self, x):
        return torch.matmul(x, self.weight)

    def __repr__(self):
        return f"SemiUnitaryLayer({self.input_dim}, {self.output_dim})"


# class SpreadLayer(nn.Module):
#     def __init__(self, in_features, out_features, dtype=None):
#         super(SpreadLayer, self).__init__()
#         self.in_features = in_features
#         self.out_features = out_features
#         self.mat = torch.ones(in_features, out_features, dtype=dtype) * torch.sqrt(
#             torch.tensor(in_features / out_features)
#         )
#
#     def forward(self, x):
#         # N in_features -> M out_features; energy is preserved (P = abs(x)^2).
#         out = torch.matmul(x, self.mat)
#         return out


#### Activation functions as defined by Zhang et al.


class Identity(nn.Module):
    """Implements the "activation" function M(z) = z."""

    def __init__(self):
        super(Identity, self).__init__()

    def forward(self, x):
        return x


class Mag(nn.Module):
    """Implements the activation function M(z) = ||z||."""

    def __init__(self):
        super(Mag, self).__init__()

    def forward(self, x):
        # Cast back to the input dtype so complex inputs stay complex
        # (with zero imaginary part) for downstream layers.
        return torch.abs(x).to(dtype=x.dtype)


class ModReLU(nn.Module):
    """
    Implements the activation function
    M(z) = ReLU(||z|| + b) * exp(j*theta_z) = ReLU(||z|| + b) * z / ||z||.
    """

    def __init__(self, b=0):
        super(ModReLU, self).__init__()
        # Fixed (non-learnable) bias on the modulus.
        self.b = torch.tensor(b)

    def forward(self, x):
        if x.is_complex():
            # ||z|| = sqrt(Re(z)^2 + Im(z)^2); a small epsilon in the
            # denominator avoids division by zero at z = 0.
            mod = torch.abs(x)
            return torch.relu(mod + self.b) * x / (mod + 1e-9)
        else:
            return torch.relu(x + self.b)

    def __repr__(self):
        return f"ModReLU(b={self.b})"


class CReLU(nn.Module):
    """Implements the activation function M(z) = ReLU(Re(z)) + j*ReLU(Im(z))."""

    def __init__(self):
        super(CReLU, self).__init__()

    def forward(self, x):
        if x.is_complex():
            return torch.relu(x.real) + 1j * torch.relu(x.imag)
        else:
            return torch.relu(x)
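# A minimal sanity check (an added sketch, not part of the original module):
# for input_dim >= output_dim, a freshly initialized SemiUnitaryLayer should
# satisfy W^H W = I, i.e. the weight has orthonormal columns. The default
# shapes and the 1e-5 tolerance are assumptions chosen for illustration.
def _check_semi_unitary(input_dim=8, output_dim=4, dtype=torch.cfloat):
    layer = SemiUnitaryLayer(input_dim, output_dim, dtype=dtype)
    w = layer.weight
    # Gram matrix of the columns; should be the identity for a
    # semi-unitary weight.
    gram = w.conj().t() @ w
    eye = torch.eye(output_dim, dtype=dtype)
    assert torch.allclose(gram, eye, atol=1e-5), "weight is not semi-unitary"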
class ZReLU(nn.Module):
    """
    Implements the activation function
    M(z) = z if 0 <= angle(z) <= pi/2, and 0 otherwise.
    """

    def __init__(self):
        super(ZReLU, self).__init__()

    def forward(self, x):
        if x.is_complex():
            # Keep z only when its phase lies in the first quadrant.
            angle = torch.angle(x)
            return x * (angle >= 0) * (angle <= torch.pi / 2)
        else:
            return torch.relu(x)
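# A minimal usage sketch (assumed, not from the original source): runs a toy
# complex batch through each activation, evaluates both complex losses, and
# exercises the semi-unitary check above. Shapes and values are illustrative.
if __name__ == "__main__":
    torch.manual_seed(0)
    z = torch.randn(2, 8, dtype=torch.cfloat)
    target = torch.randn(2, 8, dtype=torch.cfloat)

    for act in (Identity(), Mag(), ModReLU(b=-0.1), CReLU(), ZReLU()):
        out = act(z)
        print(act.__class__.__name__, tuple(out.shape), out.dtype)

    print("complex MSE:", complex_mse_loss(z, target).item())
    print("complex SSE:", complex_sse_loss(z, target).item())

    _check_semi_unitary()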