Skip to content

Commit

Permalink
Fixed indexing warning of torch meshgrid.
Browse files Browse the repository at this point in the history
Fixes #406
  • Loading branch information
TomTomRixRix committed Dec 18, 2024
1 parent 6949c2c commit 8101182
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def reconstruction_algorithm(self, time_series_sensor_data, detection_geometry:
yy, zz, nn, mm = torch.meshgrid(torch.arange(ydim, device=torch_device),
torch.arange(zdim, device=torch_device),
torch.arange(n_sensor_elements, device=torch_device),
torch.arange(n_sensor_elements, device=torch_device))
torch.arange(n_sensor_elements, device=torch_device), indexing='ij')
M = values[x, yy, zz, nn] * values[x, yy, zz, mm]
M = torch.sign(M) * torch.sqrt(torch.abs(M))
# only take upper triangle without diagonal and sum up along n and m axis (last two)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -542,7 +542,7 @@ def compute_delay_and_sum_values(time_series_sensor_data: Tensor, sensor_positio
z = zdim_start + torch.arange(zdim, device=torch_device, dtype=torch.float32)
j = torch.arange(n_sensor_elements, device=torch_device, dtype=torch.float32)

xx, yy, zz, jj = torch.meshgrid(x, y, z, j)
xx, yy, zz, jj = torch.meshgrid(x, y, z, j, indexing='ij')
jj = jj.long()

delays = torch.sqrt((yy * spacing_in_mm - sensor_positions[:, 2][jj]) ** 2 +
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def reconstruction_algorithm(self, time_series_sensor_data, detection_geometry:
yy, zz, nn, mm = torch.meshgrid(torch.arange(ydim, device=torch_device),
torch.arange(zdim, device=torch_device),
torch.arange(n_sensor_elements, device=torch_device),
torch.arange(n_sensor_elements, device=torch_device))
torch.arange(n_sensor_elements, device=torch_device), indexing='ij')
M = values[x, yy, zz, nn] * values[x, yy, zz, mm]
M = torch.sign(M) * torch.sqrt(torch.abs(M))
# only take upper triangle without diagonal and sum up along n and m axis (last two)
Expand Down

0 comments on commit 8101182

Please sign in to comment.