Skip to content

Commit

Permalink
Add warning for quantized variables. This lets a user know that the v…
Browse files Browse the repository at this point in the history
…ariable in question is stored as a quantized integer, but loaded as a float, which will cause the calculation of the bitinformation to yield bogus results depending on the data.
  • Loading branch information
JoelJaeschke committed Jul 30, 2024
1 parent 59c5537 commit 59f2089
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 0 deletions.
15 changes: 15 additions & 0 deletions tests/test_get_bitinformation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import numpy as np
import pytest
import xarray as xr
import warnings
from numpy.testing import assert_allclose, assert_equal
from xarray.core import formatting
from xarray.core.dataarray import DataArray
Expand Down Expand Up @@ -258,3 +259,17 @@ def test_implementations_agree(ds, dim, axis, request):
masked_value=None,
)
bitinfo_assert_allclose(bi_python, bi_julia, rtol=1e-4)


@pytest.mark.parametrize("implementation", ["python", "julia"])
@pytest.mark.parametrize("dataset_name", ["air_temperature", "eraint_uvz"])
def test_warn_on_quantized_variables(dataset_name, implementation):
    """Check that the quantization warning fires only for decoded data.

    The same tutorial dataset is loaded twice: once with xarray's default
    decoding and once with ``mask_and_scale=False``. ``get_bitinformation``
    must emit the quantization ``UserWarning`` for the former and must not
    emit any ``UserWarning`` for the latter.
    """
    ds_quantized = xr.tutorial.load_dataset(dataset_name)
    ds_raw = xr.tutorial.load_dataset(dataset_name, mask_and_scale=False)

    # Match on the message so an unrelated UserWarning cannot satisfy the
    # positive case by accident.
    with pytest.warns(UserWarning, match="is quantized as"):
        xb.get_bitinformation(ds_quantized, implementation=implementation)

    with warnings.catch_warnings():
        # Escalate only UserWarning: a blanket "error" filter would also turn
        # Deprecation/Future warnings raised by dependencies into failures,
        # which is not what this test is about.
        warnings.simplefilter("error", category=UserWarning)
        xb.get_bitinformation(ds_raw, implementation=implementation)
21 changes: 21 additions & 0 deletions xbitinfo/xbitinfo.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json
import logging
import os
import warnings

import numpy as np
import xarray as xr
Expand Down Expand Up @@ -233,6 +234,15 @@ def get_bitinformation( # noqa: C901
pbar = tqdm(ds.data_vars)
for var in pbar:
pbar.set_description(f"Processing var: {var} for dim: {dim}")

if _quantized_variable_is_scaled(ds, var):
loaded_dtype = ds[var].dtype
quantized_storage_dtype = ds[var].encoding["dtype"]
warnings.warn(
f"Variable {var} is quantized as {quantized_storage_dtype}, but loaded as {loaded_dtype}. Consider reopening using `mask_and_scale=False` to get sensible results",
category=UserWarning
)

if implementation == "julia":
info_per_bit_var = _jl_get_bitinformation(ds, var, axis, dim, kwargs)
if info_per_bit_var is None:
Expand Down Expand Up @@ -260,6 +270,17 @@ def get_bitinformation( # noqa: C901
return info_per_bit


def _quantized_variable_is_scaled(ds: "xr.Dataset", var: str) -> bool:
    """Tell whether ``ds[var]` was stored under a different dtype and rescaled.

    A variable counts as "scaled" when its encoding carries ``add_offset``
    and/or ``scale_factor`` *and* the dtype recorded in the encoding differs
    from the dtype of the loaded data.

    Parameters
    ----------
    ds
        Container indexed by variable name; each entry must expose ``dtype``
        and an ``encoding`` mapping.
    var
        Name of the variable to inspect.

    Returns
    -------
    bool
        ``True`` if the variable has a scale/offset encoding and its storage
        dtype differs from its loaded dtype, ``False`` otherwise. Also
        ``False`` when the encoding records no ``"dtype"`` at all.
    """
    variable = ds[var]
    encoding = variable.encoding

    # Variables that never went through a file round-trip may have no "dtype"
    # entry in their encoding; without it there is nothing to compare against,
    # so report "not scaled" instead of raising KeyError.
    storage_dtype = encoding.get("dtype")
    if storage_dtype is None:
        return False

    has_scale_or_offset = "add_offset" in encoding or "scale_factor" in encoding
    return bool(has_scale_or_offset and storage_dtype != variable.dtype)


def _jl_get_bitinformation(ds, var, axis, dim, kwargs={}):
X = ds[var].values
Main.X = X
Expand Down

0 comments on commit 59f2089

Please sign in to comment.