Skip to content

Commit

Permalink
adjusted inconsistencies
Browse files Browse the repository at this point in the history
  • Loading branch information
Dingel321 committed Aug 29, 2024
1 parent 29e2792 commit 7ffcba3
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 12 deletions.
11 changes: 2 additions & 9 deletions src/cryo_sbi/inference/models/build_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def build_npe_flow_model(config: dict, **embedding_kwargs) -> nn.Module:
elif config["MODEL"] == "NSF":
model = zuko.flows.NSF
elif config["MODEL"] == "SOSPF":
model = zuko.flows.SOSPF
model = partial(zuko.flows.SOSPF, polynomials=16, degree=5)
else:
raise NotImplementedError(
f"Model : {config['MODEL']} has not been implemented yet!"
Expand All @@ -39,12 +39,6 @@ def build_npe_flow_model(config: dict, **embedding_kwargs) -> nn.Module:
f"Model : {config['EMBEDDING']} has not been implemented yet! \
The following embeddings are implemented : {[key for key in EMBEDDING_NETS.keys()]}"
)

if "BINS" in config:
bins = config["BINS"]
print(f"Using {bins} bins for NPE")
else:
bins = 8

estimator = estimator_models.NPEWithEmbedding(
embedding_net=embedding,
Expand All @@ -55,8 +49,7 @@ def build_npe_flow_model(config: dict, **embedding_kwargs) -> nn.Module:
flow=model,
theta_shift=config["THETA_SHIFT"],
theta_scale=config["THETA_SCALE"],
bins=bins,
**{"activation": partial(nn.LeakyReLU, 0.1)},
**{"activation": nn.GELU},
)

return estimator
Expand Down
4 changes: 1 addition & 3 deletions src/cryo_sbi/inference/models/estimator_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,6 @@ def __init__(
flow: nn.Module = zuko.flows.MAF,
theta_shift: float = 0.0,
theta_scale: float = 1.0,
bins: int = 8,
**kwargs,
) -> None:
"""
Expand Down Expand Up @@ -99,8 +98,7 @@ def __init__(
output_embedding_dim,
transforms=num_transforms,
build=flow,
hidden_features=[*[hidden_flow_dim] * num_hidden_flow, 128, 64],
bins=bins,
hidden_features=[*[hidden_flow_dim] * num_hidden_flow],
**kwargs,
)
self.type = "NPE"
Expand Down

0 comments on commit 7ffcba3

Please sign in to comment.