forked from kokkos/kokkos-kernels
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implement batched serial pttrf (kokkos#2256)
* Batched serial pttrf implementation * fix: use GEMM to add matrices * fix: initialization order * fformat * fix: temporary variable in a test code * fix: docstring of pttrf * check_positive_definitiveness only if KOKKOSKERNELS_DEBUG_LEVEL > 0 * Improve the test for pttrf * fix: int type * fix: cleanup tests for SerialPttrf * cleanup: remove unused deep_copies * fix: docstrings and comments for pttrf * ConjTranspose with conj and Transpose * quick return in pttrf for size 1 or 0 matrix * Add tests for invalid input * fix: info computation --------- Co-authored-by: Yuuichi Asahi <y.asahi@nr.titech.ac.jp>
- Loading branch information
1 parent
ea430c3
commit 994891a
Showing
9 changed files
with
909 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
//@HEADER | ||
// ************************************************************************ | ||
// | ||
// Kokkos v. 4.0 | ||
// Copyright (2022) National Technology & Engineering | ||
// Solutions of Sandia, LLC (NTESS). | ||
// | ||
// Under the terms of Contract DE-NA0003525 with NTESS, | ||
// the U.S. Government retains certain rights in this software. | ||
// | ||
// Part of Kokkos, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://kokkos.org/LICENSE for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//@HEADER | ||
#ifndef KOKKOSBATCHED_PTTRF_SERIAL_IMPL_HPP_ | ||
#define KOKKOSBATCHED_PTTRF_SERIAL_IMPL_HPP_ | ||
|
||
#include <KokkosBatched_Util.hpp> | ||
#include "KokkosBatched_Pttrf_Serial_Internal.hpp" | ||
|
||
/// \author Yuuichi Asahi (yuuichi.asahi@cea.fr) | ||
|
||
namespace KokkosBatched { | ||
|
||
template <typename DViewType, typename EViewType> | ||
KOKKOS_INLINE_FUNCTION static int checkPttrfInput( | ||
[[maybe_unused]] const DViewType &d, [[maybe_unused]] const EViewType &e) { | ||
static_assert(Kokkos::is_view<DViewType>::value, | ||
"KokkosBatched::pttrf: DViewType is not a Kokkos::View."); | ||
static_assert(Kokkos::is_view<EViewType>::value, | ||
"KokkosBatched::pttrf: EViewType is not a Kokkos::View."); | ||
|
||
static_assert(DViewType::rank == 1, | ||
"KokkosBatched::pttrf: DViewType must have rank 1."); | ||
static_assert(EViewType::rank == 1, | ||
"KokkosBatched::pttrf: EViewType must have rank 1."); | ||
|
||
#if (KOKKOSKERNELS_DEBUG_LEVEL > 0) | ||
const int nd = d.extent(0); | ||
const int ne = e.extent(0); | ||
|
||
if (ne + 1 != nd) { | ||
Kokkos::printf( | ||
"KokkosBatched::pttrf: Dimensions of d and e do not match: d: %d, e: " | ||
"%d \n" | ||
"e.extent(0) must be equal to d.extent(0) - 1\n", | ||
nd, ne); | ||
return 1; | ||
} | ||
#endif | ||
return 0; | ||
} | ||
|
||
template <> | ||
struct SerialPttrf<Algo::Pttrf::Unblocked> { | ||
template <typename DViewType, typename EViewType> | ||
KOKKOS_INLINE_FUNCTION static int invoke(const DViewType &d, | ||
const EViewType &e) { | ||
// Quick return if possible | ||
if (d.extent(0) == 0) return 0; | ||
if (d.extent(0) == 1) return (d(0) < 0 ? 1 : 0); | ||
|
||
auto info = checkPttrfInput(d, e); | ||
if (info) return info; | ||
|
||
return SerialPttrfInternal<Algo::Pttrf::Unblocked>::invoke( | ||
d.extent(0), d.data(), d.stride(0), e.data(), e.stride(0)); | ||
} | ||
}; | ||
} // namespace KokkosBatched | ||
|
||
#endif // KOKKOSBATCHED_PTTRF_SERIAL_IMPL_HPP_ |
211 changes: 211 additions & 0 deletions
211
batched/dense/impl/KokkosBatched_Pttrf_Serial_Internal.hpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,211 @@ | ||
//@HEADER | ||
// ************************************************************************ | ||
// | ||
// Kokkos v. 4.0 | ||
// Copyright (2022) National Technology & Engineering | ||
// Solutions of Sandia, LLC (NTESS). | ||
// | ||
// Under the terms of Contract DE-NA0003525 with NTESS, | ||
// the U.S. Government retains certain rights in this software. | ||
// | ||
// Part of Kokkos, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://kokkos.org/LICENSE for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//@HEADER | ||
#ifndef KOKKOSBATCHED_PTTRF_SERIAL_INTERNAL_HPP_ | ||
#define KOKKOSBATCHED_PTTRF_SERIAL_INTERNAL_HPP_ | ||
|
||
#include <KokkosBatched_Util.hpp> | ||
|
||
/// \author Yuuichi Asahi (yuuichi.asahi@cea.fr) | ||
|
||
namespace KokkosBatched { | ||
|
||
template <typename AlgoType> | ||
struct SerialPttrfInternal { | ||
template <typename ValueType> | ||
KOKKOS_INLINE_FUNCTION static int invoke(const int n, | ||
ValueType *KOKKOS_RESTRICT d, | ||
const int ds0, | ||
ValueType *KOKKOS_RESTRICT e, | ||
const int es0); | ||
|
||
template <typename ValueType> | ||
KOKKOS_INLINE_FUNCTION static int invoke( | ||
const int n, ValueType *KOKKOS_RESTRICT d, const int ds0, | ||
Kokkos::complex<ValueType> *KOKKOS_RESTRICT e, const int es0); | ||
}; | ||
|
||
/// | ||
/// Real matrix | ||
/// | ||
|
||
template <> | ||
template <typename ValueType> | ||
KOKKOS_INLINE_FUNCTION int SerialPttrfInternal<Algo::Pttrf::Unblocked>::invoke( | ||
const int n, ValueType *KOKKOS_RESTRICT d, const int ds0, | ||
ValueType *KOKKOS_RESTRICT e, const int es0) { | ||
int info = 0; | ||
|
||
auto update = [&](const int i) { | ||
auto ei_tmp = e[i * es0]; | ||
e[i * es0] = ei_tmp / d[i * ds0]; | ||
d[(i + 1) * ds0] -= e[i * es0] * ei_tmp; | ||
}; | ||
|
||
auto check_positive_definitiveness = [&](const int i) { | ||
return (d[i] <= 0.0) ? (i + 1) : 0; | ||
}; | ||
|
||
// Compute the L*D*L' (or U'*D*U) factorization of A. | ||
const int i4 = (n - 1) % 4; | ||
for (int i = 0; i < i4; i++) { | ||
#if (KOKKOSKERNELS_DEBUG_LEVEL > 0) | ||
info = check_positive_definitiveness(i); | ||
if (info) { | ||
return info; | ||
} | ||
#endif | ||
|
||
update(i); | ||
} // for (int i = 0; i < i4; i++) | ||
|
||
for (int i = i4; i < n - 4; i += 4) { | ||
#if (KOKKOSKERNELS_DEBUG_LEVEL > 0) | ||
info = check_positive_definitiveness(i); | ||
if (info) { | ||
return info; | ||
} | ||
#endif | ||
|
||
update(i); | ||
|
||
#if (KOKKOSKERNELS_DEBUG_LEVEL > 0) | ||
info = check_positive_definitiveness(i + 1); | ||
if (info) { | ||
return info; | ||
} | ||
#endif | ||
|
||
update(i + 1); | ||
|
||
#if (KOKKOSKERNELS_DEBUG_LEVEL > 0) | ||
info = check_positive_definitiveness(i + 2); | ||
if (info) { | ||
return info; | ||
} | ||
#endif | ||
|
||
update(i + 2); | ||
|
||
#if (KOKKOSKERNELS_DEBUG_LEVEL > 0) | ||
info = check_positive_definitiveness(i + 3); | ||
if (info) { | ||
return info; | ||
} | ||
#endif | ||
|
||
update(i + 3); | ||
|
||
} // for (int i = i4; i < n-4; 4) | ||
|
||
#if (KOKKOSKERNELS_DEBUG_LEVEL > 0) | ||
info = check_positive_definitiveness(n - 1); | ||
if (info) { | ||
return info; | ||
} | ||
#endif | ||
|
||
return 0; | ||
} | ||
|
||
/// | ||
/// Complex matrix | ||
/// | ||
|
||
template <> | ||
template <typename ValueType> | ||
KOKKOS_INLINE_FUNCTION int SerialPttrfInternal<Algo::Pttrf::Unblocked>::invoke( | ||
const int n, ValueType *KOKKOS_RESTRICT d, const int ds0, | ||
Kokkos::complex<ValueType> *KOKKOS_RESTRICT e, const int es0) { | ||
int info = 0; | ||
|
||
auto update = [&](const int i) { | ||
auto eir_tmp = e[i * es0].real(); | ||
auto eii_tmp = e[i * es0].imag(); | ||
auto f_tmp = eir_tmp / d[i * ds0]; | ||
auto g_tmp = eii_tmp / d[i * ds0]; | ||
e[i * es0] = Kokkos::complex<ValueType>(f_tmp, g_tmp); | ||
d[(i + 1) * ds0] = d[(i + 1) * ds0] - f_tmp * eir_tmp - g_tmp * eii_tmp; | ||
}; | ||
|
||
auto check_positive_definitiveness = [&](const int i) { | ||
return (d[i] <= 0.0) ? (i + 1) : 0; | ||
}; | ||
|
||
// Compute the L*D*L' (or U'*D*U) factorization of A. | ||
const int i4 = (n - 1) % 4; | ||
for (int i = 0; i < i4; i++) { | ||
#if (KOKKOSKERNELS_DEBUG_LEVEL > 0) | ||
info = check_positive_definitiveness(i); | ||
if (info) { | ||
return info; | ||
} | ||
#endif | ||
|
||
update(i); | ||
} // for (int i = 0; i < i4; i++) | ||
|
||
for (int i = i4; i < n - 4; i += 4) { | ||
#if (KOKKOSKERNELS_DEBUG_LEVEL > 0) | ||
info = check_positive_definitiveness(i); | ||
if (info) { | ||
return info; | ||
} | ||
#endif | ||
|
||
update(i); | ||
|
||
#if (KOKKOSKERNELS_DEBUG_LEVEL > 0) | ||
info = check_positive_definitiveness(i + 1); | ||
if (info) { | ||
return info; | ||
} | ||
#endif | ||
|
||
update(i + 1); | ||
|
||
#if (KOKKOSKERNELS_DEBUG_LEVEL > 0) | ||
info = check_positive_definitiveness(i + 2); | ||
if (info) { | ||
return info; | ||
} | ||
#endif | ||
|
||
update(i + 2); | ||
|
||
#if (KOKKOSKERNELS_DEBUG_LEVEL > 0) | ||
info = check_positive_definitiveness(i + 3); | ||
if (info) { | ||
return info; | ||
} | ||
#endif | ||
|
||
update(i + 3); | ||
|
||
} // for (int i = i4; i < n-4; 4) | ||
|
||
#if (KOKKOSKERNELS_DEBUG_LEVEL > 0) | ||
info = check_positive_definitiveness(n - 1); | ||
if (info) { | ||
return info; | ||
} | ||
#endif | ||
|
||
return 0; | ||
} | ||
|
||
} // namespace KokkosBatched | ||
|
||
#endif // KOKKOSBATCHED_PTTRF_SERIAL_INTERNAL_HPP_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
//@HEADER | ||
// ************************************************************************ | ||
// | ||
// Kokkos v. 4.0 | ||
// Copyright (2022) National Technology & Engineering | ||
// Solutions of Sandia, LLC (NTESS). | ||
// | ||
// Under the terms of Contract DE-NA0003525 with NTESS, | ||
// the U.S. Government retains certain rights in this software. | ||
// | ||
// Part of Kokkos, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://kokkos.org/LICENSE for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//@HEADER | ||
#ifndef KOKKOSBATCHED_PTTRF_HPP_ | ||
#define KOKKOSBATCHED_PTTRF_HPP_ | ||
|
||
#include <KokkosBatched_Util.hpp> | ||
|
||
/// \author Yuuichi Asahi (yuuichi.asahi@cea.fr) | ||
|
||
namespace KokkosBatched { | ||
|
||
/// \brief Serial Batched Pttrf: | ||
/// Compute the Cholesky factorization L*D*L**T (or L*D*L**H) of a real | ||
/// symmetric (or complex Hermitian) positive definite tridiagonal matrix A_l | ||
/// for all l = 0, ..., N | ||
/// | ||
/// \tparam DViewType: Input type for the a diagonal matrix, needs to be a 1D | ||
/// view | ||
/// \tparam EViewType: Input type for the a upper/lower diagonal matrix, | ||
/// needs to be a 1D view | ||
/// | ||
/// \param d [inout]: n diagonal elements of the diagonal matrix D | ||
/// \param e [inout]: n-1 upper/lower diagonal elements of the diagonal matrix E | ||
/// | ||
/// No nested parallel_for is used inside of the function. | ||
/// | ||
|
||
template <typename ArgAlgo> | ||
struct SerialPttrf { | ||
template <typename DViewType, typename EViewType> | ||
KOKKOS_INLINE_FUNCTION static int invoke(const DViewType &d, | ||
const EViewType &e); | ||
}; | ||
|
||
} // namespace KokkosBatched | ||
|
||
#include "KokkosBatched_Pttrf_Serial_Impl.hpp" | ||
|
||
#endif // KOKKOSBATCHED_PTTRF_HPP_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.