Skip to content

Commit

Permalink
add documentation for inference functions
Browse files Browse the repository at this point in the history
  • Loading branch information
Dingel321 committed May 15, 2023
1 parent 1413aa4 commit 4c75625
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 78 deletions.
45 changes: 8 additions & 37 deletions src/cryo_sbi/inference/models/build_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,17 @@
from cryo_sbi.inference.models.embedding_nets import EMBEDDING_NETS


def build_npe_flow_model(config, **embedding_kwargs):
def build_npe_flow_model(config: dict, **embedding_kwargs) -> nn.Module:
"""
Function to build NPE estimator with embedding net
from config_file
Args:
config (dict): config file
embedding_kwargs (dict): kwargs for embedding net
Returns:
estimator (nn.Module): NPE estimator
"""

if config["MODEL"] == "MAF":
Expand Down Expand Up @@ -46,39 +53,3 @@ def build_npe_flow_model(config, **embedding_kwargs):
)

return estimator


def build_nre_classifier_model(config):
"""
Function to build NRE estimator with embedding net
from config_file
"""

if config["MODEL"] == "RESMLP":
model = lampe.nn.ResMLP
elif config["MODEL"] == "MLP":
model = zuko.nn.MLP
else:
raise NotImplementedError(
f"Model : {config['MODEL']} has not been implemented yet!"
)

try:
embedding = partial(EMBEDDING_NETS[config["EMBEDDING"]], config["OUT_DIM"])
except:
raise NotImplementedError(
f"Model : {config['EMBEDDING']} has not been implemented yet! \
The following embeddings are implemented : {[key for key in EMBEDDING_NETS.keys()]}"
)

estimator = estimator_models.NREWithEmbedding(
embedding_net=embedding,
output_embedding_dim=config["OUT_DIM"],
hidden_features=config["HIDDEN_FEATURES"],
activation=partial(nn.LeakyReLU, 0.1),
network=model,
theta_scale=config["THETA_SCALE"],
theta_shift=config["THETA_SHIFT"],
)

return estimator
10 changes: 9 additions & 1 deletion src/cryo_sbi/inference/models/embedding_nets.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,15 @@


def add_embedding(name):
"""Adds the class to the embedding_nets dict with specific key"""
"""
Add embedding net to EMBEDDING_NETS dict
Args:
name (str): name of embedding net
Returns:
add (function): function to add embedding net to EMBEDDING_NETS dict
"""

def add(class_):
EMBEDDING_NETS[name] = class_
Expand Down
109 changes: 69 additions & 40 deletions src/cryo_sbi/inference/models/estimator_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,37 +9,92 @@


class Standardize(nn.Module):
"""Module to standardize inputs and retransform them to the original space"""
"""
Module to standardize inputs and retransform them to the original space
Args:
mean (torch.Tensor): mean of the data
std (torch.Tensor): standard deviation of the data
Returns:
standardized (torch.Tensor): standardized data
"""

# Code adapted from :https://github.com/mackelab/sbi/blob/main/sbi/utils/sbiutils.py
def __init__(self, mean, std):
def __init__(self, mean: float, std: float) -> None:
super(Standardize, self).__init__()
mean, std = map(torch.as_tensor, (mean, std))
self.mean = mean
self.std = std
self.register_buffer("_mean", mean)
self.register_buffer("_std", std)

def forward(self, tensor):
def forward(self, tensor: torch.Tensor) -> torch.Tensor:
"""
Standardize the input tensor
Args:
tensor (torch.Tensor): input tensor
Returns:
standardized (torch.Tensor): standardized tensor
"""

return (tensor - self._mean) / self._std

def transform(self, tensor):
def transform(self, tensor: torch.Tensor) -> torch.Tensor:
"""
Transform the standardized tensor back to the original space
Args:
tensor (torch.Tensor): input tensor
Returns:
retransformed (torch.Tensor): retransformed tensor
"""

return (tensor * self._std) + self._mean


class NPEWithEmbedding(nn.Module):
"""Neural Posterior Estimation with embedding net
Attributes:
npe (NPE): NPE model
embedding (nn.Module): embedding net
standardize (Standardize): standardization module
"""

def __init__(
self,
embedding_net,
output_embedding_dim,
num_transforms=4,
num_hidden_flow=2,
hidden_flow_dim=128,
flow=zuko.flows.MAF,
theta_shift=0,
theta_scale=1,
**kwargs
):
embedding_net: nn.Module,
output_embedding_dim: int,
num_transforms: int = 4,
num_hidden_flow: int = 2,
hidden_flow_dim: int = 128,
flow: nn.Module = zuko.flows.MAF,
theta_shift: float = 0.0,
theta_scale: float = 1.0,
**kwargs,
) -> None:
"""
Neural Posterior Estimation with embedding net.
Args:
embedding_net (nn.Module): embedding net
output_embedding_dim (int): output embedding dimension
num_transforms (int, optional): number of transforms. Defaults to 4.
num_hidden_flow (int, optional): number of hidden layers in flow. Defaults to 2.
hidden_flow_dim (int, optional): hidden dimension in flow. Defaults to 128.
flow (nn.Module, optional): flow. Defaults to zuko.flows.MAF.
theta_shift (float, optional): Shift of the theta for standardization. Defaults to 0.0.
theta_scale (float, optional): Scale of the theta for standardization. Defaults to 1.0.
kwargs: additional arguments for the flow
Returns:
None
"""

super().__init__()

self.npe = NPE(
Expand All @@ -64,29 +119,3 @@ def sample(self, x: torch.Tensor, shape=(1,)):
samples_standardized = self.flow(x).sample(shape)
return self.standardize.transform(samples_standardized)


class NREWithEmbedding(nn.Module):
def __init__(
self,
embedding_net,
output_embedding_dim,
hidden_features,
activation,
network,
theta_shift=0,
theta_scale=1,
):
super().__init__()

self.nre = NRE(
1,
output_embedding_dim,
hidden_features=hidden_features,
activation=activation,
build=network,
)
self.embedding = embedding_net
self.standardize = Standardize(theta_shift, theta_scale)

def forward(self, theta, x):
return self.nre(self.standardize(theta), self.embedding(x))

0 comments on commit 4c75625

Please sign in to comment.