From 9fd7ebc500fb82854097fc9424a3594659773655 Mon Sep 17 00:00:00 2001 From: chris-langfield Date: Wed, 24 Apr 2024 22:25:23 -0400 Subject: [PATCH] make offset default pattern --- src/hexfft/array.py | 2 +- src/hexfft/hexfft.py | 8 ++++---- src/hexfft/utils.py | 24 +++--------------------- tests/test_api.py | 6 +++--- tests/test_hexfft.py | 6 +++--- tests/test_utils.py | 10 +++++----- 6 files changed, 19 insertions(+), 37 deletions(-) diff --git a/src/hexfft/array.py b/src/hexfft/array.py index dea92f6..66f8443 100644 --- a/src/hexfft/array.py +++ b/src/hexfft/array.py @@ -40,7 +40,7 @@ class HexArray(np.ndarray): """ - def __new__(cls, arr, pattern="oblique"): + def __new__(cls, arr, pattern="offset"): # np.ndarray subclass boilerplate obj = np.asarray(arr).view(cls) diff --git a/src/hexfft/hexfft.py b/src/hexfft/hexfft.py index 279544d..e6ee822 100644 --- a/src/hexfft/hexfft.py +++ b/src/hexfft/hexfft.py @@ -280,7 +280,7 @@ def _inverse(self, X): X = np.expand_dims(X, 0) F2 = scipy.fft.ifft(X, axis=2) F1 = F2 * self.phase_shift_conj - x = HexArray(scipy.fft.ifft(F1, axis=1)) + x = HexArray(scipy.fft.ifft(F1, axis=1), "oblique") if squeeze: x = np.squeeze(x) @@ -643,12 +643,12 @@ def rect_fft(x): cdtype = complex_type(dtype) N1, N2 = x.shape - F1 = HexArray(scipy.fft.fft(x, axis=0)) + F1 = HexArray(scipy.fft.fft(x, axis=0), "oblique") exp_factor = np.exp( 1.0j * np.pi * np.array([i * np.arange(N2) for i in range(N1)]) / N1 ).astype(cdtype) F2 = F1 * exp_factor - F = HexArray(np.fft.fft(F2, axis=1)) + F = HexArray(np.fft.fft(F2, axis=1), "oblique") return F @@ -663,6 +663,6 @@ def rect_ifft(X): -1.0j * np.pi * np.array([i * np.arange(N2) for i in range(N1)]) / N1 ).astype(cdtype) F1 = F2 * exp_factor - x = HexArray(scipy.fft.ifft(F1, axis=0)) + x = HexArray(scipy.fft.ifft(F1, axis=0), "oblique") return x diff --git a/src/hexfft/utils.py b/src/hexfft/utils.py index 1b7580b..919e7ea 100644 --- a/src/hexfft/utils.py +++ b/src/hexfft/utils.py @@ -69,7 +69,7 @@ def hex_to_pgram(h): p[pgram_left] = h[support_below] p[pgram_right] = h[support_above] - return HexArray(p, pattern=h.pattern) + return HexArray(p, h.pattern) def pgram_to_hex(p, N, pattern="oblique"): @@ -108,29 +108,11 @@ def pgram_to_hex(p, N, pattern="oblique"): h[support_below] = p[pgram_left] h[support_above] = p[pgram_right] - return HexArray(h, pattern=pattern) - - -def pad(x): - """ - Given an NxN array x, find the enclosing Mersereau - hexagonal region and sampling grid. - """ - assert x.shape[0] == x.shape[1] - - # Create a Mersereau hexagonal region of size N - P = x.shape[0] # i.e. = N - # Parallelogram (square in oblique coordinates) enclosing - M = 2 * (P + 1) - m1, m2 = np.meshgrid(np.arange(M), np.arange(M)) - grid = np.zeros((M, M), x.dtype) - grid[int(P // 2) : P + int(P // 2), int(P // 2) : P + int(P // 2)] = x - - return grid + return HexArray(h, pattern) def nice_test_function(shape, hcrop=True, pattern="oblique"): - h = HexArray(np.zeros(shape), pattern=pattern) + h = HexArray(np.zeros(shape), pattern) N1, N2 = shape n1, n2 = np.meshgrid(np.arange(N1), np.arange(N2), indexing="ij") if hcrop: diff --git a/tests/test_api.py b/tests/test_api.py index 966b519..0f29c40 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -105,13 +105,13 @@ def test_hex_fft_output(pattern): assert X.shape == array.shape assert X.dtype == np.complex128 - X = fft(HexArray(array.astype(np.float32)), "hex") + X = fft(HexArray(array.astype(np.float32), pattern), "hex") assert X.dtype == np.complex64 - X = fft(HexArray(array.astype(np.complex64)), "hex") + X = fft(HexArray(array.astype(np.complex64), pattern), "hex") assert X.dtype == np.complex64 - X = fft(HexArray(array.astype(np.complex128)), "hex") + X = fft(HexArray(array.astype(np.complex128), pattern), "hex") assert X.dtype == np.complex128 diff --git a/tests/test_hexfft.py b/tests/test_hexfft.py index cac0e64..b9aaecd 100644 --- a/tests/test_hexfft.py +++ b/tests/test_hexfft.py @@ -94,9 +94,9 @@ def test_pgram_hexdft(): else: center = (N / 2 - 1, N / 2 - 1) impulse = HexArray( - np.stack([hregion(n1, n2, center, i + 1) for i in range(nstack)]) + np.stack([hregion(n1, n2, center, i + 1) for i in range(nstack)]), "oblique" ) - impulse_single = HexArray(hregion(n1, n2, center, 1)) + impulse_single = HexArray(hregion(n1, n2, center, 1), "oblique") impulse_p = hex_to_pgram(impulse) impulse_single_p = hex_to_pgram(impulse_single) @@ -254,7 +254,7 @@ def test_mersereau_fft(): def test_fftshift(): for size in [8, 16, 32]: x = nice_test_function((size, size)) - h_oblique = HexArray(x) * hsupport(size) + h_oblique = HexArray(x, "oblique") * hsupport(size) h_offset = HexArray(x, "offset") * hsupport(size, "offset") shifted_oblique = fftshift(h_oblique) shifted_offset = fftshift(h_offset) diff --git a/tests/test_utils.py b/tests/test_utils.py index 8ff6e57..e50b475 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -10,7 +10,7 @@ def test_hexarray(): for arr in arrs: # make sure the indices are in oblique coordinates by default # this corresponds to a 3x3 parallopiped - hx = HexArray(arr) + hx = HexArray(arr, "oblique") n1, n2 = hx.indices t1, t2 = np.meshgrid(np.arange(3), np.arange(3)) assert np.all(n1 == t1) @@ -18,7 +18,7 @@ def test_hexarray(): # make sure internal representation in oblique coords is correct # when given an array with offset coordinates - hx = HexArray(arr, pattern="offset") + hx = HexArray(arr, "offset") n1, n2 = hx.indices t1, t2 = np.meshgrid(np.arange(3), np.arange(3)) @@ -30,7 +30,7 @@ def test_hexarray(): # test the pattern arr = np.ones((4, 5)) - hx = HexArray(arr, pattern="offset") + hx = HexArray(arr, "offset") n1, n2 = hx.indices col_indices = np.array([[i, i + 1, i + 1, i + 2] for i in range(5)]) assert np.all(n2 == col_indices) @@ -40,8 +40,8 @@ def test_hex_pgram_conversions(): # all in oblique coordinates nstack = 10 for N in [5, 8, 16, 21]: - pgram = HexArray(np.random.normal(size=(N // 2, 3 * (N // 2)))) - pgrams = HexArray(np.stack([pgram * (i + 1) for i in range(nstack)])) + pgram = HexArray(np.random.normal(size=(N // 2, 3 * (N // 2))), "oblique") + pgrams = HexArray(np.stack([pgram * (i + 1) for i in range(nstack)]), "oblique") hex = pgram_to_hex(pgram, N) hexs = pgram_to_hex(pgrams, N)