update dataset configurations, add rotation module, and refine model settings for training, new hyperparameter tuning run for corrected datasets
This commit is contained in:
@@ -319,6 +319,29 @@ class normalize_by_first(nn.Module):
|
||||
|
||||
def forward(self, data):
|
||||
return data / data[:, 0].unsqueeze(1)
|
||||
|
||||
class rotate(nn.Module):
|
||||
def __init__(self):
|
||||
super(rotate, self).__init__()
|
||||
|
||||
def forward(self, data, angle):
|
||||
# data -> (batch, n*2)
|
||||
# angle -> (batch, n)
|
||||
data_ = data
|
||||
if angle.ndim == 1:
|
||||
angle_ = angle.unsqueeze(1)
|
||||
else:
|
||||
angle_ = angle
|
||||
angle_ = angle_.expand(-1, data_.shape[1]//2)
|
||||
c = torch.cos(angle_)
|
||||
s = torch.sin(angle_)
|
||||
rot = torch.stack([torch.stack([c, -s], dim=2),
|
||||
torch.stack([s, c], dim=2)], dim=3)
|
||||
d = torch.bmm(data_.reshape(-1, 1, 2), rot.view(-1, 2, 2).to(dtype=data_.dtype)).reshape(*data.shape)
|
||||
# d = torch.bmm(data.unsqueeze(-1).mT, rot.to(dtype=data.dtype).mT).mT.squeeze(-1)
|
||||
|
||||
return d
|
||||
|
||||
|
||||
class photodiode(nn.Module):
|
||||
def __init__(self, size, bias=True):
|
||||
@@ -487,7 +510,7 @@ class MZISingle(nn.Module):
|
||||
return x * torch.exp(1j * self.phi) * torch.sin(self.omega + self.func(x))
|
||||
|
||||
def naive_angle_loss(x: torch.Tensor, target: torch.Tensor, mod=2*torch.pi):
|
||||
return torch.fmod((x - target), mod).square().mean()
|
||||
return torch.fmod((x.abs().real - target.abs().real), mod).abs().mean()
|
||||
|
||||
def cosine_loss(x: torch.Tensor, target: torch.Tensor):
|
||||
return (2*(1 - torch.cos(x - target))).mean()
|
||||
|
||||
Reference in New Issue
Block a user