diff --git a/src/scvi/external/sysvi/_base_components.py b/src/scvi/external/sysvi/_base_components.py index b556636652..ed4cfe214e 100644 --- a/src/scvi/external/sysvi/_base_components.py +++ b/src/scvi/external/sysvi/_base_components.py @@ -317,8 +317,8 @@ def forward(self, x: torch.Tensor): # Force to be non nan - TODO come up with better way to do so if self.mode == "sample_feature": v = self.encoder(x) - v = (self.activation(v) + self.eps) # Ensure that var is strictly positive + v = self.activation(v) + self.eps # Ensure that var is strictly positive elif self.mode == "feature": v = self.var_param.expand(x.shape[0], -1) # Broadcast to input size - v = (self.activation(v) + self.eps) # Ensure that var is strictly positive + v = self.activation(v) + self.eps # Ensure that var is strictly positive return v diff --git a/src/scvi/external/sysvi/_model.py b/src/scvi/external/sysvi/_model.py index 2c520d48ab..0da03933fe 100644 --- a/src/scvi/external/sysvi/_model.py +++ b/src/scvi/external/sysvi/_model.py @@ -11,8 +11,6 @@ from scvi import REGISTRY_KEYS from scvi.data import AnnDataManager -from scvi.data._constants import _SCVI_UUID_KEY -from scvi.data._utils import _check_if_view from scvi.data.fields import ( LayerField, ObsmField, @@ -137,6 +135,7 @@ def get_latent_representation( return_dist If ``True``, returns the mean and variance of the latent distribution. Otherwise, returns the mean of the latent distribution. + Returns ------- Latent Embedding @@ -192,8 +191,7 @@ def _validate_anndata( # Check that all required fields are present and match the Model's adata assert ( - self.adata.uns["layer_information"]["layer"] - == adata.uns["layer_information"]["layer"] + self.adata.uns["layer_information"]["layer"] == adata.uns["layer_information"]["layer"] ) assert ( self.adata.uns["layer_information"]["var_names"] diff --git a/src/scvi/external/sysvi/_module.py b/src/scvi/external/sysvi/_module.py index 00a82d54ae..bec63c9fb3 100644 --- a/src/scvi/external/sysvi/_module.py +++ b/src/scvi/external/sysvi/_module.py @@ -350,7 +350,8 @@ def loss( # Reconstruction loss reconst_loss_x = torch.nn.GaussianNLLLoss(reduction="none")( - generative_outputs["x_m"], x_true, generative_outputs["x_v"]).sum(dim=1) + generative_outputs["x_m"], x_true, generative_outputs["x_v"] + ).sum(dim=1) reconst_loss = reconst_loss_x diff --git a/tests/external/sysvi/test_model.py b/tests/external/sysvi/test_model.py index 32d57beab2..77fc539bb1 100644 --- a/tests/external/sysvi/test_model.py +++ b/tests/external/sysvi/test_model.py @@ -171,5 +171,3 @@ def test_model(): give_mean=False, ), ) - -