diff --git a/src/atlas/array/helpers/ArrayForEach.h b/src/atlas/array/helpers/ArrayForEach.h index 38d4d6441..f20b4c80b 100644 --- a/src/atlas/array/helpers/ArrayForEach.h +++ b/src/atlas/array/helpers/ArrayForEach.h @@ -1,5 +1,5 @@ /* - * (C) Crown Copyright 2023 Met Office + * (C) Crown Copyright 2024 Met Office * * This software is licensed under the terms of the Apache Licence Version 2.0 * which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. @@ -7,9 +7,10 @@ #pragma once +#include #include #include -#include +#include #include "atlas/array/ArrayView.h" #include "atlas/array/Range.h" @@ -22,18 +23,19 @@ namespace atlas { namespace execution { -// As in C++17 std::execution namespace. Note: unsequenced_policy is a C++20 addition. +// As in C++17 std::execution namespace. Note: unsequenced_policy is a C++20 +// addition. class sequenced_policy {}; class unsequenced_policy {}; class parallel_unsequenced_policy {}; class parallel_policy {}; -// execution policy objects as in C++ std::execution namespace. Note: unseq is a C++20 addition. -inline constexpr sequenced_policy seq{ /*unspecified*/ }; -inline constexpr parallel_policy par{ /*unspecified*/ }; -inline constexpr parallel_unsequenced_policy par_unseq{ /*unspecified*/ }; -inline constexpr unsequenced_policy unseq{ /*unspecified*/ }; - +// execution policy objects as in C++ std::execution namespace. Note: unseq is a +// C++20 addition. +inline constexpr sequenced_policy seq{/*unspecified*/}; +inline constexpr parallel_policy par{/*unspecified*/}; +inline constexpr parallel_unsequenced_policy par_unseq{/*unspecified*/}; +inline constexpr unsequenced_policy unseq{/*unspecified*/}; // Type names for execution policy (Not in C++ standard) template @@ -64,34 +66,31 @@ constexpr std::string_view policy_name(execution_policy) { // Type check for execution policy (Not in C++ standard) template -constexpr auto is_execution_policy() { return - std::is_same_v || - std::is_same_v || - std::is_same_v || - std::is_same_v; +constexpr auto is_execution_policy() { + return std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v; } template constexpr auto demote_policy() { - if constexpr(std::is_same_v) { + if constexpr (std::is_same_v) { return unseq; - } - else if constexpr (std::is_same_v) { + } else if constexpr (std::is_same_v) { return seq; - } - else { + } else { return ExecutionPolicy{}; } } template -constexpr auto is_omp_policy() { return - std::is_same_v || - std::is_same_v; +constexpr auto is_omp_policy() { + return std::is_same_v || + std::is_same_v; } - -template +template using demote_policy_t = decltype(demote_policy()); } // namespace execution @@ -102,7 +101,8 @@ namespace option { template util::Config execution_policy() { - return util::Config("execution_policy", execution::policy_name>()); + return util::Config("execution_policy", + execution::policy_name>()); } template @@ -110,7 +110,7 @@ util::Config execution_policy(T) { return execution_policy>(); } -} // namespace option +} // namespace option namespace array { namespace helpers { @@ -119,7 +119,9 @@ namespace detail { struct NoMask { template - constexpr bool operator()(Args...) const { return 0; } + constexpr bool operator()(Args...) const { + return 0; + } }; inline constexpr NoMask no_mask; @@ -131,13 +133,11 @@ constexpr auto tuplePushBack(const std::tuple& tuple, T value) { template void forEach(idx_t idxMax, const Functor& functor) { - - if constexpr(execution::is_omp_policy()) { + if constexpr (execution::is_omp_policy()) { atlas_omp_parallel_for(auto idx = idx_t{}; idx < idxMax; ++idx) { functor(idx); } - } - else { + } else { // Simple for-loop for sequenced or unsequenced execution policies. for (auto idx = idx_t{}; idx < idxMax; ++idx) { functor(idx); @@ -147,11 +147,10 @@ void forEach(idx_t idxMax, const Functor& functor) { template constexpr auto argPadding() { - if constexpr(NPad > 0) { + if constexpr (NPad > 0) { return std::tuple_cat(std::make_tuple(Range::all()), argPadding()); - } - else { + } else { return std::make_tuple(); } } @@ -159,29 +158,31 @@ constexpr auto argPadding() { template auto makeSlices(const std::tuple& slicerArgs, ArrayViewTuple&& arrayViews) { - constexpr auto nb_views = std::tuple_size_v; auto&& arrayView = std::get(arrayViews); using ArrayView = std::decay_t; - constexpr auto Dim = sizeof...(SlicerArgs); - constexpr auto Rank = ArrayView::rank(); + constexpr auto Dim = sizeof...(SlicerArgs); + constexpr auto Rank = ArrayView::rank(); - static_assert (Dim <= Rank, "Error: number of slicer arguments exceeds the rank of ArrayView."); + static_assert( + Dim <= Rank, + "Error: number of slicer arguments exceeds the rank of ArrayView."); const auto paddedArgs = std::tuple_cat(slicerArgs, argPadding()); const auto slicer = [&arrayView](const auto&... args) { return std::make_tuple(arrayView.slice(args...)); }; - if constexpr (ViewIdx == nb_views-1) { + if constexpr (ViewIdx == nb_views - 1) { return std::apply(slicer, paddedArgs); - } - else { + } else { // recurse - return std::tuple_cat(std::apply(slicer, paddedArgs), - makeSlices(slicerArgs, std::forward(arrayViews))); + return std::tuple_cat( + std::apply(slicer, paddedArgs), + makeSlices(slicerArgs, + std::forward(arrayViews))); } } @@ -192,80 +193,79 @@ template struct ArrayForEachImpl { template - static void apply(ArrayViewTuple&& arrayViews, - const Mask& mask, + static void apply(ArrayViewTuple&& arrayViews, const Mask& mask, const Function& function, const std::tuple& slicerArgs, const std::tuple& maskArgs) { // Iterate over this dimension. - if constexpr(Dim == ItrDim) { - + if constexpr (Dim == ItrDim) { // Get size of iteration dimenion from first view argument. const auto idxMax = std::get<0>(arrayViews).shape(ItrDim); forEach(idxMax, [&](idx_t idx) { - // Demote parallel execution policy to a non-parallel one in further recursion - ArrayForEachImpl, Dim + 1, ItrDims...>::apply( - std::forward(arrayViews), mask, function, - tuplePushBack(slicerArgs, idx), - tuplePushBack(maskArgs, idx)); + // Demote parallel execution policy to a non-parallel one in further + // recursion + ArrayForEachImpl< + execution::demote_policy_t, Dim + 1, + ItrDims...>::apply(std::forward(arrayViews), mask, + function, tuplePushBack(slicerArgs, idx), + tuplePushBack(maskArgs, idx)); }); } // Add a RangeAll to arguments. else { ArrayForEachImpl::apply( std::forward(arrayViews), mask, function, - tuplePushBack(slicerArgs, Range::all()), - maskArgs); + tuplePushBack(slicerArgs, Range::all()), maskArgs); } } }; template - struct is_applicable : std::false_type {}; +struct is_applicable : std::false_type {}; template -struct is_applicable> : std::is_invocable {}; +struct is_applicable> + : std::is_invocable {}; template -inline constexpr bool is_applicable_v = is_applicable::value; +inline constexpr bool is_applicable_v = is_applicable::value; template struct ArrayForEachImpl { - template - static void apply(ArrayViewTuple&& arrayViews, - const Mask& mask, + static void apply(ArrayViewTuple&& arrayViews, const Mask& mask, const Function& function, const std::tuple& slicerArgs, const std::tuple& maskArgs) { - constexpr auto maskPresent = !std::is_same_v; if constexpr (maskPresent) { - - constexpr auto invocableMask = std::is_invocable_r_v; - static_assert (invocableMask, - "Cannot invoke mask function with given arguments.\n" - "Make sure you arguments are N integers (or auto...) " - "where N == sizeof...(ItrDims). Function must return an int." - ); - - if (std::apply(mask, maskArgs)) { - return; - } - + constexpr auto invocableMask = + std::is_invocable_r_v; + static_assert( + invocableMask, + "Cannot invoke mask function with given arguments.\n" + "Make sure you arguments are N integers (or auto...) " + "where N == sizeof...(ItrDims). Function must return an int."); + + if (std::apply(mask, maskArgs)) { + return; + } } - auto slices = makeSlices(slicerArgs, std::forward(arrayViews)); + auto slices = + makeSlices(slicerArgs, std::forward(arrayViews)); - constexpr auto applicable = is_applicable_v; - static_assert(applicable, "Cannot invoke function with given arguments. " - "Make sure you the arguments are rvalue references (Slice&&) or const references (const Slice&) or regular value (Slice)" ); + constexpr auto applicable = is_applicable_v; + static_assert( + applicable, + "Cannot invoke function with given arguments. " + "Make sure you the arguments are rvalue references (Slice&&) or const " + "references (const Slice&) or regular value (Slice)"); std::apply(function, std::move(slices)); } - }; } // namespace detail @@ -286,41 +286,36 @@ struct ArrayForEach { /// and is executed with signature g(idx_i, idx_j,...), where the idxs /// are indices of ItrDims. /// When a config is supplied containing "execution_policy" = - /// "sequenced_policy" (default). All loops are then executed in sequential - /// (row-major) order. - /// With "execution_policy" = "parallel_unsequenced" the first loop is executed - /// using OpenMP. The remaining loops are executed in serial. - /// Note: The lowest ArrayView.rank() must be greater than or equal - /// to the highest dim in ItrDims. TODO: static checking for this. + /// "sequenced_policy" (default). All loops are then executed in + /// sequential (row-major) order. With "execution_policy" = + /// "parallel_unsequenced" the first loop is executed using OpenMP. + /// The remaining loops are executed in serial. Note: The lowest + /// ArrayView.rank() must be greater than or equal to the highest dim + /// in ItrDims. TODO: static checking for this. template static void apply(const eckit::Parametrisation& conf, - std::tuple&& arrayViews, - const Mask& mask, const Function& function) { - + std::tuple&& arrayViews, const Mask& mask, + const Function& function) { auto execute = [&](auto execution_policy) { apply(execution_policy, std::move(arrayViews), mask, function); }; using namespace execution; std::string execution_policy; - if (conf.get("execution_policy",execution_policy)) { + if (conf.get("execution_policy", execution_policy)) { if (execution_policy == policy_name(par_unseq)) { execute(par_unseq); - } - else if (execution_policy == policy_name(par)) { + } else if (execution_policy == policy_name(par)) { execute(par); - } - else if (execution_policy == policy_name(unseq)) { + } else if (execution_policy == policy_name(unseq)) { execute(unseq); - } - else if (execution_policy == policy_name(seq)) { + } else if (execution_policy == policy_name(seq)) { execute(seq); + } else { + throw_Exception("Unrecognized execution policy " + execution_policy, + Here()); } - else { - throw_Exception("Unrecognized execution policy "+execution_policy, Here()); - } - } - else { + } else { execute(seq); } } @@ -328,35 +323,47 @@ struct ArrayForEach { /// brief Apply "For-Each" method. /// /// details As above, but Execution policy is determined at compile-time. - template ()>> - static void apply(ExecutionPolicy, std::tuple&& arrayViews, const Mask& mask, const Function& function) { - detail::ArrayForEachImpl::apply( - std::move(arrayViews), mask, function, std::make_tuple(), std::make_tuple()); + template ()>> + static void apply(ExecutionPolicy, std::tuple&& arrayViews, + const Mask& mask, const Function& function) { + detail::ArrayForEachImpl::apply( + std::move(arrayViews), mask, function, std::make_tuple(), + std::make_tuple()); } /// brief Apply "For-Each" method /// /// details Apply ForEach with default execution policy. template - static void apply(std::tuple&& arrayViews, const Mask& mask, const Function& function) { - apply(std::move(arrayViews), mask, function); + static void apply(std::tuple&& arrayViews, const Mask& mask, + const Function& function) { + apply(std::move(arrayViews), mask, function); } /// brief Apply "For-Each" method /// - /// details Apply ForEach with run-time determined execution policy and no mask. + /// details Apply ForEach with run-time determined execution policy and no + /// mask. template - static void apply(const eckit::Parametrisation& conf, std::tuple&& arrayViews, const Function& function) { + static void apply(const eckit::Parametrisation& conf, + std::tuple&& arrayViews, + const Function& function) { apply(conf, std::move(arrayViews), detail::no_mask, function); } /// brief Apply "For-Each" method /// - /// details Apply ForEach with compile-time determined execution policy and no mask. + /// details Apply ForEach with compile-time determined execution policy and no + /// mask. template ()>> - static void apply(ExecutionPolicy executionPolicy, std::tuple&& arrayViews, const Function& function) { + typename = std::enable_if_t< + execution::is_execution_policy()>> + static void apply(ExecutionPolicy executionPolicy, + std::tuple&& arrayViews, + const Function& function) { apply(executionPolicy, std::move(arrayViews), detail::no_mask, function); } @@ -364,12 +371,22 @@ struct ArrayForEach { /// /// details Apply ForEach with default execution policy and no mask. template - static void apply(std::tuple&& arrayViews, const Function& function) { + static void apply(std::tuple&& arrayViews, + const Function& function) { apply(execution::seq, std::move(arrayViews), function); } - }; +/// brief Construct ArrayForEach and call apply +/// +/// details Construct an ArrayForEach using std::integer_sequence +/// . Remaining arguments are forwarded to apply +/// method. +template +void arrayForEachDim(std::integer_sequence, Args&&... args) { + ArrayForEach::apply(std::forward(args)...); +} + } // namespace helpers } // namespace array } // namespace atlas diff --git a/src/tests/array/test_array_foreach.cc b/src/tests/array/test_array_foreach.cc index f8c5b0b4a..42b9d2b09 100644 --- a/src/tests/array/test_array_foreach.cc +++ b/src/tests/array/test_array_foreach.cc @@ -1,5 +1,5 @@ /* - * (C) Crown Copyright 2023 Met Office + * (C) Crown Copyright 2024 Met Office * * This software is licensed under the terms of the Apache Licence Version 2.0 * which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. @@ -7,13 +7,13 @@ #include #include +#include #include "atlas/array.h" #include "atlas/array/MakeView.h" #include "atlas/array/helpers/ArrayForEach.h" #include "atlas/array/helpers/ArraySlicer.h" #include "atlas/util/Config.h" - #include "tests/AtlasTestEnvironment.h" using namespace atlas::array; @@ -207,6 +207,61 @@ CASE("test_array_foreach_3_views") { EXPECT_EQ(count, 60); } +CASE("test_array_foreach_integer_sequence") { + + const auto arr1 = ArrayT(2, 3); + const auto view1 = make_view(arr1); + + const auto arr2 = ArrayT(2, 3, 4); + const auto view2 = make_view(arr2); + + const auto arr3 = ArrayT(2, 3, 4, 5); + const auto view3 = make_view(arr3); + + const auto zero = std::integer_sequence{}; + const auto one = std::integer_sequence{}; + const auto zeroOneTwoThree = std::make_integer_sequence{}; + + + // Test slice shapes. + + const auto loopFunctorDim0 = [](auto&& slice1, auto&& slice2, auto&& slice3) { + EXPECT_EQ(slice1.rank(), 1); + EXPECT_EQ(slice1.shape(0), 3); + + EXPECT_EQ(slice2.rank(), 2); + EXPECT_EQ(slice2.shape(0), 3); + EXPECT_EQ(slice2.shape(1), 4); + + EXPECT_EQ(slice3.rank(), 3); + EXPECT_EQ(slice3.shape(0), 3); + EXPECT_EQ(slice3.shape(1), 4); + EXPECT_EQ(slice3.shape(2), 5); + }; + arrayForEachDim(zero, std::tie(view1, view2, view3), loopFunctorDim0); + + const auto loopFunctorDim1 = [](auto&& slice1, auto&& slice2, auto&& slice3) { + EXPECT_EQ(slice1.rank(), 1); + EXPECT_EQ(slice1.shape(0), 2); + + EXPECT_EQ(slice2.rank(), 2); + EXPECT_EQ(slice2.shape(0), 2); + EXPECT_EQ(slice2.shape(1), 4); + + EXPECT_EQ(slice3.rank(), 3); + EXPECT_EQ(slice3.shape(0), 2); + EXPECT_EQ(slice3.shape(1), 4); + EXPECT_EQ(slice3.shape(2), 5); + }; + arrayForEachDim(one, std::tie(view1, view2, view3), loopFunctorDim1); + + // Test that slice resolves to double. + + const auto loopFunctorDimAll = [](auto&& slice3) { + static_assert(std::is_convertible_v); + }; + arrayForEachDim(zeroOneTwoThree, std::tie(view3), loopFunctorDimAll); +} CASE("test_array_foreach_forwarding") {