Skip to content

Commit

Permalink
binning noise level same way santiago does it
Browse files Browse the repository at this point in the history
  • Loading branch information
Benjamin Aron committed Mar 22, 2024
1 parent b1262b7 commit c1fc690
Showing 1 changed file with 17 additions and 18 deletions.
35 changes: 17 additions & 18 deletions lib/smi.py
Original file line number Diff line number Diff line change
Expand Up @@ -520,12 +520,11 @@ def fit2D4D_LLS_RealSphHarm_wSorting_norm_var(self, dwi, mask, rank1sl=True):
if np.sum(l_max >= ll) > 1:
id_useful_shells = id_shells[l_max>=ll]
id_current_band = (l_all == ll)
for kk in range(n_voxels):
slm_voxel_raw = s_lm_clusters_all[id_current_band, kk, :][:, id_useful_shells]
slm_voxel_dn = self.low_rank_denoising(slm_voxel_raw, 1)
slm_dn[id_current_band, kk, :][:, id_useful_shells] = slm_voxel_dn

sl_dn[int(ll/2),kk,id_useful_shells] = np.sqrt(np.sum(slm_voxel_dn**2, axis=0))
slm_voxels_raw = s_lm_clusters_all[id_current_band, :, :][:, :, id_useful_shells]
slm_voxels_dn = self.low_rank_denoising(slm_voxels_raw.transpose(1,0,2), 1).transpose(1,0,2)
slm_dn[id_current_band, :, :][:, :, id_useful_shells]
sl_dn[int(ll/2),:,id_useful_shells] = np.sqrt(np.sum(slm_voxels_dn**2, axis=0)).T

s_lm_clusters_all = slm_dn.copy()
s_l_clusters_all = sl_dn.copy()

Expand Down Expand Up @@ -568,9 +567,11 @@ def fit2D4D_LLS_RealSphHarm_wSorting_norm_var(self, dwi, mask, rank1sl=True):

def low_rank_denoising(self, X, p):
u,s,v = np.linalg.svd(X, full_matrices=False)
s_dn = np.zeros_like(s)
s_dn[:p] = s[:p]
return u @ np.diag(s_dn) @ v
s_dn = np.zeros((s.shape[0], s.shape[1], s.shape[1]))
diag_inds = np.diag_indices(X.shape[2])
s_dn[:,diag_inds[0][:p],diag_inds[1][:p]] = s[:,:p]
u_s = np.einsum('ijk,ikl->ijl', u, s_dn)
return np.einsum('ijk,ikl->ijl', u_s, v)


def group_dwi_in_shells_b_beta_te(self):
Expand Down Expand Up @@ -916,8 +917,6 @@ def standard_model_mlfit_rot_invs(self, rot_invs, sigma_norm_limits):
np.divide(
sigma_normalized, s0_lowest_te, out=sigma_normalized, where=s0_lowest_te != 0
)
#rot_invs_normalized = (rot_invs_normalized / s0_lowest_te).T
#sigma_normalized = (sigma_normalized / s0_lowest_te).T

shells = self.table_4d[0,:]
beta = self.table_4d[1,:]
Expand Down Expand Up @@ -950,14 +949,14 @@ def standard_model_mlfit_rot_invs(self, rot_invs, sigma_norm_limits):
)
sigma_noise_norm_levels_ids = np.digitize(
sigma_normalized, sigma_noise_norm_levels_edges
) - 1
)

sigma_noise_norm_levels_ids[sigma_normalized < sigma_noise_norm_levels_edges[0]] = 0
sigma_noise_norm_levels_ids[sigma_normalized > sigma_noise_norm_levels_edges[-1]] = self.n_levels - 1
sigma_noise_norm_levels_ids[sigma_normalized > sigma_noise_norm_levels_edges[-1]] = self.n_levels + 1
sigma_noise_norm_levels_mean = 1/2 * (
sigma_noise_norm_levels_edges[1:] + sigma_noise_norm_levels_edges[:-1]
)

degree_kernel = 3
x_fit_norm = self.compute_extended_moments(
rot_invs_normalized[:, keep_rot_invs_kernel], degree_kernel
Expand Down Expand Up @@ -1009,17 +1008,17 @@ def standard_model_mlfit_rot_invs(self, rot_invs, sigma_norm_limits):

rotinvs_train_norm = rotinvs_train / rotinvs_train[:,[0]]

for i in range(self.n_levels):
for i in range(1, len(sigma_noise_norm_levels_edges)):
flag_current_noise_level = sigma_noise_norm_levels_ids == i

if not np.any(flag_current_noise_level):
continue

sigma_rotinvs_training = sigma_noise_norm_levels_mean[i] / sigma_ndirs_factor
sigma_rotinvs_training = sigma_noise_norm_levels_mean[i-1] / sigma_ndirs_factor
meas_rotinvs_train = (rotinvs_train_norm +
sigma_rotinvs_training * np.random.standard_normal((rotinvs_train_norm.shape))
)

x_train = self.compute_extended_moments(
meas_rotinvs_train[:, keep_rot_invs_kernel], degree=degree_kernel)

Expand Down

0 comments on commit c1fc690

Please sign in to comment.