Skip to content

Commit

Permalink
Rewrite hutchinson log det estimator to save some memory
Browse files Browse the repository at this point in the history
  • Loading branch information
davidnabergoj committed Oct 19, 2023
1 parent 4a042de commit a88a105
Showing 1 changed file with 4 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def hutchinson_log_abs_det_estimator(g: callable,
assert n_iterations >= 2

w = noise # (batch_size, event_size, n_hutchinson_samples)
log_abs_det_jac_f = torch.zeros(size=(batch_size, n_hutchinson_samples)) # this will be averaged at the end
log_abs_det_jac_f = torch.zeros(size=(batch_size,))
g_value = None
for k in range(1, n_iterations + 1):
# Compute VJP, reshape appropriately for hutchinson averaging
Expand All @@ -35,9 +35,9 @@ def hutchinson_log_abs_det_estimator(g: callable,
ws = list(ws)
# (batch_size, event_size, n_hutchinson_samples)
w = torch.concatenate([ws[i][..., None] for i in range(n_hutchinson_samples)], dim=2)
log_abs_det_jac_f += (-1) ** (k + 1) / k * torch.sum(w * noise, dim=1) # sum over event dim
assert log_abs_det_jac_f.shape == (batch_size, n_hutchinson_samples)
log_abs_det_jac_f = torch.mean(log_abs_det_jac_f, dim=1) # hutchinson averaging over the many different series
log_abs_det_jac_f += (-1) ** (k + 1) / k * torch.sum(w * noise, dim=1).mean(
dim=1) # sum over event dim, average over hutchinson dim
assert log_abs_det_jac_f.shape == (batch_size,)
return g_value, log_abs_det_jac_f


Expand Down

0 comments on commit a88a105

Please sign in to comment.