diff --git a/src/fidder/erase/sparse_local_mean.py b/src/fidder/erase/sparse_local_mean.py index 77a23b8..c1f98e6 100644 --- a/src/fidder/erase/sparse_local_mean.py +++ b/src/fidder/erase/sparse_local_mean.py @@ -110,9 +110,10 @@ def estimate_local_mean_3d( grid = CubicBSplineGrid3d(resolution=resolution) optimiser = torch.optim.Adam(grid.parameters(), lr=0.01) + foreground_sample_idx_rescaled = foreground_sample_idx / volume.shape for i in range(500): # what does the model predict for our observations? - prediction = grid(foreground_sample_idx).squeeze() + prediction = grid(foreground_sample_idx_rescaled).squeeze() # zero gradients and calculate loss between observations and model prediction optimiser.zero_grad()