diff --git a/src/cryo_sbi/inference/models/estimator_models.py b/src/cryo_sbi/inference/models/estimator_models.py index 19498af..ae922af 100644 --- a/src/cryo_sbi/inference/models/estimator_models.py +++ b/src/cryo_sbi/inference/models/estimator_models.py @@ -104,8 +104,8 @@ def __init__( ) self.embedding = embedding_net() - theta_shifts = (theta_shift, 0.0, 0.0, 0.0, 0.0, 2.75, 50.0, 0.05) - theta_scales = (theta_scale, 1.0, 1.0, 1.0, 1.0, 2.75, 50.0, 0.05) + theta_shifts = (theta_shift, 0.0, 0.0, 0.0, 0.0, 2.75, 50.0, -2.0) + theta_scales = (theta_scale, 1.0, 1.0, 1.0, 1.0, 2.75, 50.0, 1.0) self.standardize = Standardize(theta_shifts, theta_scales) def forward(self, theta: torch.Tensor, x: torch.Tensor) -> torch.Tensor: