diff --git a/src/cryo_sbi/inference/models/estimator_models.py b/src/cryo_sbi/inference/models/estimator_models.py index 9c5a1cc..19498af 100644 --- a/src/cryo_sbi/inference/models/estimator_models.py +++ b/src/cryo_sbi/inference/models/estimator_models.py @@ -20,6 +20,7 @@ class Standardize(nn.Module): def __init__(self, mean: float, std: float) -> None: super(Standardize, self).__init__() mean, std = map(torch.as_tensor, (mean, std)) + print(mean, std) self.mean = mean self.std = std self.register_buffer("_mean", mean) @@ -103,7 +104,9 @@ def __init__( ) self.embedding = embedding_net() - self.standardize = Standardize(theta_shift, theta_scale) + theta_shifts = (theta_shift, 0.0, 0.0, 0.0, 0.0, 2.75, 50.0, 0.05) + theta_scales = (theta_scale, 1.0, 1.0, 1.0, 1.0, 2.75, 50.0, 0.05) + self.standardize = Standardize(theta_shifts, theta_scales) def forward(self, theta: torch.Tensor, x: torch.Tensor) -> torch.Tensor: """