Skip to content

Commit 5c4c79c

Browse files
yuvaltassacopybara-github
authored andcommitted
Refactor LD back-substitution (CSR version).
PiperOrigin-RevId: 717403904 Change-Id: Idcd5e71f03a960203c22cb737d399fac5a0ba59c
1 parent 2a10054 commit 5c4c79c

File tree

1 file changed

+53
-59
lines changed

1 file changed

+53
-59
lines changed

src/engine/engine_core_smooth.c

+53-59
Original file line numberDiff line numberDiff line change
@@ -1609,84 +1609,78 @@ void mj_solveLD(const mjModel* m, mjtNum* restrict x, int n,
16091609
// like mj_solveLD, but using the CSR representation of L
16101610
void mj_solveLDs(mjtNum* restrict x, const mjtNum* qLDs, const mjtNum* qLDiagInv, int nv, int n,
16111611
const int* rownnz, const int* rowadr, const int* diagnum, const int* colind) {
1612-
// single vector
1613-
if (n == 1) {
1614-
// x <- L^-T x
1615-
for (int i=nv-1; i > 0; i--) {
1616-
// skip diagonal rows, zero elements in input vector
1617-
mjtNum x_i = x[i];
1618-
if (x_i == 0 || diagnum[i]) {
1619-
continue;
1620-
}
1621-
1622-
int start = rowadr[i];
1623-
int end = start + rownnz[i] - 1;
1624-
for (int adr=start; adr < end; adr++) {
1625-
x[colind[adr]] -= qLDs[adr] * x_i;
1626-
}
1627-
}
1628-
1629-
// x <- D^-1 x
1630-
for (int i=0; i < nv; i++) {
1631-
x[i] *= qLDiagInv[i];
1612+
// x <- L^-T x
1613+
for (int i=nv-1; i > 0; i--) {
1614+
// skip diagonal rows
1615+
if (diagnum[i]) {
1616+
continue;
16321617
}
16331618

1634-
// x <- L^-1 x
1635-
for (int i=1; i < nv; i++) {
1636-
// skip diagonal rows
1637-
if (diagnum[i]) {
1638-
i += diagnum[i] - 1; // iterating forward: skip ahead, adjust i
1639-
continue;
1619+
// one vector
1620+
if (n == 1) {
1621+
mjtNum x_i;
1622+
if ((x_i = x[i])) {
1623+
int start = rowadr[i];
1624+
int end = start + rownnz[i] - 1;
1625+
for (int adr=start; adr < end; adr++) {
1626+
x[colind[adr]] -= qLDs[adr] * x_i;
1627+
}
16401628
}
1641-
1642-
int adr = rowadr[i];
1643-
x[i] -= mju_dotSparse(qLDs+adr, x, rownnz[i] - 1, colind+adr, /*flg_unc1=*/0);
16441629
}
1645-
}
1646-
1647-
// multiple vectors
1648-
else {
1649-
// x <- L^-T x
1650-
for (int i=nv-1; i > 0; i--) {
1651-
// skip diagonal rows
1652-
if (diagnum[i]) {
1653-
continue;
1654-
}
16551630

1631+
// multiple vectors
1632+
else {
16561633
int start = rowadr[i];
16571634
int end = start + rownnz[i] - 1;
1658-
for (int adr=start; adr < end; adr++) {
1659-
int j = colind[adr];
1660-
mjtNum val = qLDs[adr];
1661-
for (int offset=0; offset < n*nv; offset+=nv) {
1662-
mjtNum x_i;
1663-
if ((x_i = x[i+offset])) {
1664-
x[j+offset] -= val * x_i;
1635+
for (int offset=0; offset < n*nv; offset+=nv) {
1636+
mjtNum x_i;
1637+
if ((x_i = x[i+offset])) {
1638+
for (int adr=start; adr < end; adr++) {
1639+
x[offset + colind[adr]] -= qLDs[adr] * x_i;
16651640
}
16661641
}
16671642
}
16681643
}
1644+
}
1645+
1646+
// x <- D^-1 x
1647+
for (int i=0; i < nv; i++) {
1648+
mjtNum invD_i = qLDiagInv[i];
1649+
1650+
// one vector
1651+
if (n == 1) {
1652+
x[i] *= invD_i;
1653+
}
16691654

1670-
// x <- D^-1 x
1671-
for (int i=0; i < nv; i++) {
1672-
mjtNum invD_i = qLDiagInv[i];
1655+
// multiple vectors
1656+
else {
16731657
for (int offset=0; offset < n*nv; offset+=nv) {
16741658
x[i+offset] *= invD_i;
16751659
}
16761660
}
1661+
}
16771662

1678-
// x <- L^-1 x
1679-
for (int i=1; i < nv; i++) {
1680-
// skip diagonal rows
1681-
if (diagnum[i]) {
1682-
i += diagnum[i] - 1; // iterating forward: skip ahead, adjust i
1683-
continue;
1663+
// x <- L^-1 x
1664+
for (int i=1; i < nv; i++) {
1665+
// skip diagonal rows
1666+
if (diagnum[i]) {
1667+
i += diagnum[i] - 1; // iterating forward: skip ahead, adjust i
1668+
continue;
1669+
}
1670+
1671+
int adr = rowadr[i];
1672+
int d = rownnz[i] - 1;
1673+
if (d > 0) {
1674+
// one vector
1675+
if (n == 1) {
1676+
x[i] -= mju_dotSparse(qLDs+adr, x, d, colind+adr, /*flg_unc1=*/0);
16841677
}
16851678

1686-
int adr = rowadr[i];
1687-
int d = rownnz[i] - 1;
1688-
for (int offset=0; offset < n*nv; offset+=nv) {
1689-
x[i+offset] -= mju_dotSparse(qLDs+adr, x+offset, d, colind+adr, /*flg_unc1=*/0);
1679+
// multiple vectors
1680+
else {
1681+
for (int offset=0; offset < n*nv; offset+=nv) {
1682+
x[i+offset] -= mju_dotSparse(qLDs+adr, x+offset, d, colind+adr, /*flg_unc1=*/0);
1683+
}
16901684
}
16911685
}
16921686
}

0 commit comments

Comments
 (0)