Skip to content

Commit

Permalink
Cleanup sh rotation matrix
Browse files Browse the repository at this point in the history
  • Loading branch information
chris-hld committed Jan 9, 2024
1 parent c573f30 commit 06e8905
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 56 deletions.
43 changes: 15 additions & 28 deletions spaudiopy/sph.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,8 +219,8 @@ def inverse_sht(F_nm, azi, colat, sh_type, N_sph=None, Y_nm=None):
return np.matmul(Y_nm, F_nm[:(N_sph + 1) ** 2, :])


def rotation_matrix(N_sph, yaw, pitch, roll, sh_type='real',
return_as_blocks=True):
def sh_rotation_matrix(N_sph, yaw, pitch, roll, sh_type='real',
return_as_blocks=False):
"""Computes a Wigner-D matrix for the rotation of spherical harmonics.
Parameters
Expand All @@ -235,15 +235,16 @@ def rotation_matrix(N_sph, yaw, pitch, roll, sh_type='real',
Rotation around X axis.
sh_type : 'complex' or 'real' spherical harmonics. Currently only 'real' is
supported.
return_as_blocks: whether to return a list of blocks (one for each order)
or the full block diagonal matrix
return_as_blocks: return full block diagonal matrix, or a list of blocks
Returns
-------
Either a list r_blocks of numpy arrays [r(n) for n in range(N_sph)], where
the shape of r is (..., 2*n-1, 2*n-1) or a block diagonal matrix R with
shape (..., (N_sph+1)**2, (N_sph+1)**2)
A block diagonal matrix R with shape (..., (N_sph+1)**2, (N_sph+1)**2), or
a list r_blocks of numpy arrays [r(n) for n in range(N_sph)], where the
shape of r is (..., 2*n-1, 2*n-1)
References
----------
Implemented according to: Ivanic, Joseph, and Klaus Ruedenberg. "Rotation
matrices for real spherical harmonics. Direct determination by recursion."
The Journal of Physical Chemistry 100.15 (1996): 6342-6347.
Expand All @@ -258,9 +259,7 @@ def rotation_matrix(N_sph, yaw, pitch, roll, sh_type='real',
elif sh_type == 'real':
pass
else:
raise ValueError('Unknown SH type')


raise ValueError('Unknown SH type')

rot_mat_cartesian = utils.rotation_euler(yaw, pitch, roll)

Expand All @@ -286,11 +285,9 @@ def _rot_p_func(i, l, a, b, r1, rlm1):
else:
return ri0 * rlm1[..., a + l - 1, b + l - 1]


def _rot_u_func(l, m, n, r1, rlm1):
return _rot_p_func(0, l, m, n, r1, rlm1)


def _rot_v_func(l, m, n, r1, rlm1):
if m == 0:
p0 = _rot_p_func(1, l, 1, n, r1, rlm1)
Expand All @@ -310,7 +307,6 @@ def _rot_v_func(l, m, n, r1, rlm1):
else:
return p1 + _rot_p_func(1, l, m + 1, n, r1, rlm1)


def _rot_w_func(l, m, n, r1, rlm1):
if m > 0:
p0 = _rot_p_func(1, l, m + 1, n, r1, rlm1)
Expand Down Expand Up @@ -361,11 +357,11 @@ def _rot_w_func(l, m, n, r1, rlm1):
if return_as_blocks:
return r_blocks
else:
# compose a block-diagonal matrix
# compose a block-diagonal matrix
R = np.zeros(2*[(N_sph+1)**2])
index = 0
for r_block in r_blocks:
R[..., index:index + r_block.shape[-1],
R[..., index:index + r_block.shape[-1],
index:index + r_block.shape[-1]] = r_block
index += r_block.shape[-1]
return R
Expand Down Expand Up @@ -400,23 +396,14 @@ def rotate_sh(F_nm, yaw, pitch, roll, sh_type='real'):
pass
else:
raise ValueError('Unknown SH type')

N_sph = np.sqrt(F_nm.shape[-1]) - 1
assert N_sph == np.floor(N_sph), 'Invalid number of coefficients'
N_sph = int(N_sph)

r_blocks = rotation_matrix(N_sph, yaw, pitch, roll, sh_type)

index = 0

F_nm_rot = np.zeros_like(F_nm)

for r_block in r_blocks:
F_nm_rot[..., index:index + r_block.shape[-1]] = (r_block @
F_nm[..., index:index + r_block.shape[-1]][..., None]).squeeze(-1)
index += r_block.shape[-1]
R = sh_rotation_matrix(N_sph, yaw, pitch, roll, sh_type)

return F_nm_rot
return F_nm @ R.T


def check_cond_sht(N_sph, azi, colat, sh_type, lim=None):
Expand Down
46 changes: 18 additions & 28 deletions tests/test_algo.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,39 +49,29 @@ def test_calculate_grid_weights(test_n_sph):
assert_allclose(q_weights_t, q_weights)




@pytest.mark.parametrize('test_n_sph', N_SPHS)
def test_rotation(test_n_sph):
# use integer angles here (these definitely will not lead to special
# cases in which a rotation matches by chance)
test_yaw, test_pitch, test_roll = (1, 3, 5), (1, 3, 5), (1, 3, 5)
test_yaw, test_pitch, test_roll = (0, np.pi/2, 5), (0, 3, 5), (0, -3, 5)

tgrid = spa.grids.load_t_design(degree=2*test_n_sph)
tazi, tzen, _ = spa.utils.cart2sph(*tgrid.T)

for yaw in test_yaw:
for pitch in test_pitch:
for roll in test_roll:
tgrid = spa.grids.load_t_design(degree=2*test_n_sph)
tx = tgrid[:, 0]
ty = tgrid[:, 1]
tz = tgrid[:, 2]
tazi, tcolat, _ = spa.utils.cart2sph(tx, ty, tz)

print(yaw, pitch, roll)
R = spa.utils.rotation_euler(yaw, pitch, roll)
tgrid_rot = (R @ tgrid.T).T
tx_rot = tgrid_rot[:, 0]
ty_rot = tgrid_rot[:, 1]
tz_rot = tgrid_rot[:, 2]
tazi_rot, tcolat_rot, _ = spa.utils.cart2sph(
tx_rot, ty_rot, tz_rot)

shmat = spa.sph.sh_matrix(test_n_sph, tazi, tcolat, 'real')
shmat_ref = spa.sph.sh_matrix(test_n_sph, tazi_rot, tcolat_rot)

R = spa.sph.rotation_matrix(test_n_sph, yaw, pitch, roll,
sh_type='real',
return_as_blocks=False)

shmat_rot = (R @ shmat[..., None]).squeeze(-1)
assert_allclose(shmat_ref, shmat_rot, rtol=1e-3)
tgrid_rot = (R @ tgrid.T).T
tazi_rot, tzen_rot, _ = spa.utils.cart2sph(*tgrid_rot.T)

shmat = spa.sph.sh_matrix(test_n_sph, tazi, tzen, 'real')
shmat_ref = spa.sph.sh_matrix(test_n_sph, tazi_rot, tzen_rot)

shmat_rot = spa.sph.rotate_sh(shmat, yaw, pitch, roll)
R = spa.sph.sh_rotation_matrix(test_n_sph, yaw, pitch, roll,
sh_type='real')

shmat_rot = (R @ shmat.T).T
assert_allclose(shmat_ref, shmat_rot, rtol=1e-3)

shmat_rotate_sh = spa.sph.rotate_sh(shmat, yaw, pitch, roll)
assert_allclose(shmat_ref, shmat_rotate_sh, rtol=1e-3)

0 comments on commit 06e8905

Please sign in to comment.