diff --git a/docs/source/data_format_conversion.md b/docs/source/data_format_conversion.md
index b0a6ed628..d769fd3fc 100644
--- a/docs/source/data_format_conversion.md
+++ b/docs/source/data_format_conversion.md
@@ -95,6 +95,8 @@ would also be a valid input to `transform`.
 import io
 import json
 
+import pandas as pd
+
 buffer = io.BytesIO()
 data = pd.DataFrame({"str_col": [*"abc"], "int_col": range(3)})
 data.to_parquet(buffer)
diff --git a/pandera/api/dataframe/container.py b/pandera/api/dataframe/container.py
index 4f7f58d33..89ab6e994 100644
--- a/pandera/api/dataframe/container.py
+++ b/pandera/api/dataframe/container.py
@@ -349,15 +349,19 @@ def __repr__(self) -> str:
             f"checks={self.checks}, "
             f"parsers={self.parsers}, "
             f"index={self.index.__repr__()}, "
-            f"coerce={self.coerce}, "
             f"dtype={self._dtype}, "
+            f"coerce={self.coerce}, "
             f"strict={self.strict}, "
             f"name={self.name}, "
             f"ordered={self.ordered}, "
-            f"unique_column_names={self.unique_column_names}"
-            f"metadata='{self.metadata}, "
+            f"unique={self.unique}, "
+            f"report_duplicates={self.report_duplicates}, "
             f"unique_column_names={self.unique_column_names}, "
-            f"add_missing_columns={self.add_missing_columns}"
+            f"add_missing_columns={self.add_missing_columns}, "
+            f"title={self.title}, "
+            f"description={self.description}, "
+            f"metadata={self.metadata}, "
+            f"drop_invalid_rows={self.drop_invalid_rows}"
             ")>"
         )
diff --git a/pandera/backends/pandas/checks.py b/pandera/backends/pandas/checks.py
index fb0b41a7d..3c9cc3d61 100644
--- a/pandera/backends/pandas/checks.py
+++ b/pandera/backends/pandas/checks.py
@@ -238,16 +238,36 @@ def postprocess(
     ) -> CheckResult:
         """Postprocesses the result of applying the check function."""
         assert check_obj.shape == check_output.shape
-        check_obj = check_obj.unstack()
-        check_output = check_output.unstack()
         if check_obj.index.equals(check_output.index) and self.check.ignore_na:
             check_output = check_output | check_obj.isna()
-        failure_cases = (
-            check_obj[~check_output]  # type: ignore [call-overload]
-            .rename("failure_case")
-            .rename_axis(["column", "index"])
-            .reset_index()
-        )
+
+        # collect failure cases across all columns: indexing check_obj with
+        # ~check_output nulls out the values that passed the check.
+        select_failure_cases = check_obj[~check_output]
+        failure_cases = []
+        for col in select_failure_cases.columns:
+            cases = select_failure_cases[col].rename("failure_case").dropna()
+            if len(cases) == 0:
+                continue
+            failure_cases.append(
+                cases.to_frame()
+                .assign(column=col)
+                .rename_axis("index")
+                .reset_index()
+            )
+        if failure_cases:
+            failure_cases = pd.concat(failure_cases, axis=0)
+            # convert to a dataframe where each row is a failure case at
+            # a particular index, and failure case values are dictionaries
+            # indicating which column and value failed in that row.
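+            # For example (illustrative, mirroring the test_schemas.py update
+            # below): if "int_col" and "float_col" both fail at index 0, that
+            # row's failure case is rendered as a single dict:
+            # {"int_col": 1.0, "float_col": 1.0}.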
+            failure_cases = (
+                failure_cases.set_index("column")
+                .groupby("index")
+                .agg(lambda df: df.to_dict())
+            )
+        else:
+            failure_cases = pd.DataFrame(columns=["index", "failure_case"])
+
         if not failure_cases.empty and self.check.n_failure_cases is not None:
             failure_cases = failure_cases.drop_duplicates().head(
                 self.check.n_failure_cases
diff --git a/pandera/typing/pandas.py b/pandera/typing/pandas.py
index 9c2a0b7c3..ee88a1977 100644
--- a/pandera/typing/pandas.py
+++ b/pandera/typing/pandas.py
@@ -19,6 +19,7 @@
 import numpy as np
 import pandas as pd
 
+from pandera.config import config_context
 from pandera.engines import PYDANTIC_V2
 from pandera.errors import SchemaError, SchemaInitError
 from pandera.typing.common import (
@@ -30,12 +31,6 @@
 )
 from pandera.typing.formats import Formats
 
-try:
-    from typing import get_args
-except ImportError:
-    from typing_extensions import get_args
-
-
 try:
     from typing import _GenericAlias  # type: ignore[attr-defined]
 except ImportError:  # pragma: no cover
@@ -190,12 +185,24 @@ def _get_schema_model(cls, field):
         def __get_pydantic_core_schema__(
             cls, _source_type: Any, _handler: GetCoreSchemaHandler
         ) -> core_schema.CoreSchema:
-            schema_model = get_args(_source_type)[0]
-            return core_schema.no_info_plain_validator_function(
-                functools.partial(
-                    cls.pydantic_validate,
-                    schema_model=schema_model,
-                ),
+            # disable validation while extracting the schema from the
+            # type annotation
+            with config_context(validation_enabled=False):
+                schema = _source_type().__orig_class__.__args__[0].to_schema()
+
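+            # map each supported column dtype to a pydantic-core schema:
+            # DataFrame[Model] is then exposed to pydantic as a list of
+            # per-row typed dicts, e.g. list[{"str_col": str, "int_col": int}]
+            # (illustrative; only the dtypes enumerated below are handled).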
column("datetime_col", dtype=pl.Datetime()), lazy=True, size=10, + allow_null=False, ) lf = data.draw(strategy) lf = lf.cast({"datetime_col": pl.Datetime(time_zone=time_zone)}) diff --git a/tests/polars/test_polars_dtypes.py b/tests/polars/test_polars_dtypes.py index 872d7b002..2d7c4ab9b 100644 --- a/tests/polars/test_polars_dtypes.py +++ b/tests/polars/test_polars_dtypes.py @@ -75,6 +75,8 @@ def get_dataframe_strategy(type_: pl.DataType) -> st.SearchStrategy: @settings(max_examples=1) def test_coerce_no_cast(dtype, data): """Test that dtypes can be coerced without casting.""" + if dtype is pe.Categorical: + pl.enable_string_cache() pandera_dtype = dtype() df = data.draw(get_dataframe_strategy(type_=pandera_dtype.type)) coerced = pandera_dtype.coerce(data_container=PolarsData(df)) diff --git a/tests/polars/test_polars_model.py b/tests/polars/test_polars_model.py index b88481f1f..20931bdd2 100644 --- a/tests/polars/test_polars_model.py +++ b/tests/polars/test_polars_model.py @@ -237,6 +237,7 @@ def test_dataframe_schema_with_tz_agnostic_dates(time_zone, data): column("datetime_col", dtype=pl.Datetime()), lazy=True, size=10, + allow_null=False, ) lf = data.draw(strategy) lf = lf.cast({"datetime_col": pl.Datetime(time_zone=time_zone)})