add regenerator class and update dataset configurations for model training

This commit is contained in:
Joseph Hopfmüller
2024-12-05 23:55:03 +01:00
parent 884d9f73c9
commit 0e29b87395
7 changed files with 82705 additions and 353 deletions

View File

@@ -569,6 +569,78 @@ class ZReLU(nn.Module):
return x * (torch.angle(x) >= 0) * (torch.angle(x) <= torch.pi / 2)
else:
return torch.relu(x)
class regenerator(nn.Module):
def __init__(
self,
*dims,
layer_function=ONN,
layer_kwargs: dict | None = None,
layer_parametrizations: list[dict] = None,
activation_function=Pow,
dtype=torch.float64,
dropout_prob=0.01,
scale=False,
**kwargs,
):
super(regenerator, self).__init__()
if len(dims) == 0:
try:
dims = kwargs["dims"]
except KeyError:
raise ValueError("dims must be provided")
self._n_hidden_layers = len(dims) - 2
self._layers = nn.Sequential()
if layer_kwargs is None:
layer_kwargs = {}
# self.powers = []
for i in range(self._n_hidden_layers + 1):
if scale:
self._layers.append(Scale(dims[i]))
self._layers.append(layer_function(dims[i], dims[i + 1], dtype=dtype, **layer_kwargs))
if i < self._n_hidden_layers:
if dropout_prob is not None:
self._layers.append(DropoutComplex(p=dropout_prob))
self._layers.append(activation_function(bias=True, size=dims[i + 1]))
self._layers.append(Scale(dims[-1]))
# add parametrizations
if layer_parametrizations is not None:
for layer in self._layers:
for layer_parametrization in layer_parametrizations:
tensor_name = layer_parametrization.get("tensor_name", None)
parametrization = layer_parametrization.get("parametrization", None)
param_kwargs = layer_parametrization.get("kwargs", {})
if tensor_name is not None and tensor_name in layer._parameters and parametrization is not None:
parametrization(layer, tensor_name, **param_kwargs)
# def __call__(self, input_x, **kwargs):
# return self.forward(input_x, **kwargs)
def forward(self, input_x, trace_powers=False):
x = input_x
if trace_powers:
powers = [x.abs().square().sum()]
# check if tracing
if torch.jit.is_tracing():
for layer in self._layers:
x = layer(x)
if trace_powers:
powers.append(x.abs().square().sum())
else:
# with torch.nn.utils.parametrize.cached():
for layer in self._layers:
x = layer(x)
if trace_powers:
powers.append(x.abs().square().sum())
if trace_powers:
return x, powers
return x
__all__ = [