diff --git a/src/wf_psf/psf_models/psf_model_semiparametric.py b/src/wf_psf/psf_models/psf_model_semiparametric.py index 44a2bdba..846d332d 100644 --- a/src/wf_psf/psf_models/psf_model_semiparametric.py +++ b/src/wf_psf/psf_models/psf_model_semiparametric.py @@ -12,7 +12,7 @@ from tensorflow.python.keras.engine import data_adapter from wf_psf.psf_models import psf_models as psfm from wf_psf.psf_models import tf_layers as tfl -from wf_psf.utils.utils import tf_decompose_obscured_opd_basis +from wf_psf.utils.utils import decompose_tf_obscured_opd_basis from wf_psf.psf_models.tf_layers import ( TFBatchPolychromaticPSF, TFBatchMonochromaticPSF, @@ -358,7 +358,7 @@ def project_DD_features(self, tf_zernike_cube=None): np.transpose( np.array( [ - tf_decompose_obscured_opd_basis( + decompose_tf_obscured_opd_basis( tf_opd=inter_res_v2[j, :, :], tf_obscurations=self.obscurations, tf_zk_basis=tf_zernike_cube, diff --git a/src/wf_psf/tests/test_utils/utils_test.py b/src/wf_psf/tests/test_utils/utils_test.py index 155fae75..0e3907c1 100644 --- a/src/wf_psf/tests/test_utils/utils_test.py +++ b/src/wf_psf/tests/test_utils/utils_test.py @@ -10,12 +10,15 @@ import pytest import tensorflow as tf import numpy as np -from wf_psf.utils.utils import zernike_generator +from wf_psf.utils.utils import ( + zernike_generator, + compute_unobscured_zernike_projection, + decompose_tf_obscured_opd_basis +) from wf_psf.sims.psf_simulator import PSFSimulator def test_unobscured_zernike_projection(): - from wf_psf.utils.utils import unobscured_zernike_projection n_zernikes = 20 wfe_dim = 256 @@ -41,14 +44,14 @@ def test_unobscured_zernike_projection(): ) # Compute normalisation factor - norm_factor = unobscured_zernike_projection( + norm_factor = compute_unobscured_zernike_projection( tf_zernike_cube[0, :, :], tf_zernike_cube[0, :, :] ) # Compute projections for each zernike estimated_zk_array = np.array( [ - unobscured_zernike_projection( + compute_unobscured_zernike_projection( tf_unobscured_opd, tf_zernike_cube[j, :, :], norm_factor=norm_factor ) for j in range(n_zernikes) @@ -61,8 +64,7 @@ def test_unobscured_zernike_projection(): def test_tf_decompose_obscured_opd_basis(): - from wf_psf.utils.utils import tf_decompose_obscured_opd_basis - + n_zernikes = 20 wfe_dim = 256 tol = 1e-5 @@ -95,7 +97,7 @@ def test_tf_decompose_obscured_opd_basis(): ) # Compute zernike array from OPD - obsc_coeffs = tf_decompose_obscured_opd_basis( + obsc_coeffs = decompose_tf_obscured_opd_basis( tf_opd=tf_obscured_opd, tf_obscurations=tf_obscurations, tf_zk_basis=tf_zernike_cube,