diff --git a/src/calviper/math/optimizer.py b/src/calviper/math/optimizer.py index 8fd7dfd..167a0d2 100644 --- a/src/calviper/math/optimizer.py +++ b/src/calviper/math/optimizer.py @@ -1,4 +1,5 @@ import itertools +import numba import numpy as np @@ -30,6 +31,7 @@ def gradient_(target: np.ndarray, model: np.ndarray, parameter: np.ndarray) -> n return gradient_ @staticmethod + #@numba.njit def gradient(target: np.ndarray, model: np.ndarray, parameter: np.ndarray) -> np.ndarray: # cache_ = target, model.conj() # numerator_ = np.matmul(cache_, parameter) @@ -38,8 +40,8 @@ def gradient(target: np.ndarray, model: np.ndarray, parameter: np.ndarray) -> np target = target.reshape(n_time, n_channel, 2, 2, n_antennas, n_antennas) model = model.reshape(n_time, n_channel, 2, 2, n_antennas, n_antennas) - numerator_ = np.zeros((n_time, n_channel, 2, n_antennas)) - denominator_ = np.zeros((n_time, n_channel, 2, n_antennas)) + numerator_ = np.zeros((n_time, n_channel, 2, n_antennas), dtype=np.complex64) + denominator_ = np.zeros((n_time, n_channel, 2, n_antennas), dtype=np.complex64) # polarizations per baseline are in the order [XX, XY, YX, YY] I think ... so for p, q in itertools.product([0, 1], [0, 1]): diff --git a/src/calviper/math/solver/least_squares.py b/src/calviper/math/solver/least_squares.py index 4291666..9b1514a 100644 --- a/src/calviper/math/solver/least_squares.py +++ b/src/calviper/math/solver/least_squares.py @@ -146,8 +146,8 @@ def solve(self, vis, iterations, optimizer=MeanSquaredError(), stopping=1e-3): self.losses.append(optimizer.loss(y_pred, vis)) - #if n % (iterations // 10) == 0: - # logger.info(f"iteration: {n}\tloss: {np.abs(self.losses[-1])}") + if n % (iterations // 10) == 0: + logger.info(f"iteration: {n}\tloss: {np.abs(self.losses[-1])}") if self.losses[-1] < stopping: logger.info(f"Iteration: ({n})\tStopping criterion reached: {self.losses[-1]}")