Skip to content

Commit

Permalink
Merge pull request #92 from flatironinstitute/centering_every_volume
Browse files Browse the repository at this point in the history
Centering every volume
  • Loading branch information
DSilva27 authored Aug 21, 2024
2 parents efb77bc + 94b07dd commit 21ca809
Showing 1 changed file with 17 additions and 16 deletions.
33 changes: 17 additions & 16 deletions src/cryo_challenge/_preprocessing/align_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,21 +134,22 @@ def align_submission(
--------
volumes (torch.Tensor): aligned submission volumes
"""
obj_vol = volumes[0].numpy().astype(np.float32)

obj_vol = Volume(obj_vol / obj_vol.sum())
ref_vol = Volume(ref_volume / ref_volume.sum())

_, R_est = align_BO(
ref_vol,
obj_vol,
loss_type=params["BOT_loss"],
downsampled_size=params["BOT_box_size"],
max_iters=params["BOT_iter"],
refine=params["BOT_refine"],
)
R_est = Rotation(R_est.astype(np.float32))

volumes = torch.from_numpy(Volume(volumes.numpy()).rotate(R_est)._data)
for i in range(len(volumes)):
obj_vol = volumes[i].numpy().astype(np.float32).copy()

obj_vol = Volume(obj_vol / obj_vol.sum())
ref_vol = Volume(ref_volume.copy() / ref_volume.sum())

_, R_est = align_BO(
ref_vol,
obj_vol,
loss_type=params["BOT_loss"],
downsampled_size=params["BOT_box_size"],
max_iters=params["BOT_iter"],
refine=params["BOT_refine"],
)
R_est = Rotation(R_est.astype(np.float32))

volumes[i] = torch.from_numpy(Volume(volumes[i].numpy()).rotate(R_est)._data)

return volumes

0 comments on commit 21ca809

Please sign in to comment.