diff --git a/MARBL_tools/netcdf_comparison.py b/MARBL_tools/netcdf_comparison.py index 087d6bf8..b6331a45 100755 --- a/MARBL_tools/netcdf_comparison.py +++ b/MARBL_tools/netcdf_comparison.py @@ -206,8 +206,8 @@ def _variable_check_loose(ds_base, ds_new, rtol, atol, thres): conversion_factor = _get_conversion_factor(ds_base, ds_new, var) # (1) Are NaNs in the same place? - mask = ~np.isnan(ds_base[var].data) - if np.any(mask ^ ~np.isnan(ds_new[var].data)): + mask = np.isfinite(ds_base[var].data) + if np.any(mask ^ np.isfinite(ds_new[var].data)): error_checking['messages'].append('NaNs are not in same place') # (2) compare everywhere that baseline is 0 @@ -237,8 +237,13 @@ def _variable_check_loose(ds_base, ds_new, rtol, atol, thres): new_data = np.where(np.abs(ds_base[var].data[mask]) > thres, conversion_factor*ds_new[var].data[mask], 0) + # denominator for relative error is column max value + # note the assumption that column is first dimension + col_max = np.nanmax(np.abs(ds_base[var].data), axis=tuple(np.arange(1,len(ds_base[var].data.shape)))) + rel_denom = (np.ones(ds_base[var].data.transpose().shape)*col_max).transpose() + rel_denom = rel_denom[mask] rel_err = np.where(base_data != 0, np.abs(new_data - base_data), 0) / \ - np.where(base_data != 0, np.abs(base_data), 1) + np.where(rel_denom != 0, rel_denom, 1) if np.any(rel_err > rtol): if rtol == 0: abs_err = np.abs(new_data - base_data)