Convert ArrowExtensionArray to proper NumPy dtype (#56290)
phofl authored Dec 21, 2023
1 parent 3096bd6 commit 2a9c3d7
Showing 6 changed files with 89 additions and 42 deletions.
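
In short, to_numpy for PyArrow backed dtypes now infers a proper NumPy dtype (for example float64 with NaN for integers that contain missing values) instead of always falling back to object with pd.NA. A rough before/after sketch, mirroring the updated test_to_numpy_int_with_na further down:

import pandas as pd

arr = pd.array([1, None], dtype="int64[pyarrow]")
result = arr.to_numpy()
# before this commit: np.array([1, pd.NA], dtype=object)
# after this commit:  np.array([1.0, np.nan]), i.e. a float64 array
assert isinstance(result[0], float)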
2 changes: 1 addition & 1 deletion doc/source/whatsnew/v2.2.0.rst
@@ -194,7 +194,7 @@ documentation.
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

``to_numpy`` for NumPy nullable and Arrow types will now convert to a
-suitable NumPy dtype instead of ``object`` dtype for nullable extension dtypes.
+suitable NumPy dtype instead of ``object`` dtype for nullable and PyArrow backed extension dtypes.

*Old behavior:*

54 changes: 54 additions & 0 deletions pandas/core/arrays/_utils.py
@@ -0,0 +1,54 @@
from __future__ import annotations

from typing import (
    TYPE_CHECKING,
    Any,
)

import numpy as np

from pandas._libs import lib
from pandas.errors import LossySetitemError

from pandas.core.dtypes.cast import np_can_hold_element
from pandas.core.dtypes.common import is_numeric_dtype

if TYPE_CHECKING:
    from pandas._typing import (
        ArrayLike,
        npt,
    )


def to_numpy_dtype_inference(
    arr: ArrayLike, dtype: npt.DTypeLike | None, na_value, hasna: bool
) -> tuple[npt.DTypeLike, Any]:
    if dtype is None and is_numeric_dtype(arr.dtype):
        dtype_given = False
        if hasna:
            if arr.dtype.kind == "b":
                dtype = np.dtype(np.object_)
            else:
                if arr.dtype.kind in "iu":
                    dtype = np.dtype(np.float64)
                else:
                    dtype = arr.dtype.numpy_dtype  # type: ignore[union-attr]
                if na_value is lib.no_default:
                    na_value = np.nan
        else:
            dtype = arr.dtype.numpy_dtype  # type: ignore[union-attr]
    elif dtype is not None:
        dtype = np.dtype(dtype)
        dtype_given = True
    else:
        dtype_given = True

    if na_value is lib.no_default:
        na_value = arr.dtype.na_value

    if not dtype_given and hasna:
        try:
            np_can_hold_element(dtype, na_value)  # type: ignore[arg-type]
        except LossySetitemError:
            dtype = np.dtype(np.object_)
    return dtype, na_value
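
A rough usage sketch of the helper above. It is a private pandas internal (this assumes pandas 2.2 and the signature shown here), so the snippet is illustrative only; the real call sites are the to_numpy implementations changed below.

import pandas as pd
from pandas._libs import lib
from pandas.core.arrays._utils import to_numpy_dtype_inference

int_arr = pd.array([1, 2, None], dtype="Int64")
# integer kind with missing values and no requested dtype -> float64 / np.nan
to_numpy_dtype_inference(int_arr, None, lib.no_default, int_arr._hasna)
# (dtype('float64'), nan)

bool_arr = pd.array([True, None], dtype="boolean")
# boolean kind with missing values has no lossless NumPy equivalent -> object / pd.NA
to_numpy_dtype_inference(bool_arr, None, lib.no_default, bool_arr._hasna)
# (dtype('O'), <NA>)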
16 changes: 10 additions & 6 deletions pandas/core/arrays/arrow/array.py
@@ -39,6 +39,7 @@
    is_bool_dtype,
    is_integer,
    is_list_like,
+    is_numeric_dtype,
    is_scalar,
)
from pandas.core.dtypes.dtypes import DatetimeTZDtype
@@ -50,8 +51,10 @@
    ops,
    roperator,
)
+from pandas.core.algorithms import map_array
from pandas.core.arraylike import OpsMixin
from pandas.core.arrays._arrow_string_mixins import ArrowStringArrayMixin
+from pandas.core.arrays._utils import to_numpy_dtype_inference
from pandas.core.arrays.base import (
    ExtensionArray,
    ExtensionArraySupportsAnyAll,
@@ -1317,12 +1320,7 @@ def to_numpy(
        copy: bool = False,
        na_value: object = lib.no_default,
    ) -> np.ndarray:
-        if dtype is not None:
-            dtype = np.dtype(dtype)
-
-        if na_value is lib.no_default:
-            na_value = self.dtype.na_value
-
+        dtype, na_value = to_numpy_dtype_inference(self, dtype, na_value, self._hasna)
        pa_type = self._pa_array.type
        if not self._hasna or isna(na_value) or pa.types.is_null(pa_type):
            data = self
@@ -1366,6 +1364,12 @@ def to_numpy(
            result[~mask] = data[~mask]._pa_array.to_numpy()
        return result

+    def map(self, mapper, na_action=None):
+        if is_numeric_dtype(self.dtype):
+            return map_array(self.to_numpy(), mapper, na_action=None)
+        else:
+            return super().map(mapper, na_action)
+
    @doc(ExtensionArray.duplicated)
    def duplicated(
        self, keep: Literal["first", "last", False] = "first"
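
The new map override above sends numeric PyArrow dtypes through to_numpy, which now yields a proper NumPy dtype, instead of the generic object path. A hedged example of the expected effect (exact result dtypes may vary by pandas version):

import pandas as pd

ser = pd.Series([1, 2, None], dtype="int64[pyarrow]")
# numeric pyarrow dtype: mapping operates on the float64/NaN NumPy view
ser.map(lambda x: x * 2)  # roughly [2.0, 4.0, NaN] with a NumPy float dtype
# non-numeric pyarrow dtypes still go through the base ExtensionArray.map path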
34 changes: 3 additions & 31 deletions pandas/core/arrays/masked.py
@@ -35,15 +35,11 @@
    IS64,
    is_platform_windows,
)
-from pandas.errors import (
-    AbstractMethodError,
-    LossySetitemError,
-)
+from pandas.errors import AbstractMethodError
from pandas.util._decorators import doc
from pandas.util._validators import validate_fillna_kwargs

from pandas.core.dtypes.base import ExtensionDtype
-from pandas.core.dtypes.cast import np_can_hold_element
from pandas.core.dtypes.common import (
    is_bool,
    is_integer_dtype,
@@ -80,6 +76,7 @@
)
from pandas.core.array_algos.quantile import quantile_with_mask
from pandas.core.arraylike import OpsMixin
+from pandas.core.arrays._utils import to_numpy_dtype_inference
from pandas.core.arrays.base import ExtensionArray
from pandas.core.construction import (
    array as pd_array,
@@ -477,32 +474,7 @@ def to_numpy(
        array([ True, False, False])
        """
        hasna = self._hasna
-
-        if dtype is None:
-            dtype_given = False
-            if hasna:
-                if self.dtype.kind == "b":
-                    dtype = object
-                else:
-                    if self.dtype.kind in "iu":
-                        dtype = np.dtype(np.float64)
-                    else:
-                        dtype = self.dtype.numpy_dtype
-                    if na_value is lib.no_default:
-                        na_value = np.nan
-            else:
-                dtype = self.dtype.numpy_dtype
-        else:
-            dtype = np.dtype(dtype)
-            dtype_given = True
-        if na_value is lib.no_default:
-            na_value = libmissing.NA
-
-        if not dtype_given and hasna:
-            try:
-                np_can_hold_element(dtype, na_value)  # type: ignore[arg-type]
-            except LossySetitemError:
-                dtype = object
+        dtype, na_value = to_numpy_dtype_inference(self, dtype, na_value, hasna)

        if hasna:
            if (
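
The masked-array hunk above moves the existing dtype-inference logic into the shared helper without changing the intended behaviour. For reference, a sketch of what that logic does for NumPy nullable dtypes:

import pandas as pd

pd.array([1, 2, None], dtype="Int64").to_numpy()
# array([ 1.,  2., nan])            (ints with NA are promoted to float64)
pd.array([1, 2, 3], dtype="Int64").to_numpy()
# array([1, 2, 3])                  (no NA: the natural int64 dtype is kept)
pd.array([True, None], dtype="boolean").to_numpy()
# array([True, <NA>], dtype=object) (bools with NA fall back to object)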
14 changes: 10 additions & 4 deletions pandas/tests/extension/test_arrow.py
@@ -278,7 +278,13 @@ def test_map(self, data_missing, na_action):
            expected = data_missing.to_numpy(dtype=object)
            tm.assert_numpy_array_equal(result, expected)
        else:
-            super().test_map(data_missing, na_action)
+            result = data_missing.map(lambda x: x, na_action=na_action)
+            if data_missing.dtype == "float32[pyarrow]":
+                # map roundtrips through objects, which converts to float64
+                expected = data_missing.to_numpy(dtype="float64", na_value=np.nan)
+            else:
+                expected = data_missing.to_numpy()
+            tm.assert_numpy_array_equal(result, expected)

    def test_astype_str(self, data, request):
        pa_dtype = data.dtype.pyarrow_dtype
@@ -1585,7 +1591,7 @@ def test_to_numpy_with_defaults(data):
    else:
        expected = np.array(data._pa_array)

-    if data._hasna:
+    if data._hasna and not is_numeric_dtype(data.dtype):
        expected = expected.astype(object)
        expected[pd.isna(data)] = pd.NA

@@ -1597,8 +1603,8 @@ def test_to_numpy_int_with_na():
    data = [1, None]
    arr = pd.array(data, dtype="int64[pyarrow]")
    result = arr.to_numpy()
-    expected = np.array([1, pd.NA], dtype=object)
-    assert isinstance(result[0], int)
+    expected = np.array([1, np.nan])
+    assert isinstance(result[0], float)
    tm.assert_numpy_array_equal(result, expected)


11 changes: 11 additions & 0 deletions pandas/tests/series/test_npfuncs.py
@@ -5,6 +5,8 @@
import numpy as np
import pytest

+import pandas.util._test_decorators as td
+
from pandas import Series
import pandas._testing as tm

@@ -33,3 +35,12 @@ def test_numpy_argwhere(index):
    expected = np.array([[3], [4]], dtype=np.int64)

    tm.assert_numpy_array_equal(result, expected)
+
+
+@td.skip_if_no("pyarrow")
+def test_log_arrow_backed_missing_value():
+    # GH#56285
+    ser = Series([1, 2, None], dtype="float64[pyarrow]")
+    result = np.log(ser)
+    expected = np.log(Series([1, 2, None], dtype="float64"))
+    tm.assert_series_equal(result, expected)
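
For reference, both sides of that comparison should evaluate to roughly the same values (log 1 = 0, log 2 ≈ 0.6931, NaN for the missing entry):

np.log(Series([1, 2, None], dtype="float64"))
# 0    0.000000
# 1    0.693147
# 2         NaN
# dtype: float64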
