
Commit

Batch VJP computation in hutchinson estimator across all hutchinson samples
davidnabergoj committed Oct 19, 2023
1 parent a88a105 commit 64ddd42
Showing 1 changed file with 7 additions and 7 deletions.
@@ -25,16 +25,16 @@ def hutchinson_log_abs_det_estimator(g: callable,
     g_value = None
     for k in range(1, n_iterations + 1):
         # Compute VJP, reshape appropriately for hutchinson averaging
-        tmp = [torch.autograd.functional.vjp(g, x, w[..., i], strict=True) for i in range(n_hutchinson_samples)]
-        gs, ws = zip(*tmp)
+        gs_r, ws_r = torch.autograd.functional.vjp(
+            g,
+            x[..., None].repeat(1, 1, n_hutchinson_samples).view(batch_size * n_hutchinson_samples, event_size),
+            w.view(batch_size * n_hutchinson_samples, event_size)
+        )

         if g_value is None:
-            gs = list(gs)
-            g_value = gs[0]
+            g_value = gs_r.view(batch_size, event_size, n_hutchinson_samples)[..., 0]

-        ws = list(ws)
-        # (batch_size, event_size, n_hutchinson_samples)
-        w = torch.concatenate([ws[i][..., None] for i in range(n_hutchinson_samples)], dim=2)
+        w = ws_r.view(batch_size, event_size, n_hutchinson_samples)
         log_abs_det_jac_f += (-1) ** (k + 1) / k * torch.sum(w * noise, dim=1).mean(
             dim=1)  # sum over event dim, average over hutchinson dim
         assert log_abs_det_jac_f.shape == (batch_size,)
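
For readers skimming the diff, here is a minimal, self-contained sketch (not the repository's code) of the batching idea the commit describes: the Hutchinson probe dimension is folded into the batch dimension so that a single torch.autograd.functional.vjp call replaces the per-sample Python loop. The toy linear g, the repeat_interleave/permute reshapes, and the tensor sizes are illustrative assumptions; only the names batch_size, event_size, n_hutchinson_samples, x, w, and g are borrowed from the diff.

    import torch

    # Illustrative sizes and a stand-in transformation (assumptions, not repository code).
    batch_size, event_size, n_hutchinson_samples = 4, 3, 8
    A = torch.randn(event_size, event_size)

    def g(x):
        # Row-wise linear map: (batch_size, event_size) -> (batch_size, event_size).
        return x @ A.T

    x = torch.randn(batch_size, event_size)
    w = torch.randn(batch_size, event_size, n_hutchinson_samples)  # Hutchinson probes

    # Before: one VJP call per Hutchinson sample.
    per_sample = torch.stack(
        [torch.autograd.functional.vjp(g, x, w[..., i], strict=True)[1]
         for i in range(n_hutchinson_samples)],
        dim=-1,
    )  # (batch_size, event_size, n_hutchinson_samples)

    # After: fold the Hutchinson dimension into the batch dimension and call vjp once.
    x_flat = x.repeat_interleave(n_hutchinson_samples, dim=0)   # (batch_size * n_hutchinson_samples, event_size)
    w_flat = w.permute(0, 2, 1).reshape(-1, event_size)         # probe i for batch b lands in row b * n_hutchinson_samples + i
    _, vjp_flat = torch.autograd.functional.vjp(g, x_flat, w_flat)
    batched = vjp_flat.reshape(batch_size, n_hutchinson_samples, event_size).permute(0, 2, 1)

    # Both routes produce the same per-sample vector-Jacobian products.
    assert torch.allclose(per_sample, batched, atol=1e-6)

The payoff is one autograd graph traversal over a larger batch instead of n_hutchinson_samples separate traversals, which is typically faster on GPU; the exact reshape pattern used here is chosen for clarity and may differ from the one in the committed code.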
