From 63249f2aa95ef0b0300ea2f1cc68200cc8b13484 Mon Sep 17 00:00:00 2001
From: Richard Shadrach <45562402+rhshadrach@users.noreply.github.com>
Date: Tue, 18 Feb 2025 12:39:50 -0500
Subject: [PATCH] ENH: Add dtype argument to str.decode (#60940)

* ENH: Add dtype argument to str.decode

* Refinements

* cleanup

* cleanup

* type-hint fixup

* Simplify condition

* lint
---
 doc/source/whatsnew/v2.3.0.rst       |  1 +
 pandas/core/strings/accessor.py      | 16 ++++++++++++++--
 pandas/tests/strings/test_strings.py | 24 ++++++++++++++++++++++++
 3 files changed, 39 insertions(+), 2 deletions(-)

diff --git a/doc/source/whatsnew/v2.3.0.rst b/doc/source/whatsnew/v2.3.0.rst
index 32d9253326277..e1f8e007dc68f 100644
--- a/doc/source/whatsnew/v2.3.0.rst
+++ b/doc/source/whatsnew/v2.3.0.rst
@@ -37,6 +37,7 @@ Other enhancements
   updated to work correctly with NumPy >= 2 (:issue:`57739`)
 - :meth:`Series.str.decode` result now has ``StringDtype`` when ``future.infer_string`` is True (:issue:`60709`)
 - :meth:`~Series.to_hdf` and :meth:`~DataFrame.to_hdf` now round-trip with ``StringDtype`` (:issue:`60663`)
+- The :meth:`Series.str.decode` has gained the argument ``dtype`` to control the dtype of the result (:issue:`60940`)
 - The :meth:`~Series.cumsum`, :meth:`~Series.cummin`, and :meth:`~Series.cummax` reductions are now implemented for ``StringDtype`` columns when backed by PyArrow (:issue:`60633`)
 - The :meth:`~Series.sum` reduction is now implemented for ``StringDtype`` columns (:issue:`59853`)
 
diff --git a/pandas/core/strings/accessor.py b/pandas/core/strings/accessor.py
index b854338c2d1d7..75fbd642c3520 100644
--- a/pandas/core/strings/accessor.py
+++ b/pandas/core/strings/accessor.py
@@ -34,6 +34,7 @@
     is_numeric_dtype,
     is_object_dtype,
     is_re,
+    is_string_dtype,
 )
 from pandas.core.dtypes.dtypes import (
     ArrowDtype,
@@ -2102,7 +2103,9 @@ def slice_replace(self, start=None, stop=None, repl=None):
         result = self._data.array._str_slice_replace(start, stop, repl)
         return self._wrap_result(result)
 
-    def decode(self, encoding, errors: str = "strict"):
+    def decode(
+        self, encoding, errors: str = "strict", dtype: str | DtypeObj | None = None
+    ):
         """
         Decode character string in the Series/Index using indicated encoding.
 
@@ -2116,6 +2119,12 @@ def decode(self, encoding, errors: str = "strict"):
         errors : str, optional
             Specifies the error handling scheme.
             Possible values are those supported by :meth:`bytes.decode`.
+        dtype : str or dtype, optional
+            The dtype of the result. When not ``None``, must be either a string or
+            object dtype. When ``None``, the dtype of the result is determined by
+            ``pd.options.future.infer_string``.
+
+            .. versionadded:: 2.3.0
 
         Returns
         -------
@@ -2137,6 +2146,10 @@ def decode(self, encoding, errors: str = "strict"):
         2   ()
         dtype: object
         """
+        if dtype is not None and not is_string_dtype(dtype):
+            raise ValueError(f"dtype must be string or object, got {dtype=}")
+        if dtype is None and get_option("future.infer_string"):
+            dtype = "str"
         # TODO: Add a similar _bytes interface.
         if encoding in _cpython_optimized_decoders:
             # CPython optimized implementation
@@ -2146,7 +2159,6 @@ def decode(self, encoding, errors: str = "strict"):
             f = lambda x: decoder(x, errors)[0]
         arr = self._data.array
         result = arr._str_map(f)
-        dtype = "str" if get_option("future.infer_string") else None
         return self._wrap_result(result, dtype=dtype)
 
     @forbid_nonstring_types(["bytes"])
diff --git a/pandas/tests/strings/test_strings.py b/pandas/tests/strings/test_strings.py
index ee531b32aa82d..025f837982595 100644
--- a/pandas/tests/strings/test_strings.py
+++ b/pandas/tests/strings/test_strings.py
@@ -601,6 +601,30 @@ def test_decode_errors_kwarg():
     tm.assert_series_equal(result, expected)
 
 
+def test_decode_string_dtype(string_dtype):
+    # https://github.com/pandas-dev/pandas/pull/60940
+    ser = Series([b"a", b"b"])
+    result = ser.str.decode("utf-8", dtype=string_dtype)
+    expected = Series(["a", "b"], dtype=string_dtype)
+    tm.assert_series_equal(result, expected)
+
+
+def test_decode_object_dtype(object_dtype):
+    # https://github.com/pandas-dev/pandas/pull/60940
+    ser = Series([b"a", rb"\ud800"])
+    result = ser.str.decode("utf-8", dtype=object_dtype)
+    expected = Series(["a", r"\ud800"], dtype=object_dtype)
+    tm.assert_series_equal(result, expected)
+
+
+def test_decode_bad_dtype():
+    # https://github.com/pandas-dev/pandas/pull/60940
+    ser = Series([b"a", b"b"])
+    msg = "dtype must be string or object, got dtype='int64'"
+    with pytest.raises(ValueError, match=msg):
+        ser.str.decode("utf-8", dtype="int64")
+
+
 @pytest.mark.parametrize(
     "form, expected",
     [