refactor complex loss functions for improved readability; update settings and dataset classes for consistency
@@ -8,10 +8,7 @@ def complex_mse_loss(input, target):
     Compute the mean squared error between two complex tensors.
     """
     if input.is_complex():
-        return torch.mean(
-            torch.square(input.real - target.real)
-            + torch.square(input.imag - target.imag)
-        )
+        return torch.mean(torch.square(input.real - target.real) + torch.square(input.imag - target.imag))
     else:
         return F.mse_loss(input, target)

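Note (a reviewer sketch, not part of the commit): for complex tensors, summing the squared real and imaginary residuals is the same as taking the squared modulus of the complex difference, so the one-liner above is equivalent to torch.mean(torch.abs(input - target) ** 2). A minimal check, assuming only PyTorch:

import torch

a = torch.randn(4, dtype=torch.complex64)
b = torch.randn(4, dtype=torch.complex64)

# Real/imaginary decomposition, as in the diff above
lhs = torch.mean(torch.square(a.real - b.real) + torch.square(a.imag - b.imag))
# Squared modulus of the complex residual
rhs = torch.mean(torch.abs(a - b) ** 2)
assert torch.allclose(lhs, rhs)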
@@ -21,10 +18,7 @@ def complex_sse_loss(input, target):
     Compute the sum squared error between two complex tensors.
     """
     if input.is_complex():
-        return torch.sum(
-            torch.square(input.real - target.real)
-            + torch.square(input.imag - target.imag)
-        )
+        return torch.sum(torch.square(input.real - target.real) + torch.square(input.imag - target.imag))
     else:
         return torch.sum(torch.square(input - target))

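The SSE variant differs from the MSE loss only by summing instead of averaging, so on the same inputs the two agree up to a factor of the element count. A small sanity sketch (illustrative only; a and b are hypothetical example tensors):

import torch

a = torch.randn(8, dtype=torch.complex64)
b = torch.randn(8, dtype=torch.complex64)

residual = torch.square(a.real - b.real) + torch.square(a.imag - b.imag)
# SSE is MSE scaled by the number of elements
assert torch.allclose(torch.sum(residual), torch.mean(residual) * a.numel())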
@@ -41,19 +35,20 @@ class UnitaryLayer(nn.Module):
     def reset_parameters(self):
         q, _ = torch.linalg.qr(self.weight)
         self.weight.data = q

     def forward(self, x):
         return torch.matmul(x, self.weight)

     def __repr__(self):
         return f"UnitaryLayer({self.in_features}, {self.out_features})"


 class SemiUnitaryLayer(nn.Module):
     def __init__(self, input_dim, output_dim, dtype=None):
         super(SemiUnitaryLayer, self).__init__()
         self.input_dim = input_dim
         self.output_dim = output_dim

         # Create a larger square matrix for QR decomposition
         self.weight = nn.Parameter(torch.randn(max(input_dim, output_dim), max(input_dim, output_dim), dtype=dtype))
         self.reset_parameters()
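Background on the QR trick used in both layers (an illustration, not part of the commit): torch.linalg.qr on a random square Gaussian matrix returns an orthogonal factor Q, which is what makes the weights (semi-)unitary after reset_parameters:

import torch

w = torch.randn(5, 5)
q, _ = torch.linalg.qr(w)
# Q has orthonormal columns, so Q^T Q is (numerically) the identity
assert torch.allclose(q.T @ q, torch.eye(5), atol=1e-6)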
@@ -62,14 +57,14 @@ class SemiUnitaryLayer(nn.Module):
         # Ensure the weights are semi-unitary by QR decomposition
         q, _ = torch.linalg.qr(self.weight)
         if self.input_dim > self.output_dim:
-            self.weight.data = q[:self.input_dim, :self.output_dim]
+            self.weight.data = q[: self.input_dim, : self.output_dim]
         else:
-            self.weight.data = q[:self.output_dim, :self.input_dim].t()
+            self.weight.data = q[: self.output_dim, : self.input_dim].t()

     def forward(self, x):
         out = torch.matmul(x, self.weight)
         return out

     def __repr__(self):
         return f"SemiUnitaryLayer({self.input_dim}, {self.output_dim})"

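The slicing in reset_parameters preserves that property: keeping the first output_dim columns of Q leaves a matrix with orthonormal columns, i.e. a semi-unitary weight. A minimal check with hypothetical dimensions (6 in, 3 out):

import torch

input_dim, output_dim = 6, 3
q, _ = torch.linalg.qr(torch.randn(max(input_dim, output_dim), max(input_dim, output_dim)))
w = q[: input_dim, : output_dim]
# W^T W equals the (output_dim x output_dim) identity: W is semi-unitary
assert torch.allclose(w.T @ w, torch.eye(output_dim), atol=1e-6)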