Skip to content

Commit

Permalink
add extract_regex_span function
Browse files Browse the repository at this point in the history
  • Loading branch information
arashandishgar committed Feb 19, 2025
1 parent c7a9100 commit 228c2c5
Show file tree
Hide file tree
Showing 4 changed files with 232 additions and 35 deletions.
11 changes: 10 additions & 1 deletion cpp/src/arrow/compute/api_scalar.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
#include "arrow/type.h"
#include "arrow/util/checked_cast.h"
#include "arrow/util/logging.h"

namespace arrow {

namespace internal {
Expand Down Expand Up @@ -325,6 +324,9 @@ static auto kElementWiseAggregateOptionsType =
DataMember("skip_nulls", &ElementWiseAggregateOptions::skip_nulls));
static auto kExtractRegexOptionsType = GetFunctionOptionsType<ExtractRegexOptions>(
DataMember("pattern", &ExtractRegexOptions::pattern));
// Options-type descriptor for ExtractRegexSpanOptions; "pattern" is the only
// data member it exposes.
static auto kExtractRegexSpanOptionsType =
GetFunctionOptionsType<ExtractRegexSpanOptions>(
DataMember("pattern", &ExtractRegexSpanOptions::pattern));
static auto kJoinOptionsType = GetFunctionOptionsType<JoinOptions>(
DataMember("null_handling", &JoinOptions::null_handling),
DataMember("null_replacement", &JoinOptions::null_replacement));
Expand Down Expand Up @@ -438,6 +440,12 @@ ExtractRegexOptions::ExtractRegexOptions(std::string pattern)
ExtractRegexOptions::ExtractRegexOptions() : ExtractRegexOptions("") {}
constexpr char ExtractRegexOptions::kTypeName[];

// Builds the options from a regular expression pattern and tags them with the
// ExtractRegexSpanOptions type descriptor.
ExtractRegexSpanOptions::ExtractRegexSpanOptions(std::string pattern)
: FunctionOptions(internal::kExtractRegexSpanOptionsType),
pattern(std::move(pattern)) {}
// Default construction delegates to the pattern constructor with an empty pattern.
ExtractRegexSpanOptions::ExtractRegexSpanOptions() : ExtractRegexSpanOptions("") {}
// Out-of-class definition of the static constexpr member declared in the class.
constexpr char ExtractRegexSpanOptions::kTypeName[];

JoinOptions::JoinOptions(NullHandlingBehavior null_handling, std::string null_replacement)
: FunctionOptions(internal::kJoinOptionsType),
null_handling(null_handling),
Expand Down Expand Up @@ -684,6 +692,7 @@ void RegisterScalarOptions(FunctionRegistry* registry) {
DCHECK_OK(registry->AddFunctionOptionsType(kDayOfWeekOptionsType));
DCHECK_OK(registry->AddFunctionOptionsType(kElementWiseAggregateOptionsType));
DCHECK_OK(registry->AddFunctionOptionsType(kExtractRegexOptionsType));
DCHECK_OK(registry->AddFunctionOptionsType(kExtractRegexSpanOptionsType));
DCHECK_OK(registry->AddFunctionOptionsType(kJoinOptionsType));
DCHECK_OK(registry->AddFunctionOptionsType(kListSliceOptionsType));
DCHECK_OK(registry->AddFunctionOptionsType(kMakeStructOptionsType));
Expand Down
10 changes: 10 additions & 0 deletions cpp/src/arrow/compute/api_scalar.h
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,17 @@ class ARROW_EXPORT ExtractRegexOptions : public FunctionOptions {
/// Regular expression with named capture fields
std::string pattern;
};
/// Options for the extract_regex_span function
class ARROW_EXPORT ExtractRegexSpanOptions : public FunctionOptions {
public:
explicit ExtractRegexSpanOptions(std::string pattern);
ExtractRegexSpanOptions();
static constexpr char const kTypeName[] = "ExtractRegexSpanOptions";

/// Regular expression with named capture fields
std::string pattern;
};
/// Options for IsIn and IndexIn functions
class ARROW_EXPORT SetLookupOptions : public FunctionOptions {
public:
Expand Down
197 changes: 164 additions & 33 deletions cpp/src/arrow/compute/kernels/scalar_string_ascii.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include <string>

#include "arrow/array/builder_nested.h"
#include "arrow/array/builder_primitive.h"
#include "arrow/compute/kernels/scalar_string_internal.h"
#include "arrow/result.h"
#include "arrow/util/config.h"
Expand Down Expand Up @@ -2185,51 +2186,61 @@ void AddAsciiStringReplaceSubstring(FunctionRegistry* registry) {
using ExtractRegexState = OptionsWrapper<ExtractRegexOptions>;

// TODO cache this once per ExtractRegexOptions
struct ExtractRegexData {
// Use unique_ptr<> because RE2 is non-movable (for ARROW_ASSIGN_OR_RAISE)
std::unique_ptr<RE2> regex;
std::vector<std::string> group_names;

class ExtractRegexData {
public:
static Result<ExtractRegexData> Make(const ExtractRegexOptions& options,
bool is_utf8 = true) {
ExtractRegexData data(options.pattern, is_utf8);
RETURN_NOT_OK(RegexStatus(*data.regex));

const int group_count = data.regex->NumberOfCapturingGroups();
const auto& name_map = data.regex->CapturingGroupNames();
data.group_names.reserve(group_count);

for (int i = 0; i < group_count; i++) {
auto item = name_map.find(i + 1); // re2 starts counting from 1
if (item == name_map.end()) {
// XXX should we instead just create fields with an empty name?
return Status::Invalid("Regular expression contains unnamed groups");
}
data.group_names.emplace_back(item->second);
}
ARROW_RETURN_NOT_OK(data.Init());
return data;
}

Result<TypeHolder> ResolveOutputType(const std::vector<TypeHolder>& types) const {
const DataType* input_type = types[0].type;
if (input_type == nullptr) {
// as mentioned here
// https://arrow.apache.org/docs/developers/cpp/development.html#code-style-linting-and-ci
// nullptr should not be used
if (input_type == NULLPTR) {
// No input type specified
return nullptr;
return NULLPTR;
}
// Input type is either [Large]Binary or [Large]String and is also the type
// of each field in the output struct type.
DCHECK(is_base_binary_like(input_type->id()));
FieldVector fields;
fields.reserve(group_names.size());
fields.reserve(group_names_.size());
std::shared_ptr<DataType> owned_type = input_type->GetSharedPtr();
std::transform(group_names.begin(), group_names.end(), std::back_inserter(fields),
std::transform(group_names_.begin(), group_names_.end(), std::back_inserter(fields),
[&](const std::string& name) { return field(name, owned_type); });
return struct_(std::move(fields));
return struct_(fields);
}
int64_t num_group() const { return group_names_.size(); }
std::shared_ptr<RE2> regex() const { return regex_; }

private:
protected:
explicit ExtractRegexData(const std::string& pattern, bool is_utf8 = true)
: regex(new RE2(pattern, MakeRE2Options(is_utf8))) {}
: regex_(new RE2(pattern, MakeRE2Options(is_utf8))) {}

Status Init() {
RETURN_NOT_OK(RegexStatus(*regex_));

const int group_count = regex_->NumberOfCapturingGroups();
const auto& name_map = regex_->CapturingGroupNames();
group_names_.reserve(group_count);

for (int i = 0; i < group_count; i++) {
auto item = name_map.find(i + 1); // re2 starts counting from 1
if (item == name_map.end()) {
// XXX should we instead just create fields with an empty name?
return Status::Invalid("Regular expression contains unnamed groups");
}
group_names_.emplace_back(item->second);
}
return Status::OK();
}

std::shared_ptr<RE2> regex_;
std::vector<std::string> group_names_;
};

Result<TypeHolder> ResolveExtractRegexOutput(KernelContext* ctx,
Expand All @@ -2250,7 +2261,7 @@ struct ExtractRegexBase {

explicit ExtractRegexBase(const ExtractRegexData& data)
: data(data),
group_count(static_cast<int>(data.group_names.size())),
group_count(static_cast<int>(data.num_group())),
found_values(group_count) {
args.reserve(group_count);
args_pointers.reserve(group_count);
Expand All @@ -2265,7 +2276,7 @@ struct ExtractRegexBase {
}

bool Match(std::string_view s) {
return RE2::PartialMatchN(ToStringPiece(s), *data.regex, args_pointers_start,
return RE2::PartialMatchN(ToStringPiece(s), *data.regex(), args_pointers_start,
group_count);
}
};
Expand All @@ -2284,11 +2295,10 @@ struct ExtractRegex : public ExtractRegexBase {
}

Status Extract(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) {
// TODO: why is this needed? Type resolution should already be
// done and the output type set in the output variable
ARROW_ASSIGN_OR_RAISE(TypeHolder out_type, data.ResolveOutputType(batch.GetTypes()));
DCHECK_NE(out_type.type, nullptr);
std::shared_ptr<DataType> type = out_type.GetSharedPtr();
ExtractRegexOptions options = ExtractRegexState::Get(ctx);
DCHECK_NE(out->array_data(), NULLPTR);
std::shared_ptr<DataType> type = out->array_data()->type;
DCHECK_NE(type, NULLPTR);

std::unique_ptr<ArrayBuilder> array_builder;
RETURN_NOT_OK(MakeBuilder(ctx->memory_pool(), type, &array_builder));
Expand Down Expand Up @@ -2347,6 +2357,126 @@ void AddAsciiStringExtractRegex(FunctionRegistry* registry) {
}
DCHECK_OK(registry->AddFunction(std::move(func)));
}
/// \brief Compiled regex and capture-group names for extract_regex_span.
///
/// Reuses the pattern compilation and group-name validation of
/// ExtractRegexData (via Init()), but resolves a different output type: one
/// fixed-size-list field per capture group instead of one string field.
class ExtractRegexSpanData : public ExtractRegexData {
 public:
  /// \brief Compile `pattern` and collect its capture-group names.
  ///
  /// \return the initialized data, or the error Status reported by Init()
  /// (for example when the pattern contains unnamed groups).
  static Result<ExtractRegexSpanData> Make(const std::string& pattern) {
    auto data = ExtractRegexSpanData(pattern, true);
    ARROW_RETURN_NOT_OK(data.Init());
    return data;
  }

  /// \brief Compute the struct output type for the given input types.
  ///
  /// Only `types[0]` is inspected. The result has one field per capture
  /// group, named `<group>_span`, of type fixed_size_list(index_type, 2)
  /// where the two list slots hold the position and the length of the
  /// capture. `index_type` is int32 when the input type is binary-like and
  /// int64 otherwise.
  ///
  /// \return NULLPTR when no input type is specified.
  Result<TypeHolder> ResolveOutputType(const std::vector<TypeHolder>& types) const {
    const DataType* input_type = types[0].type;
    if (input_type == NULLPTR) {
      // No input type specified
      return NULLPTR;
    }
    DCHECK(is_base_binary_like(input_type->id()));
    // Both types depend only on the input type, so build them once and share
    // them between all fields instead of recreating them per capture group.
    const auto index_type = is_binary_like(input_type->id()) ? int32() : int64();
    // List size is 2 because every span is a (position, length) pair.
    const auto span_type = fixed_size_list(index_type, 2);
    FieldVector fields;
    fields.reserve(group_names_.size());
    for (const auto& group_name : group_names_) {
      fields.push_back(field(group_name + "_span", span_type));
    }
    return struct_(fields);
  }

 private:
  ExtractRegexSpanData(const std::string& pattern, const bool is_utf8)
      : ExtractRegexData(pattern, is_utf8) {}
};

/// \brief Kernel implementation of extract_regex_span for one input type.
///
/// For every non-null input value that matches the pattern, emits a struct
/// whose i-th field is the (position, length) pair of the i-th capture group,
/// or null when that group did not capture anything. Null inputs and values
/// that do not match produce a null struct.
template <typename Type>
struct ExtractRegexSpan : ExtractRegexBase {
  using ArrayType = typename TypeTraits<Type>::ArrayType;
  using BuilderType = typename TypeTraits<Type>::BuilderType;
  using ExtractRegexBase::ExtractRegexBase;

  /// Kernel entry point: compiles the pattern from the kernel options and
  /// runs Extract() on the batch.
  static Status Exec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) {
    const auto& options = OptionsWrapper<ExtractRegexSpanOptions>::Get(ctx);
    ARROW_ASSIGN_OR_RAISE(auto data, ExtractRegexSpanData::Make(options.pattern));
    return ExtractRegexSpan{data}.Extract(ctx, batch, out);
  }

  /// Builds the output struct array for `batch[0]` and stores it in `out`.
  /// The output type is taken from `out`, which must already hold array data.
  Status Extract(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) {
    DCHECK_NE(out->array_data(), NULLPTR);
    std::shared_ptr<DataType> out_type = out->array_data()->type;
    DCHECK_NE(out_type, NULLPTR);

    std::unique_ptr<ArrayBuilder> out_builder;
    ARROW_RETURN_NOT_OK(MakeBuilder(ctx->memory_pool(), out_type, &out_builder));
    auto struct_builder = checked_pointer_cast<StructBuilder>(std::move(out_builder));

    // One fixed-size-list builder per capture group, plus direct access to
    // the integer builder holding each list's values.
    std::vector<FixedSizeListBuilder*> span_builders;
    std::vector<ArrayBuilder*> array_builders;
    span_builders.reserve(group_count);
    array_builders.reserve(group_count);
    for (int i = 0; i < group_count; i++) {
      span_builders.push_back(
          checked_cast<FixedSizeListBuilder*>(struct_builder->field_builder(i)));
      array_builders.push_back(span_builders[i]->value_builder());
    }

    // The index width depends only on the input type, so decide it once per
    // batch rather than for every captured group of every row.
    const bool use_int32_index = is_binary_like(batch.GetTypes()[0].id());

    auto visit_null = [&]() { return struct_builder->AppendNull(); };
    auto visit_value = [&](std::string_view element) -> Status {
      if (Match(element)) {
        for (int i = 0; i < group_count; i++) {
          // A null data pointer marks a group that did not participate in
          // the match, see
          // https://github.com/google/re2/issues/24#issuecomment-97653183
          if (found_values[i].data() != NULLPTR) {
            int64_t begin = found_values[i].data() - element.data();
            int64_t size = found_values[i].size();
            if (use_int32_index) {
              ARROW_RETURN_NOT_OK(checked_cast<Int32Builder*>(array_builders[i])
                                      ->AppendValues({static_cast<int32_t>(begin),
                                                      static_cast<int32_t>(size)}));
            } else {
              ARROW_RETURN_NOT_OK(checked_cast<Int64Builder*>(array_builders[i])
                                      ->AppendValues({begin, size}));
            }

            ARROW_RETURN_NOT_OK(span_builders[i]->Append());
          } else {
            ARROW_RETURN_NOT_OK(span_builders[i]->AppendNull());
          }
        }
        ARROW_RETURN_NOT_OK(struct_builder->Append());
      } else {
        ARROW_RETURN_NOT_OK(struct_builder->AppendNull());
      }
      return Status::OK();
    };
    ARROW_RETURN_NOT_OK(
        VisitArraySpanInline<Type>(batch[0].array, visit_value, visit_null));

    ARROW_ASSIGN_OR_RAISE(auto out_array, struct_builder->Finish());
    out->value = out_array->data();
    return Status::OK();
  }
};

// User-facing documentation for the "extract_regex_span" function.
const FunctionDoc extract_regex_doc_span(
    "Extract string spans captured by a regex pattern",
    ("For each string in `strings`, match the regular expression and, if\n"
     "successful, emit a struct with one field per named capture group.\n"
     "Each field is named `<group>_span` and holds a fixed-size list of two\n"
     "integers (int32 or int64, depending on the input type): the position\n"
     "and the length of the captured substring. A group that captured\n"
     "nothing yields a null field.\n"
     "If the input is null or the regular expression fails matching, a null\n"
     "output value is emitted.\n"
     "\n"
     "The regular expression must use only named capture groups."),
    {"strings"}, "ExtractRegexSpanOptions", true);

// Output-type resolver for the extract_regex_span kernels: compiles the
// pattern stored in the kernel state and derives the struct type from it.
Result<TypeHolder> resolver(KernelContext* ctx, const std::vector<TypeHolder>& types) {
  const ExtractRegexSpanOptions& span_options =
      OptionsWrapper<ExtractRegexSpanOptions>::Get(*ctx->state());
  ARROW_ASSIGN_OR_RAISE(ExtractRegexSpanData regex_data,
                        ExtractRegexSpanData::Make(span_options.pattern));
  return regex_data.ResolveOutputType(types);
}

// Registers the unary "extract_regex_span" scalar function, with one kernel
// for each type returned by BaseBinaryTypes().
void AddAsciiStringExtractRegexSpan(FunctionRegistry* registry) {
  auto func = std::make_shared<ScalarFunction>("extract_regex_span", Arity::Unary(),
                                               extract_regex_doc_span);
  OutputType output_type(resolver);
  for (const auto& type : BaseBinaryTypes()) {
    ScalarKernel kernel({type}, output_type,
                        GenerateVarBinaryToVarBinary<ExtractRegexSpan>(type),
                        OptionsWrapper<ExtractRegexSpanOptions>::Init);
    // The kernel builds its own output array, validity included, so nothing
    // is preallocated for it.
    kernel.null_handling = NullHandling::COMPUTED_NO_PREALLOCATE;
    kernel.mem_allocation = MemAllocation::NO_PREALLOCATE;
    DCHECK_OK(func->AddKernel(std::move(kernel)));
  }
  // Hand ownership to the registry, as AddAsciiStringExtractRegex does.
  DCHECK_OK(registry->AddFunction(std::move(func)));
}
#endif // ARROW_WITH_RE2

// ----------------------------------------------------------------------
Expand Down Expand Up @@ -3457,6 +3587,7 @@ void RegisterScalarStringAscii(FunctionRegistry* registry) {
AddAsciiStringSplitWhitespace(registry);
#ifdef ARROW_WITH_RE2
AddAsciiStringSplitRegex(registry);
AddAsciiStringExtractRegexSpan(registry);
#endif
AddAsciiStringJoin(registry);
AddAsciiStringRepeat(registry);
Expand Down
Loading

0 comments on commit 228c2c5

Please sign in to comment.