refactor complex loss functions for improved readability; update settings and dataset classes for consistency

This commit is contained in:
Joseph Hopfmüller
2024-11-24 01:55:32 +01:00
parent 9a16a5637d
commit 7343ccb3a5
4 changed files with 392 additions and 361 deletions

View File

@@ -8,10 +8,7 @@ def complex_mse_loss(input, target):
Compute the mean squared error between two complex tensors.
"""
if input.is_complex():
return torch.mean(
torch.square(input.real - target.real)
+ torch.square(input.imag - target.imag)
)
return torch.mean(torch.square(input.real - target.real) + torch.square(input.imag - target.imag))
else:
return F.mse_loss(input, target)
@@ -21,10 +18,7 @@ def complex_sse_loss(input, target):
Compute the sum squared error between two complex tensors.
"""
if input.is_complex():
return torch.sum(
torch.square(input.real - target.real)
+ torch.square(input.imag - target.imag)
)
return torch.sum(torch.square(input.real - target.real) + torch.square(input.imag - target.imag))
else:
return torch.sum(torch.square(input - target))
@@ -41,19 +35,20 @@ class UnitaryLayer(nn.Module):
def reset_parameters(self):
    """Replace ``self.weight`` with the Q factor of its own QR decomposition.

    The decomposition is computed under ``torch.no_grad()`` so that this
    re-initialisation does not record an autograd graph that would be
    discarded immediately afterwards.
    """
    with torch.no_grad():
        q_factor, _ = torch.linalg.qr(self.weight)
    # Assign through ``.data`` so the existing parameter object is kept.
    self.weight.data = q_factor
def forward(self, x):
    """Return ``x`` right-multiplied by the layer weight (``x @ weight``)."""
    return torch.matmul(x, self.weight)
def __repr__(self):
    """Return a short description showing ``in_features`` and ``out_features``."""
    return f"UnitaryLayer({self.in_features}, {self.out_features})"
class SemiUnitaryLayer(nn.Module):
def __init__(self, input_dim, output_dim, dtype=None):
    """Store the layer dimensions and allocate the weight parameter.

    Args:
        input_dim: Stored as ``self.input_dim``.
        output_dim: Stored as ``self.output_dim``.
        dtype: Forwarded to ``torch.randn`` when drawing the initial weight.
    """
    super().__init__()
    self.input_dim = input_dim
    self.output_dim = output_dim

    # Create a larger square matrix for QR decomposition: its side is the
    # larger of the two dimensions.
    side = max(input_dim, output_dim)
    self.weight = nn.Parameter(torch.randn(side, side, dtype=dtype))
    self.reset_parameters()
@@ -62,14 +57,14 @@ class SemiUnitaryLayer(nn.Module):
# Ensure the weights are semi-unitary by QR decomposition
q, _ = torch.linalg.qr(self.weight)
if self.input_dim > self.output_dim:
self.weight.data = q[:self.input_dim, :self.output_dim]
self.weight.data = q[: self.input_dim, : self.output_dim]
else:
self.weight.data = q[:self.output_dim, :self.input_dim].t()
self.weight.data = q[: self.output_dim, : self.input_dim].t()
def forward(self, x):
    """Return ``x`` right-multiplied by the layer weight (``x @ weight``)."""
    return torch.matmul(x, self.weight)
def __repr__(self):
    """Return a short description showing ``input_dim`` and ``output_dim``."""
    return f"SemiUnitaryLayer({self.input_dim}, {self.output_dim})"