This commit is contained in:
Joseph Hopfmüller
2025-01-27 21:05:49 +01:00
parent f38d0ca3bb
commit 249fe1e940
19 changed files with 2266 additions and 880 deletions

View File

@@ -441,8 +441,7 @@ class input_rotator(nn.Module):
# return out
#### as defined by zhang et al
#### as defined by zhang et alas
class DropoutComplex(nn.Module):
def __init__(self, p=0.5):
@@ -464,7 +463,7 @@ class Scale(nn.Module):
self.scale = nn.Parameter(torch.ones(size, dtype=torch.float32))
def forward(self, x):
return x * self.scale
return x * torch.sqrt(self.scale)
def __repr__(self):
return f"Scale({self.size})"
@@ -546,35 +545,31 @@ class EOActivation(nn.Module):
raise ValueError("Size must be specified")
self.size = size
self.alpha = nn.Parameter(torch.rand(size))
self.V_bias = nn.Parameter(torch.rand(size))
self.gain = nn.Parameter(torch.rand(size))
# if bias:
# self.phase_bias = nn.Parameter(torch.zeros(size))
# else:
# self.register_buffer("phase_bias", torch.zeros(size))
# self.register_buffer("phase_bias", torch.clamp(torch.ones(size) + torch.randn(size)*0.1, 0, 1)*torch.pi)
self.register_buffer("responsivity", torch.ones(size)*0.9)
self.register_buffer("V_pi", torch.ones(size)*3)
self.V_bias = nn.Parameter(torch.rand(size))
# self.register_buffer("gain", torch.ones(size))
# self.register_buffer("responsivity", torch.ones(size))
# self.register_buffer("V_pi", torch.ones(size))
self.reset_weights()
def reset_weights(self):
if "alpha" in self._parameters:
self.alpha.data = torch.rand(self.size)
if "V_pi" in self._parameters:
self.V_pi.data = torch.rand(self.size)*3
# if "V_pi" in self._parameters:
# self.V_pi.data = torch.rand(self.size)*3
if "V_bias" in self._parameters:
self.V_bias.data = torch.randn(self.size)
if "gain" in self._parameters:
self.gain.data = torch.rand(self.size)
if "responsivity" in self._parameters:
self.responsivity.data = torch.ones(self.size)*0.9
# if "responsivity" in self._parameters:
# self.responsivity.data = torch.ones(self.size)*0.9
# if "bias" in self._parameters:
# self.phase_bias.data = torch.zeros(self.size)
def forward(self, x: torch.Tensor):
phi_b = torch.pi * self.V_bias / (self.V_pi + 1e-8)
g_phi = torch.pi * (self.alpha * self.gain * self.responsivity) / (self.V_pi + 1e-8)
phi_b = torch.pi * self.V_bias# / (self.V_pi)
g_phi = torch.pi * (self.alpha * self.gain)# * self.responsivity)# / (self.V_pi)
intermediate = g_phi * x.abs().square() + phi_b
return (
1j