diff --git a/pandas/tests/extension/decimal/array.py b/pandas/tests/extension/decimal/array.py index d24b70a884c45..f0a6a42cb252e 100644 --- a/pandas/tests/extension/decimal/array.py +++ b/pandas/tests/extension/decimal/array.py @@ -12,6 +12,7 @@ from pandas.core.dtypes.common import ( is_dtype_equal, is_float, + is_integer, pandas_dtype, ) @@ -71,11 +72,14 @@ class DecimalArray(OpsMixin, ExtensionScalarOpsMixin, ExtensionArray): def __init__(self, values, dtype=None, copy=False, context=None) -> None: for i, val in enumerate(values): - if is_float(val): + if is_float(val) or is_integer(val): if np.isnan(val): values[i] = DecimalDtype.na_value else: - values[i] = DecimalDtype.type(val) + # error: Argument 1 has incompatible type "float | int | + # integer[Any]"; expected "Decimal | float | str | tuple[int, + # Sequence[int], int]" + values[i] = DecimalDtype.type(val) # type: ignore[arg-type] elif not isinstance(val, decimal.Decimal): raise TypeError("All values must be of type " + str(decimal.Decimal)) values = np.asarray(values, dtype=object) diff --git a/pandas/tests/extension/decimal/test_decimal.py b/pandas/tests/extension/decimal/test_decimal.py index 6feac7fb9d9dc..16ce2c9312f7c 100644 --- a/pandas/tests/extension/decimal/test_decimal.py +++ b/pandas/tests/extension/decimal/test_decimal.py @@ -267,20 +267,16 @@ def test_series_repr(self, data): assert "Decimal: " in repr(ser) -@pytest.mark.xfail( - reason=( - "DecimalArray constructor raises bc _from_sequence wants Decimals, not ints." - "Easy to fix, just need to do it." - ), - raises=TypeError, -) -def test_series_constructor_coerce_data_to_extension_dtype_raises(): - xpr = ( - "Cannot cast data to extension dtype 'decimal'. Pass the " - "extension array directly." +def test_series_constructor_coerce_data_to_extension_dtype(): + dtype = DecimalDtype() + ser = pd.Series([0, 1, 2], dtype=dtype) + + arr = DecimalArray( + [decimal.Decimal(0), decimal.Decimal(1), decimal.Decimal(2)], + dtype=dtype, ) - with pytest.raises(ValueError, match=xpr): - pd.Series([0, 1, 2], dtype=DecimalDtype()) + exp = pd.Series(arr) + tm.assert_series_equal(ser, exp) def test_series_constructor_with_dtype():