[DRAFT] PyArrow #1637

Draft · wants to merge 7 commits into base: main
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -36,12 +36,13 @@ keywords = ["data science", "machine learning", "typing"]
 license = {file = "LICENSE"}
 requires-python = ">=3.8,<4"
 dependencies = [
-    "pandas >= 1.4.0, != 1.4.2",
+    "pandas >= 1.5.0",
     "scikit-learn >= 0.22",
     "python-dateutil >= 2.8.1",
     "scipy >= 1.4.0",
     "importlib-resources >= 5.10.0",
     "numpy >= 1.21.0",
+    "pyarrow >= 1.0.0",
 ]

 [project.urls]
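Note: the new pyarrow requirement is what enables pandas' Arrow-backed string arrays, which this PR adopts as the physical type for string-like logical types. A minimal sketch of the dtype in isolation, assuming pandas >= 1.5 with pyarrow installed:

    import pandas as pd

    # Arrow-backed string array; requires pyarrow to be importable.
    s = pd.Series(["alpha", "beta", None], dtype="string[pyarrow]")

    # Missing values surface as pd.NA, matching the default StringDtype.
    assert s.isna().tolist() == [False, False, True]

    # str(s.dtype) may render as plain "string" even though the requested
    # alias was "string[pyarrow]", so naive dtype string comparisons are
    # unreliable -- the motivation for _check_data_type_equality below.
    print(repr(s.dtype))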
8 changes: 6 additions & 2 deletions woodwork/accessor_utils.py
@@ -4,7 +4,11 @@
 import pandas as pd

 from woodwork.exceptions import ColumnNotPresentInSchemaError, WoodworkNotInitError
-from woodwork.utils import _get_column_logical_type, import_or_none
+from woodwork.utils import (
+    _check_data_type_equality,
+    _get_column_logical_type,
+    import_or_none,
+)

 dd = import_or_none("dask.dataframe")
 ps = import_or_none("pyspark.pandas")
@@ -123,7 +127,7 @@ def get_invalid_schema_message(dataframe, schema):
     for name in dataframe.columns:
         df_dtype = dataframe[name].dtype
         valid_dtype = logical_types[name]._get_valid_dtype(type(dataframe[name]))
-        if str(df_dtype) != valid_dtype:
+        if not _check_data_type_equality(str(df_dtype), valid_dtype):
             return (
                 f"dtype mismatch for column {name} between DataFrame dtype, "
                 f"{df_dtype}, and {logical_types[name]} dtype, {valid_dtype}"
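The hunks above call _check_data_type_equality from woodwork.utils, but the helper's implementation is not part of the hunks shown in this diff. A plausible sketch, assuming its job is to treat the plain and pyarrow-backed spellings of a dtype as equal:

    def _check_data_type_equality(dtype_a, dtype_b):
        """Hypothetical sketch -- the real helper lives in woodwork/utils.py
        and is not shown in this PR. Treats "string" and "string[pyarrow]"
        as the same dtype."""

        def normalize(dtype):
            return str(dtype).replace("[pyarrow]", "")

        return normalize(dtype_a) == normalize(dtype_b)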
14 changes: 10 additions & 4 deletions woodwork/column_accessor.py
@@ -22,7 +22,11 @@
 from woodwork.logical_types import _NULLABLE_PHYSICAL_TYPES, LatLong, Ordinal
 from woodwork.statistics_utils import _get_box_plot_info_for_column
 from woodwork.table_schema import TableSchema
-from woodwork.utils import _get_column_logical_type, import_or_none
+from woodwork.utils import (
+    _check_data_type_equality,
+    _get_column_logical_type,
+    import_or_none,
+)

 dd = import_or_none("dask.dataframe")
 ps = import_or_none("pyspark.pandas")
@@ -107,7 +111,9 @@ def init(
             logical_type.validate(self._series)
         else:
             valid_dtype = logical_type._get_valid_dtype(type(self._series))
-            if valid_dtype != str(self._series.dtype) and not (
+            if not _check_data_type_equality(
+                valid_dtype, str(self._series.dtype)
+            ) and not (
                 pdtypes.is_integer_dtype(valid_dtype)
                 and pdtypes.is_float_dtype(self._series.dtype)
             ):
@@ -287,7 +293,7 @@ def wrapper(*args, **kwargs):
                 valid_dtype = self._schema.logical_type._get_valid_dtype(
                     type(result),
                 )
-                if str(result.dtype) == valid_dtype:
+                if _check_data_type_equality(str(result.dtype), valid_dtype):
                     result.ww.init(schema=self.schema, validate=False)
                 else:
                     invalid_schema_message = (
@@ -571,7 +577,7 @@ def _validate_schema(schema, series):
         raise TypeError("Provided schema must be a Woodwork.ColumnSchema object.")

     valid_dtype = schema.logical_type._get_valid_dtype(type(series))
-    if str(series.dtype) != valid_dtype:
+    if not _check_data_type_equality(str(series.dtype), valid_dtype):
         raise ValueError(
             f"dtype mismatch between Series dtype {series.dtype}, and {schema.logical_type} dtype, {valid_dtype}",
         )
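With the relaxed comparison, initializing Woodwork on a series that already carries the pyarrow spelling should no longer raise a dtype mismatch. A sketch of the intended behavior, assuming a string-typed logical type:

    import pandas as pd
    import woodwork as ww  # noqa: F401 -- registers the ww accessor

    s = pd.Series(["short text", "longer text"], dtype="string[pyarrow]")
    s.ww.init(logical_type="NaturalLanguage")  # should not raise a mismatch
    print(s.ww.logical_type)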
2 changes: 1 addition & 1 deletion woodwork/deserializers/deserializer_base.py
@@ -94,7 +94,7 @@ def _set_init_dict(self, loading_info):
             cat_object = pd.CategoricalDtype(pd.Series(cat_values))
             col_type = cat_object
         elif table_type == "spark" and col_type == "object":
-            col_type = "string"
+            col_type = "string[pyarrow]"
         self.column_dtypes[col_name] = col_type

     if "index" in self.kwargs.keys():
47 changes: 24 additions & 23 deletions woodwork/logical_types.py
@@ -17,6 +17,7 @@
 )
 from woodwork.type_sys.utils import _get_specified_ltype_params
 from woodwork.utils import (
+    _check_data_type_equality,
     _infer_datetime_format,
     _is_valid_latlong_series,
     _is_valid_latlong_value,
@@ -45,7 +46,7 @@ class LogicalType(object, metaclass=LogicalTypeMetaClass):
     """Base class for all other Logical Types"""

     type_string = ClassNameDescriptor()
-    primary_dtype = "string"
+    primary_dtype = "string[pyarrow]"
     pyspark_dtype = None
     standard_tags = set()

@@ -68,7 +69,7 @@ def _get_valid_dtype(cls, series_type):
     def transform(self, series, null_invalid_values=False):
         """Converts the series dtype to match the logical type's if it is different."""
         new_dtype = self._get_valid_dtype(type(series))
-        if new_dtype != str(series.dtype):
+        if new_dtype != str(series.dtype) and new_dtype != series.dtype:
             # Update the underlying series
             try:
                 series = series.astype(new_dtype)
@@ -81,7 +82,7 @@ def validate(self, series, *args, **kwargs):
         specific validation, as required. When the series' dtype does not match the logical types' required dtype,
         raises a TypeValidationError."""
         valid_dtype = self._get_valid_dtype(type(series))
-        if valid_dtype != str(series.dtype):
+        if not _check_data_type_equality(valid_dtype, str(series.dtype)):
             raise TypeValidationError(
                 f"Series dtype '{series.dtype}' is incompatible with {self.type_string} LogicalType, try converting to {valid_dtype} dtype",
             )
@@ -97,7 +98,7 @@ class Address(LogicalType):
         ['26387 Russell Hill, Dallas, TX 34521', '54305 Oxford Street, Seattle, WA 95132']
     """

-    primary_dtype = "string"
+    primary_dtype = "string[pyarrow]"


 class Age(LogicalType):
@@ -270,7 +271,7 @@ class Categorical(LogicalType):
     """

     primary_dtype = "category"
-    pyspark_dtype = "string"
+    pyspark_dtype = "string[pyarrow]"
     standard_tags = {"category"}

     def __init__(self, encoding=None):
@@ -291,7 +292,7 @@ class CountryCode(LogicalType):
     """

     primary_dtype = "category"
-    pyspark_dtype = "string"
+    pyspark_dtype = "string[pyarrow]"
     standard_tags = {"category"}


@@ -306,7 +307,7 @@ class CurrencyCode(LogicalType):
     """

     primary_dtype = "category"
-    pyspark_dtype = "string"
+    pyspark_dtype = "string[pyarrow]"
     standard_tags = {"category"}


@@ -486,7 +487,7 @@ class EmailAddress(LogicalType):
         "team@example.com"]
     """

-    primary_dtype = "string"
+    primary_dtype = "string[pyarrow]"

     def transform(self, series, null_invalid_values=False):
         if null_invalid_values:
@@ -518,7 +519,7 @@ class Filepath(LogicalType):
         "/tmp"]
     """

-    primary_dtype = "string"
+    primary_dtype = "string[pyarrow]"


 class PersonFullName(LogicalType):
@@ -533,7 +534,7 @@ class PersonFullName(LogicalType):
         "James Brown"]
     """

-    primary_dtype = "string"
+    primary_dtype = "string[pyarrow]"


 class IPAddress(LogicalType):
@@ -548,7 +549,7 @@ class IPAddress(LogicalType):
         "2001:0db8:0000:0000:0000:ff00:0042:8329"]
     """

-    primary_dtype = "string"
+    primary_dtype = "string[pyarrow]"


 class LatLong(LogicalType):
@@ -616,7 +617,7 @@ class NaturalLanguage(LogicalType):
         "When will humans go to mars?"]
     """

-    primary_dtype = "string"
+    primary_dtype = "string[pyarrow]"


 class Unknown(LogicalType):
@@ -631,7 +632,7 @@ class Unknown(LogicalType):

     """

-    primary_dtype = "string"
+    primary_dtype = "string[pyarrow]"


 class Ordinal(LogicalType):
@@ -651,7 +652,7 @@ class Ordinal(LogicalType):
     """

     primary_dtype = "category"
-    pyspark_dtype = "string"
+    pyspark_dtype = "string[pyarrow]"
     standard_tags = {"category"}

     def __init__(self, order=None):
@@ -707,7 +708,7 @@ class PhoneNumber(LogicalType):
         "5551235495"]
     """

-    primary_dtype = "string"
+    primary_dtype = "string[pyarrow]"

     def transform(self, series, null_invalid_values=False):
         if null_invalid_values:
@@ -741,7 +742,7 @@ class SubRegionCode(LogicalType):
     """

     primary_dtype = "category"
-    pyspark_dtype = "string"
+    pyspark_dtype = "string[pyarrow]"
     standard_tags = {"category"}


@@ -771,7 +772,7 @@ class URL(LogicalType):
         "example.com"]
     """

-    primary_dtype = "string"
+    primary_dtype = "string[pyarrow]"

     def transform(self, series, null_invalid_values=False):
         if null_invalid_values:
@@ -804,7 +805,7 @@ class PostalCode(LogicalType):
     """

     primary_dtype = "category"
-    pyspark_dtype = "string"
+    pyspark_dtype = "string[pyarrow]"
     standard_tags = {"category"}

     def transform(self, series, null_invalid_values=False):
@@ -813,9 +814,9 @@ def transform(self, series, null_invalid_values=False):

         if pd.api.types.is_numeric_dtype(series):
             try:
-                series = series.astype("Int64").astype("string")
+                series = series.astype("Int64").astype("string[pyarrow]")
             except TypeError:
-                raise TypeConversionError(series, "string", type(self))
+                raise TypeConversionError(series, "string[pyarrow]", type(self))

         return super().transform(series)

@@ -851,7 +852,7 @@ def validate(self, series, return_invalid_values=False):
     "float64",
     "float128",
     "object",
-    "string",
+    "string[pyarrow]",
     "timedelta64[ns]",
 }

@@ -889,7 +890,7 @@ def _replace_nans(series: pd.Series, primary_dtype: Optional[str] = None) -> pd.Series:
     original_dtype = series.dtype
     if primary_dtype == str(original_dtype):
         return series
-    if str(original_dtype) == "string":
+    if "string" in str(original_dtype):
         series = series.replace(ww.config.get_option("nan_values"), pd.NA)
         return series
     if not _is_spark_series(series):
@@ -945,7 +946,7 @@ def _get_index_invalid_latlong(series):

 def _coerce_string(series, regex=None):
     if pd.api.types.is_object_dtype(series) or not pd.api.types.is_string_dtype(series):
-        series = series.astype("string")
+        series = series.astype("string[pyarrow]")

     if isinstance(regex, str):
         invalid = _get_index_invalid_string(series, regex)
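Taken together, the logical_types.py changes make transform() target the Arrow-backed dtype for string-like types. A sketch of the expected round trip under this PR:

    import pandas as pd
    from woodwork.logical_types import NaturalLanguage

    s = pd.Series(["foo", "bar"], dtype="object")
    transformed = NaturalLanguage().transform(s)
    # With primary_dtype = "string[pyarrow]", the cast should now yield an
    # Arrow-backed series rather than the NumPy-backed "string" dtype.
    print(transformed.dtype)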
2 changes: 1 addition & 1 deletion woodwork/serializers/orc_serializer.py
@@ -34,6 +34,6 @@ def save_orc_file(dataframe, filepath):
     df = dataframe.copy()
     for c in df:
         if df[c].dtype.name == "category":
-            df[c] = df[c].astype("string")
+            df[c] = df[c].astype("string[pyarrow]")
     pa_table = Table.from_pandas(df, preserve_index=False)
     orc.write_table(pa_table, filepath)
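The cast away from category before writing is presumably because pyarrow's ORC writer cannot serialize the dictionary-encoded arrays that Table.from_pandas produces for categorical columns. A sketch of the call this guards, with a hypothetical output path, assuming the installed pyarrow build ships the ORC module:

    import pandas as pd
    from pyarrow import Table, orc

    df = pd.DataFrame({"c": pd.Categorical(["a", "b", "a"])})
    # Without the cast, orc.write_table may raise ArrowNotImplementedError
    # on the dictionary-encoded column.
    df["c"] = df["c"].astype("string[pyarrow]")
    orc.write_table(Table.from_pandas(df, preserve_index=False), "example.orc")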
8 changes: 7 additions & 1 deletion woodwork/statistics_utils/_get_mode.py
@@ -1,6 +1,12 @@
+from pyarrow.lib import ArrowNotImplementedError
+
+
 def _get_mode(series):
     """Get the mode value for a series"""
-    mode_values = series.mode()
+    try:
+        mode_values = series.mode()
+    except ArrowNotImplementedError:
+        mode_values = series.astype("string").mode()
     if len(mode_values) > 0:
         return mode_values[0]
     return None
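A sketch of the case the new fallback covers: some pyarrow versions do not implement the mode kernel for Arrow-backed strings, so the series is retried through the NumPy-backed "string" dtype. Hypothetical repro, dependent on the installed pyarrow version and using the module path from this diff:

    import pandas as pd
    from woodwork.statistics_utils._get_mode import _get_mode

    s = pd.Series(["a", "b", "a"], dtype="string[pyarrow]")
    # If series.mode() raises ArrowNotImplementedError internally,
    # _get_mode falls back to series.astype("string").mode().
    assert _get_mode(s) == "a"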