Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support sorted merges in cudf.polars #18075

Open
wants to merge 10 commits into
base: branch-25.04
Choose a base branch
from
35 changes: 30 additions & 5 deletions python/cudf_polars/cudf_polars/dsl/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -1653,11 +1653,36 @@ def do_evaluate(cls, schema: Schema, df: DataFrame) -> DataFrame:
class MergeSorted(IR):
    """Merge sorted operation."""

    # NOTE(review): a second ``__init__`` is defined further down in this
    # class body; in Python the later definition rebinds the name, so this
    # raising constructor is never the one that runs.
    def __init__(self, schema: Schema, left: IR, right: IR, key: str):
        # libcudf merge is not stable wrt order of inputs, since
        # it uses a priority queue to manage the tables it produces.
        # See: https://github.com/rapidsai/cudf/issues/16010
        raise NotImplementedError("MergeSorted not yet implemented")

    __slots__ = ("key",)
    _non_child = ("key",)
    key: str
    """Key that is sorted."""

    def __init__(self, schema: Schema, key: str, left: IR, right: IR):
        """
        Build a merge of two sorted inputs.

        Parameters
        ----------
        schema
            Schema stored on the node.
        key
            Name of the column the merge is keyed on.
        left
            Left input; asserted to be a ``Sort`` node.
        right
            Right input; asserted to be a ``Sort`` node with the same
            ``order`` as ``left``.

        Raises
        ------
        AssertionError
            If either input is not a ``Sort``, if the two sort orders
            differ, or if ``left`` has more schema entries than ``right``.
        """
        # Both inputs must already be Sort nodes sorted the same way.
        assert isinstance(left, Sort)
        assert isinstance(right, Sort)
        assert left.order == right.order
        # The left schema may not be wider than the right one.
        assert len(left.schema.keys()) <= len(right.schema.keys())
        self.schema = schema
        self.key = key
        self.children = (left, right)
        # Presumably forwarded to ``do_evaluate`` ahead of the child
        # dataframes (its signature is ``(key, *dfs)``) -- TODO confirm.
        self._non_child_args = (key,)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this defined above as _non_child? Can something be reused or simplified here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we need both. The _non_child arguments are needed for serialization/deserialization, among other things, while the _non_child_args are used to execute the IR nodes.


@classmethod
def do_evaluate(cls, key: str, *dfs: DataFrame) -> DataFrame:
    """
    Evaluate the merge and return a dataframe.

    Parameters
    ----------
    key
        Name of the key column present in both inputs.
    dfs
        Exactly two dataframes, unpacked as ``(left, right)``.

    Returns
    -------
    DataFrame built from the ``plc.merge.merge`` result, labelled with
    the column names of ``left``.
    """
    left, right = dfs
    # Discard every column of ``right`` whose name is not also in ``left``.
    right_only_names = right.column_names_set - left.column_names_set
    right = right.discard_columns(right_only_names)

    left_key_column = left.select_columns({key})[0]
    right_key_column = right.select_columns({key})[0]

    # NOTE(review): the tables are handed over right-first while the key
    # indices, orders and null orders are listed left-first -- confirm
    # against the pylibcudf ``merge`` contract that this is intended.
    tables = [right.table, left.table]
    key_indices = [left.column_names.index(key), right.column_names.index(key)]
    orders = [left_key_column.order, right_key_column.order]
    null_orders = [left_key_column.null_order, right_key_column.null_order]

    merged = plc.merge.merge(tables, key_indices, orders, null_orders)
    return DataFrame.from_table(merged, left.column_names)


class MapFunction(IR):
Expand Down
4 changes: 2 additions & 2 deletions python/cudf_polars/cudf_polars/dsl/translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,14 +467,14 @@ def _(
def _(
    node: pl_ir.MergeSorted, translator: Translator, schema: dict[str, plc.DataType]
) -> ir.IR:
    """
    Translate a polars ``MergeSorted`` node into ``ir.MergeSorted``.

    Parameters
    ----------
    node
        The polars IR node; its ``key``, ``input_left`` and ``input_right``
        attributes are read.
    translator
        Translator used to translate the two input nodes.
    schema
        Schema passed through to the resulting IR node.

    Returns
    -------
    The ``ir.MergeSorted`` node wrapping both translated inputs.
    """
    key = node.key
    inp_left = translator.translate_ir(n=node.input_left)
    inp_right = translator.translate_ir(n=node.input_right)
    # ``ir.MergeSorted`` is constructed as (schema, key, left, right): the
    # key goes straight after the schema and is passed exactly once.
    return ir.MergeSorted(
        schema,
        key,
        inp_left,
        inp_right,
    )


Expand Down
10 changes: 0 additions & 10 deletions python/cudf_polars/tests/test_mapfunction.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,6 @@
)


def test_merge_sorted_raises():
    """Check a chained ``merge_sorted`` query with ``assert_ir_translation_raises``."""
    first = pl.LazyFrame({"a": [1, 6, 9], "b": [1, -10, 4]})
    second = pl.LazyFrame({"a": [-1, 5, 11, 20], "b": [2, 7, -4, None]})
    third = pl.LazyFrame({"a": [-10, 20, 21], "b": [1, 2, 3]})

    # Merge the first two frames on "a", then merge the third into the result.
    merged_pair = first.merge_sorted(second, key="a")
    query = merged_pair.merge_sorted(third, key="a")

    assert_ir_translation_raises(query, NotImplementedError)


def test_explode_multiple_raises():
df = pl.LazyFrame({"a": [[1, 2], [3, 4]], "b": [[5, 6], [7, 8]]})
q = df.explode("a", "b")
Expand Down
59 changes: 59 additions & 0 deletions python/cudf_polars/tests/test_merge_sorted.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations

import pytest

import polars as pl

from cudf_polars.testing.asserts import assert_gpu_result_equal


@pytest.mark.parametrize(
    "descending",
    [
        pytest.param(True, marks=pytest.mark.xfail(reason="polars/issues/21464")),
        False,
    ],
)
def test_merge_sorted_without_nulls(descending):
    """Merge two null-free frames sorted on ``age`` and compare with the GPU result."""
    left_data = {"name": ["steve", "elise", "bob"], "age": [42, 44, 18]}
    left = pl.LazyFrame(left_data).sort("age", descending=descending)

    # The right frame carries an extra "height" column the left one lacks.
    right_data = {
        "name": ["anna", "megan", "steve", "thomas"],
        "age": [21, 33, 42, 20],
        "height": [5, 5, 5, 5],
    }
    right = pl.LazyFrame(right_data).sort("age", descending=descending)

    assert_gpu_result_equal(left.merge_sorted(right, key="age"))


@pytest.mark.parametrize(
    "descending",
    [
        pytest.param(True, marks=pytest.mark.xfail(reason="polars/issues/21464")),
        False,
    ],
)
def test_merge_sorted_with_nulls(descending):
    """Merge two frames whose ``age`` key holds a null and compare with the GPU result."""
    left_data = {"name": ["steve", "elise", "bob", "john"], "age": [42, 44, 18, None]}
    left = pl.LazyFrame(left_data).sort("age", descending=descending)

    # The right frame carries an extra "height" column the left one lacks.
    right_data = {
        "name": ["anna", "megan", "steve", "thomas", "john"],
        "age": [21, 33, 42, 20, None],
        "height": [5, 5, 5, 5, 5],
    }
    right = pl.LazyFrame(right_data).sort("age", descending=descending)

    assert_gpu_result_equal(left.merge_sorted(right, key="age"))
Loading