From 6d0a3acb1b5e5cd4d635315582de1d23fc01268a Mon Sep 17 00:00:00 2001 From: Malte Londschien Date: Fri, 24 Jan 2025 09:23:04 +0100 Subject: [PATCH] try something --- src/tabmat/dense_matrix.py | 4 +--- src/tabmat/ext/dense.pyx | 6 +++--- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/src/tabmat/dense_matrix.py b/src/tabmat/dense_matrix.py index ae19d74e..3502251f 100644 --- a/src/tabmat/dense_matrix.py +++ b/src/tabmat/dense_matrix.py @@ -165,9 +165,7 @@ def _cross_sandwich( def _get_col_stds(self, weights: np.ndarray, col_means: np.ndarray) -> np.ndarray: """Get standard deviations of columns using weights `weights`.""" - sqrt_arg = transpose_square_dot_weights( - self._array - col_means[np.newaxis, :], weights - ) + sqrt_arg = transpose_square_dot_weights(self._array, weights, col_means) # Minor floating point errors above can result in a very slightly # negative sqrt_arg (e.g. -5e-16). We just set those values equal to # zero. diff --git a/src/tabmat/ext/dense.pyx b/src/tabmat/ext/dense.pyx index 32a3ac49..0538df5b 100644 --- a/src/tabmat/ext/dense.pyx +++ b/src/tabmat/ext/dense.pyx @@ -100,7 +100,7 @@ def dense_matvec(np.ndarray X, floating[:] v, int[:] rows, int[:] cols): raise Exception("The matrix X is not contiguous.") return out -def transpose_square_dot_weights(np.ndarray X, floating[:] weights): +def transpose_square_dot_weights(np.ndarray X, floating[:] weights, floating[:] shift): cdef floating* Xp = X.data cdef int nrows = weights.shape[0] cdef int ncols = X.shape[1] @@ -112,11 +112,11 @@ def transpose_square_dot_weights(np.ndarray X, floating[:] weights): if X.flags["C_CONTIGUOUS"]: for j in prange(ncols, nogil=True): for i in range(nrows): - outp[j] = outp[j] + weights[i] * (Xp[i * ncols + j] ** 2) + outp[j] = outp[j] + weights[i] * ((Xp[i * ncols + j] - shift[j]) ** 2) elif X.flags["F_CONTIGUOUS"]: for j in prange(ncols, nogil=True): for i in range(nrows): - outp[j] = outp[j] + weights[i] * (Xp[j * nrows + i] ** 2) + outp[j] = outp[j] + weights[i] * ((Xp[j * nrows + i] - shift[j]) ** 2) else: raise Exception("The matrix X is not contiguous.") return out