apacheGH-41540: [R] Simplify arrow_eval() logic and bindings environments (apache#41537)

### Rationale for this change

NSE is hard enough. I wanted to see if I could remove some layers of
complexity.

### What changes are included in this PR?

* There are no longer separate collections of `agg_funcs` and
`nse_funcs`. Now that the aggregation functions return Expressions
(apache#41223), there's no reason to treat
them separately. All bindings return Expressions now.
* Both are removed and functions are just stored in `.cache$functions`.
There was a note wondering why both `nse_funcs` and that needed to
exist. They don't.
* `arrow_mask()` no longer has an `aggregation` argument: agg functions
are always present.
* Because agg functions are always present, `filter` and `arrange` now
have to check whether the expressions passed to them contain
aggregations--this is supported in regular dplyr but we have deferred
supporting it here for now (see
apache#41350). If we decide we want to
support it later, these checks are the entry points where we'd drop in
the `left_join()` as in `mutate()`.
* The logic of evaluating expressions in `filter()` has been
simplified.
* Assorted other cleanups: `register_binding()` has two fewer arguments,
for example, and the duplicate functions for referencing agg_funcs are
gone.
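
The single-store idea in the list above can be pictured with a minimal base-R sketch. This is not the package's code: `toy_cache`, `toy_register()` and `toy_call()` are hypothetical names, a plain environment stands in for `.cache`, and the bindings return strings rather than Expressions. It only shows that one store can hold scalar and aggregation bindings alike, each registered under both `pkg::fun` and `fun`.

```r
# Hypothetical stand-in for the package's `.cache` environment.
toy_cache <- new.env(parent = emptyenv())
toy_cache$functions <- list()

# Register `fun` under its qualified and unqualified name.
# Invisibly returns whatever was registered before, or NULL.
toy_register <- function(fun_name, fun) {
  unqualified_name <- sub("^.*::", "", fun_name)
  previous_fun <- toy_cache$functions[[unqualified_name]]
  toy_cache$functions[[unqualified_name]] <- fun
  toy_cache$functions[[fun_name]] <- fun
  invisible(previous_fun)
}

# Look a binding up by either name and call it.
toy_call <- function(fun_name, ...) {
  toy_cache$functions[[fun_name]](...)
}

# A scalar binding and an aggregation binding live in the same store.
toy_register("base::abs", function(x) paste0("abs(", x, ")"))
toy_register("base::sum", function(x) paste0("sum(", x, ")"))

scalar_result <- toy_call("abs", "x")
agg_result <- toy_call("base::sum", "x")
```

With one store there is nothing to keep in sync, which is presumably why the `update_cache` argument could go away.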

There is one more refactor I intend to pursue, and that's to rework
`abandon_ship()` and how `arrow_eval()` does error handling, but I ~may~ will
defer that to a followup.

### Are these changes tested?

Yes, though I'll add some more for filter/aggregate in the followup
since I'm reworking things there.

### Are there any user-facing changes?

There are a couple of edge cases where the error message will change
subtly. For example, if you supplied a comma-separated list of filter
expressions, and more than one of them did not evaluate, previously you
would be informed of all of the failures; now, we error on the first
one. I don't think this is concerning.
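
The difference in reporting can be sketched in base R. `eval_all()`, `eval_first_failure()` and `label()` are hypothetical helpers, not arrow functions, and plain `eval()` stands in for `arrow_eval()`; the sketch only contrasts collecting every failure with stopping at the first one.

```r
# Turn an expression into a one-line label for messages.
label <- function(e) paste(deparse(e), collapse = " ")

# Old style: evaluate everything, then report all failures together.
eval_all <- function(exprs, env) {
  results <- lapply(exprs, function(e) try(eval(e, env), silent = TRUE))
  bad <- vapply(results, inherits, logical(1), what = "try-error")
  if (any(bad)) {
    labels <- vapply(exprs[bad], label, character(1))
    stop("Not supported: ", paste(labels, collapse = ", "), call. = FALSE)
  }
  results
}

# New style: evaluate one at a time and stop at the first failure.
eval_first_failure <- function(exprs, env) {
  results <- vector("list", length(exprs))
  for (i in seq_along(exprs)) {
    res <- try(eval(exprs[[i]], env), silent = TRUE)
    if (inherits(res, "try-error")) {
      stop("Not supported: ", label(exprs[[i]]), call. = FALSE)
    }
    results[[i]] <- res
  }
  results
}

env <- list2env(list(x = 1:3))
exprs <- list(quote(x > 1), quote(nope(x)), quote(also_nope(x)))
msg_all <- tryCatch(eval_all(exprs, env), error = conditionMessage)
msg_first <- tryCatch(eval_first_failure(exprs, env), error = conditionMessage)
```

Here `msg_all` names both unsupported expressions, while `msg_first` names only `nope(x)`.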
* GitHub Issue: apache#41540
nealrichardson authored May 7, 2024
1 parent c79b6a5 commit 03f8ae7
Showing 15 changed files with 109 additions and 261 deletions.
8 changes: 8 additions & 0 deletions r/R/dplyr-arrange.R
@@ -47,6 +47,14 @@ arrange.arrow_dplyr_query <- function(.data, ..., .by_group = FALSE) {
       msg <- paste("Expression", names(sorts)[i], "not supported in Arrow")
       return(abandon_ship(call, .data, msg))
     }
+    if (length(mask$.aggregations)) {
+      # dplyr lets you arrange on e.g. x < mean(x), but we haven't implemented it.
+      # But we could, the same way it works in mutate() via join, if someone asks.
+      # Until then, just error.
+      # TODO: add a test for this
+      msg <- paste("Expression", format_expr(expr), "not supported in arrange() in Arrow")
+      return(abandon_ship(call, .data, msg))
+    }
     descs[i] <- x[["desc"]]
   }
   .data$arrange_vars <- c(sorts, .data$arrange_vars)
17 changes: 1 addition & 16 deletions r/R/dplyr-eval.R
@@ -121,24 +121,9 @@ arrow_not_supported <- function(msg) {
 }

 # Create a data mask for evaluating a dplyr expression
-arrow_mask <- function(.data, aggregation = FALSE) {
+arrow_mask <- function(.data) {
   f_env <- new_environment(.cache$functions)

-  if (aggregation) {
-    # Add the aggregation functions to the environment.
-    for (f in names(agg_funcs)) {
-      f_env[[f]] <- agg_funcs[[f]]
-    }
-  } else {
-    # Add functions that need to error hard and clear.
-    # Some R functions will still try to evaluate on an Expression
-    # and return NA with a warning :exploding_head:
-    fail <- function(...) stop("Not implemented")
-    for (f in c("mean", "sd")) {
-      f_env[[f]] <- fail
-    }
-  }
-
   # Assign the schema to the expressions
   schema <- .data$.data$schema
   walk(.data$selected_columns, ~ (.$schema <- schema))
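
As a rough illustration of why the function environment matters in the hunk above, here is a base-R sketch of a layered mask. It is an assumption-laden stand-in: the real code builds the mask with rlang, whereas `toy_functions`, `f_env` and `toy_mask` here are hypothetical and use only `list2env()` and `eval()`. The point is that a binding placed in the function environment shadows the base R function of the same name, so `mean()` builds a description instead of computing on the column.

```r
# Hypothetical binding: returns a description instead of computing a value.
toy_functions <- list(
  mean = function(x) structure(list(op = "mean", arg = x), class = "toy_expr")
)

# Function environment; its parent is baseenv() so everything else still resolves.
f_env <- list2env(toy_functions, parent = baseenv())

# Columns are looked up first, then the bindings, then base R.
toy_mask <- list2env(list(x = "field:x"), parent = f_env)

res <- eval(quote(mean(x)), toy_mask)
```

Because the aggregation bindings now always sit in that environment, the old trick of overriding `mean` and `sd` with a hard failure is no longer needed.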
54 changes: 15 additions & 39 deletions r/R/dplyr-filter.R
@@ -35,48 +35,24 @@ filter.arrow_dplyr_query <- function(.data, ..., .by = NULL, .preserve = FALSE)
   }

   # tidy-eval the filter expressions inside an Arrow data_mask
-  filters <- lapply(expanded_filters, arrow_eval, arrow_mask(out))
-  bad_filters <- map_lgl(filters, ~ inherits(., "try-error"))
-  if (any(bad_filters)) {
-    # This is similar to abandon_ship() except that the filter eval is
-    # vectorized, and we apply filters that _did_ work before abandoning ship
-    # with the rest
-    expr_labs <- map_chr(expanded_filters[bad_filters], format_expr)
-    if (query_on_dataset(out)) {
-      # Abort. We don't want to auto-collect if this is a Dataset because that
-      # could blow up, too big.
-      stop(
-        "Filter expression not supported for Arrow Datasets: ",
-        oxford_paste(expr_labs, quote = FALSE),
-        "\nCall collect() first to pull data into R.",
-        call. = FALSE
-      )
-    } else {
-      arrow_errors <- map2_chr(
-        filters[bad_filters], expr_labs,
-        handle_arrow_not_supported
-      )
-      if (length(arrow_errors) == 1) {
-        msg <- paste0(arrow_errors, "; ")
-      } else {
-        msg <- paste0("* ", arrow_errors, "\n", collapse = "")
-      }
-      warning(
-        msg, "pulling data into R",
-        immediate. = TRUE,
-        call. = FALSE
-      )
-      # Set any valid filters first, then collect and then apply the invalid ones in R
-      out <- dplyr::collect(set_filters(out, filters[!bad_filters]))
-      if (by$from_by) {
-        out <- dplyr::ungroup(out)
-      }
-      return(dplyr::filter(out, !!!expanded_filters[bad_filters], .by = {{ .by }}))
+  mask <- arrow_mask(out)
+  for (expr in expanded_filters) {
+    filt <- arrow_eval(expr, mask)
+    if (inherits(filt, "try-error")) {
+      msg <- handle_arrow_not_supported(filt, format_expr(expr))
+      return(abandon_ship(match.call(), .data, msg))
     }
+    if (length(mask$.aggregations)) {
+      # dplyr lets you filter on e.g. x < mean(x), but we haven't implemented it.
+      # But we could, the same way it works in mutate() via join, if someone asks.
+      # Until then, just error.
+      # TODO: add a test for this
+      msg <- paste("Expression", format_expr(expr), "not supported in filter() in Arrow")
+      return(abandon_ship(match.call(), .data, msg))
+    }
+    out <- set_filters(out, filt)
   }

-  out <- set_filters(out, filters)
-
   if (by$from_by) {
     out$group_by_vars <- character()
   }
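
The `length(mask$.aggregations)` check in the hunk above relies on aggregation bindings leaving a trace in the mask when they are evaluated. The sketch below imitates that idea in base R under stated assumptions: `toy_mask`, its `mean` binding and `contains_aggregation()` are hypothetical, and how the real bindings record themselves is not shown in this diff.

```r
toy_mask <- new.env()
toy_mask$.aggregations <- list()
toy_mask$x <- c(1, 5, 9)

# Hypothetical aggregation binding: records the call as a side effect
# and returns a placeholder value instead of a real mean.
toy_mask$mean <- function(x) {
  id <- paste0("..agg", length(toy_mask$.aggregations) + 1)
  toy_mask$.aggregations[[id]] <- list(fun = "mean", data = x)
  NA_real_
}

# Evaluate one expression and report whether it touched an aggregation.
contains_aggregation <- function(expr, mask) {
  mask$.aggregations <- list() # reset before each expression
  eval(expr, mask)
  length(mask$.aggregations) > 0
}

plain <- contains_aggregation(quote(x > 2), toy_mask)
aggregated <- contains_aggregation(quote(x < mean(x)), toy_mask)
```

A caller can then reject the expression, as `filter()` and `arrange()` do here, or handle it differently later.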
26 changes: 13 additions & 13 deletions r/R/dplyr-funcs-agg.R
@@ -29,56 +29,56 @@
 # you can use list_compute_functions("^hash_")

 register_bindings_aggregate <- function() {
-  register_binding_agg("base::sum", function(..., na.rm = FALSE) {
+  register_binding("base::sum", function(..., na.rm = FALSE) {
     set_agg(
       fun = "sum",
       data = ensure_one_arg(list2(...), "sum"),
       options = list(skip_nulls = na.rm, min_count = 0L)
     )
   })
-  register_binding_agg("base::prod", function(..., na.rm = FALSE) {
+  register_binding("base::prod", function(..., na.rm = FALSE) {
     set_agg(
       fun = "product",
       data = ensure_one_arg(list2(...), "prod"),
       options = list(skip_nulls = na.rm, min_count = 0L)
     )
   })
-  register_binding_agg("base::any", function(..., na.rm = FALSE) {
+  register_binding("base::any", function(..., na.rm = FALSE) {
     set_agg(
       fun = "any",
       data = ensure_one_arg(list2(...), "any"),
       options = list(skip_nulls = na.rm, min_count = 0L)
     )
   })
-  register_binding_agg("base::all", function(..., na.rm = FALSE) {
+  register_binding("base::all", function(..., na.rm = FALSE) {
     set_agg(
       fun = "all",
       data = ensure_one_arg(list2(...), "all"),
       options = list(skip_nulls = na.rm, min_count = 0L)
     )
   })
-  register_binding_agg("base::mean", function(x, na.rm = FALSE) {
+  register_binding("base::mean", function(x, na.rm = FALSE) {
     set_agg(
       fun = "mean",
       data = list(x),
       options = list(skip_nulls = na.rm, min_count = 0L)
     )
   })
-  register_binding_agg("stats::sd", function(x, na.rm = FALSE, ddof = 1) {
+  register_binding("stats::sd", function(x, na.rm = FALSE, ddof = 1) {
     set_agg(
       fun = "stddev",
       data = list(x),
       options = list(skip_nulls = na.rm, min_count = 0L, ddof = ddof)
     )
   })
-  register_binding_agg("stats::var", function(x, na.rm = FALSE, ddof = 1) {
+  register_binding("stats::var", function(x, na.rm = FALSE, ddof = 1) {
     set_agg(
       fun = "variance",
       data = list(x),
       options = list(skip_nulls = na.rm, min_count = 0L, ddof = ddof)
     )
   })
-  register_binding_agg(
+  register_binding(
     "stats::quantile",
     function(x, probs, na.rm = FALSE) {
       if (length(probs) != 1) {
@@ -103,7 +103,7 @@ register_bindings_aggregate <- function() {
       "approximate quantile (t-digest) is computed"
     )
   )
-  register_binding_agg(
+  register_binding(
     "stats::median",
     function(x, na.rm = FALSE) {
       # TODO: Bind to the Arrow function that returns an exact median and remove
@@ -122,28 +122,28 @@ register_bindings_aggregate <- function() {
     },
     notes = "approximate median (t-digest) is computed"
   )
-  register_binding_agg("dplyr::n_distinct", function(..., na.rm = FALSE) {
+  register_binding("dplyr::n_distinct", function(..., na.rm = FALSE) {
     set_agg(
       fun = "count_distinct",
       data = ensure_one_arg(list2(...), "n_distinct"),
       options = list(na.rm = na.rm)
     )
   })
-  register_binding_agg("dplyr::n", function() {
+  register_binding("dplyr::n", function() {
     set_agg(
       fun = "count_all",
       data = list(),
       options = list()
     )
   })
-  register_binding_agg("base::min", function(..., na.rm = FALSE) {
+  register_binding("base::min", function(..., na.rm = FALSE) {
     set_agg(
       fun = "min",
       data = ensure_one_arg(list2(...), "min"),
       options = list(skip_nulls = na.rm, min_count = 0L)
     )
   })
-  register_binding_agg("base::max", function(..., na.rm = FALSE) {
+  register_binding("base::max", function(..., na.rm = FALSE) {
     set_agg(
       fun = "max",
       data = ensure_one_arg(list2(...), "max"),
119 changes: 24 additions & 95 deletions r/R/dplyr-funcs.R
@@ -22,8 +22,8 @@ NULL

 #' Register compute bindings
 #'
-#' The `register_binding()` and `register_binding_agg()` functions
-#' are used to populate a list of functions that operate on (and return)
+#' `register_binding()` is used to populate a list of functions that operate on
+#' (and return)
 #' Expressions. These are the basis for the `.data` mask inside dplyr methods.
 #'
 #' @section Writing bindings:
@@ -40,39 +40,21 @@ NULL
 #' * Inside your function, you can call any other binding with `call_binding()`.
 #'
 #' @param fun_name A string containing a function name in the form `"function"` or
-#' `"package::function"`. The package name is currently not used but
-#' may be used in the future to allow these types of function calls.
-#' @param fun A function or `NULL` to un-register a previous function.
+#' `"package::function"`.
+#' @param fun A function, or `NULL` to un-register a previous function.
 #' This function must accept `Expression` objects as arguments and return
 #' `Expression` objects instead of regular R objects.
-#' @param agg_fun An aggregate function or `NULL` to un-register a previous
-#' aggregate function. This function must accept `Expression` objects as
-#' arguments and return a `list()` with components:
-#' - `fun`: string function name
-#' - `data`: list of 0 or more `Expression`s
-#' - `options`: list of function options, as passed to call_function
-#' @param update_cache Update .cache$functions at the time of registration.
-#' the default is FALSE because the majority of usage is to register
-#' bindings at package load, after which we create the cache once. The
-#' reason why .cache$functions is needed in addition to nse_funcs for
-#' non-aggregate functions could be revisited...it is currently used
-#' as the data mask in mutate, filter, and aggregate (but not
-#' summarise) because the data mask has to be a list.
-#' @param registry An environment in which the functions should be
-#' assigned.
 #' @param notes string for the docs: note any limitations or differences in
 #' behavior between the Arrow version and the R function.
 #' @return The previously registered binding or `NULL` if no previously
 #' registered function existed.
 #' @keywords internal
 register_binding <- function(fun_name,
                              fun,
-                             registry = nse_funcs,
-                             update_cache = FALSE,
                              notes = character(0)) {
   unqualified_name <- sub("^.*?:{+}", "", fun_name)

-  previous_fun <- registry[[unqualified_name]]
+  previous_fun <- .cache$functions[[unqualified_name]]

   # if the unqualified name exists in the registry, warn
   if (!is.null(previous_fun) && !identical(fun, previous_fun)) {
@@ -87,58 +69,25 @@ register_binding <- function(fun_name,

   # register both as `pkg::fun` and as `fun` if `qualified_name` is prefixed
   # unqualified_name and fun_name will be the same if not prefixed
-  registry[[unqualified_name]] <- fun
-  registry[[fun_name]] <- fun
-
+  .cache$functions[[unqualified_name]] <- fun
+  .cache$functions[[fun_name]] <- fun
   .cache$docs[[fun_name]] <- notes
-
-  if (update_cache) {
-    fun_cache <- .cache$functions
-    fun_cache[[unqualified_name]] <- fun
-    fun_cache[[fun_name]] <- fun
-    .cache$functions <- fun_cache
-  }
-
   invisible(previous_fun)
 }

-unregister_binding <- function(fun_name, registry = nse_funcs,
-                               update_cache = FALSE) {
+unregister_binding <- function(fun_name) {
   unqualified_name <- sub("^.*?:{+}", "", fun_name)
-  previous_fun <- registry[[unqualified_name]]
+  previous_fun <- .cache$functions[[unqualified_name]]

-  rm(
-    list = unique(c(fun_name, unqualified_name)),
-    envir = registry,
-    inherits = FALSE
-  )
-
-  if (update_cache) {
-    fun_cache <- .cache$functions
-    fun_cache[[unqualified_name]] <- NULL
-    fun_cache[[fun_name]] <- NULL
-    .cache$functions <- fun_cache
-  }
+  .cache$functions[[unqualified_name]] <- NULL
+  .cache$functions[[fun_name]] <- NULL

   invisible(previous_fun)
 }

-#' @rdname register_binding
-#' @keywords internal
-register_binding_agg <- function(fun_name,
-                                 agg_fun,
-                                 registry = agg_funcs,
-                                 notes = character(0)) {
-  register_binding(fun_name, agg_fun, registry = registry, notes = notes)
-}
-
 # Supports functions and tests that call previously-defined bindings
 call_binding <- function(fun_name, ...) {
-  nse_funcs[[fun_name]](...)
-}
-
-call_binding_agg <- function(fun_name, ...) {
-  agg_funcs[[fun_name]](...)
+  .cache$functions[[fun_name]](...)
 }

 create_binding_cache <- function() {
@@ -147,15 +96,15 @@ create_binding_cache <- function() {

   # Register all available Arrow Compute functions, namespaced as arrow_fun.
   all_arrow_funs <- list_compute_functions()
-  arrow_funcs <- set_names(
+  .cache$functions <- set_names(
     lapply(all_arrow_funs, function(fun) {
       force(fun)
       function(...) Expression$create(fun, ...)
     }),
     paste0("arrow_", all_arrow_funs)
   )

-  # Register bindings into nse_funcs and agg_funcs
+  # Register bindings into the cache
   register_bindings_array_function_map()
   register_bindings_aggregate()
   register_bindings_conditional()
@@ -165,37 +114,17 @@ create_binding_cache <- function() {
   register_bindings_type()
   register_bindings_augmented()

-  # We only create the cache for nse_funcs and not agg_funcs
-  .cache$functions <- c(as.list(nse_funcs), arrow_funcs)
-}
-
-# environments in the arrow namespace used in the above functions
-nse_funcs <- new.env(parent = emptyenv())
-agg_funcs <- new.env(parent = emptyenv())
-.cache <- new.env(parent = emptyenv())
-
-# we register 2 versions of the "::" binding - one for use with nse_funcs
-# and another one for use with agg_funcs (registered in dplyr-funcs-agg.R)
-nse_funcs[["::"]] <- function(lhs, rhs) {
-  lhs_name <- as.character(substitute(lhs))
-  rhs_name <- as.character(substitute(rhs))
+  .cache$functions[["::"]] <- function(lhs, rhs) {
+    lhs_name <- as.character(substitute(lhs))
+    rhs_name <- as.character(substitute(rhs))

-  fun_name <- paste0(lhs_name, "::", rhs_name)
+    fun_name <- paste0(lhs_name, "::", rhs_name)

-  # if we do not have a binding for pkg::fun, then fall back on to the
-  # regular pkg::fun function
-  nse_funcs[[fun_name]] %||% asNamespace(lhs_name)[[rhs_name]]
+    # if we do not have a binding for pkg::fun, then fall back on to the
+    # regular pkg::fun function
+    .cache$functions[[fun_name]] %||% asNamespace(lhs_name)[[rhs_name]]
+  }
 }

-agg_funcs[["::"]] <- function(lhs, rhs) {
-  lhs_name <- as.character(substitute(lhs))
-  rhs_name <- as.character(substitute(rhs))
-
-  fun_name <- paste0(lhs_name, "::", rhs_name)
-
-  # if we do not have a binding for pkg::fun, then fall back on to the
-  # nse_funcs (useful when we have a regular function inside an aggregating one)
-  # and then, if searching nse_funcs fails too, fall back to the
-  # regular `pkg::fun()` function
-  agg_funcs[[fun_name]] %||% nse_funcs[[fun_name]] %||% asNamespace(lhs_name)[[rhs_name]]
-}
+# environment in the arrow namespace used in the above functions
+.cache <- new.env(parent = emptyenv())
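
The fallback in the `::` binding above can be tried out in isolation with a small base-R sketch. `toy_functions` and the string-returning `base::sum` binding are hypothetical; `%||%` is defined locally so the sketch does not depend on the R version. The override returns a registered binding when one exists for `pkg::fun`, and otherwise the real function from that namespace.

```r
# Null-default operator, defined here so the sketch is self-contained.
`%||%` <- function(x, y) if (is.null(x)) y else x

# Hypothetical registry with a single qualified binding.
toy_functions <- list(
  "base::sum" = function(x) "arrow-style sum"
)

# Resolve pkg::fun to a registered binding if there is one,
# otherwise fall back to the regular function from that namespace.
toy_functions[["::"]] <- function(lhs, rhs) {
  lhs_name <- as.character(substitute(lhs))
  rhs_name <- as.character(substitute(rhs))
  fun_name <- paste0(lhs_name, "::", rhs_name)
  toy_functions[[fun_name]] %||% asNamespace(lhs_name)[[rhs_name]]
}

mask <- list2env(toy_functions, parent = baseenv())
bound <- eval(quote(base::sum(1:3)), mask)
fallback <- eval(quote(base::nchar("abc")), mask)
```

With a single store there is only one lookup to fall through, where the removed `agg_funcs` version had to chain two.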