Skip to content

Commit

Permalink
merge split blocks to reduce kernel calls
Browse files Browse the repository at this point in the history
  • Loading branch information
RSchwan committed Feb 20, 2025
1 parent fb2d944 commit 773082c
Showing 1 changed file with 14 additions and 2 deletions.
16 changes: 14 additions & 2 deletions include/piqp/sparse/blocksparse_stage_kkt.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -538,12 +538,12 @@ class BlocksparseStageKKT : public KKTSystem<T>
if (i + 1 >= current_block_info.diag_block_start + current_block_info.diag_block_size) {

bool hit_optimal_ratio = current_block_info.diag_block_size >= 2 * current_block_info.off_diag_block_size;
bool at_end = i + 1 >= n - current_block_info.arrow_width;
auto next_block_grows = [&]() {
if (i >= n) return false;
block_structure_info next_block_info = get_next_block_structure(I(i) + 1, current_block_info);
return next_block_info.diag_block_size + next_block_info.off_diag_block_size > current_block_info.diag_block_size + current_block_info.off_diag_block_size;
};
bool at_end = i + 1 >= n - current_block_info.arrow_width;

if (hit_optimal_ratio || at_end || next_block_grows()) {
// std::cout << "B " << current_block_info.diag_block_start << " " << current_block_info.diag_block_size << " " << current_block_info.off_diag_block_size << " " << current_block_info.arrow_width << std::endl;
Expand Down Expand Up @@ -582,6 +582,18 @@ class BlocksparseStageKKT : public KKTSystem<T>
}
}

// merge blocks which are split in two
// this doesn't change the flops, but reduces the number of kernel calls
for (std::size_t i = 0; i < block_info.size() - 1; i++) {
if (block_info[i].off_diag_size == block_info[i + 1].diag_size && block_info[i + 1].off_diag_size == 0) {
block_info[i].diag_size += block_info[i].off_diag_size;
block_info[i].off_diag_size = 0;
auto iter = block_info.begin();
std::advance(iter, i + 1);
block_info.erase(iter);
}
}

// last block corresponds to corner block of arrow
block_info.push_back({current_block_info.diag_block_start, current_block_info.arrow_width, 0});
assert(block_info.size() >= 2);
Expand Down Expand Up @@ -1260,7 +1272,7 @@ class BlocksparseStageKKT : public KKTSystem<T>
// L_1 = chol(D_1)
blasfeo_dpotrf_l(m, kkt_fac.D[0]->ref(), 0, 0, kkt_fac.D[0]->ref(), 0, 0);

if (kkt_fac.B[0]) {
if (N > 2 && kkt_fac.B[0]) {
m = kkt_fac.B[0]->rows();
n = kkt_fac.B[0]->cols();
assert(kkt_fac.D[0]->rows() == n && kkt_fac.D[0]->cols() == n && "size mismatch");
Expand Down

0 comments on commit 773082c

Please sign in to comment.