diff --git a/doc/source/whatsnew/v2.3.0.rst b/doc/source/whatsnew/v2.3.0.rst index e1f8e007dc68f..09134763977c3 100644 --- a/doc/source/whatsnew/v2.3.0.rst +++ b/doc/source/whatsnew/v2.3.0.rst @@ -38,7 +38,7 @@ Other enhancements - :meth:`Series.str.decode` result now has ``StringDtype`` when ``future.infer_string`` is True (:issue:`60709`) - :meth:`~Series.to_hdf` and :meth:`~DataFrame.to_hdf` now round-trip with ``StringDtype`` (:issue:`60663`) - The :meth:`Series.str.decode` has gained the argument ``dtype`` to control the dtype of the result (:issue:`60940`) -- The :meth:`~Series.cumsum`, :meth:`~Series.cummin`, and :meth:`~Series.cummax` reductions are now implemented for ``StringDtype`` columns when backed by PyArrow (:issue:`60633`) +- The :meth:`~Series.cumsum`, :meth:`~Series.cummin`, and :meth:`~Series.cummax` reductions are now implemented for ``StringDtype`` columns (:issue:`60633`) - The :meth:`~Series.sum` reduction is now implemented for ``StringDtype`` columns (:issue:`59853`) .. --------------------------------------------------------------------------- diff --git a/pandas/core/arrays/string_.py b/pandas/core/arrays/string_.py index 623a6a10c75b5..7227ea77ca433 100644 --- a/pandas/core/arrays/string_.py +++ b/pandas/core/arrays/string_.py @@ -49,6 +49,7 @@ ) from pandas.core import ( + missing, nanops, ops, ) @@ -870,6 +871,88 @@ def _reduce( raise TypeError(f"Cannot perform reduction '{name}' with string dtype") + def _accumulate(self, name: str, *, skipna: bool = True, **kwargs) -> StringArray: + """ + Return an ExtensionArray performing an accumulation operation. + + The underlying data type might change. + + Parameters + ---------- + name : str + Name of the function, supported values are: + - cummin + - cummax + - cumsum + - cumprod + skipna : bool, default True + If True, skip NA values. + **kwargs + Additional keyword arguments passed to the accumulation function. + Currently, there is no supported kwarg. + + Returns + ------- + array + + Raises + ------ + NotImplementedError : subclass does not define accumulations + """ + if name == "cumprod": + msg = f"operation '{name}' not supported for dtype '{self.dtype}'" + raise TypeError(msg) + + # We may need to strip out trailing NA values + tail: np.ndarray | None = None + na_mask: np.ndarray | None = None + ndarray = self._ndarray + np_func = { + "cumsum": np.cumsum, + "cummin": np.minimum.accumulate, + "cummax": np.maximum.accumulate, + }[name] + + if self._hasna: + na_mask = cast("npt.NDArray[np.bool_]", isna(ndarray)) + if np.all(na_mask): + return type(self)(ndarray) + if skipna: + if name == "cumsum": + ndarray = np.where(na_mask, "", ndarray) + else: + # We can retain the running min/max by forward/backward filling. + ndarray = ndarray.copy() + missing.pad_or_backfill_inplace( + ndarray, + method="pad", + axis=0, + ) + missing.pad_or_backfill_inplace( + ndarray, + method="backfill", + axis=0, + ) + else: + # When not skipping NA values, the result should be null from + # the first NA value onward. + idx = np.argmax(na_mask) + tail = np.empty(len(ndarray) - idx, dtype="object") + tail[:] = self.dtype.na_value + ndarray = ndarray[:idx] + + # mypy: Cannot call function of unknown type + np_result = np_func(ndarray) # type: ignore[operator] + + if tail is not None: + np_result = np.hstack((np_result, tail)) + elif na_mask is not None: + # Argument 2 to "where" has incompatible type "NAType | float" + np_result = np.where(na_mask, self.dtype.na_value, np_result) # type: ignore[arg-type] + + result = type(self)(np_result) + return result + def _wrap_reduction_result(self, axis: AxisInt | None, result) -> Any: if self.dtype.na_value is np.nan and result is libmissing.NA: # the masked_reductions use pd.NA -> convert to np.nan diff --git a/pandas/tests/apply/test_str.py b/pandas/tests/apply/test_str.py index ce71cfec535e4..e5a9492630b13 100644 --- a/pandas/tests/apply/test_str.py +++ b/pandas/tests/apply/test_str.py @@ -5,7 +5,6 @@ import pytest from pandas.compat import ( - HAS_PYARROW, WASM, ) @@ -162,17 +161,10 @@ def test_agg_cython_table_series(series, func, expected): ), ), ) -def test_agg_cython_table_transform_series(request, series, func, expected): +def test_agg_cython_table_transform_series(series, func, expected): # GH21224 # test transforming functions in # pandas.core.base.SelectionMixin._cython_table (cumprod, cumsum) - if series.dtype == "string" and func == "cumsum" and not HAS_PYARROW: - request.applymarker( - pytest.mark.xfail( - raises=NotImplementedError, - reason="TODO(infer_string) cumsum not yet implemented for string", - ) - ) warn = None if isinstance(func, str) else FutureWarning with tm.assert_produces_warning(warn, match="is currently using Series.*"): result = series.agg(func) diff --git a/pandas/tests/extension/test_string.py b/pandas/tests/extension/test_string.py index 6ce48e434d329..25129111180d6 100644 --- a/pandas/tests/extension/test_string.py +++ b/pandas/tests/extension/test_string.py @@ -196,11 +196,7 @@ def _supports_reduction(self, ser: pd.Series, op_name: str) -> bool: def _supports_accumulation(self, ser: pd.Series, op_name: str) -> bool: assert isinstance(ser.dtype, StorageExtensionDtype) - return ser.dtype.storage == "pyarrow" and op_name in [ - "cummin", - "cummax", - "cumsum", - ] + return op_name in ["cummin", "cummax", "cumsum"] def _cast_pointwise_result(self, op_name: str, obj, other, pointwise_result): dtype = cast(StringDtype, tm.get_dtype(obj)) diff --git a/pandas/tests/series/test_cumulative.py b/pandas/tests/series/test_cumulative.py index 89882d9d797c5..db83cf1112e74 100644 --- a/pandas/tests/series/test_cumulative.py +++ b/pandas/tests/series/test_cumulative.py @@ -265,13 +265,14 @@ def test_cumprod_timedelta(self): ([pd.NA, pd.NA, pd.NA], "cummax", False, [pd.NA, pd.NA, pd.NA]), ], ) - def test_cum_methods_pyarrow_strings( - self, pyarrow_string_dtype, data, op, skipna, expected_data + def test_cum_methods_ea_strings( + self, string_dtype_no_object, data, op, skipna, expected_data ): - # https://github.com/pandas-dev/pandas/pull/60633 - ser = pd.Series(data, dtype=pyarrow_string_dtype) + # https://github.com/pandas-dev/pandas/pull/60633 - pyarrow + # https://github.com/pandas-dev/pandas/pull/60938 - Python + ser = pd.Series(data, dtype=string_dtype_no_object) method = getattr(ser, op) - expected = pd.Series(expected_data, dtype=pyarrow_string_dtype) + expected = pd.Series(expected_data, dtype=string_dtype_no_object) result = method(skipna=skipna) tm.assert_series_equal(result, expected)