From bcff1f7c55eb7077fa42e8e1ef231d0542fadd46 Mon Sep 17 00:00:00 2001 From: Matthew Murray <41342305+Matt711@users.noreply.github.com> Date: Mon, 24 Feb 2025 14:44:56 -0500 Subject: [PATCH] Add a slice expression to polars IR (#18050) Closes https://github.com/rapidsai/cudf/issues/18051 Authors: - Matthew Murray (https://github.com/Matt711) Approvers: - Matthew Roeschke (https://github.com/mroeschke) URL: https://github.com/rapidsai/cudf/pull/18050 --- .../cudf_polars/containers/dataframe.py | 4 +- python/cudf_polars/cudf_polars/dsl/expr.py | 4 +- .../cudf_polars/dsl/expressions/slicing.py | 51 +++++++++++++++++++ .../cudf_polars/cudf_polars/dsl/translate.py | 14 +++++ .../tests/expressions/test_slice.py | 24 +++++++++ 5 files changed, 95 insertions(+), 2 deletions(-) create mode 100644 python/cudf_polars/cudf_polars/dsl/expressions/slicing.py create mode 100644 python/cudf_polars/tests/expressions/test_slice.py diff --git a/python/cudf_polars/cudf_polars/containers/dataframe.py b/python/cudf_polars/cudf_polars/containers/dataframe.py index a605b476197..a2b496b8cfe 100644 --- a/python/cudf_polars/cudf_polars/containers/dataframe.py +++ b/python/cudf_polars/cudf_polars/containers/dataframe.py @@ -295,7 +295,7 @@ def filter(self, mask: Column) -> Self: table = plc.stream_compaction.apply_boolean_mask(self.table, mask.obj) return type(self).from_table(table, self.column_names).sorted_like(self) - def slice(self, zlice: tuple[int, int] | None) -> Self: + def slice(self, zlice: tuple[int, int | None] | None) -> Self: """ Slice a dataframe. 
@@ -312,6 +312,8 @@ def slice(self, zlice: tuple[int, int] | None) -> Self: if zlice is None: return self start, length = zlice + if length is None: + length = self.num_rows if start < 0: start += self.num_rows # Polars implementation wraps negative start by num_rows, then diff --git a/python/cudf_polars/cudf_polars/dsl/expr.py b/python/cudf_polars/cudf_polars/dsl/expr.py index 98d49e36fb1..3ba54543a3e 100644 --- a/python/cudf_polars/cudf_polars/dsl/expr.py +++ b/python/cudf_polars/cudf_polars/dsl/expr.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. # SPDX-License-Identifier: Apache-2.0 # TODO: remove need for this # ruff: noqa: D101 @@ -30,6 +30,7 @@ from cudf_polars.dsl.expressions.literal import Literal, LiteralColumn from cudf_polars.dsl.expressions.rolling import GroupedRollingWindow, RollingWindow from cudf_polars.dsl.expressions.selection import Filter, Gather +from cudf_polars.dsl.expressions.slicing import Slice from cudf_polars.dsl.expressions.sorting import Sort, SortBy from cudf_polars.dsl.expressions.string import StringFunction from cudf_polars.dsl.expressions.ternary import Ternary @@ -53,6 +54,7 @@ "LiteralColumn", "NamedExpr", "RollingWindow", + "Slice", "Sort", "SortBy", "StringFunction", diff --git a/python/cudf_polars/cudf_polars/dsl/expressions/slicing.py b/python/cudf_polars/cudf_polars/dsl/expressions/slicing.py new file mode 100644 index 00000000000..2d3640cce86 --- /dev/null +++ b/python/cudf_polars/cudf_polars/dsl/expressions/slicing.py @@ -0,0 +1,51 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. 
+# SPDX-License-Identifier: Apache-2.0
+"""Slicing DSL nodes."""
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+import pylibcudf as plc
+
+from cudf_polars.dsl.expressions.base import ExecutionContext, Expr
+
+if TYPE_CHECKING:
+    from collections.abc import Mapping
+
+    from cudf_polars.containers import Column, DataFrame
+
+
+__all__ = ["Slice"]
+
+
+class Slice(Expr):
+    """Take `length` rows of the child from `offset`; a None length runs to the end."""
+
+    __slots__ = ("length", "offset")
+    _non_child = ("dtype", "offset", "length")
+
+    def __init__(
+        self,
+        dtype: plc.DataType,
+        offset: int,
+        length: int | None,
+        column: Expr,
+    ) -> None:
+        self.dtype = dtype
+        self.offset = offset
+        self.length = length
+        self.children = (column,)
+
+    def do_evaluate(
+        self,
+        df: DataFrame,
+        *,
+        context: ExecutionContext = ExecutionContext.FRAME,
+        mapping: Mapping[Expr, Column] | None = None,
+    ) -> Column:
+        """Evaluate this expression given a dataframe for context."""
+        # Slice the evaluated child, not whichever column happens to be first in df.
+        column = self.children[0].evaluate(df, context=context, mapping=mapping)
+        frame = type(df).from_table(plc.Table([column.obj]), ["slice"])
+        return frame.slice((self.offset, self.length)).columns[0]
diff --git a/python/cudf_polars/cudf_polars/dsl/translate.py b/python/cudf_polars/cudf_polars/dsl/translate.py
index 22f97f2bf52..369328d3a8c 100644
--- a/python/cudf_polars/cudf_polars/dsl/translate.py
+++ b/python/cudf_polars/cudf_polars/dsl/translate.py
@@ -690,6 +690,20 @@ def _(node: pl_expr.SortBy, translator: Translator, dtype: plc.DataType) -> expr
     )
 
 
+@_translate_expr.register
+def _(node: pl_expr.Slice, translator: Translator, dtype: plc.DataType) -> expr.Expr:
+    offset = translator.translate_expr(n=node.offset)
+    length = translator.translate_expr(n=node.length)
+    assert isinstance(offset, expr.Literal)
+    assert isinstance(length, expr.Literal)
+    return expr.Slice(
+        dtype,
+        offset.value.as_py(),
+        length.value.as_py(),
+        translator.translate_expr(n=node.input),
+    )
+
+
 @_translate_expr.register
 def _(node: pl_expr.Gather, translator: Translator, dtype: plc.DataType) -> expr.Expr:
     return expr.Gather(
diff --git
a/python/cudf_polars/tests/expressions/test_slice.py b/python/cudf_polars/tests/expressions/test_slice.py
new file mode 100644
index 00000000000..9873be2455f
--- /dev/null
+++ b/python/cudf_polars/tests/expressions/test_slice.py
@@ -0,0 +1,24 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-License-Identifier: Apache-2.0
+from __future__ import annotations
+
+import pytest
+
+import polars as pl
+
+from cudf_polars.testing.asserts import assert_gpu_result_equal
+
+
+@pytest.mark.parametrize(
+    "zlice",
+    [
+        (1,),  # offset only: no length given
+        (1, 3),  # offset with an explicit length
+        (-1,),  # negative offset
+    ],
+)
+def test_slice(zlice):
+    df = pl.LazyFrame({"a": [0, 1, 2, 3], "b": [1, 2, 3, 4]})
+    q = df.select(pl.col("a").slice(*zlice))  # NOTE(review): "b" is never sliced
+
+    assert_gpu_result_equal(q)