Skip to content

Commit 45c69b9

Browse files
committed
add block diag methods
1 parent 22514f6 commit 45c69b9

File tree

1 file changed

+18
-10
lines changed

1 file changed

+18
-10
lines changed

include/albatross/src/linalg/block_diagonal.hpp

+18-10
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,9 @@ struct BlockDiagonalLDLT {
3737
solve(const Eigen::Matrix<_Scalar, _Rows, _Cols> &rhs,
3838
ThreadPool *pool) const;
3939

40-
template <class _Scalar, int _Rows, int _Cols>
41-
Eigen::Matrix<_Scalar, _Rows, _Cols>
42-
sqrt_solve(const Eigen::Matrix<_Scalar, _Rows, _Cols> &rhs,
43-
ThreadPool *pool) const;
40+
template <typename Derived>
41+
Eigen::MatrixXd sqrt_solve(const Eigen::DenseBase<Derived> &rhs,
42+
ThreadPool *pool) const;
4443

4544
BlockDiagonal sqrt_transpose() const;
4645

@@ -51,6 +50,8 @@ struct BlockDiagonalLDLT {
5150
Eigen::Index rows() const;
5251

5352
Eigen::Index cols() const;
53+
54+
bool operator==(const BlockDiagonalLDLT &other) const;
5455
};
5556

5657
struct BlockDiagonal {
@@ -141,20 +142,23 @@ BlockDiagonalLDLT::solve(const Eigen::Matrix<_Scalar, _Rows, _Cols> &rhs,
141142
return output;
142143
}
143144

144-
template <class _Scalar, int _Rows, int _Cols>
145-
inline Eigen::Matrix<_Scalar, _Rows, _Cols>
146-
BlockDiagonalLDLT::sqrt_solve(const Eigen::Matrix<_Scalar, _Rows, _Cols> &rhs,
145+
template <typename Derived>
146+
inline Eigen::MatrixXd
147+
BlockDiagonalLDLT::sqrt_solve(const Eigen::DenseBase<Derived> &rhs,
147148
ThreadPool *pool) const {
148149
ALBATROSS_ASSERT(cols() == rhs.rows());
149-
Eigen::Matrix<_Scalar, _Rows, _Cols> output(rows(), rhs.cols());
150+
Eigen::MatrixXd output(rows(), rhs.cols());
150151

151152
auto solve_and_fill_one_block = [&](const size_t i, const Eigen::Index row) {
152-
const auto rhs_chunk = rhs.block(row, 0, blocks[i].rows(), rhs.cols());
153+
const auto rhs_chunk =
154+
rhs.derived().block(row, 0, blocks[i].rows(), rhs.cols());
153155
output.block(row, 0, blocks[i].rows(), rhs.cols()) =
154156
blocks[i].sqrt_solve(rhs_chunk);
155157
};
156158

157-
apply_map(block_to_row_map(), solve_and_fill_one_block, pool);
159+
// Intentionally leaving pool out here due to an unknown bug
160+
// in which the thread pool version crashes in sqrt_solve.
161+
apply_map(block_to_row_map(), solve_and_fill_one_block);
158162
return output;
159163
}
160164

@@ -182,6 +186,10 @@ inline Eigen::Index BlockDiagonalLDLT::cols() const {
182186
return n;
183187
}
184188

189+
inline bool
190+
BlockDiagonalLDLT::operator==(const BlockDiagonalLDLT &other) const {
191+
return blocks == other.blocks;
192+
}
185193
/*
186194
* Block Diagonal
187195
*/

0 commit comments

Comments
 (0)