Skip to content

Commit

Permalink
try something
Browse files Browse the repository at this point in the history
  • Loading branch information
mlondschien committed Jan 24, 2025
1 parent f9fa1a5 commit 6d0a3ac
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 6 deletions.
4 changes: 1 addition & 3 deletions src/tabmat/dense_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,9 +165,7 @@ def _cross_sandwich(

def _get_col_stds(self, weights: np.ndarray, col_means: np.ndarray) -> np.ndarray:
"""Get standard deviations of columns using weights `weights`."""
sqrt_arg = transpose_square_dot_weights(
self._array - col_means[np.newaxis, :], weights
)
sqrt_arg = transpose_square_dot_weights(self._array, weights, col_means)
# Minor floating point errors above can result in a very slightly
# negative sqrt_arg (e.g. -5e-16). We just set those values equal to
# zero.
Expand Down
6 changes: 3 additions & 3 deletions src/tabmat/ext/dense.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def dense_matvec(np.ndarray X, floating[:] v, int[:] rows, int[:] cols):
raise Exception("The matrix X is not contiguous.")
return out

def transpose_square_dot_weights(np.ndarray X, floating[:] weights):
def transpose_square_dot_weights(np.ndarray X, floating[:] weights, floating[:] shift):
cdef floating* Xp = <floating*>X.data
cdef int nrows = weights.shape[0]
cdef int ncols = X.shape[1]
Expand All @@ -112,11 +112,11 @@ def transpose_square_dot_weights(np.ndarray X, floating[:] weights):
if X.flags["C_CONTIGUOUS"]:
for j in prange(ncols, nogil=True):
for i in range(nrows):
outp[j] = outp[j] + weights[i] * (Xp[i * ncols + j] ** 2)
outp[j] = outp[j] + weights[i] * ((Xp[i * ncols + j] - shift[j]) ** 2)
elif X.flags["F_CONTIGUOUS"]:
for j in prange(ncols, nogil=True):
for i in range(nrows):
outp[j] = outp[j] + weights[i] * (Xp[j * nrows + i] ** 2)
outp[j] = outp[j] + weights[i] * ((Xp[j * nrows + i] - shift[j]) ** 2)
else:
raise Exception("The matrix X is not contiguous.")
return out

0 comments on commit 6d0a3ac

Please sign in to comment.