From 9946b0d459336aa1e8e36f98412c5ef72b2c1a7e Mon Sep 17 00:00:00 2001 From: Dingel321 Date: Thu, 13 Jun 2024 11:33:10 -0400 Subject: [PATCH] improved bug in parameter normalization --- src/cryo_sbi/inference/models/estimator_models.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/cryo_sbi/inference/models/estimator_models.py b/src/cryo_sbi/inference/models/estimator_models.py index 19498af..ae922af 100644 --- a/src/cryo_sbi/inference/models/estimator_models.py +++ b/src/cryo_sbi/inference/models/estimator_models.py @@ -104,8 +104,8 @@ def __init__( ) self.embedding = embedding_net() - theta_shifts = (theta_shift, 0.0, 0.0, 0.0, 0.0, 2.75, 50.0, 0.05) - theta_scales = (theta_scale, 1.0, 1.0, 1.0, 1.0, 2.75, 50.0, 0.05) + theta_shifts = (theta_shift, 0.0, 0.0, 0.0, 0.0, 2.75, 50.0, -2.0) + theta_scales = (theta_scale, 1.0, 1.0, 1.0, 1.0, 2.75, 50.0, 1.0) self.standardize = Standardize(theta_shifts, theta_scales) def forward(self, theta: torch.Tensor, x: torch.Tensor) -> torch.Tensor: