Skip to content

Commit

Permalink
fix bug with arrow structure
Browse files Browse the repository at this point in the history
  • Loading branch information
RSchwan committed Feb 17, 2025
1 parent 41d0738 commit ab16b69
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 22 deletions.
32 changes: 13 additions & 19 deletions include/piqp/sparse/blocksparse_stage_kkt.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -576,8 +576,9 @@ class BlocksparseStageKKT : public KKTSystem<T>
current_block_info.prev_diag_block_size = current_block_info.diag_block_size;
current_block_info.diag_block_size = current_block_info.off_diag_block_size;
current_block_info.off_diag_block_size = 0;
break;
}

if (at_end) break;
}
}

Expand Down Expand Up @@ -1170,20 +1171,20 @@ class BlocksparseStageKKT : public KKTSystem<T>
}
}

// the terms AtA.E or GtG.E might be smaller,
// thus we have to zero the whole matrix just in case
if (!allocate && kkt_fac.E[i] && !mat_set) {
kkt_fac.E[i]->setZero();
}

if (AtA.E[i]) {
if (allocate) {
if (!kkt_fac.E[i]) {
kkt_fac.E[i] = std::make_unique<BlasfeoMat>(m, n);
}
} else {
if (mat_set) {
// E_i += delta^{-1} * AtA.E_i
blasfeo_dgead(delta_inv, *AtA.E[i], *kkt_fac.E[i]);
} else {
// E_i = delta^{-1} * AtA.E_i
blasfeo_dgecpsc(delta_inv, *AtA.E[i], *kkt_fac.E[i]);
mat_set = true;
}
// E_i += delta^{-1} * AtA.E_i
blasfeo_dgead(delta_inv, *AtA.E[i], *kkt_fac.E[i]);
}
}

Expand All @@ -1193,23 +1194,16 @@ class BlocksparseStageKKT : public KKTSystem<T>
kkt_fac.E[i] = std::make_unique<BlasfeoMat>(m, n);
}
} else {
if (mat_set) {
// E_i += GtG.E_i
blasfeo_dgead(1.0, *GtG.E[i], *kkt_fac.E[i]);
} else {
// E_i = GtG.E_i
blasfeo_dgecp(*GtG.E[i], *kkt_fac.E[i]);
}
// E_i += GtG.E_i
blasfeo_dgead(1.0, *GtG.E[i], *kkt_fac.E[i]);
}
}

// Only the arrow can have more allocated blocks because
// of the factorization if the previous factors exist.
if (!mat_set && i > 0 && kkt_fac.E[i - 1] && kkt_fac.B[i - 1]) {
if (allocate && !kkt_fac.E[i]) {
kkt_fac.E[i] = std::make_unique<BlasfeoMat>(kkt_fac.E[i - 1]->rows(), kkt_fac.D[i - 1]->rows());
} else {
kkt_fac.E[i]->setZero();
kkt_fac.E[i] = std::make_unique<BlasfeoMat>(kkt_fac.E[i - 1]->rows(), kkt_fac.D[i]->rows());
}
}
}
Expand Down
6 changes: 3 additions & 3 deletions include/piqp/utils/blasfeo_wrapper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ static inline void blasfeo_dgecp(BlasfeoMat& A, BlasfeoMat& B)
{
int m = A.rows();
int n = A.cols();
assert(B.rows() == m && B.cols() == n && "size mismatch");
assert(B.rows() >= m && B.cols() >= n && "size mismatch");
blasfeo_dgecp(m, n, A.ref(), 0, 0, B.ref(), 0, 0);
}

Expand All @@ -31,7 +31,7 @@ static inline void blasfeo_dgecpsc(double alpha, BlasfeoMat& A, BlasfeoMat& B)
{
int m = A.rows();
int n = A.cols();
assert(B.rows() == m && B.cols() == n && "size mismatch");
assert(B.rows() >= m && B.cols() >= n && "size mismatch");
blasfeo_dgecpsc(m, n, alpha, A.ref(), 0, 0, B.ref(), 0, 0);
}

Expand Down Expand Up @@ -63,7 +63,7 @@ static inline void blasfeo_dgead(double alpha, BlasfeoMat& A, BlasfeoMat& B)
{
int m = A.rows();
int n = A.cols();
assert(B.rows() == m && B.cols() == n && "size mismatch");
assert(B.rows() >= m && B.cols() >= n && "size mismatch");
blasfeo_dgead(m, n, alpha, A.ref(), 0, 0, B.ref(), 0, 0);
}

Expand Down

0 comments on commit ab16b69

Please sign in to comment.