@@ -37,10 +37,9 @@ struct BlockDiagonalLDLT {
37
37
solve (const Eigen::Matrix<_Scalar, _Rows, _Cols> &rhs,
38
38
ThreadPool *pool) const ;
39
39
40
- template <class _Scalar , int _Rows, int _Cols>
41
- Eigen::Matrix<_Scalar, _Rows, _Cols>
42
- sqrt_solve (const Eigen::Matrix<_Scalar, _Rows, _Cols> &rhs,
43
- ThreadPool *pool) const ;
40
+ template <typename Derived>
41
+ Eigen::MatrixXd sqrt_solve (const Eigen::DenseBase<Derived> &rhs,
42
+ ThreadPool *pool) const ;
44
43
45
44
BlockDiagonal sqrt_transpose () const ;
46
45
@@ -51,6 +50,8 @@ struct BlockDiagonalLDLT {
51
50
Eigen::Index rows () const ;
52
51
53
52
Eigen::Index cols () const ;
53
+
54
+ bool operator ==(const BlockDiagonalLDLT &other) const ;
54
55
};
55
56
56
57
struct BlockDiagonal {
@@ -141,20 +142,23 @@ BlockDiagonalLDLT::solve(const Eigen::Matrix<_Scalar, _Rows, _Cols> &rhs,
141
142
return output;
142
143
}
143
144
144
- template <class _Scalar , int _Rows, int _Cols >
145
- inline Eigen::Matrix<_Scalar, _Rows, _Cols>
146
- BlockDiagonalLDLT::sqrt_solve (const Eigen::Matrix<_Scalar, _Rows, _Cols > &rhs,
145
+ template <typename Derived >
146
+ inline Eigen::MatrixXd
147
+ BlockDiagonalLDLT::sqrt_solve (const Eigen::DenseBase<Derived > &rhs,
147
148
ThreadPool *pool) const {
148
149
ALBATROSS_ASSERT (cols () == rhs.rows ());
149
- Eigen::Matrix<_Scalar, _Rows, _Cols> output (rows (), rhs.cols ());
150
+ Eigen::MatrixXd output (rows (), rhs.cols ());
150
151
151
152
auto solve_and_fill_one_block = [&](const size_t i, const Eigen::Index row) {
152
- const auto rhs_chunk = rhs.block (row, 0 , blocks[i].rows (), rhs.cols ());
153
+ const auto rhs_chunk =
154
+ rhs.derived ().block (row, 0 , blocks[i].rows (), rhs.cols ());
153
155
output.block (row, 0 , blocks[i].rows (), rhs.cols ()) =
154
156
blocks[i].sqrt_solve (rhs_chunk);
155
157
};
156
158
157
- apply_map (block_to_row_map (), solve_and_fill_one_block, pool);
159
+ // Intentionally leaving pool out here due to an unknown bug
160
+ // in which the thread pool version crashes in sqrt_solve.
161
+ apply_map (block_to_row_map (), solve_and_fill_one_block);
158
162
return output;
159
163
}
160
164
@@ -182,6 +186,10 @@ inline Eigen::Index BlockDiagonalLDLT::cols() const {
182
186
return n;
183
187
}
184
188
189
+ inline bool
190
+ BlockDiagonalLDLT::operator ==(const BlockDiagonalLDLT &other) const {
191
+ return blocks == other.blocks ;
192
+ }
185
193
/*
186
194
* Block Diagonal
187
195
*/
0 commit comments