diff --git a/earthkit/data/readers/netcdf.py b/earthkit/data/readers/netcdf.py index 0074ff27..b3ec3318 100644 --- a/earthkit/data/readers/netcdf.py +++ b/earthkit/data/readers/netcdf.py @@ -336,6 +336,14 @@ def __init__(self, ds, *args, **kwargs): self.ds = ds self._fields = None Index.__init__(self, *args, **kwargs) + # populate with in-built xarray methods: + for method in dir(ds): + if not method.startswith("_") and method not in dir(self): + try: + setattr(self.__class__, method, classmethod(getattr(ds, method))) + except Exception: + # Ignore incompatible methods + pass @property def fields(self): diff --git a/earthkit/data/wrappers/xarray.py b/earthkit/data/wrappers/xarray.py index 4d088ead..107378c4 100644 --- a/earthkit/data/wrappers/xarray.py +++ b/earthkit/data/wrappers/xarray.py @@ -20,6 +20,14 @@ class XArrayDataArrayWrapper(Wrapper): def __init__(self, data): self.data = data + # populate with in-built xarray methods: + for method in dir(data): + if not method.startswith("_") and method not in dir(self): + try: + setattr(self.__class__, method, classmethod(getattr(data, method))) + except Exception: + # Ignore those that are incompatible + pass # def axis(self, axis): # """ @@ -138,10 +146,10 @@ def wrapper(data, *args, **kwargs): if isinstance(data, xr.Dataset): ds = data elif isinstance(data, xr.DataArray): - try: - ds = data.to_dataset() - except ValueError: - return XArrayDataArrayWrapper(data, *args, **kwargs) + # try: + # ds = data.to_dataset() + # except ValueError: + return XArrayDataArrayWrapper(data, *args, **kwargs) if ds is not None: fs = netcdf.XArrayFieldList(ds, **kwargs) diff --git a/tests/translators/test_translators.py b/tests/translators/test_translators.py index f4a92501..81368174 100644 --- a/tests/translators/test_translators.py +++ b/tests/translators/test_translators.py @@ -150,7 +150,7 @@ def test_transform_from_grib_file(): def test_transform_from_xarray_object(): # transform grib-based data object - da = xr.DataArray([]) + da = xr.DataArray([], name="a") ds = xr.Dataset({"a": da}) # da to np.ndarray diff --git a/tests/wrappers/test_xarray.py b/tests/wrappers/test_xarray.py index fc8e7398..db928183 100644 --- a/tests/wrappers/test_xarray.py +++ b/tests/wrappers/test_xarray.py @@ -12,29 +12,42 @@ import logging +import numpy as np +import xarray as xr + from earthkit.data import from_object, wrappers from earthkit.data.wrappers import xarray as xr_wrapper LOG = logging.getLogger(__name__) -def test_dataset_wrapper(): - import xarray as xr +TEST_DA = xr.DataArray( + np.arange(9).reshape(3, 3), name="test", coords={"x": [1, 2, 3], "y": [1, 2, 3]} +) +TEST_DS = TEST_DA.to_dataset() - _wrapper = xr_wrapper.wrapper(xr.Dataset()) + +def test_dataset_wrapper(): + _wrapper = xr_wrapper.wrapper(TEST_DS) assert isinstance(_wrapper, xr_wrapper.XArrayDatasetWrapper) - _wrapper = wrappers.get_wrapper(xr.Dataset()) + _wrapper = wrappers.get_wrapper(TEST_DS) assert isinstance(_wrapper, xr_wrapper.XArrayDatasetWrapper) - _wrapper = from_object(xr.Dataset()) + _wrapper = from_object(TEST_DS) assert isinstance(_wrapper, xr_wrapper.XArrayDatasetWrapper) def test_dataarray_wrapper(): - import xarray as xr - - _wrapper = xr_wrapper.wrapper(xr.DataArray()) + _wrapper = xr_wrapper.wrapper(TEST_DA) assert isinstance(_wrapper, xr_wrapper.XArrayDataArrayWrapper) - _wrapper = wrappers.get_wrapper(xr.DataArray()) + _wrapper = wrappers.get_wrapper(TEST_DA) assert isinstance(_wrapper, xr_wrapper.XArrayDataArrayWrapper) - _wrapper = from_object(xr.DataArray()) + _wrapper = from_object(TEST_DA) assert isinstance(_wrapper, xr_wrapper.XArrayDataArrayWrapper) + + +def test_inbuilt_xarray_methods(): + _wrapper = from_object(TEST_DA) + assert _wrapper.mean().equals(TEST_DA.mean()) + + _wrapper = from_object(TEST_DS) + assert _wrapper.mean().equals(TEST_DS.mean())