From 1d22912e9ddeb7cc7e508d18086c597540b76b7e Mon Sep 17 00:00:00 2001 From: Oscar Esteban Date: Thu, 29 Aug 2024 17:38:32 +0200 Subject: [PATCH] fix: revise tests, all green locally --- test/test_dipy.py | 2 +- test/test_dmri_utils.py | 639 ---------------------------------------- test/test_model.py | 16 +- 3 files changed, 9 insertions(+), 648 deletions(-) delete mode 100644 test/test_dmri_utils.py diff --git a/test/test_dipy.py b/test/test_dipy.py index cd62415a..54c91006 100644 --- a/test/test_dipy.py +++ b/test/test_dipy.py @@ -27,7 +27,7 @@ from dipy.core.gradients import gradient_table from dipy.io import read_bvals_bvecs -from eddymotion.model.dipy import ( +from eddymotion.model._dipy import ( PairwiseOrientationKernel, compute_exponential_covariance, compute_pairwise_angles, diff --git a/test/test_dmri_utils.py b/test/test_dmri_utils.py deleted file mode 100644 index f5738c22..00000000 --- a/test/test_dmri_utils.py +++ /dev/null @@ -1,639 +0,0 @@ -# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*- -# vi: set ft=python sts=4 ts=4 sw=4 et: -# -# Copyright 2024 The NiPreps Developers -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# We support and encourage derived works from this project, please read -# about our expectations at -# -# https://www.nipreps.org/community/licensing/ -# - -import numpy as np -import pytest - -from eddymotion.model.dmri_utils import ( - find_shelling_scheme, -) - - -@pytest.mark.parametrize( - ("bvals", "exp_scheme", "exp_bval_groups", "exp_bval_estimated"), - [ - ( - np.asarray( - [ - 5, - 300, - 300, - 300, - 300, - 300, - 305, - 1005, - 995, - 1000, - 1000, - 1005, - 1000, - 1000, - 1005, - 995, - 1000, - 1005, - 5, - 995, - 1000, - 1000, - 995, - 1005, - 995, - 1000, - 995, - 995, - 2005, - 2000, - 2005, - 2005, - 1995, - 2000, - 2005, - 2000, - 1995, - 2005, - 5, - 1995, - 2005, - 1995, - 1995, - 2005, - 2005, - 1995, - 2000, - 2000, - 2000, - 1995, - 2000, - 2000, - 2005, - 2005, - 1995, - 2005, - 2005, - 1990, - 1995, - 1995, - 1995, - 2005, - 2000, - 1990, - 2010, - 5, - ] - ), - "multi-shell", - [ - np.asarray([5, 5, 5, 5]), - np.asarray([300, 300, 300, 300, 300, 305]), - np.asarray( - [ - 1005, - 995, - 1000, - 1000, - 1005, - 1000, - 1000, - 1005, - 995, - 1000, - 1005, - 995, - 1000, - 1000, - 995, - 1005, - 995, - 1000, - 995, - 995, - ] - ), - np.asarray( - [ - 2005, - 2000, - 2005, - 2005, - 1995, - 2000, - 2005, - 2000, - 1995, - 2005, - 1995, - 2005, - 1995, - 1995, - 2005, - 2005, - 1995, - 2000, - 2000, - 2000, - 1995, - 2000, - 2000, - 2005, - 2005, - 1995, - 2005, - 2005, - 1990, - 1995, - 1995, - 1995, - 2005, - 2000, - 1990, - 2010, - ] - ), - ], - [5, 300, 1000, 2000], - ), - ], -) -def test_find_shelling_scheme_array(bvals, exp_scheme, exp_bval_groups, exp_bval_estimated): - obt_scheme, obt_bval_groups, obt_bval_estimated = find_shelling_scheme(bvals) - assert obt_scheme == exp_scheme - assert all( - np.allclose(obt_arr, exp_arr) - for obt_arr, exp_arr in zip(obt_bval_groups, exp_bval_groups, strict=True) - ) - assert np.allclose(obt_bval_estimated, exp_bval_estimated) - - -@pytest.mark.parametrize( - ("dwi_btable", "exp_scheme", "exp_bval_groups", "exp_bval_estimated"), - [ - ( - "ds000114_singleshell", - "single-shell", - [ - np.asarray([0, 0, 0, 0, 0, 0, 0]), - np.asarray( - [ - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - ] - ), - ], - [0.0, 1000.0], - ), - ( - "hcph_multishell", - "multi-shell", - [ - np.asarray([0, 0, 0, 0, 0, 0]), - np.asarray([700, 700, 700, 700, 700, 700, 700, 700, 700, 700, 700, 700]), - np.asarray( - [ - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - ] - ), - np.asarray( - [ - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - ] - ), - np.asarray( - [ - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - ] - ), - ], - [0.0, 700.0, 1000.0, 2000.0, 3000.0], - ), - ( - "ds004737_dsi", - "DSI", - [ - np.asarray([5, 5, 5, 5, 5, 5, 5, 5, 5]), - np.asarray([995, 995, 800, 800, 995, 995, 795, 995]), - np.asarray([1195, 1195, 1195, 1195, 1000, 1195, 1195, 1000]), - np.asarray([1595, 1595, 1595, 1600.0]), - np.asarray( - [ - 1800, - 1795, - 1795, - 1790, - 1995, - 1800, - 1795, - 1990, - 1990, - 1795, - 1990, - 1795, - 1795, - 1995, - ] - ), - np.asarray([2190, 2195, 2190, 2195, 2000, 2000, 2000, 2195, 2195, 2190]), - np.asarray([2590, 2595, 2600, 2395, 2595, 2600, 2395]), - np.array([2795, 2790, 2795, 2795, 2790, 2795, 2795, 2790, 2795]), - np.array([3590, 3395, 3595, 3595, 3395, 3395, 3400]), - np.array([3790, 3790]), - np.array([4195, 4195]), - np.array([4390, 4395, 4390]), - np.array( - [ - 4790, - 4990, - 4990, - 5000, - 5000, - 4990, - 4795, - 4985, - 5000, - 4795, - 5000, - 4990, - 4990, - 4790, - 5000, - 4990, - 4795, - 4795, - 4990, - 5000, - 4990, - ] - ), - ], - [ - 5.0, - 995.0, - 1195.0, - 1595.0, - 1797.5, - 2190.0, - 2595.0, - 2795.0, - 3400.0, - 3790.0, - 4195.0, - 4390.0, - 4990.0, - ], - ), - ], -) -def test_find_shelling_scheme_files( - dwi_btable, exp_scheme, exp_bval_groups, exp_bval_estimated, repodata -): - bvals = np.loadtxt(repodata / f"{dwi_btable}.bval") - - obt_scheme, obt_bval_groups, obt_bval_estimated = find_shelling_scheme(bvals) - assert obt_scheme == exp_scheme - assert all( - np.allclose(obt_arr, exp_arr) - for obt_arr, exp_arr in zip(obt_bval_groups, exp_bval_groups, strict=True) - ) - assert np.allclose(obt_bval_estimated, exp_bval_estimated) diff --git a/test/test_model.py b/test/test_model.py index ba60ebfe..2173aa3d 100644 --- a/test/test_model.py +++ b/test/test_model.py @@ -31,8 +31,8 @@ from eddymotion.data.dmri import DWI from eddymotion.data.splitting import lovo_split from eddymotion.exceptions import ModelNotFittedError -from eddymotion.model.base import DEFAULT_MAX_S0, DEFAULT_MIN_S0 -from eddymotion.model.dipy import GaussianProcessModel +from eddymotion.model._dipy import GaussianProcessModel +from eddymotion.model.dmri import DEFAULT_MAX_S0, DEFAULT_MIN_S0 def test_trivial_model(): @@ -40,9 +40,9 @@ def test_trivial_model(): rng = np.random.default_rng(1234) - # Should not allow initialization without a B0 - with pytest.raises(ValueError): - model.TrivialB0Model(gtab=np.eye(4)) + # Should not allow initialization without an oracle + with pytest.raises(TypeError): + model.TrivialModel() _S0 = rng.normal(size=(2, 2, 2)) @@ -52,7 +52,7 @@ def test_trivial_model(): a_max=DEFAULT_MAX_S0, ) - tmodel = model.TrivialB0Model(gtab=np.eye(4), S0=_clipped_S0) + tmodel = model.TrivialModel(predicted=_clipped_S0) data = None assert tmodel.fit(data) is None @@ -111,7 +111,7 @@ def test_average_model(): def test_gp_model(): gp = GaussianProcessModel("test") - assert isinstance(gp, model.dipy.GaussianProcessModel) + assert isinstance(gp, model._dipy.GaussianProcessModel) X, y = make_regression(n_samples=100, n_features=3, noise=0, random_state=0) @@ -150,7 +150,7 @@ def test_two_initialisations(datadir): # Initialisation via ModelFactory model2 = model.ModelFactory.init( gtab=data_train[1], - model="avg", + model="avgdwi", S0=dmri_dataset.bzero, th_low=100, th_high=1000,