diff --git a/test/test_model.py b/test/test_model.py index 23a097b0..6325510f 100644 --- a/test/test_model.py +++ b/test/test_model.py @@ -28,6 +28,7 @@ from sklearn.gaussian_process.kernels import DotProduct, WhiteKernel from eddymotion import model +from eddymotion.model.dipy import GaussianProcessModel from eddymotion.data.dmri import DWI from eddymotion.data.splitting import lovo_split from eddymotion.exceptions import ModelNotFittedError @@ -107,12 +108,8 @@ def test_average_model(): assert np.all(tmodel_2000.predict([0, 0, 0]) == 1100) -def test_gp_model(datadir): - dwi = DWI.from_filename(datadir / "dwi.h5") - - kernel = DotProduct() + WhiteKernel() - - gp = model.GaussianProcessModel(dwi=dwi, kernel=kernel) +def test_gp_model(): + gp = GaussianProcessModel(kernel="default") assert isinstance(gp, model.GaussianProcessModel)