From 81606cccb4abe18f68c6203e8c0c6caff632c66a Mon Sep 17 00:00:00 2001
From: Paolo Davini
Date: Thu, 17 Oct 2024 17:08:24 +0200
Subject: [PATCH] support for dataarray
---
smmregrid/gridtype.py | 3 ++-
smmregrid/regrid.py | 7 ++++++-
tests/basic_test.py | 9 ++++++++-
3 files changed, 16 insertions(+), 3 deletions(-)
diff --git a/smmregrid/gridtype.py b/smmregrid/gridtype.py
index a0f7300..2a18744 100644
--- a/smmregrid/gridtype.py
+++ b/smmregrid/gridtype.py
@@ -45,7 +45,8 @@ def _handle_default_dimensions(self, extra_dims):
update_dims = DEFAULT_DIMS
for dim in extra_dims.keys():
- update_dims[dim] = update_dims[dim] + extra_dims[dim]
+ if extra_dims[dim]:
+ update_dims[dim] = update_dims[dim] + extra_dims[dim]
return update_dims
def __eq__(self, other):
diff --git a/smmregrid/regrid.py b/smmregrid/regrid.py
index f2f37c7..3483f13 100644
--- a/smmregrid/regrid.py
+++ b/smmregrid/regrid.py
@@ -143,7 +143,12 @@ def __init__(self, source_grid=None, target_grid=None, weights=None,
source_grid_array_to_cdo = source_grid
else:
# when feeding from xarray, select the variable and its bounds
- source_grid_array_to_cdo = source_grid_array[[list(gridtype.variables.keys())[0]] + gridtype.bounds]
+ if isinstance(source_grid_array, xarray.Dataset):
+ stored_vars = [list(gridtype.variables.keys())[0]] + gridtype.bounds
+ self.loggy.debug('Storing variables %s', stored_vars)
+ source_grid_array_to_cdo = source_grid_array[stored_vars]
+ else:
+ source_grid_array_to_cdo = source_grid_array
gridtype.weights = cdo_generate_weights(source_grid_array_to_cdo, target_grid, method=method,
vertical_dim=gridtype.vertical_dim,
diff --git a/tests/basic_test.py b/tests/basic_test.py
index 142c8c0..e9f429c 100644
--- a/tests/basic_test.py
+++ b/tests/basic_test.py
@@ -25,7 +25,7 @@ def test_healpix_extra(method):
rfield = interpolator.regrid(xfield)
assert rfield['tas'].shape == (2, 180, 360)
-@pytest.mark.parametrize("method", ['con', 'nn', 'bil'])
+@pytest.mark.parametrize("method", ['con', 'nn', 'bic'])
def test_nan_preserve(method):
"""Test to verify that NaN are preserved"""
xfield = xarray.open_mfdataset(os.path.join(INDIR, 'tas-ecearth.nc'))
@@ -34,3 +34,10 @@ def test_nan_preserve(method):
interpolator = Regridder(weights=wfield, space_dims='pippo', loglevel='debug')
rfield = interpolator.regrid(xfield)
assert numpy.isnan(rfield['tas'][1,:,:]).all().compute()
+
+@pytest.mark.parametrize("method", ['nn'])
+def test_datarray(method):
+ xfield = xarray.open_mfdataset(os.path.join(INDIR, 'tas-ecearth.nc'))
+ interpolator = Regridder(source_grid=xfield['tas'], target_grid=tfile, loglevel='debug', method = method)
+ interp = interpolator.regrid(source_data=xfield)
+ assert interp['tas'].shape == (12, 180, 360)