diff --git a/normalizing_flows/bijections/finite/residual/log_abs_det_estimators.py b/normalizing_flows/bijections/finite/residual/log_abs_det_estimators.py
index 5be083e..7c64e87 100644
--- a/normalizing_flows/bijections/finite/residual/log_abs_det_estimators.py
+++ b/normalizing_flows/bijections/finite/residual/log_abs_det_estimators.py
@@ -21,7 +21,7 @@ def hutchinson_log_abs_det_estimator(g: callable,
     assert n_iterations >= 2
     w = noise  # (batch_size, event_size, n_hutchinson_samples)
-    log_abs_det_jac_f = torch.zeros(size=(batch_size, n_hutchinson_samples))  # this will be averaged at the end
+    log_abs_det_jac_f = torch.zeros(size=(batch_size,))
     g_value = None
     for k in range(1, n_iterations + 1):
         # Compute VJP, reshape appropriately for hutchinson averaging
@@ -35,9 +35,9 @@ def hutchinson_log_abs_det_estimator(g: callable,
         ws = list(ws)
         # (batch_size, event_size, n_hutchinson_samples)
         w = torch.concatenate([ws[i][..., None] for i in range(n_hutchinson_samples)], dim=2)
-        log_abs_det_jac_f += (-1) ** (k + 1) / k * torch.sum(w * noise, dim=1)  # sum over event dim
-        assert log_abs_det_jac_f.shape == (batch_size, n_hutchinson_samples)
-    log_abs_det_jac_f = torch.mean(log_abs_det_jac_f, dim=1)  # hutchinson averaging over the many different series
+        log_abs_det_jac_f += (-1) ** (k + 1) / k * torch.sum(w * noise, dim=1).mean(
+            dim=1)  # sum over event dim, average over hutchinson dim
+        assert log_abs_det_jac_f.shape == (batch_size,)
     return g_value, log_abs_det_jac_f
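
For context: the estimator targets log|det ∂f/∂x| of the residual map f(x) = x + g(x) via the power series log det(I + J_g) = Σ_{k≥1} ((-1)^{k+1}/k) tr(J_g^k), with each trace approximated by Hutchinson samples, tr(A) ≈ E_w[wᵀ A w]. The change keeps the accumulator at shape (batch_size,) by summing over the event dimension and averaging over the Hutchinson dimension inside the loop. Below is a minimal, self-contained sketch of that accumulation pattern; the linear block g, the dimensions and sample counts, and the per-sample autograd VJP loop are illustrative assumptions rather than the repository's API — only the accumulation line mirrors the diff.

```python
import torch

torch.manual_seed(0)
batch_size, event_size, n_hutchinson_samples, n_iterations = 4, 3, 64, 10

# Hypothetical contractive residual block g(x) = x @ A.T with small spectral norm,
# so the power series for log det(I + J_g) converges.
A = 0.1 * torch.randn(event_size, event_size)

def g(x):
    return x @ A.T

x = torch.randn(batch_size, event_size, requires_grad=True)
noise = torch.randn(batch_size, event_size, n_hutchinson_samples)

g_value = g(x)
w = noise
log_abs_det = torch.zeros(batch_size)
for k in range(1, n_iterations + 1):
    # w <- (J_g^T)^k noise, one vector-Jacobian product per Hutchinson sample
    ws = [torch.autograd.grad(g_value, x, grad_outputs=w[..., s], retain_graph=True)[0]
          for s in range(n_hutchinson_samples)]
    w = torch.stack(ws, dim=2)  # (batch_size, event_size, n_hutchinson_samples)
    # sum over event dim -> (batch_size, n_hutchinson_samples); mean over hutchinson dim -> (batch_size,)
    log_abs_det += (-1) ** (k + 1) / k * torch.sum(w * noise, dim=1).mean(dim=1)

exact = torch.logdet(torch.eye(event_size) + A)  # same for every batch element since g is linear
print(log_abs_det)  # stochastic estimate, shape (batch_size,)
print(exact)
```

By linearity of the mean, averaging over the Hutchinson dimension inside the loop yields the same estimate as accumulating a (batch_size, n_hutchinson_samples) tensor and averaging once after the loop, while avoiding the larger accumulator.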