Skip to content

Commit

Permalink
SH rotation matrix (#28)
Browse files Browse the repository at this point in the history
* adds rotation matrices

* some bugfixes, tests pass now at rtol=1e-3

---------

Co-authored-by: Benjamin Stahl <[email protected]>
  • Loading branch information
BenjSta and Benjamin Stahl authored Jan 9, 2024
1 parent 3ff17c1 commit c573f30
Show file tree
Hide file tree
Showing 2 changed files with 238 additions and 0 deletions.
200 changes: 200 additions & 0 deletions spaudiopy/sph.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,206 @@ def inverse_sht(F_nm, azi, colat, sh_type, N_sph=None, Y_nm=None):
return np.matmul(Y_nm, F_nm[:(N_sph + 1) ** 2, :])


def rotation_matrix(N_sph, yaw, pitch, roll, sh_type='real',
return_as_blocks=True):
"""Computes a Wigner-D matrix for the rotation of spherical harmonics.
Parameters
----------
N_sph : int
Maximum SH order.
yaw: float
Rotation around Z axis.
pitch: float
Rotation around Y axis.
roll: float
Rotation around X axis.
sh_type : 'complex' or 'real' spherical harmonics. Currently only 'real' is
supported.
return_as_blocks: whether to return a list of blocks (one for each order)
or the full block diagonal matrix
Returns
-------
Either a list r_blocks of numpy arrays [r(n) for n in range(N_sph)], where
the shape of r is (..., 2*n-1, 2*n-1) or a block diagonal matrix R with
shape (..., (N_sph+1)**2, (N_sph+1)**2)
Implemented according to: Ivanic, Joseph, and Klaus Ruedenberg. "Rotation
matrices for real spherical harmonics. Direct determination by recursion."
The Journal of Physical Chemistry 100.15 (1996): 6342-6347.
Ported from https://git.iem.at/audioplugins/IEMPluginSuite
"""

if sh_type == 'complex':
raise NotImplementedError(
'Currently only real valued SHs can be rotated')
elif sh_type == 'real':
pass
else:
raise ValueError('Unknown SH type')



rot_mat_cartesian = utils.rotation_euler(yaw, pitch, roll)

r_blocks = [np.array([[1]])]

# change order to y, z, x
r1 = rot_mat_cartesian[..., [1, 2, 0], :]
r1 = r1[..., :, [1, 2, 0]]
r_blocks.append(r1)

# auxiliary functions
def _rot_p_func(i, l, a, b, r1, rlm1):
ri1 = r1[..., i + 1, 2]
rim1 = r1[..., i + 1, 0]
ri0 = r1[..., i + 1, 1]

if b == -l:
return (ri1 * rlm1[..., a + l - 1, 0] +
rim1 * rlm1[..., a + l - 1, 2 * l - 2])
elif b == l:
return (ri1 * rlm1[..., a + l - 1, 2 * l - 2] -
rim1 * rlm1[..., a + l - 1, 0])
else:
return ri0 * rlm1[..., a + l - 1, b + l - 1]


def _rot_u_func(l, m, n, r1, rlm1):
return _rot_p_func(0, l, m, n, r1, rlm1)


def _rot_v_func(l, m, n, r1, rlm1):
if m == 0:
p0 = _rot_p_func(1, l, 1, n, r1, rlm1)
p1 = _rot_p_func(-1, l, -1, n, r1, rlm1)
return p0 + p1

elif m > 0:
p0 = _rot_p_func(1, l, m - 1, n, r1, rlm1)
if m == 1:
return p0 * np.sqrt(2)
else:
return p0 - _rot_p_func(-1, l, 1 - m, n, r1, rlm1)
else:
p1 = _rot_p_func(-1, l, -m - 1, n, r1, rlm1)
if m == -1:
return p1 * np.sqrt(2)
else:
return p1 + _rot_p_func(1, l, m + 1, n, r1, rlm1)


def _rot_w_func(l, m, n, r1, rlm1):
if m > 0:
p0 = _rot_p_func(1, l, m + 1, n, r1, rlm1)
p1 = _rot_p_func(-1, l, -m - 1, n, r1, rlm1)
return p0 + p1
elif m < 0:
p0 = _rot_p_func(1, l, m - 1, n, r1, rlm1)
p1 = _rot_p_func(-1, l, 1 - m, n, r1, rlm1)
return p0 - p1
return 0

rlm1 = r1
for l in range(2, N_sph + 1):
rl = np.zeros((2 * l + 1, 2 * l + 1))
for m in range(-l, l + 1):
for n in range(-l, l + 1):
d = int(m == 0)
if abs(n) == l:
denom = (2 * l) * (2 * l - 1)
else:
denom = l * l - n * n

u = np.sqrt((l * l - m * m) / denom)
v = (
np.sqrt((1.0 + d) * (l + abs(m) - 1.0)
* (l + abs(m)) / denom)
* (1.0 - 2.0 * d)
* 0.5
)
w = (
np.sqrt((l - abs(m) - 1.0) * (l - abs(m)) / denom)
* (1.0 - d)
* (-0.5)
)

if u != 0:
u *= _rot_u_func(l, m, n, r1, rlm1)
if v != 0:
v *= _rot_v_func(l, m, n, r1, rlm1)
if w != 0:
w *= _rot_w_func(l, m, n, r1, rlm1)

rl[..., m + l, n + l] = u + v + w

r_blocks.append(rl)
rlm1 = rl

if return_as_blocks:
return r_blocks
else:
# compose a block-diagonal matrix
R = np.zeros(2*[(N_sph+1)**2])
index = 0
for r_block in r_blocks:
R[..., index:index + r_block.shape[-1],
index:index + r_block.shape[-1]] = r_block
index += r_block.shape[-1]
return R


def rotate_sh(F_nm, yaw, pitch, roll, sh_type='real'):
"""Rotate spherical harmonics coefficients.
Parameters
----------
F_nm : (..., (N_sph+1)**2) numpy.ndarray
Spherical harmonics coefficients
yaw: float
Rotation around Z axis.
pitch: float
Rotation around Y axis.
roll: float
Rotation around X axis.
sh_type : 'complex' or 'real' spherical harmonics. Currently only 'real' is
supported.
Returns
-------
F_nm_rot : (..., (N_sph+1)**2) numpy.ndarray
Rotated spherical harmonics coefficients.
"""
if sh_type == 'complex':
raise NotImplementedError(
'Currently only real valued SHs can be rotated')
elif sh_type == 'real':
pass
else:
raise ValueError('Unknown SH type')

N_sph = np.sqrt(F_nm.shape[-1]) - 1
assert N_sph == np.floor(N_sph), 'Invalid number of coefficients'
N_sph = int(N_sph)

r_blocks = rotation_matrix(N_sph, yaw, pitch, roll, sh_type)

index = 0

F_nm_rot = np.zeros_like(F_nm)

for r_block in r_blocks:
F_nm_rot[..., index:index + r_block.shape[-1]] = (r_block @
F_nm[..., index:index + r_block.shape[-1]][..., None]).squeeze(-1)
index += r_block.shape[-1]

return F_nm_rot


def check_cond_sht(N_sph, azi, colat, sh_type, lim=None):
"""Check if condition number for a least-squares SHT(N_sph) is high."""
if lim is None:
Expand Down
38 changes: 38 additions & 0 deletions tests/test_algo.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,41 @@ def test_calculate_grid_weights(test_n_sph):

# Perfect Reconstruction
assert_allclose(q_weights_t, q_weights)




@pytest.mark.parametrize('test_n_sph', N_SPHS)
def test_rotation(test_n_sph):
# use integer angles here (these definitely will not lead to special
# cases in which a rotation matches by chance)
test_yaw, test_pitch, test_roll = (1, 3, 5), (1, 3, 5), (1, 3, 5)
for yaw in test_yaw:
for pitch in test_pitch:
for roll in test_roll:
tgrid = spa.grids.load_t_design(degree=2*test_n_sph)
tx = tgrid[:, 0]
ty = tgrid[:, 1]
tz = tgrid[:, 2]
tazi, tcolat, _ = spa.utils.cart2sph(tx, ty, tz)

R = spa.utils.rotation_euler(yaw, pitch, roll)
tgrid_rot = (R @ tgrid.T).T
tx_rot = tgrid_rot[:, 0]
ty_rot = tgrid_rot[:, 1]
tz_rot = tgrid_rot[:, 2]
tazi_rot, tcolat_rot, _ = spa.utils.cart2sph(
tx_rot, ty_rot, tz_rot)

shmat = spa.sph.sh_matrix(test_n_sph, tazi, tcolat, 'real')
shmat_ref = spa.sph.sh_matrix(test_n_sph, tazi_rot, tcolat_rot)

R = spa.sph.rotation_matrix(test_n_sph, yaw, pitch, roll,
sh_type='real',
return_as_blocks=False)

shmat_rot = (R @ shmat[..., None]).squeeze(-1)
assert_allclose(shmat_ref, shmat_rot, rtol=1e-3)

shmat_rot = spa.sph.rotate_sh(shmat, yaw, pitch, roll)
assert_allclose(shmat_ref, shmat_rot, rtol=1e-3)

0 comments on commit c573f30

Please sign in to comment.