add regenerator class and update dataset configurations for model training
@@ -569,6 +569,78 @@ class ZReLU(nn.Module):
            return x * (torch.angle(x) >= 0) * (torch.angle(x) <= torch.pi / 2)
        else:
            return torch.relu(x)

class regenerator(nn.Module):
    def __init__(
        self,
        *dims,
        layer_function=ONN,
        layer_kwargs: dict | None = None,
        layer_parametrizations: list[dict] | None = None,
        activation_function=Pow,
        dtype=torch.float64,
        dropout_prob=0.01,
        scale=False,
        **kwargs,
    ):
        super(regenerator, self).__init__()
        # dims may be given positionally or via the "dims" keyword argument
        if len(dims) == 0:
            try:
                dims = kwargs["dims"]
            except KeyError:
                raise ValueError("dims must be provided")
        self._n_hidden_layers = len(dims) - 2
        self._layers = nn.Sequential()
        if layer_kwargs is None:
            layer_kwargs = {}
        # self.powers = []

        # build the stack: optional Scale, then a linear layer; between hidden
        # layers, complex dropout followed by the activation function
        for i in range(self._n_hidden_layers + 1):
            if scale:
                self._layers.append(Scale(dims[i]))
            self._layers.append(layer_function(dims[i], dims[i + 1], dtype=dtype, **layer_kwargs))
            if i < self._n_hidden_layers:
                if dropout_prob is not None:
                    self._layers.append(DropoutComplex(p=dropout_prob))
                self._layers.append(activation_function(bias=True, size=dims[i + 1]))

        self._layers.append(Scale(dims[-1]))

        # add parametrizations
        if layer_parametrizations is not None:
            for layer in self._layers:
                for layer_parametrization in layer_parametrizations:
                    tensor_name = layer_parametrization.get("tensor_name", None)
                    parametrization = layer_parametrization.get("parametrization", None)
                    param_kwargs = layer_parametrization.get("kwargs", {})
                    if tensor_name is not None and tensor_name in layer._parameters and parametrization is not None:
                        parametrization(layer, tensor_name, **param_kwargs)
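
        # Illustration (not part of the original commit): a hypothetical
        # layer_parametrizations entry. Each dict supplies "tensor_name",
        # a "parametrization" callable applied as
        # parametrization(layer, tensor_name, **kwargs), and optional "kwargs".
        # Assuming a layer exposes a parameter named "weight", PyTorch's
        # torch.nn.utils.parametrizations.orthogonal matches this calling convention:
        #
        #     from torch.nn.utils.parametrizations import orthogonal
        #     layer_parametrizations = [
        #         {"tensor_name": "weight", "parametrization": orthogonal, "kwargs": {}},
        #     ]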

    # def __call__(self, input_x, **kwargs):
    #     return self.forward(input_x, **kwargs)

    def forward(self, input_x, trace_powers=False):
        x = input_x

        if trace_powers:
            # record the total signal power sum(|x|^2) at the input and after each layer
            powers = [x.abs().square().sum()]

        # check if tracing
        if torch.jit.is_tracing():
            for layer in self._layers:
                x = layer(x)
                if trace_powers:
                    powers.append(x.abs().square().sum())
        else:
            # with torch.nn.utils.parametrize.cached():
            for layer in self._layers:
                x = layer(x)
                if trace_powers:
                    powers.append(x.abs().square().sum())

        if trace_powers:
            return x, powers
        return x


__all__ = [
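
A minimal usage sketch of the new class (not part of the commit; the dims and dtype below are hypothetical, and it assumes the ONN, Pow, Scale, and DropoutComplex layers defined earlier in this module operate on complex tensors):

    import torch

    # 4 -> 16 -> 16 -> 4 stack: two hidden layers with Pow activations and
    # complex dropout between them, plus Scale stages when scale=True
    model = regenerator(4, 16, 16, 4, dtype=torch.complex128, scale=True)

    x = torch.randn(8, 4, dtype=torch.complex128)
    y, powers = model(x, trace_powers=True)
    # powers holds sum(|x|^2) at the input and after every layer, which is
    # handy for checking how much power each stage gains or loses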