Skip to content

Commit

Permalink
improved bug in parameter normalization
Browse files Browse the repository at this point in the history
  • Loading branch information
Dingel321 committed Jun 13, 2024
1 parent 3290324 commit 9946b0d
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/cryo_sbi/inference/models/estimator_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,8 @@ def __init__(
)

self.embedding = embedding_net()
theta_shifts = (theta_shift, 0.0, 0.0, 0.0, 0.0, 2.75, 50.0, 0.05)
theta_scales = (theta_scale, 1.0, 1.0, 1.0, 1.0, 2.75, 50.0, 0.05)
theta_shifts = (theta_shift, 0.0, 0.0, 0.0, 0.0, 2.75, 50.0, -2.0)
theta_scales = (theta_scale, 1.0, 1.0, 1.0, 1.0, 2.75, 50.0, 1.0)
self.standardize = Standardize(theta_shifts, theta_scales)

def forward(self, theta: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
Expand Down

0 comments on commit 9946b0d

Please sign in to comment.