wip
@@ -441,8 +441,7 @@ class input_rotator(nn.Module):
        # return out


#### as defined by Zhang et al.

class DropoutComplex(nn.Module):
    def __init__(self, p=0.5):
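The hunk ends at the __init__ signature. A minimal sketch of how DropoutComplex might be completed, assuming the common convention of sampling one real Bernoulli keep-mask per complex element (so real and imaginary parts drop together) with inverted-dropout rescaling; nothing below the signature is shown in the diff:

import torch
import torch.nn as nn

class DropoutComplex(nn.Module):
    def __init__(self, p=0.5):
        super().__init__()
        self.p = p  # drop probability

    def forward(self, x: torch.Tensor):
        if not self.training or self.p == 0.0:
            return x
        keep = 1.0 - self.p
        # One real mask per complex element; rescale so E[out] == x.
        mask = torch.bernoulli(torch.full(x.shape, keep, device=x.device))
        return x * mask / keep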
@@ -464,7 +463,7 @@ class Scale(nn.Module):
        self.scale = nn.Parameter(torch.ones(size, dtype=torch.float32))

    def forward(self, x):
        return x * torch.sqrt(self.scale)

    def __repr__(self):
        return f"Scale({self.size})"
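Note the design implication of the new return: scale is initialized to ones, so torch.sqrt(self.scale) starts at 1 and the layer is initially the identity, but nothing stops gradient descent from pushing an entry of scale negative, at which point sqrt returns NaN (a clamp or softplus reparameterization would guard this). A quick check, assuming the elided constructor just records size and registers the parameter as shown:

import torch

s = Scale(4)  # constructor signature assumed from the visible parameter line
x = torch.randn(2, 4, dtype=torch.complex64)
assert torch.allclose(s(x), x)  # sqrt(ones) == 1: identity at initialization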
@@ -546,35 +545,31 @@ class EOActivation(nn.Module):
            raise ValueError("Size must be specified")
        self.size = size
        # Trainable per-channel parameters.
        self.alpha = nn.Parameter(torch.rand(size))
        self.V_bias = nn.Parameter(torch.rand(size))
        self.gain = nn.Parameter(torch.rand(size))
        # if bias:
        #     self.phase_bias = nn.Parameter(torch.zeros(size))
        # else:
        #     self.register_buffer("phase_bias", torch.zeros(size))
        # self.register_buffer("phase_bias", torch.clamp(torch.ones(size) + torch.randn(size)*0.1, 0, 1)*torch.pi)
        # Fixed physical constants kept as non-trainable buffers.
        self.register_buffer("responsivity", torch.ones(size) * 0.9)
        self.register_buffer("V_pi", torch.ones(size) * 3)
        # self.register_buffer("gain", torch.ones(size))
        # self.register_buffer("responsivity", torch.ones(size))
        # self.register_buffer("V_pi", torch.ones(size))

        self.reset_weights()

    def reset_weights(self):
        if "alpha" in self._parameters:
            self.alpha.data = torch.rand(self.size)
        # V_pi and responsivity are registered as buffers, so they live in
        # self._buffers rather than self._parameters; testing _parameters for
        # them can never succeed, which left these resets dead.
        if "V_pi" in self._buffers:
            self.V_pi.data = torch.rand(self.size) * 3
        # if "V_pi" in self._parameters:
        #     self.V_pi.data = torch.rand(self.size)*3
        if "V_bias" in self._parameters:
            self.V_bias.data = torch.randn(self.size)
        if "gain" in self._parameters:
            self.gain.data = torch.rand(self.size)
        if "responsivity" in self._buffers:
            self.responsivity.data = torch.ones(self.size) * 0.9
        # if "responsivity" in self._parameters:
        #     self.responsivity.data = torch.ones(self.size)*0.9
        # if "bias" in self._parameters:
        #     self.phase_bias.data = torch.zeros(self.size)

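The _parameters/_buffers distinction that the comment above relies on is standard nn.Module behavior and easy to verify:

import torch
import torch.nn as nn

m = nn.Module()
m.alpha = nn.Parameter(torch.rand(3))     # registered in m._parameters
m.register_buffer("V_pi", torch.ones(3))  # registered in m._buffers

print("alpha" in m._parameters)  # True
print("V_pi" in m._parameters)   # False: buffers never appear here
print("V_pi" in m._buffers)      # True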
    def forward(self, x: torch.Tensor):
        # Superseded form, normalized by V_pi:
        # phi_b = torch.pi * self.V_bias / (self.V_pi + 1e-8)
        # g_phi = torch.pi * (self.alpha * self.gain * self.responsivity) / (self.V_pi + 1e-8)
        phi_b = torch.pi * self.V_bias  # / (self.V_pi)
        g_phi = torch.pi * (self.alpha * self.gain)  # * self.responsivity) / (self.V_pi)
        # Detected optical power |x|^2 sets the induced phase shift.
        intermediate = g_phi * x.abs().square() + phi_b
        return (
            1j
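The diff is cut off mid-expression after the leading 1j. For orientation only: the variable names match the electro-optic activation of Williamson et al. (2019), whose published transfer function would complete the return as sketched below. This is that paper's form written under the diff's names, an assumption rather than a reconstruction of the author's actual line:

        # f(x) = 1j*sqrt(1-alpha) * exp(-1j*phi/2) * cos(phi/2) * x,
        # with phi = g_phi * |x|^2 + phi_b (the `intermediate` above).
        return (
            1j
            * torch.sqrt(1.0 - self.alpha)
            * torch.exp(-0.5j * intermediate)
            * torch.cos(0.5 * intermediate)
            * x
        )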