From 23823d940f42f81b75cf1d896eb3f9b861de21c2 Mon Sep 17 00:00:00 2001 From: Enrico Minack Date: Wed, 12 Feb 2025 14:56:41 +0100 Subject: [PATCH] Add Dataset.assert_sort to Python API --- python/pyarrow/_compute.pyx | 7 +++-- python/pyarrow/_dataset.pyx | 29 ++++++++++++++++++ python/pyarrow/acero.py | 14 +++++++++ python/pyarrow/tests/test_dataset.py | 45 ++++++++++++++++++++++++++++ 4 files changed, 93 insertions(+), 2 deletions(-) diff --git a/python/pyarrow/_compute.pyx b/python/pyarrow/_compute.pyx index 030cb4ee34f7e..43653c56e25ac 100644 --- a/python/pyarrow/_compute.pyx +++ b/python/pyarrow/_compute.pyx @@ -1981,8 +1981,11 @@ class PartitionNthOptions(_PartitionNthOptions): cdef class Ordering(_Weakrefable): - def __init__(self): - _forbid_instantiation(self.__class__) + def __init__(self, sort_keys, *, null_placement="at_end"): + c_sort_keys = unwrap_sort_keys(sort_keys, allow_str=False) + c_null_placement = unwrap_null_placement(null_placement) + ordering = COrdering(c_sort_keys, c_null_placement) + self.init(ordering) cdef void init(self, const COrdering& sp): self.wrapped = sp diff --git a/python/pyarrow/_dataset.pyx b/python/pyarrow/_dataset.pyx index 45b1acfa8464b..804095005bd84 100644 --- a/python/pyarrow/_dataset.pyx +++ b/python/pyarrow/_dataset.pyx @@ -872,6 +872,35 @@ cdef class Dataset(_Weakrefable): ) return res + def assert_sort(self, sorting, *, null_placement="at_end"): + """ + Assert the Dataset is sorted by one or multiple columns. + + Parameters + ---------- + sorting : str or list[tuple(name, order)] + Name of the column to use to sort (ascending), or + a list of multiple sorting conditions where + each entry is a tuple with column name + and sorting order ("ascending" or "descending") + null_placement : str, default "at_end" + Where nulls in input should be sorted, only applying to + columns/fields mentioned in `sorting`. + Accepted values are "at_start", "at_end". 
+ + Returns + ------- + InMemoryDataset + A new dataset where sorted order is guaranteed or an exception is raised. + """ + if isinstance(sorting, str): + sorting = [(sorting, "ascending")] + + res = _pac()._assert_sorted( + self, output_type=InMemoryDataset, sort_keys=sorting, null_placement=null_placement + ) + return res + def join(self, right_dataset, keys, right_keys=None, join_type="left outer", left_suffix=None, right_suffix=None, coalesce_keys=True, use_threads=True): diff --git a/python/pyarrow/acero.py b/python/pyarrow/acero.py index 8f73d8cac4770..6e5baa91a9505 100644 --- a/python/pyarrow/acero.py +++ b/python/pyarrow/acero.py @@ -400,6 +400,20 @@ def _sort_source(table_or_dataset, sort_keys, output_type=Table, **kwargs): raise TypeError("Unsupported output type") +def _assert_sorted(dataset, sort_keys, *, null_placement="at_end", output_type=Table): + + ordering = Ordering(sort_keys, null_placement=null_placement) + decl = _dataset_to_decl(dataset, use_threads=True, ordering=ordering) + result_table = decl.to_table(use_threads=True) + + if output_type == Table: + return result_table + elif output_type == ds.InMemoryDataset: + return ds.InMemoryDataset(result_table) + else: + raise TypeError("Unsupported output type") + + def _group_by(table, aggregates, keys, use_threads=True): decl = Declaration.from_sequence([ diff --git a/python/pyarrow/tests/test_dataset.py b/python/pyarrow/tests/test_dataset.py index b6aaa2840d83c..44c98b04d8702 100644 --- a/python/pyarrow/tests/test_dataset.py +++ b/python/pyarrow/tests/test_dataset.py @@ -29,6 +29,8 @@ from shutil import copytree from urllib.parse import quote +from pyarrow import ArrowException + try: import numpy as np except ImportError: @@ -5654,6 +5656,49 @@ def test_dataset_sort_by(tempdir, dstype): assert sorted_tab_dict["b"] == ["foo", "car", "bar", "foobar"] +def do_test_dataset_assert_sorted(tempdir, dstype, table, expect_sorted, **kwargs): + if dstype == "fs": + filename = "-".join([f"{k}={v}" for k, v 
in kwargs.items()]) + ds.write_dataset(table, tempdir / filename, format="ipc") + dt = ds.dataset(tempdir / filename, format="ipc") + elif dstype == "mem": + dt = ds.dataset(table) + else: + raise NotImplementedError + + if expect_sorted: + dt.assert_sort(**kwargs).to_table() + else: + with pytest.raises(ArrowException, match="Data is not ordered"): + dt.assert_sort(**kwargs).to_table() + + +@pytest.mark.parametrize('dstype', [ + "fs", "mem" +]) +def test_dataset_assert_sorted(tempdir, dstype): + table = pa.table([ + pa.array([1, 2, 3, 4, None]), + pa.array(["b", "a", "b", "a", "c"]), + ], names=["values", "keys"]) + + def assert_sorted(**kwargs): + do_test_dataset_assert_sorted(tempdir, dstype, table, True, **kwargs) + + assert_sorted(sorting="values") + assert_sorted(sorting="values", null_placement="at_end") + assert_sorted(sorting=[("values", "ascending")]) + assert_sorted(sorting=[("values", "ascending")], null_placement="at_end") + + def assert_not_sorted(**kwargs): + do_test_dataset_assert_sorted(tempdir, dstype, table, False, **kwargs) + + assert_not_sorted(sorting="keys") + assert_not_sorted(sorting="values", null_placement="at_start") + assert_not_sorted(sorting=[("values", "descending")]) + assert_not_sorted(sorting=[("values", "ascending")], null_placement="at_start") + + def test_checksum_write_dataset_read_dataset_to_table(tempdir): """Check that checksum verification works for datasets created with ds.write_dataset and read with ds.dataset.to_table"""