diff --git a/src/cryo_challenge/_preprocessing/align_utils.py b/src/cryo_challenge/_preprocessing/align_utils.py index cbbb6a1..d2a7784 100644 --- a/src/cryo_challenge/_preprocessing/align_utils.py +++ b/src/cryo_challenge/_preprocessing/align_utils.py @@ -134,21 +134,22 @@ def align_submission( -------- volumes (torch.Tensor): aligned submission volumes """ - obj_vol = volumes[0].numpy().astype(np.float32) - - obj_vol = Volume(obj_vol / obj_vol.sum()) - ref_vol = Volume(ref_volume / ref_volume.sum()) - - _, R_est = align_BO( - ref_vol, - obj_vol, - loss_type=params["BOT_loss"], - downsampled_size=params["BOT_box_size"], - max_iters=params["BOT_iter"], - refine=params["BOT_refine"], - ) - R_est = Rotation(R_est.astype(np.float32)) - - volumes = torch.from_numpy(Volume(volumes.numpy()).rotate(R_est)._data) + for i in range(len(volumes)): + obj_vol = volumes[i].numpy().astype(np.float32).copy() + + obj_vol = Volume(obj_vol / obj_vol.sum()) + ref_vol = Volume(ref_volume.copy() / ref_volume.sum()) + + _, R_est = align_BO( + ref_vol, + obj_vol, + loss_type=params["BOT_loss"], + downsampled_size=params["BOT_box_size"], + max_iters=params["BOT_iter"], + refine=params["BOT_refine"], + ) + R_est = Rotation(R_est.astype(np.float32)) + + volumes[i] = torch.from_numpy(Volume(volumes[i].numpy()).rotate(R_est)._data) return volumes