@pytest.mark.parametrize("implementation", ["python", "julia"])
@pytest.mark.parametrize("dataset_name", ["air_temperature", "eraint_uvz"])
def test_warn_on_quantized_variables(dataset_name, implementation):
    """Check the quantization warning of ``xb.get_bitinformation``.

    The same tutorial dataset is loaded twice: once with xarray's default
    decoding and once with ``mask_and_scale=False``. The decoded dataset must
    trigger the quantization ``UserWarning``; the undecoded one must not emit
    any ``UserWarning``.
    """
    ds_quantized = xr.tutorial.load_dataset(dataset_name)
    ds_raw = xr.tutorial.load_dataset(dataset_name, mask_and_scale=False)

    # Match on the message text so that an unrelated UserWarning cannot
    # satisfy the expectation by accident.
    with pytest.warns(UserWarning, match="is quantized as"):
        xb.get_bitinformation(ds_quantized, implementation=implementation)

    with warnings.catch_warnings():
        # Escalate only UserWarning: warnings of other categories raised by
        # dependencies (e.g. deprecations) are not what this test is about
        # and must not make it fail.
        warnings.simplefilter("error", category=UserWarning)
        xb.get_bitinformation(ds_raw, implementation=implementation)
def _quantized_variable_is_scaled(ds: "xr.Dataset", var: str) -> bool:
    """Tell whether ``ds[var]`` was decoded from a scaled, quantized storage type.

    A variable counts as scaled when its ``encoding`` records a storage
    ``dtype`` that differs from the dtype of the loaded data and the encoding
    also carries an ``add_offset`` or a ``scale_factor`` entry.

    Parameters
    ----------
    ds : xr.Dataset
        Dataset holding the variable. Only ``ds[var].dtype`` and
        ``ds[var].encoding`` are read.
    var : str
        Name of the variable to inspect.

    Returns
    -------
    bool
        ``True`` if the variable has scale/offset encoding and its loaded
        dtype differs from the stored one, ``False`` otherwise. Also
        ``False`` when the encoding has no ``dtype`` entry.
    """
    variable = ds[var]
    encoding = variable.encoding

    storage_dtype = encoding.get("dtype")
    if storage_dtype is None:
        # Without a recorded storage dtype there is nothing to compare the
        # loaded dtype against, so the variable cannot be flagged.
        return False

    has_scale_or_offset = "add_offset" in encoding or "scale_factor" in encoding
    return has_scale_or_offset and storage_dtype != variable.dtype