From 228c2c5b0924c43fdd99bd10bfe851354a1aee62 Mon Sep 17 00:00:00 2001 From: arash andishgar Date: Wed, 19 Feb 2025 16:47:29 +0330 Subject: [PATCH] add extract_regex_span function --- cpp/src/arrow/compute/api_scalar.cc | 11 +- cpp/src/arrow/compute/api_scalar.h | 10 + .../compute/kernels/scalar_string_ascii.cc | 197 +++++++++++++++--- .../compute/kernels/scalar_string_test.cc | 49 ++++- 4 files changed, 232 insertions(+), 35 deletions(-) diff --git a/cpp/src/arrow/compute/api_scalar.cc b/cpp/src/arrow/compute/api_scalar.cc index 61a16f5f5eb9b..e6606ba53eda8 100644 --- a/cpp/src/arrow/compute/api_scalar.cc +++ b/cpp/src/arrow/compute/api_scalar.cc @@ -30,7 +30,6 @@ #include "arrow/type.h" #include "arrow/util/checked_cast.h" #include "arrow/util/logging.h" - namespace arrow { namespace internal { @@ -325,6 +324,9 @@ static auto kElementWiseAggregateOptionsType = DataMember("skip_nulls", &ElementWiseAggregateOptions::skip_nulls)); static auto kExtractRegexOptionsType = GetFunctionOptionsType( DataMember("pattern", &ExtractRegexOptions::pattern)); +static auto kExtractRegexSpanOptionsType = + GetFunctionOptionsType( + DataMember("pattern", &ExtractRegexSpanOptions::pattern)); static auto kJoinOptionsType = GetFunctionOptionsType( DataMember("null_handling", &JoinOptions::null_handling), DataMember("null_replacement", &JoinOptions::null_replacement)); @@ -438,6 +440,12 @@ ExtractRegexOptions::ExtractRegexOptions(std::string pattern) ExtractRegexOptions::ExtractRegexOptions() : ExtractRegexOptions("") {} constexpr char ExtractRegexOptions::kTypeName[]; +ExtractRegexSpanOptions::ExtractRegexSpanOptions(std::string pattern) + : FunctionOptions(internal::kExtractRegexSpanOptionsType), + pattern(std::move(pattern)) {} +ExtractRegexSpanOptions::ExtractRegexSpanOptions() : ExtractRegexSpanOptions("") {} +constexpr char ExtractRegexSpanOptions::kTypeName[]; + JoinOptions::JoinOptions(NullHandlingBehavior null_handling, std::string null_replacement) : FunctionOptions(internal::kJoinOptionsType), null_handling(null_handling), @@ -684,6 +692,7 @@ void RegisterScalarOptions(FunctionRegistry* registry) { DCHECK_OK(registry->AddFunctionOptionsType(kDayOfWeekOptionsType)); DCHECK_OK(registry->AddFunctionOptionsType(kElementWiseAggregateOptionsType)); DCHECK_OK(registry->AddFunctionOptionsType(kExtractRegexOptionsType)); + DCHECK_OK(registry->AddFunctionOptionsType(kExtractRegexSpanOptionsType)); DCHECK_OK(registry->AddFunctionOptionsType(kJoinOptionsType)); DCHECK_OK(registry->AddFunctionOptionsType(kListSliceOptionsType)); DCHECK_OK(registry->AddFunctionOptionsType(kMakeStructOptionsType)); diff --git a/cpp/src/arrow/compute/api_scalar.h b/cpp/src/arrow/compute/api_scalar.h index 0e5a388b1074f..3e299ac134f79 100644 --- a/cpp/src/arrow/compute/api_scalar.h +++ b/cpp/src/arrow/compute/api_scalar.h @@ -264,7 +264,17 @@ class ARROW_EXPORT ExtractRegexOptions : public FunctionOptions { /// Regular expression with named capture fields std::string pattern; }; +class ARROW_EXPORT ExtractRegexSpanOptions : public FunctionOptions { + public: + explicit ExtractRegexSpanOptions(std::string pattern); + ExtractRegexSpanOptions(); + static constexpr char const kTypeName[] = "ExtractRegexSpanOptions"; + /// Regular expression with named capture fields + std::string pattern; + + /// Shows the matched string +}; /// Options for IsIn and IndexIn functions class ARROW_EXPORT SetLookupOptions : public FunctionOptions { public: diff --git a/cpp/src/arrow/compute/kernels/scalar_string_ascii.cc b/cpp/src/arrow/compute/kernels/scalar_string_ascii.cc index e58f7b065a8e5..535bc7979a380 100644 --- a/cpp/src/arrow/compute/kernels/scalar_string_ascii.cc +++ b/cpp/src/arrow/compute/kernels/scalar_string_ascii.cc @@ -22,6 +22,7 @@ #include #include "arrow/array/builder_nested.h" +#include "arrow/array/builder_primitive.h" #include "arrow/compute/kernels/scalar_string_internal.h" #include "arrow/result.h" #include "arrow/util/config.h" @@ -2185,51 +2186,61 @@ void AddAsciiStringReplaceSubstring(FunctionRegistry* registry) { using ExtractRegexState = OptionsWrapper; // TODO cache this once per ExtractRegexOptions -struct ExtractRegexData { - // Use unique_ptr<> because RE2 is non-movable (for ARROW_ASSIGN_OR_RAISE) - std::unique_ptr regex; - std::vector group_names; - +class ExtractRegexData { + public: static Result Make(const ExtractRegexOptions& options, bool is_utf8 = true) { ExtractRegexData data(options.pattern, is_utf8); - RETURN_NOT_OK(RegexStatus(*data.regex)); - - const int group_count = data.regex->NumberOfCapturingGroups(); - const auto& name_map = data.regex->CapturingGroupNames(); - data.group_names.reserve(group_count); - - for (int i = 0; i < group_count; i++) { - auto item = name_map.find(i + 1); // re2 starts counting from 1 - if (item == name_map.end()) { - // XXX should we instead just create fields with an empty name? - return Status::Invalid("Regular expression contains unnamed groups"); - } - data.group_names.emplace_back(item->second); - } + ARROW_RETURN_NOT_OK(data.Init()); return data; } Result ResolveOutputType(const std::vector& types) const { const DataType* input_type = types[0].type; - if (input_type == nullptr) { + // as mentioned here + // https://arrow.apache.org/docs/developers/cpp/development.html#code-style-linting-and-ci + // nullptr should not be used + if (input_type == NULLPTR) { // No input type specified - return nullptr; + return NULLPTR; } // Input type is either [Large]Binary or [Large]String and is also the type // of each field in the output struct type. DCHECK(is_base_binary_like(input_type->id())); FieldVector fields; - fields.reserve(group_names.size()); + fields.reserve(group_names_.size()); std::shared_ptr owned_type = input_type->GetSharedPtr(); - std::transform(group_names.begin(), group_names.end(), std::back_inserter(fields), + std::transform(group_names_.begin(), group_names_.end(), std::back_inserter(fields), [&](const std::string& name) { return field(name, owned_type); }); - return struct_(std::move(fields)); + return struct_(fields); } + int64_t num_group() const { return group_names_.size(); } + std::shared_ptr regex() const { return regex_; } - private: + protected: explicit ExtractRegexData(const std::string& pattern, bool is_utf8 = true) - : regex(new RE2(pattern, MakeRE2Options(is_utf8))) {} + : regex_(new RE2(pattern, MakeRE2Options(is_utf8))) {} + + Status Init() { + RETURN_NOT_OK(RegexStatus(*regex_)); + + const int group_count = regex_->NumberOfCapturingGroups(); + const auto& name_map = regex_->CapturingGroupNames(); + group_names_.reserve(group_count); + + for (int i = 0; i < group_count; i++) { + auto item = name_map.find(i + 1); // re2 starts counting from 1 + if (item == name_map.end()) { + // XXX should we instead just create fields with an empty name? + return Status::Invalid("Regular expression contains unnamed groups"); + } + group_names_.emplace_back(item->second); + } + return Status::OK(); + } + + std::shared_ptr regex_; + std::vector group_names_; }; Result ResolveExtractRegexOutput(KernelContext* ctx, @@ -2250,7 +2261,7 @@ struct ExtractRegexBase { explicit ExtractRegexBase(const ExtractRegexData& data) : data(data), - group_count(static_cast(data.group_names.size())), + group_count(static_cast(data.num_group())), found_values(group_count) { args.reserve(group_count); args_pointers.reserve(group_count); @@ -2265,7 +2276,7 @@ struct ExtractRegexBase { } bool Match(std::string_view s) { - return RE2::PartialMatchN(ToStringPiece(s), *data.regex, args_pointers_start, + return RE2::PartialMatchN(ToStringPiece(s), *data.regex(), args_pointers_start, group_count); } }; @@ -2284,11 +2295,10 @@ struct ExtractRegex : public ExtractRegexBase { } Status Extract(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) { - // TODO: why is this needed? Type resolution should already be - // done and the output type set in the output variable - ARROW_ASSIGN_OR_RAISE(TypeHolder out_type, data.ResolveOutputType(batch.GetTypes())); - DCHECK_NE(out_type.type, nullptr); - std::shared_ptr type = out_type.GetSharedPtr(); + ExtractRegexOptions options = ExtractRegexState::Get(ctx); + DCHECK_NE(out->array_data(), NULLPTR); + std::shared_ptr type = out->array_data()->type; + DCHECK_NE(type, NULLPTR); std::unique_ptr array_builder; RETURN_NOT_OK(MakeBuilder(ctx->memory_pool(), type, &array_builder)); @@ -2347,6 +2357,126 @@ void AddAsciiStringExtractRegex(FunctionRegistry* registry) { } DCHECK_OK(registry->AddFunction(std::move(func))); } +class ExtractRegexSpanData : public ExtractRegexData { + public: + static Result Make(const std::string& pattern) { + auto data = ExtractRegexSpanData(pattern, true); + ARROW_RETURN_NOT_OK(data.Init()); + return data; + } + + Result ResolveOutputType(const std::vector& types) const { + const DataType* input_type = types[0].type; + if (input_type == NULLPTR) { + return NULLPTR; + } + DCHECK(is_base_binary_like(input_type->id())); + const size_t field_count = group_names_.size(); + FieldVector fields; + fields.reserve(field_count); + const auto owned_type = input_type->GetSharedPtr(); + for (const auto& group_name : group_names_) { + auto type = is_binary_like(owned_type->id()) ? int32() : int64(); + // size list is 2 as every span contains position and length + fields.push_back(field(group_name + "_span", fixed_size_list(type, 2))); + } + return struct_(fields); + } + + private: + ExtractRegexSpanData(const std::string& pattern, const bool is_utf8) + : ExtractRegexData(pattern, is_utf8) {} +}; + +template +struct ExtractRegexSpan : ExtractRegexBase { + using ArrayType = typename TypeTraits::ArrayType; + using BuilderType = typename TypeTraits::BuilderType; + using ExtractRegexBase::ExtractRegexBase; + + static Status Exec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) { + auto options = OptionsWrapper::Get(ctx); + ARROW_ASSIGN_OR_RAISE(auto data, ExtractRegexSpanData::Make(options.pattern)); + return ExtractRegexSpan{data}.Extract(ctx, batch, out); + } + Status Extract(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) { + DCHECK_NE(out->array_data(), NULLPTR); + std::shared_ptr out_type = out->array_data()->type; + DCHECK_NE(out_type, NULLPTR); + std::unique_ptr out_builder; + ARROW_RETURN_NOT_OK( + MakeBuilder(ctx->memory_pool(), out->type()->GetSharedPtr(), &out_builder)); + auto struct_builder = checked_pointer_cast(std::move(out_builder)); + std::vector span_builders; + std::vector array_builders; + span_builders.reserve(group_count); + array_builders.reserve(group_count); + for (int i = 0; i < group_count; i++) { + span_builders.push_back( + checked_cast(struct_builder->field_builder(i))); + array_builders.push_back(span_builders[i]->value_builder()); + } + auto visit_null = [&]() { return struct_builder->AppendNull(); }; + auto visit_value = [&](std::string_view element) -> Status { + if (Match(element)) { + for (int i = 0; i < group_count; i++) { + // https://github.com/google/re2/issues/24#issuecomment-97653183 + if (found_values[i].data() != NULLPTR) { + int64_t begin = found_values[i].data() - element.data(); + int64_t size = found_values[i].size(); + if (is_binary_like(batch.GetTypes()[0].id())) { + ARROW_RETURN_NOT_OK(checked_cast(array_builders[i]) + ->AppendValues({static_cast(begin), + static_cast(size)})); + } else { + ARROW_RETURN_NOT_OK(checked_cast(array_builders[i]) + ->AppendValues({begin, size})); + } + + ARROW_RETURN_NOT_OK(span_builders[i]->Append()); + } else { + ARROW_RETURN_NOT_OK(span_builders[i]->AppendNull()); + } + } + ARROW_RETURN_NOT_OK(struct_builder->Append()); + } else { + ARROW_RETURN_NOT_OK(struct_builder->AppendNull()); + } + return Status::OK(); + }; + ARROW_RETURN_NOT_OK( + VisitArraySpanInline(batch[0].array, visit_value, visit_null)); + + ARROW_ASSIGN_OR_RAISE(auto out_array, struct_builder->Finish()); + out->value = out_array->data(); + return Status::OK(); + } +}; + +const FunctionDoc extract_regex_doc_span( + "likes extract_regex; however, it contains the position and length of results", "", + {"strings"}, "ExtractRegexSpanOptions", true); + +Result resolver(KernelContext* ctx, const std::vector& types) { + auto options = OptionsWrapper::Get(*ctx->state()); + ARROW_ASSIGN_OR_RAISE(auto span, ExtractRegexSpanData::Make(options.pattern)); + return span.ResolveOutputType(types); +} + +void AddAsciiStringExtractRegexSpan(FunctionRegistry* registry) { + auto func = std::make_shared("extract_regex_span", Arity::Unary(), + extract_regex_doc_span); + OutputType output_type(resolver); + for (const auto& type : BaseBinaryTypes()) { + ScalarKernel kernel({type}, output_type, + GenerateVarBinaryToVarBinary(type), + OptionsWrapper::Init); + kernel.null_handling = NullHandling::COMPUTED_NO_PREALLOCATE; + kernel.mem_allocation = MemAllocation::NO_PREALLOCATE; + DCHECK_OK(func->AddKernel(std::move(kernel))); + } + DCHECK_OK(registry->AddFunction(func)); +} #endif // ARROW_WITH_RE2 // ---------------------------------------------------------------------- @@ -3457,6 +3587,7 @@ void RegisterScalarStringAscii(FunctionRegistry* registry) { AddAsciiStringSplitWhitespace(registry); #ifdef ARROW_WITH_RE2 AddAsciiStringSplitRegex(registry); + AddAsciiStringExtractRegexSpan(registry); #endif AddAsciiStringJoin(registry); AddAsciiStringRepeat(registry); diff --git a/cpp/src/arrow/compute/kernels/scalar_string_test.cc b/cpp/src/arrow/compute/kernels/scalar_string_test.cc index 38455dc146711..4023491dee5b0 100644 --- a/cpp/src/arrow/compute/kernels/scalar_string_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_string_test.cc @@ -314,6 +314,7 @@ TYPED_TEST(TestBinaryKernels, NonUtf8Regex) { this->MakeArray({"\xfc\x40", "this \xfc\x40 that \xfc\x40"}), this->MakeArray({"bazz", "this bazz that \xfc\x40"}), &options); } + // TODO the following test is broken { ExtractRegexOptions options("(?P[\\xfc])(?P\\d)"); auto null_bitmap = std::make_shared("0"); @@ -370,6 +371,7 @@ TYPED_TEST(TestBinaryKernels, NonUtf8WithNullRegex) { this->template MakeArray({{"\x00\x40", 2}}), this->type(), R"(["bazz"])", &options); } + // TODO the following test is broken { ExtractRegexOptions options("(?P[\\x00])(?P\\d)"); auto null_bitmap = std::make_shared("0"); @@ -1958,6 +1960,29 @@ TYPED_TEST(TestBaseBinaryKernels, ExtractRegex) { R"([{"letter": "a", "digit": "1"}, {"letter": "b", "digit": "3"}])", &options); } +TYPED_TEST(TestBaseBinaryKernels, ExtractRegexSapn) { + ExtractRegexSpanOptions options{"(?P[ab])(?P\\d)"}; + auto type_fixe_size_list = is_binary_like(this->type()->id()) ? int32() : int64(); + auto out_type = struct_({field("letter_span", fixed_size_list(type_fixe_size_list, 2)), + field("digit_span", fixed_size_list(type_fixe_size_list, 2))}); + this->CheckUnary("extract_regex_span", R"([])", out_type, R"([])", &options); + this->CheckUnary( + "extract_regex_span", R"(["a1", "b2", "c3", null])", out_type, + R"([{"letter_span":[0,1], "digit_span":[1,1]}, {"letter_span":[0,1], "digit_span":[1,1]}, null, null])", + &options); + this->CheckUnary( + "extract_regex_span", R"(["a1", "c3", null, "b2"])", out_type, + R"([{"letter_span":[0,1], "digit_span": [1,1]}, null, null, {"letter_span":[0,1], "digit_span":[1,1]}])", + &options); + this->CheckUnary( + "extract_regex_span", R"(["a1", "b2"])", out_type, + R"([{"letter_span": [0,1], "digit_span": [1,1]}, {"letter_span": [0,1], "digit_span": [1,1]}])", + &options); + this->CheckUnary( + "extract_regex_span", R"(["a1", "zb3z"])", out_type, + R"([{"letter_span": [0,1], "digit_span": [1,1]}, {"letter_span": [1,1], "digit_span": [2,1]}])", + &options); +} TYPED_TEST(TestBaseBinaryKernels, ExtractRegexNoCapture) { // XXX Should we accept this or is it a user error? @@ -1966,12 +1991,23 @@ TYPED_TEST(TestBaseBinaryKernels, ExtractRegexNoCapture) { this->CheckUnary("extract_regex", R"(["oofoo", "bar", null])", type, R"([{}, null, null])", &options); } - +TYPED_TEST(TestBaseBinaryKernels, ExtractRegexSpanNoCapture) { + // XXX Should we accept this or is it a user error? + ExtractRegexSpanOptions options{"foo"}; + auto type = struct_({}); + this->CheckUnary("extract_regex_span", R"(["oofoo", "bar", null])", type, + R"([{}, null, null])", &options); +} TYPED_TEST(TestBaseBinaryKernels, ExtractRegexNoOptions) { Datum input = ArrayFromJSON(this->type(), "[]"); ASSERT_RAISES(Invalid, CallFunction("extract_regex", {input})); } +TYPED_TEST(TestBaseBinaryKernels, ExtractRegexSpanNoOptions) { + Datum input = ArrayFromJSON(this->type(), "[]"); + ASSERT_RAISES(Invalid, CallFunction("extract_regex_span", {input})); +} + TYPED_TEST(TestBaseBinaryKernels, ExtractRegexInvalid) { Datum input = ArrayFromJSON(this->type(), "[]"); ExtractRegexOptions options{"invalid["}; @@ -1984,6 +2020,17 @@ TYPED_TEST(TestBaseBinaryKernels, ExtractRegexInvalid) { Invalid, ::testing::HasSubstr("Regular expression contains unnamed groups"), CallFunction("extract_regex", {input}, &options)); } +TYPED_TEST(TestBaseBinaryKernels, ExtractRegexSpanInvalid) { + Datum input = ArrayFromJSON(this->type(), "[]"); + ExtractRegexSpanOptions options{"invalid["}; + EXPECT_RAISES_WITH_MESSAGE_THAT( + Invalid, ::testing::HasSubstr("Invalid regular expression: missing ]"), + CallFunction("extract_regex_span", {input}, &options)); + options = ExtractRegexSpanOptions{"(.)"}; + EXPECT_RAISES_WITH_MESSAGE_THAT( + Invalid, ::testing::HasSubstr("Regular expression contains unnamed groups"), + CallFunction("extract_regex_span", {input}, &options)); +} #endif