move hypertraining class into separate file;
move settings dataclasses into separate file; add SemiUnitaryLayer; clean up model response plotting code; continue hyperparameter search
@@ -1,116 +1,160 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


def complex_mse_loss(input, target):
    """
    Compute the mean squared error between two tensors,
    supporting both complex and real inputs.
    """
    if input.is_complex():
        return torch.mean(
            torch.square(input.real - target.real)
            + torch.square(input.imag - target.imag)
        )
    else:
        return F.mse_loss(input, target)


def complex_sse_loss(input, target):
    """
    Compute the sum of squared errors between two tensors,
    supporting both complex and real inputs.
    """
    if input.is_complex():
        return torch.sum(
            torch.square(input.real - target.real)
            + torch.square(input.imag - target.imag)
        )
    else:
        return torch.sum(torch.square(input - target))


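# Illustrative usage sketch (not part of the original commit; the _demo_* helper
# below is hypothetical). It shows how the two loss helpers above can be called
# on complex-valued predictions; the tensor shapes are arbitrary.
def _demo_complex_losses():
    pred = torch.randn(8, 4, dtype=torch.cfloat)
    target = torch.randn(8, 4, dtype=torch.cfloat)
    mse = complex_mse_loss(pred, target)  # mean of squared real/imag errors
    sse = complex_sse_loss(pred, target)  # sum of squared real/imag errors
    return mse, sse

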
class UnitaryLayer(nn.Module):
    def __init__(self, in_features, out_features, dtype=None):
        assert in_features >= out_features
        super(UnitaryLayer, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = nn.Parameter(torch.randn(in_features, out_features, dtype=dtype))
        self.reset_parameters()

    def reset_parameters(self):
        # Re-orthonormalize the columns of the weight via QR decomposition.
        q, _ = torch.linalg.qr(self.weight)
        self.weight.data = q

    def forward(self, x):
        return torch.matmul(x, self.weight)

    def __repr__(self):
        return f"UnitaryLayer({self.in_features}, {self.out_features})"


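# Illustrative sketch (not part of the original commit; helper name hypothetical):
# after initialization the weight of UnitaryLayer has orthonormal columns, so
# W^H W is (approximately) the identity.
def _demo_unitary_layer():
    layer = UnitaryLayer(8, 4, dtype=torch.cfloat)
    w = layer.weight.detach()
    gram = w.conj().t() @ w                      # 4 x 4 Gram matrix
    eye = torch.eye(4, dtype=torch.cfloat)
    assert torch.allclose(gram, eye, atol=1e-5)  # columns are orthonormal
    x = torch.randn(2, 8, dtype=torch.cfloat)
    return layer(x)                              # shape (2, 4)

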
class SemiUnitaryLayer(nn.Module):
    def __init__(self, input_dim, output_dim, dtype=None):
        super(SemiUnitaryLayer, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim

        # Create a larger square matrix for QR decomposition
        self.weight = nn.Parameter(
            torch.randn(max(input_dim, output_dim), max(input_dim, output_dim), dtype=dtype)
        )
        self.reset_parameters()

    def reset_parameters(self):
        # Ensure the weights are semi-unitary by QR decomposition
        q, _ = torch.linalg.qr(self.weight)
        if self.input_dim > self.output_dim:
            self.weight.data = q[:self.input_dim, :self.output_dim]
        else:
            self.weight.data = q[:self.output_dim, :self.input_dim].t()

    def forward(self, x):
        return torch.matmul(x, self.weight)

    def __repr__(self):
        return f"SemiUnitaryLayer({self.input_dim}, {self.output_dim})"


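# Illustrative sketch (not part of the original commit; helper name hypothetical):
# SemiUnitaryLayer keeps orthonormal columns when input_dim > output_dim and
# orthonormal rows otherwise, so the Gram matrix over the smaller dimension is ~I.
def _demo_semi_unitary_layer():
    layer = SemiUnitaryLayer(6, 3, dtype=torch.cfloat)
    w = layer.weight.detach()                    # shape (6, 3)
    gram = w.conj().t() @ w                      # 3 x 3
    assert torch.allclose(gram, torch.eye(3, dtype=torch.cfloat), atol=1e-5)
    x = torch.randn(2, 6, dtype=torch.cfloat)
    return layer(x)                              # shape (2, 3)

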
# class SpreadLayer(nn.Module):
#     def __init__(self, in_features, out_features, dtype=None):
#         super(SpreadLayer, self).__init__()
#         self.in_features = in_features
#         self.out_features = out_features
#         self.mat = torch.ones(in_features, out_features, dtype=dtype) * torch.sqrt(torch.tensor(in_features / out_features))

#     def forward(self, x):
#         # N in_features -> M out_features, energy is preserved (P = abs(x)^2)
#         out = torch.matmul(x, self.mat)
#         return out


#### as defined by Zhang et al.


class Identity(nn.Module):
    """
    implements the "activation" function
    M(z) = z
    """

    def __init__(self):
        super(Identity, self).__init__()

    def forward(self, x):
        return x


class Mag(nn.Module):
    """
    implements the activation function
    M(z) = ||z||
    """

    def __init__(self):
        super(Mag, self).__init__()

    def forward(self, x):
        # Element-wise magnitude, cast back to the input dtype.
        return torch.abs(x).to(dtype=x.dtype)


# class Tanh(nn.Module):
#     """
#     implements the activation function
#     M(z) = tanh(z) = sinh(z)/cosh(z) = (exp(z)-exp(-z))/(exp(z)+exp(-z)) = (exp(2*z)-1)/(exp(2*z)+1)
#     """
#     def __init__(self):
#         super(Tanh, self).__init__()

#     def forward(self, x):
#         return torch.tanh(x)


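# Illustrative sketch (not part of the original commit; helper name hypothetical):
# Mag (defined above) returns the element-wise magnitude cast back to the input
# dtype, so complex inputs stay complex with zero imaginary part.
def _demo_mag():
    z = torch.tensor([3 + 4j, 0 - 2j], dtype=torch.cfloat)
    return Mag()(z)  # tensor([5.+0.j, 2.+0.j])

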
class ModReLU(nn.Module):
    """
    implements the activation function
    M(z) = ReLU(||z|| + b)*exp(j*theta_z)
         = ReLU(||z|| + b)*z/||z||
    """

    def __init__(self, b=0):
        super(ModReLU, self).__init__()
        self.b = torch.tensor(b)

    def forward(self, x):
        if x.is_complex():
            mod = torch.abs(x)  # ||z||, as in the docstring
            return torch.relu(mod + self.b) * x / mod
        else:
            return torch.relu(x + self.b)

    def __repr__(self):
        return f"ModReLU(b={self.b})"


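# Illustrative sketch (not part of the original commit; helper name hypothetical):
# ModReLU thresholds the magnitude while preserving the phase. With b = -1,
# elements whose magnitude is <= 1 are zeroed.
def _demo_modrelu():
    act = ModReLU(b=-1.0)
    z = torch.tensor([3 + 4j, 0.5 + 0j], dtype=torch.cfloat)
    out = act(z)
    zero = torch.tensor(0j, dtype=torch.cfloat)
    assert torch.allclose(out[1], zero)                            # ||0.5|| - 1 < 0 -> zeroed
    assert torch.allclose(torch.angle(out[0]), torch.angle(z[0]))  # phase preserved
    return out

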
class CReLU(nn.Module):
    """
    implements the activation function
    M(z) = ReLU(Re(z)) + j*ReLU(Im(z))
    """

    def __init__(self):
        super(CReLU, self).__init__()
        self.relu = nn.ReLU()

    def forward(self, x):
        if x.is_complex():
            return torch.relu(x.real) + 1j * torch.relu(x.imag)
        else:
            return torch.relu(x)


class ZReLU(nn.Module):
    """
    implements the activation function
@@ -122,20 +166,8 @@ class ZReLU(nn.Module):
    def __init__(self):
        super(ZReLU, self).__init__()

    def forward(self, x):
        if x.is_complex():
            return x * (torch.angle(x) >= 0) * (torch.angle(x) <= torch.pi / 2)
        else:
            return torch.relu(x)

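# Illustrative sketch (not part of the original commit; helper name hypothetical):
# CReLU rectifies real and imaginary parts independently, while ZReLU passes only
# elements whose phase lies in [0, pi/2].
def _demo_split_activations():
    z = torch.tensor([1 + 1j, -1 + 1j, 1 - 1j], dtype=torch.cfloat)
    crelu_out = CReLU()(z)  # tensor([1.+1.j, 0.+1.j, 1.+0.j])
    zrelu_out = ZReLU()(z)  # only 1+1j (phase pi/4) survives; the rest are zeroed
    return crelu_out, zrelu_out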