training loop speedup

Joseph Hopfmüller
2024-11-20 11:29:18 +01:00
parent 1622c38582
commit cdca5de473
11 changed files with 1026 additions and 151 deletions


@@ -0,0 +1,141 @@
import torch
import torch.nn as nn


def complex_mse_loss(input, target):
    """
    Compute the mean squared error between two complex tensors.
    """
    return torch.mean(torch.square(input.real - target.real) + torch.square(input.imag - target.imag))


def complex_sse_loss(input, target):
    """
    Compute the sum squared error between two complex tensors.
    """
    if input.is_complex():
        return torch.sum(torch.square(input.real - target.real) + torch.square(input.imag - target.imag))
    else:
        return torch.sum(torch.square(input - target))
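
# Minimal usage sketch (illustration only, not part of the original commit): for complex
# tensors, (Re(a-b))^2 + (Im(a-b))^2 == |a-b|^2, so complex_mse_loss should agree with
# torch.mean(torch.abs(a - b) ** 2). The helper name _check_complex_mse_example is hypothetical.
def _check_complex_mse_example():
    a = torch.randn(8, 4, dtype=torch.cfloat)
    b = torch.randn(8, 4, dtype=torch.cfloat)
    expected = torch.mean(torch.abs(a - b) ** 2)
    assert torch.allclose(complex_mse_loss(a, b), expected)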
class UnitaryLayer(nn.Module):
    def __init__(self, in_features, out_features):
        super(UnitaryLayer, self).__init__()
        assert in_features >= out_features
        self.in_features = in_features
        self.out_features = out_features
        self.weight = nn.Parameter(torch.randn(in_features, out_features, dtype=torch.cfloat))
        self.reset_parameters()

    def reset_parameters(self):
        q, _ = torch.linalg.qr(self.weight)
        self.weight.data = q

    @staticmethod
    @torch.jit.script
    def _unitary_forward(x, weight):
        out = torch.matmul(x, weight)
        return out

    def forward(self, x):
        return self._unitary_forward(x, self.weight)
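
# Illustration only (not part of the original commit): reset_parameters() keeps the reduced
# QR factor, so the weight columns should be orthonormal, i.e. weight^H @ weight ~= I, which
# is what makes the layer (semi-)unitary. The helper name _check_unitary_example is hypothetical.
def _check_unitary_example():
    layer = UnitaryLayer(16, 8)
    w = layer.weight.detach()
    gram = w.conj().T @ w
    eye = torch.eye(8, dtype=torch.cfloat)
    assert torch.allclose(gram, eye, atol=1e-5)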
#### activation functions as defined by Zhang et al.
class Identity(nn.Module):
    """
    implements the "activation" function
    M(z) = z
    """
    def __init__(self):
        super(Identity, self).__init__()

    def forward(self, x):
        return x
class Mag(nn.Module):
    """
    implements the activation function
    M(z) = ||z||
    """
    def __init__(self):
        super(Mag, self).__init__()

    # @torch.jit.script  # scripting forward() directly does not bind self; left unscripted
    def forward(self, x):
        # ||z|| = sqrt(Re(z)^2 + Im(z)^2); torch.abs on a complex tensor returns the magnitude
        return torch.abs(x)
# class Tanh(nn.Module):
#     """
#     implements the activation function
#     M(z) = tanh(z) = sinh(z)/cosh(z) = (exp(z)-exp(-z))/(exp(z)+exp(-z)) = (exp(2*z)-1)/(exp(2*z)+1)
#     """
#     def __init__(self):
#         super(Tanh, self).__init__()
#
#     def forward(self, x):
#         return torch.tanh(x)
class ModReLU(nn.Module):
    """
    implements the activation function
    M(z) = ReLU(||z|| + b)*exp(j*theta_z)
         = ReLU(||z|| + b)*z/||z||
    """
    def __init__(self, b=0):
        super(ModReLU, self).__init__()
        self.b = b
        self.relu = nn.ReLU()

    @staticmethod
    # @torch.jit.script
    def _mod_relu(x, b):
        # ||z||; the clamp guards against division by zero at z = 0
        mod = torch.abs(x)
        return torch.relu(mod + b) * x / mod.clamp_min(1e-12)

    def forward(self, x):
        return self._mod_relu(x, self.b)
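
# Worked example (illustration only, not from the original commit): for z = 3 + 4j,
# ||z|| = 5, so with b = -1 ModReLU gives ReLU(5 - 1) * z / 5 = 0.8 * (3 + 4j) = 2.4 + 3.2j;
# the magnitude shrinks while the phase is preserved, and for b <= -5 the output is 0.
# The helper name _modrelu_example is hypothetical.
def _modrelu_example():
    act = ModReLU(b=-1)
    z = torch.tensor([3.0 + 4.0j])
    out = act(z)
    assert torch.allclose(out, torch.tensor([2.4 + 3.2j]))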
class CReLU(nn.Module):
    """
    implements the activation function
    M(z) = ReLU(Re(z)) + j*ReLU(Im(z))
    """
    def __init__(self):
        super(CReLU, self).__init__()
        self.relu = nn.ReLU()

    # @torch.jit.script  # scripting forward() directly does not bind self; left unscripted
    def forward(self, x):
        return torch.relu(x.real) + 1j*torch.relu(x.imag)
class ZReLU(nn.Module):
    """
    implements the activation function
    M(z) = z if 0 <= angle(z) <= pi/2
         = 0 otherwise
    """
    def __init__(self):
        super(ZReLU, self).__init__()

    # @torch.jit.script  # scripting forward() directly does not bind self; left unscripted
    def forward(self, x):
        angle = torch.angle(x)
        return x * (angle >= 0) * (angle <= torch.pi / 2)
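
# Small worked example for the pointwise activations (illustration only, not from the original
# commit): for z = -1 + 2j, Mag gives sqrt(5) ~= 2.236, CReLU gives ReLU(-1) + j*ReLU(2) = 2j,
# and ZReLU gives 0 because angle(z) > pi/2. The helper name _activation_examples is hypothetical.
def _activation_examples():
    z = torch.tensor([-1.0 + 2.0j])
    assert torch.allclose(Mag()(z), torch.tensor([5.0]).sqrt())
    assert torch.allclose(CReLU()(z), torch.tensor([0.0 + 2.0j]))
    assert torch.allclose(ZReLU()(z), torch.tensor([0.0 + 0.0j]))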
# class ComplexFeedForwardNN(nn.Module):
#     def __init__(self, in_features, hidden_features, out_features):
#         super(ComplexFeedForwardNN, self).__init__()
#         self.in_features = in_features
#         self.hidden_features = hidden_features
#         self.out_features = out_features
#         self.fc1 = UnitaryLayer(in_features, hidden_features)
#         self.fc2 = UnitaryLayer(hidden_features, out_features)
#
#     def forward(self, x):
#         x = self.fc1(x)
#         x = self.fc2(x)
#         return x
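
# Minimal end-to-end sketch (illustration only, not part of the original commit): wiring
# UnitaryLayer and ModReLU into a tiny complex-valued model and running a few training steps
# with complex_sse_loss. The layer sizes, optimizer, and learning rate are assumptions for
# demonstration, not the repository's actual training loop; recent PyTorch versions handle
# complex parameters in Adam. Note that plain gradient steps do not keep the weights exactly
# unitary, so reset_parameters() would have to be re-applied to re-orthonormalize them.
if __name__ == "__main__":
    model = nn.Sequential(
        UnitaryLayer(16, 8),
        ModReLU(b=-0.1),
        UnitaryLayer(8, 4),
    )
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    x = torch.randn(32, 16, dtype=torch.cfloat)
    target = torch.randn(32, 4, dtype=torch.cfloat)

    for step in range(10):
        optimizer.zero_grad()
        loss = complex_sse_loss(model(x), target)
        loss.backward()
        optimizer.step()
        print(f"step {step}: loss = {loss.item():.4f}")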