From 64ddd4215af80d9404c19b96a939cf4613bcb2cd Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Thu, 19 Oct 2023 19:35:57 +0200 Subject: [PATCH] Batch VJP computation in hutchinson estimator across all hutchinson samples --- .../finite/residual/log_abs_det_estimators.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/normalizing_flows/bijections/finite/residual/log_abs_det_estimators.py b/normalizing_flows/bijections/finite/residual/log_abs_det_estimators.py index 7c64e87..f62f8c5 100644 --- a/normalizing_flows/bijections/finite/residual/log_abs_det_estimators.py +++ b/normalizing_flows/bijections/finite/residual/log_abs_det_estimators.py @@ -25,16 +25,16 @@ def hutchinson_log_abs_det_estimator(g: callable, g_value = None for k in range(1, n_iterations + 1): # Compute VJP, reshape appropriately for hutchinson averaging - tmp = [torch.autograd.functional.vjp(g, x, w[..., i], strict=True) for i in range(n_hutchinson_samples)] - gs, ws = zip(*tmp) + gs_r, ws_r = torch.autograd.functional.vjp( + g, + x[..., None].repeat(1, 1, n_hutchinson_samples).view(batch_size * n_hutchinson_samples, event_size), + w.view(batch_size * n_hutchinson_samples, event_size) + ) if g_value is None: - gs = list(gs) - g_value = gs[0] + g_value = gs_r.view(batch_size, event_size, n_hutchinson_samples)[..., 0] - ws = list(ws) - # (batch_size, event_size, n_hutchinson_samples) - w = torch.concatenate([ws[i][..., None] for i in range(n_hutchinson_samples)], dim=2) + w = ws_r.view(batch_size, event_size, n_hutchinson_samples) log_abs_det_jac_f += (-1) ** (k + 1) / k * torch.sum(w * noise, dim=1).mean( dim=1) # sum over event dim, average over hutchinson dim assert log_abs_det_jac_f.shape == (batch_size,)