Add looped version of gradient; next step would be to JIT the function.
jrhosk committed Feb 21, 2025
1 parent 5385745 commit 66f6e7e
Showing 2 changed files with 6 additions and 4 deletions.
6 changes: 4 additions & 2 deletions src/calviper/math/optimizer.py
@@ -1,4 +1,5 @@
 import itertools
+import numba

 import numpy as np

@@ -30,6 +31,7 @@ def gradient_(target: np.ndarray, model: np.ndarray, parameter: np.ndarray) -> np.ndarray:
 return gradient_

 @staticmethod
+#@numba.njit
 def gradient(target: np.ndarray, model: np.ndarray, parameter: np.ndarray) -> np.ndarray:
 # cache_ = target, model.conj()
 # numerator_ = np.matmul(cache_, parameter)
@@ -38,8 +40,8 @@ def gradient(target: np.ndarray, model: np.ndarray, parameter: np.ndarray) -> np.ndarray:
 target = target.reshape(n_time, n_channel, 2, 2, n_antennas, n_antennas)
 model = model.reshape(n_time, n_channel, 2, 2, n_antennas, n_antennas)

-numerator_ = np.zeros((n_time, n_channel, 2, n_antennas))
-denominator_ = np.zeros((n_time, n_channel, 2, n_antennas))
+numerator_ = np.zeros((n_time, n_channel, 2, n_antennas), dtype=np.complex64)
+denominator_ = np.zeros((n_time, n_channel, 2, n_antennas), dtype=np.complex64)

 # polarizations per baseline are in the order [XX, XY, YX, YY] I think ... so
 for p, q in itertools.product([0, 1], [0, 1]):
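A note on the commit message's planned next step ("JIT the function"): this change adds the numba import and leaves the @numba.njit decorator commented out, so the looped gradient is not yet compiled. The sketch below shows one way that could look, under stated assumptions: the name gradient_looped, the explicit n_time/n_channel/n_antennas arguments, the assumed parameter layout of (n_time, n_channel, 2, n_antennas), and the accumulation formula are illustrative stand-ins rather than calviper's actual update rule, and itertools.product is replaced by plain range loops because its support in nopython mode is uncertain.

import numba
import numpy as np


@numba.njit(cache=True)
def gradient_looped(target, model, parameter, n_time, n_channel, n_antennas):
    # Hypothetical sketch: mirrors the reshape in the hunk above
    # (numba's reshape accepts a single tuple argument).
    target = target.reshape((n_time, n_channel, 2, 2, n_antennas, n_antennas))
    model = model.reshape((n_time, n_channel, 2, 2, n_antennas, n_antennas))

    numerator_ = np.zeros((n_time, n_channel, 2, n_antennas), dtype=np.complex64)
    denominator_ = np.zeros((n_time, n_channel, 2, n_antennas), dtype=np.complex64)

    # Plain range loops instead of itertools.product; nopython mode compiles
    # these into tight native loops. The scalar update below is a placeholder
    # for the real gradient accumulation.
    for t in range(n_time):
        for c in range(n_channel):
            for p in range(2):
                for q in range(2):
                    for i in range(n_antennas):
                        for j in range(n_antennas):
                            v = target[t, c, p, q, i, j] * np.conj(model[t, c, p, q, i, j])
                            numerator_[t, c, p, i] += v * parameter[t, c, q, j]
                            m = model[t, c, p, q, i, j] * parameter[t, c, q, j]
                            denominator_[t, c, p, i] += abs(m) ** 2

    return numerator_ / denominator_

The first call pays the compilation cost and cache=True reuses the compiled code across runs; if the installed numba version does accept itertools.product in nopython mode, the original loop structure could simply be kept.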
4 changes: 2 additions & 2 deletions src/calviper/math/solver/least_squares.py
@@ -146,8 +146,8 @@ def solve(self, vis, iterations, optimizer=MeanSquaredError(), stopping=1e-3):

 self.losses.append(optimizer.loss(y_pred, vis))

-#if n % (iterations // 10) == 0:
-# logger.info(f"iteration: {n}\tloss: {np.abs(self.losses[-1])}")
+if n % (iterations // 10) == 0:
+logger.info(f"iteration: {n}\tloss: {np.abs(self.losses[-1])}")

 if self.losses[-1] < stopping:
 logger.info(f"Iteration: ({n})\tStopping criterion reached: {self.losses[-1]}")
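One small caveat about the re-enabled progress logging above: iterations // 10 is zero whenever iterations < 10, so the modulus would raise a ZeroDivisionError for short solves. If small iteration counts are expected, a guarded cadence is a cheap fix. The snippet below is a hypothetical drop-in for the two added lines, not code from the repository; it assumes the same n, iterations, logger, and self.losses names used in solve.

# Hypothetical guard: log roughly ten times per solve without risking
# a ZeroDivisionError when iterations < 10.
log_every = max(1, iterations // 10)

if n % log_every == 0:
    logger.info(f"iteration: {n}\tloss: {np.abs(self.losses[-1])}")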
