Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add warning for quantized variables #286

Merged
1 change: 1 addition & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ CHANGELOG
X.X.X (unreleased)
------------------

* Add warning for quantized variables (:pr:`286`, :issue:`202`) `Joel Jaeschke`_.
* Update BitInformation.jl version to v0.6.3 (:pr:`292`) `Hauke Schulz`_.
* Improve test/docs environment separation (:pr:`275`, :issue:`267`) `Aryan Bakliwal`_.
* Set default masked value to None for integers (:pr:`289`) `Hauke Schulz`_.
Expand Down
15 changes: 15 additions & 0 deletions tests/test_get_bitinformation.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Tests for `xbitinfo` package."""

import os
import warnings

import numpy as np
import pytest
Expand Down Expand Up @@ -267,3 +268,17 @@ def test_implementations_agree(ds, dim, axis, request):
masked_value=None,
)
bitinfo_assert_allclose(bi_python, bi_julia, rtol=1e-4)


@pytest.mark.parametrize("implementation", ["python", "julia"])
@pytest.mark.parametrize("dataset_name", ["air_temperature", "eraint_uvz"])
def test_warn_on_quantized_variables(dataset_name, implementation):
    """Check when ``get_bitinformation`` emits a ``UserWarning``.

    The same tutorial dataset is loaded twice: once with the loader's
    defaults and once with ``mask_and_scale=False``.  The first copy must
    make ``get_bitinformation`` warn; the second must not produce any
    warning at all.
    """
    decoded = xr.tutorial.load_dataset(dataset_name)
    undecoded = xr.tutorial.load_dataset(dataset_name, mask_and_scale=False)

    # Default load: a UserWarning is expected.
    with pytest.warns(UserWarning):
        xb.get_bitinformation(decoded, implementation=implementation)

    # Escalate every warning to an exception so that any warning raised
    # for the ``mask_and_scale=False`` copy fails the test.
    with warnings.catch_warnings():
        warnings.simplefilter("error")
        xb.get_bitinformation(undecoded, implementation=implementation)
30 changes: 30 additions & 0 deletions xbitinfo/xbitinfo.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json
import logging
import os
import warnings

import numpy as np
import xarray as xr
Expand Down Expand Up @@ -233,6 +234,15 @@ def get_bitinformation( # noqa: C901
pbar = tqdm(ds.data_vars)
for var in pbar:
pbar.set_description(f"Processing var: {var} for dim: {dim}")

if _quantized_variable_is_scaled(ds, var):
loaded_dtype = ds[var].dtype
quantized_storage_dtype = ds[var].encoding["dtype"]
warnings.warn(
f"Variable {var} is quantized as {quantized_storage_dtype}, but loaded as {loaded_dtype}. Consider reopening using `mask_and_scale=False` to get sensible results",
category=UserWarning,
)

if implementation == "julia":
info_per_bit_var = _jl_get_bitinformation(ds, var, axis, dim, kwargs)
if info_per_bit_var is None:
Expand Down Expand Up @@ -260,6 +270,26 @@ def get_bitinformation( # noqa: C901
return info_per_bit


def _quantized_variable_is_scaled(ds: xr.DataArray, var: str) -> bool:
has_scale_or_offset = any(
["add_offset" in ds[var].encoding, "scale_factor" in ds[var].encoding]
)

if not has_scale_or_offset:
return False

loaded_dtype = ds[var].dtype
storage_dtype = ds[var].encoding.get("dtype", None)
assert (
storage_dtype is not None
), f"Variable {var} is likely quantized, but does not have a storage dtype"

if loaded_dtype == storage_dtype:
return False

return True


def _jl_get_bitinformation(ds, var, axis, dim, kwargs={}):
X = ds[var].values
Main.X = X
Expand Down