From 4c756252102fb00a84d555fe1025faff6b859dc9 Mon Sep 17 00:00:00 2001 From: Dingel321 Date: Mon, 15 May 2023 16:35:35 +0200 Subject: [PATCH] add documentation for inference functions --- src/cryo_sbi/inference/models/build_models.py | 45 ++------ .../inference/models/embedding_nets.py | 10 +- .../inference/models/estimator_models.py | 109 +++++++++++------- 3 files changed, 86 insertions(+), 78 deletions(-) diff --git a/src/cryo_sbi/inference/models/build_models.py b/src/cryo_sbi/inference/models/build_models.py index 6f5283d..939cca5 100644 --- a/src/cryo_sbi/inference/models/build_models.py +++ b/src/cryo_sbi/inference/models/build_models.py @@ -6,10 +6,17 @@ from cryo_sbi.inference.models.embedding_nets import EMBEDDING_NETS -def build_npe_flow_model(config, **embedding_kwargs): +def build_npe_flow_model(config: dict, **embedding_kwargs) -> nn.Module: """ Function to build NPE estimator with embedding net from config_file + + Args: + config (dict): config file + embedding_kwargs (dict): kwargs for embedding net + + Returns: + estimator (nn.Module): NPE estimator """ if config["MODEL"] == "MAF": @@ -46,39 +53,3 @@ def build_npe_flow_model(config, **embedding_kwargs): ) return estimator - - -def build_nre_classifier_model(config): - """ - Function to build NRE estimator with embedding net - from config_file - """ - - if config["MODEL"] == "RESMLP": - model = lampe.nn.ResMLP - elif config["MODEL"] == "MLP": - model = zuko.nn.MLP - else: - raise NotImplementedError( - f"Model : {config['MODEL']} has not been implemented yet!" - ) - - try: - embedding = partial(EMBEDDING_NETS[config["EMBEDDING"]], config["OUT_DIM"]) - except: - raise NotImplementedError( - f"Model : {config['EMBEDDING']} has not been implemented yet! \ -The following embeddings are implemented : {[key for key in EMBEDDING_NETS.keys()]}" - ) - - estimator = estimator_models.NREWithEmbedding( - embedding_net=embedding, - output_embedding_dim=config["OUT_DIM"], - hidden_features=config["HIDDEN_FEATURES"], - activation=partial(nn.LeakyReLU, 0.1), - network=model, - theta_scale=config["THETA_SCALE"], - theta_shift=config["THETA_SHIFT"], - ) - - return estimator diff --git a/src/cryo_sbi/inference/models/embedding_nets.py b/src/cryo_sbi/inference/models/embedding_nets.py index 65ccabd..826e7e9 100644 --- a/src/cryo_sbi/inference/models/embedding_nets.py +++ b/src/cryo_sbi/inference/models/embedding_nets.py @@ -9,7 +9,15 @@ def add_embedding(name): - """Adds the class to the embedding_nets dict with specific key""" + """ + Add embedding net to EMBEDDING_NETS dict + + Args: + name (str): name of embedding net + + Returns: + add (function): function to add embedding net to EMBEDDING_NETS dict + """ def add(class_): EMBEDDING_NETS[name] = class_ diff --git a/src/cryo_sbi/inference/models/estimator_models.py b/src/cryo_sbi/inference/models/estimator_models.py index cf6d2f2..95f394f 100644 --- a/src/cryo_sbi/inference/models/estimator_models.py +++ b/src/cryo_sbi/inference/models/estimator_models.py @@ -9,10 +9,19 @@ class Standardize(nn.Module): - """Module to standardize inputs and retransform them to the original space""" + """ + Module to standardize inputs and retransform them to the original space + + Args: + mean (torch.Tensor): mean of the data + std (torch.Tensor): standard deviation of the data + + Returns: + standardized (torch.Tensor): standardized data + """ # Code adapted from :https://github.com/mackelab/sbi/blob/main/sbi/utils/sbiutils.py - def __init__(self, mean, std): + def __init__(self, mean: float, std: float) -> None: super(Standardize, self).__init__() mean, std = map(torch.as_tensor, (mean, std)) self.mean = mean @@ -20,26 +29,72 @@ def __init__(self, mean, std): self.register_buffer("_mean", mean) self.register_buffer("_std", std) - def forward(self, tensor): + def forward(self, tensor: torch.Tensor) -> torch.Tensor: + """ + Standardize the input tensor + + Args: + tensor (torch.Tensor): input tensor + + Returns: + standardized (torch.Tensor): standardized tensor + """ + return (tensor - self._mean) / self._std - def transform(self, tensor): + def transform(self, tensor: torch.Tensor) -> torch.Tensor: + """ + Transform the standardized tensor back to the original space + + Args: + tensor (torch.Tensor): input tensor + + Returns: + retransformed (torch.Tensor): retransformed tensor + """ + return (tensor * self._std) + self._mean class NPEWithEmbedding(nn.Module): + """Neural Posterior Estimation with embedding net + + Attributes: + npe (NPE): NPE model + embedding (nn.Module): embedding net + standardize (Standardize): standardization module + """ + def __init__( self, - embedding_net, - output_embedding_dim, - num_transforms=4, - num_hidden_flow=2, - hidden_flow_dim=128, - flow=zuko.flows.MAF, - theta_shift=0, - theta_scale=1, - **kwargs - ): + embedding_net: nn.Module, + output_embedding_dim: int, + num_transforms: int = 4, + num_hidden_flow: int = 2, + hidden_flow_dim: int = 128, + flow: nn.Module = zuko.flows.MAF, + theta_shift: float = 0.0, + theta_scale: float = 1.0, + **kwargs, + ) -> None: + """ + Neural Posterior Estimation with embedding net. + + Args: + embedding_net (nn.Module): embedding net + output_embedding_dim (int): output embedding dimension + num_transforms (int, optional): number of transforms. Defaults to 4. + num_hidden_flow (int, optional): number of hidden layers in flow. Defaults to 2. + hidden_flow_dim (int, optional): hidden dimension in flow. Defaults to 128. + flow (nn.Module, optional): flow. Defaults to zuko.flows.MAF. + theta_shift (float, optional): Shift of the theta for standardization. Defaults to 0.0. + theta_scale (float, optional): Scale of the theta for standardization. Defaults to 1.0. + kwargs: additional arguments for the flow + + Returns: + None + """ + super().__init__() self.npe = NPE( @@ -64,29 +119,3 @@ def sample(self, x: torch.Tensor, shape=(1,)): samples_standardized = self.flow(x).sample(shape) return self.standardize.transform(samples_standardized) - -class NREWithEmbedding(nn.Module): - def __init__( - self, - embedding_net, - output_embedding_dim, - hidden_features, - activation, - network, - theta_shift=0, - theta_scale=1, - ): - super().__init__() - - self.nre = NRE( - 1, - output_embedding_dim, - hidden_features=hidden_features, - activation=activation, - build=network, - ) - self.embedding = embedding_net - self.standardize = Standardize(theta_shift, theta_scale) - - def forward(self, theta, x): - return self.nre(self.standardize(theta), self.embedding(x))