Skip to content

Commit

Permalink
TST: fix Decimal constructor xfail (pandas-dev#54338)
Browse files Browse the repository at this point in the history
* TST: fix Decimal constructor xfail

* mypy fixup
  • Loading branch information
jbrockmendel authored Aug 1, 2023
1 parent 55ec5e7 commit ae6a335
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 15 deletions.
8 changes: 6 additions & 2 deletions pandas/tests/extension/decimal/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from pandas.core.dtypes.common import (
is_dtype_equal,
is_float,
is_integer,
pandas_dtype,
)

Expand Down Expand Up @@ -71,11 +72,14 @@ class DecimalArray(OpsMixin, ExtensionScalarOpsMixin, ExtensionArray):

def __init__(self, values, dtype=None, copy=False, context=None) -> None:
for i, val in enumerate(values):
if is_float(val):
if is_float(val) or is_integer(val):
if np.isnan(val):
values[i] = DecimalDtype.na_value
else:
values[i] = DecimalDtype.type(val)
# error: Argument 1 has incompatible type "float | int |
# integer[Any]"; expected "Decimal | float | str | tuple[int,
# Sequence[int], int]"
values[i] = DecimalDtype.type(val) # type: ignore[arg-type]
elif not isinstance(val, decimal.Decimal):
raise TypeError("All values must be of type " + str(decimal.Decimal))
values = np.asarray(values, dtype=object)
Expand Down
22 changes: 9 additions & 13 deletions pandas/tests/extension/decimal/test_decimal.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,20 +267,16 @@ def test_series_repr(self, data):
assert "Decimal: " in repr(ser)


@pytest.mark.xfail(
reason=(
"DecimalArray constructor raises bc _from_sequence wants Decimals, not ints."
"Easy to fix, just need to do it."
),
raises=TypeError,
)
def test_series_constructor_coerce_data_to_extension_dtype_raises():
xpr = (
"Cannot cast data to extension dtype 'decimal'. Pass the "
"extension array directly."
def test_series_constructor_coerce_data_to_extension_dtype():
dtype = DecimalDtype()
ser = pd.Series([0, 1, 2], dtype=dtype)

arr = DecimalArray(
[decimal.Decimal(0), decimal.Decimal(1), decimal.Decimal(2)],
dtype=dtype,
)
with pytest.raises(ValueError, match=xpr):
pd.Series([0, 1, 2], dtype=DecimalDtype())
exp = pd.Series(arr)
tm.assert_series_equal(ser, exp)


def test_series_constructor_with_dtype():
Expand Down

0 comments on commit ae6a335

Please sign in to comment.