Skip to content

Commit

Permalink
support for dataarray
Browse files Browse the repository at this point in the history
  • Loading branch information
oloapinivad committed Oct 17, 2024
1 parent d9c1239 commit 81606cc
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 3 deletions.
3 changes: 2 additions & 1 deletion smmregrid/gridtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ def _handle_default_dimensions(self, extra_dims):

update_dims = DEFAULT_DIMS
for dim in extra_dims.keys():
update_dims[dim] = update_dims[dim] + extra_dims[dim]
if extra_dims[dim]:
update_dims[dim] = update_dims[dim] + extra_dims[dim]
return update_dims

def __eq__(self, other):
Expand Down
7 changes: 6 additions & 1 deletion smmregrid/regrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,12 @@ def __init__(self, source_grid=None, target_grid=None, weights=None,
source_grid_array_to_cdo = source_grid
else:
# when feeding from xarray, select the variable and its bounds
source_grid_array_to_cdo = source_grid_array[[list(gridtype.variables.keys())[0]] + gridtype.bounds]
if isinstance(source_grid_array, xarray.Dataset):
stored_vars = [list(gridtype.variables.keys())[0]] + gridtype.bounds
self.loggy.debug('Storing variables %s', stored_vars)
source_grid_array_to_cdo = source_grid_array[stored_vars]
else:
source_grid_array_to_cdo = source_grid_array

gridtype.weights = cdo_generate_weights(source_grid_array_to_cdo, target_grid, method=method,
vertical_dim=gridtype.vertical_dim,
Expand Down
9 changes: 8 additions & 1 deletion tests/basic_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def test_healpix_extra(method):
rfield = interpolator.regrid(xfield)
assert rfield['tas'].shape == (2, 180, 360)

@pytest.mark.parametrize("method", ['con', 'nn', 'bil'])
@pytest.mark.parametrize("method", ['con', 'nn', 'bic'])
def test_nan_preserve(method):
"""Test to verify that NaN are preserved"""
xfield = xarray.open_mfdataset(os.path.join(INDIR, 'tas-ecearth.nc'))
Expand All @@ -34,3 +34,10 @@ def test_nan_preserve(method):
interpolator = Regridder(weights=wfield, space_dims='pippo', loglevel='debug')
rfield = interpolator.regrid(xfield)
assert numpy.isnan(rfield['tas'][1,:,:]).all().compute()

@pytest.mark.parametrize("method", ['nn'])
def test_datarray(method):
xfield = xarray.open_mfdataset(os.path.join(INDIR, 'tas-ecearth.nc'))
interpolator = Regridder(source_grid=xfield['tas'], target_grid=tfile, loglevel='debug', method = method)
interp = interpolator.regrid(source_data=xfield)
assert interp['tas'].shape == (12, 180, 360)

0 comments on commit 81606cc

Please sign in to comment.