Skip to content

Commit

Permalink
BUG: __eq__ raising for new arrow string dtype for incompatible objec…
Browse files Browse the repository at this point in the history
…ts (#56245)
  • Loading branch information
phofl authored Dec 21, 2023
1 parent 8864319 commit 64c20dc
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 1 deletion.
1 change: 1 addition & 0 deletions doc/source/whatsnew/v2.2.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -691,6 +691,7 @@ Strings
- Bug in :meth:`Series.str.replace` when ``n < 0`` for :class:`ArrowDtype` with ``pyarrow.string`` (:issue:`56404`)
- Bug in :meth:`Series.str.startswith` and :meth:`Series.str.endswith` with arguments of type ``tuple[str, ...]`` for :class:`ArrowDtype` with ``pyarrow.string`` dtype (:issue:`56579`)
- Bug in :meth:`Series.str.startswith` and :meth:`Series.str.endswith` with arguments of type ``tuple[str, ...]`` for ``string[pyarrow]`` (:issue:`54942`)
- Bug in comparison operations for ``dtype="string[pyarrow_numpy]"`` raising if dtypes can't be compared (:issue:`56008`)

Interval
^^^^^^^^
Expand Down
6 changes: 5 additions & 1 deletion pandas/core/arrays/string_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
BaseStringArray,
StringDtype,
)
from pandas.core.ops import invalid_comparison
from pandas.core.strings.object_array import ObjectStringArrayMixin

if not pa_version_under10p1:
Expand Down Expand Up @@ -676,7 +677,10 @@ def _convert_int_dtype(self, result):
return result

def _cmp_method(self, other, op):
result = super()._cmp_method(other, op)
try:
result = super()._cmp_method(other, op)
except pa.ArrowNotImplementedError:
return invalid_comparison(self, other, op)
if op == operator.ne:
return result.to_numpy(np.bool_, na_value=True)
else:
Expand Down
16 changes: 16 additions & 0 deletions pandas/tests/series/test_logical_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,3 +530,19 @@ def test_int_dtype_different_index_not_bool(self):

result = ser1 ^ ser2
tm.assert_series_equal(result, expected)

def test_pyarrow_numpy_string_invalid(self):
# GH#56008
pytest.importorskip("pyarrow")
ser = Series([False, True])
ser2 = Series(["a", "b"], dtype="string[pyarrow_numpy]")
result = ser == ser2
expected = Series(False, index=ser.index)
tm.assert_series_equal(result, expected)

result = ser != ser2
expected = Series(True, index=ser.index)
tm.assert_series_equal(result, expected)

with pytest.raises(TypeError, match="Invalid comparison"):
ser > ser2

0 comments on commit 64c20dc

Please sign in to comment.