Commit
extend var documentation and remove unused "linear" mode
Hrovatin committed Sep 14, 2024
1 parent 3c93e7f commit 3e8ffd0
Showing 1 changed file with 7 additions and 24 deletions.
31 changes: 7 additions & 24 deletions src/scvi/external/sysvi/_base_components.py
@@ -69,10 +69,10 @@ class EncoderDecoder(Module):
The number of nodes per hidden layer
n_layers
Number of hidden layers
var_eps
See :class:`~scvi.external.sysvi.VarEncoder`
var_mode
See :class:`~scvi.external.sysvi.VarEncoder`
How to compute variance from model outputs; see :class:`~scvi.external.sysvi.VarEncoder`
'sample_feature' - learn per sample and feature
'feature' - learn per feature, constant across samples
sample
Return samples from predicted distribution
kwargs
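As a minimal sketch of the two remaining modes (illustrative names and sizes, not the sysvi code itself): 'sample_feature' predicts a variance for every sample from the hidden representation, while 'feature' learns a single free parameter per output feature that is shared across samples.

import torch
from torch.nn import Linear, Parameter

n_hidden, n_output, n_samples = 256, 10, 4
y = torch.randn(n_samples, n_hidden)  # hidden representation

# 'sample_feature': variance predicted from the input; one value per sample and feature
var_sample_feature = torch.exp(Linear(n_hidden, n_output)(y))  # shape (4, 10)

# 'feature': one learned parameter per feature, broadcast across samples
var_param = Parameter(torch.zeros(1, n_output))
var_feature = torch.exp(var_param).expand(n_samples, -1)  # shape (4, 10)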
@@ -86,16 +86,13 @@ def __init__(
n_cov: int,
n_hidden: int = 256,
n_layers: int = 3,
var_eps: float = 1e-4,
var_mode: Literal["sample_feature", "feature", "linear"] = "feature",
sample: bool = False,
**kwargs,
):
super().__init__()
self.sample = sample

self.var_eps = var_eps

self.decoder_y = Layers(
n_in=n_input,
n_cov=n_cov,
@@ -106,7 +103,7 @@
)

self.mean_encoder = Linear(n_hidden, n_output)
self.var_encoder = VarEncoder(n_hidden, n_output, mode=var_mode, eps=var_eps)
self.var_encoder = VarEncoder(n_hidden, n_output, mode=var_mode)

def forward(self, x: torch.Tensor, cov: torch.Tensor | None = None):
y = self.decoder_y(x=x, cov=cov)
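The forward pass above follows the usual two-head pattern: a shared trunk produces a hidden representation, and separate heads read off the mean and a strictly positive variance. A minimal sketch with made-up sizes, assuming a Normal output distribution as suggested by the `sample` flag:

import torch

trunk = torch.nn.Linear(20, 256)  # stand-in for decoder_y
mean_head = torch.nn.Linear(256, 10)
log_var_head = torch.nn.Linear(256, 10)

x = torch.randn(4, 20)
y = torch.relu(trunk(x))
mean = mean_head(y)
var = torch.exp(log_var_head(y)) + 1e-4  # strictly positive, as in VarEncoder
sample = torch.distributions.Normal(mean, var.sqrt()).sample()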
@@ -216,7 +213,7 @@ def set_online_update_hooks(self, hook_first_layer=True):
def _hook_fn_weight(grad):
new_grad = torch.zeros_like(grad)
if self.n_cov > 0:
new_grad[:, -self.n_cov :] = grad[:, -self.n_cov :]
new_grad[:, -self.n_cov:] = grad[:, -self.n_cov:]
return new_grad

def _hook_fn_zero_out(grad):
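The weight hook above zeroes every gradient column except the trailing covariate columns, so online updates only adjust covariate weights. A self-contained sketch of the same mechanism with hypothetical sizes:

import torch

n_in, n_cov, n_out = 8, 3, 5
layer = torch.nn.Linear(n_in + n_cov, n_out)

def _hook_fn_weight(grad):
    new_grad = torch.zeros_like(grad)
    new_grad[:, -n_cov:] = grad[:, -n_cov:]  # keep only covariate columns
    return new_grad

layer.weight.register_hook(_hook_fn_weight)
layer(torch.randn(2, n_in + n_cov)).sum().backward()
assert layer.weight.grad[:, :-n_cov].abs().sum() == 0  # non-covariate grads zeroed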
@@ -285,42 +282,33 @@ class VarEncoder(Module):
How to compute variance
'sample_feature' - learn per sample and feature
'feature' - learn per feature, constant across samples
'linear' - linear with respect to the input mean, var = a1 * mean + a0;
not recommended, as the implementation constrains positivity poorly
eps
"""

def __init__(
self,
n_input: int,
n_output: int,
mode: Literal["sample_feature", "feature", "linear"],
eps: float = 1e-4,
):
super().__init__()

self.eps = eps
self.eps = 1e-4
self.mode = mode
if self.mode == "sample_feature":
self.encoder = Linear(n_input, n_output)
elif self.mode == "feature":
self.var_param = Parameter(torch.zeros(1, n_output))
elif self.mode == "linear":
self.var_param_a1 = Parameter(torch.tensor([1.0]))
self.var_param_a0 = Parameter(torch.tensor([self.eps]))
else:
raise ValueError("Mode not recognised.")
self.activation = torch.exp

def forward(self, x: torch.Tensor, x_m: torch.Tensor):
def forward(self, x: torch.Tensor):
"""Forward pass through model
Parameters
----------
x
Used to encode var if mode is sample_feature; dim = n_samples x n_input
x_m
Used to predict var instead of x if mode is linear; dim = n_samples x 1
Returns
-------
@@ -337,9 +325,4 @@ def forward(self, x: torch.Tensor, x_m: torch.Tensor):
v = (
torch.nan_to_num(self.activation(v)) + self.eps
) # Ensure that var is strictly positive
elif self.mode == "linear":
v = self.var_param_a1 * x_m.detach().clone() + self.var_param_a0
# TODO come up with a better way to constrain this to positive while having lin relationship
# Could activation be used for log-lin relationship?
v = torch.clamp(torch.nan_to_num(v), min=self.eps)
return v
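After this commit, VarEncoder no longer takes `eps` or `x_m`: variance comes from the hidden representation alone and is kept strictly positive by the exp activation plus the fixed eps. A usage sketch, assuming the module path shown in this diff:

import torch
from scvi.external.sysvi._base_components import VarEncoder

enc = VarEncoder(n_input=256, n_output=10, mode="sample_feature")
v = enc(torch.randn(4, 256))  # one variance per sample and feature
assert (v > 0).all()  # exp activation plus eps keeps variance strictly positive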
