From 81606cccb4abe18f68c6203e8c0c6caff632c66a Mon Sep 17 00:00:00 2001 From: Paolo Davini Date: Thu, 17 Oct 2024 17:08:24 +0200 Subject: [PATCH] support for dataarray --- smmregrid/gridtype.py | 3 ++- smmregrid/regrid.py | 7 ++++++- tests/basic_test.py | 9 ++++++++- 3 files changed, 16 insertions(+), 3 deletions(-) diff --git a/smmregrid/gridtype.py b/smmregrid/gridtype.py index a0f7300..2a18744 100644 --- a/smmregrid/gridtype.py +++ b/smmregrid/gridtype.py @@ -45,7 +45,8 @@ def _handle_default_dimensions(self, extra_dims): update_dims = DEFAULT_DIMS for dim in extra_dims.keys(): - update_dims[dim] = update_dims[dim] + extra_dims[dim] + if extra_dims[dim]: + update_dims[dim] = update_dims[dim] + extra_dims[dim] return update_dims def __eq__(self, other): diff --git a/smmregrid/regrid.py b/smmregrid/regrid.py index f2f37c7..3483f13 100644 --- a/smmregrid/regrid.py +++ b/smmregrid/regrid.py @@ -143,7 +143,12 @@ def __init__(self, source_grid=None, target_grid=None, weights=None, source_grid_array_to_cdo = source_grid else: # when feeding from xarray, select the variable and its bounds - source_grid_array_to_cdo = source_grid_array[[list(gridtype.variables.keys())[0]] + gridtype.bounds] + if isinstance(source_grid_array, xarray.Dataset): + stored_vars = [list(gridtype.variables.keys())[0]] + gridtype.bounds + self.loggy.debug('Storing variables %s', stored_vars) + source_grid_array_to_cdo = source_grid_array[stored_vars] + else: + source_grid_array_to_cdo = source_grid_array gridtype.weights = cdo_generate_weights(source_grid_array_to_cdo, target_grid, method=method, vertical_dim=gridtype.vertical_dim, diff --git a/tests/basic_test.py b/tests/basic_test.py index 142c8c0..e9f429c 100644 --- a/tests/basic_test.py +++ b/tests/basic_test.py @@ -25,7 +25,7 @@ def test_healpix_extra(method): rfield = interpolator.regrid(xfield) assert rfield['tas'].shape == (2, 180, 360) -@pytest.mark.parametrize("method", ['con', 'nn', 'bil']) +@pytest.mark.parametrize("method", ['con', 'nn', 'bic']) def test_nan_preserve(method): """Test to verify that NaN are preserved""" xfield = xarray.open_mfdataset(os.path.join(INDIR, 'tas-ecearth.nc')) @@ -34,3 +34,10 @@ def test_nan_preserve(method): interpolator = Regridder(weights=wfield, space_dims='pippo', loglevel='debug') rfield = interpolator.regrid(xfield) assert numpy.isnan(rfield['tas'][1,:,:]).all().compute() + +@pytest.mark.parametrize("method", ['nn']) +def test_datarray(method): + xfield = xarray.open_mfdataset(os.path.join(INDIR, 'tas-ecearth.nc')) + interpolator = Regridder(source_grid=xfield['tas'], target_grid=tfile, loglevel='debug', method = method) + interp = interpolator.regrid(source_data=xfield) + assert interp['tas'].shape == (12, 180, 360)