diff --git a/jones.ipynb b/jones.ipynb index e17a1f5..aedd182 100644 --- a/jones.ipynb +++ b/jones.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "id": "5076dee4-776f-427e-b790-ce440f171cbe", "metadata": {}, "outputs": [], @@ -12,7 +12,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "id": "b75de828-c739-476f-aad5-a6b01e53aef7", "metadata": {}, "outputs": [], @@ -32,7 +32,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "id": "6ed9a29d-367e-46ec-ab9d-ac55cfcfad6e", "metadata": {}, "outputs": [], @@ -43,153 +43,10 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "id": "d9493256-8a80-4bee-ae5e-c4145627b157", "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
nameintentsshapepolarizationscan_numberspw_namefield_namesource_nameline_namefield_coordsstart_frequencyend_frequency
0gaincaltest2_0[CALIBRATE_DELAY#ON_SOURCE, CALIBRATE_PHASE#ON...(957, 45, 8, 4)[XX, XY, YX, YY][2, 4, 6, 9, 11, 14, 16, 18, 21, 23, 26]X0000000000#ALMA_RB_03#BB_1#SW-01#FULL_RES_0[J2255-3500_0][J2255-3500_0][][icrs, 22h55m57.68s, -35d00m00.00s]8.607155e+108.629030e+10
3gaincaltest2_2[CALIBRATE_DELAY#ON_SOURCE, CALIBRATE_PHASE#ON...(957, 45, 8, 4)[XX, XY, YX, YY][2, 4, 6, 9, 11, 14, 16, 18, 21, 23, 26]X0000000000#ALMA_RB_03#BB_2#SW-01#FULL_RES_1[J2255-3500_0][J2255-3500_0][][icrs, 22h55m57.68s, -35d00m00.00s]8.794655e+108.816530e+10
2gaincaltest2_4[CALIBRATE_DELAY#ON_SOURCE, CALIBRATE_PHASE#ON...(957, 45, 8, 4)[XX, XY, YX, YY][2, 4, 6, 9, 11, 14, 16, 18, 21, 23, 26]X0000000000#ALMA_RB_03#BB_3#SW-01#FULL_RES_2[J2255-3500_0][J2255-3500_0][][icrs, 22h55m57.68s, -35d00m00.00s]9.632156e+109.654030e+10
1gaincaltest2_6[CALIBRATE_DELAY#ON_SOURCE, CALIBRATE_PHASE#ON...(957, 45, 8, 4)[XX, XY, YX, YY][2, 4, 6, 9, 11, 14, 16, 18, 21, 23, 26]X0000000000#ALMA_RB_03#BB_4#SW-01#FULL_RES_3[J2255-3500_0][J2255-3500_0][][icrs, 22h55m57.68s, -35d00m00.00s]9.819656e+109.841530e+10
\n", - "
" - ], - "text/plain": [ - " name intents \\\n", - "0 gaincaltest2_0 [CALIBRATE_DELAY#ON_SOURCE, CALIBRATE_PHASE#ON... \n", - "3 gaincaltest2_2 [CALIBRATE_DELAY#ON_SOURCE, CALIBRATE_PHASE#ON... \n", - "2 gaincaltest2_4 [CALIBRATE_DELAY#ON_SOURCE, CALIBRATE_PHASE#ON... \n", - "1 gaincaltest2_6 [CALIBRATE_DELAY#ON_SOURCE, CALIBRATE_PHASE#ON... \n", - "\n", - " shape polarization \\\n", - "0 (957, 45, 8, 4) [XX, XY, YX, YY] \n", - "3 (957, 45, 8, 4) [XX, XY, YX, YY] \n", - "2 (957, 45, 8, 4) [XX, XY, YX, YY] \n", - "1 (957, 45, 8, 4) [XX, XY, YX, YY] \n", - "\n", - " scan_number \\\n", - "0 [2, 4, 6, 9, 11, 14, 16, 18, 21, 23, 26] \n", - "3 [2, 4, 6, 9, 11, 14, 16, 18, 21, 23, 26] \n", - "2 [2, 4, 6, 9, 11, 14, 16, 18, 21, 23, 26] \n", - "1 [2, 4, 6, 9, 11, 14, 16, 18, 21, 23, 26] \n", - "\n", - " spw_name field_name \\\n", - "0 X0000000000#ALMA_RB_03#BB_1#SW-01#FULL_RES_0 [J2255-3500_0] \n", - "3 X0000000000#ALMA_RB_03#BB_2#SW-01#FULL_RES_1 [J2255-3500_0] \n", - "2 X0000000000#ALMA_RB_03#BB_3#SW-01#FULL_RES_2 [J2255-3500_0] \n", - "1 X0000000000#ALMA_RB_03#BB_4#SW-01#FULL_RES_3 [J2255-3500_0] \n", - "\n", - " source_name line_name field_coords \\\n", - "0 [J2255-3500_0] [] [icrs, 22h55m57.68s, -35d00m00.00s] \n", - "3 [J2255-3500_0] [] [icrs, 22h55m57.68s, -35d00m00.00s] \n", - "2 [J2255-3500_0] [] [icrs, 22h55m57.68s, -35d00m00.00s] \n", - "1 [J2255-3500_0] [] [icrs, 22h55m57.68s, -35d00m00.00s] \n", - "\n", - " start_frequency end_frequency \n", - "0 8.607155e+10 8.629030e+10 \n", - "3 8.794655e+10 8.816530e+10 \n", - "2 9.632156e+10 9.654030e+10 \n", - "1 9.819656e+10 9.841530e+10 " - ] - }, - "execution_count": 4, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "ps = ms.open_processing_set(\"data/gaincaltest2.ps.zarr\")\n", "\n", @@ -199,7 +56,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "id": "0613dde0-351b-4c16-a775-206029bb91ae", "metadata": {}, "outputs": [], @@ -209,21 +66,10 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "id": "79e116f8-09d8-4151-b30a-ada716b467f3", "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "(45, 8, 4)" - ] - }, - "execution_count": 6, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "V = dataset.VISIBILITY.mean(dim=\"time\").data.compute()\n", "V.shape" @@ -231,21 +77,10 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "id": "4dc7448f-20ab-4eca-9e8c-11eba2e971ca", "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "(957, 45, 8, 4)" - ] - }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "s = dataset.VISIBILITY.shape\n", "s" @@ -271,7 +106,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": null, "id": "0ce2493e-0c7e-4e6f-9711-5b63df1973ef", "metadata": {}, "outputs": [], @@ -281,7 +116,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": null, "id": "e62cd1f3-36ce-44d0-8132-a31cdd4f4fa8", "metadata": {}, "outputs": [], @@ -291,63 +126,30 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": null, "id": "6403894b-a634-45a8-9219-05aed71457f1", "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "(957, 45, 8, 2, 2)" - ] - }, - "execution_count": 10, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "G.matrix.shape" ] }, { "cell_type": "code", - "execution_count": 11, + "execution_count": null, "id": "7ccb777d-9509-4d19-a2c1-77d53c9564f9", "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "(957, 8, 36)" - ] - }, - "execution_count": 11, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "G.parameters.shape" ] }, { "cell_type": "code", - "execution_count": 12, + "execution_count": null, "id": "3ee0983f-5a19-4001-973c-c88216caa061", "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "(45, 8, 2, 2)" - ] - }, - "execution_count": 12, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "g = G.matrix.mean(axis=0)\n", "g.shape" @@ -355,7 +157,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": null, "id": "bf6e3883-7923-4a5f-be89-ac0244ef8c17", "metadata": {}, "outputs": [], @@ -372,28 +174,17 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": null, "id": "273a36a3-7f9d-4bfe-93de-3ac5888adab5", "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "(8, 4, 9, 9)" - ] - }, - "execution_count": 14, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "v_.shape" ] }, { "cell_type": "code", - "execution_count": 15, + "execution_count": null, "id": "2ff5cc79-02c8-49a0-9016-c7113a657226", "metadata": {}, "outputs": [], @@ -404,36 +195,14 @@ { "cell_type": "code", "execution_count": null, - "id": "e4db2435-76f6-4203-af7b-b78e69baaec0", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 16, "id": "99704e1f-e800-40bf-a0e4-5d4715e011b3", "metadata": {}, - "outputs": [ - { - "ename": "ValueError", - "evalue": "operands could not be broadcast together with shapes (9,9) (8,4,9) ", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[16], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m gain_solutions \u001b[38;5;241m=\u001b[39m \u001b[43msolver\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msolve\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 2\u001b[0m \u001b[43m \u001b[49m\u001b[43mvis\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mv_\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3\u001b[0m \u001b[43m \u001b[49m\u001b[43miterations\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m40\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 4\u001b[0m \u001b[43m \u001b[49m\u001b[43moptimizer\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcv\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmath\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43moptimizer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mMeanSquaredError\u001b[49m\u001b[43m(\u001b[49m\u001b[43malpha\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m0.2\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 5\u001b[0m \u001b[43m \u001b[49m\u001b[43mstopping\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m1e-4\u001b[39;49m\n\u001b[1;32m 6\u001b[0m \u001b[43m)\u001b[49m\n", - "File \u001b[0;32m/export/home/fornax/Development/calviper/src/calviper/math/solver/least_squares.py:110\u001b[0m, in \u001b[0;36mLeastSquaresSolver.solve\u001b[0;34m(self, vis, iterations, optimizer, stopping)\u001b[0m\n\u001b[1;32m 99\u001b[0m gradient_ \u001b[38;5;241m=\u001b[39m optimizer\u001b[38;5;241m.\u001b[39mgradient(\n\u001b[1;32m 100\u001b[0m target\u001b[38;5;241m=\u001b[39mvis,\n\u001b[1;32m 101\u001b[0m model\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel_,\n\u001b[1;32m 102\u001b[0m parameter\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mparameter\n\u001b[1;32m 103\u001b[0m )\n\u001b[1;32m 105\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mparameter \u001b[38;5;241m=\u001b[39m optimizer\u001b[38;5;241m.\u001b[39mstep(\n\u001b[1;32m 106\u001b[0m parameter\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mparameter,\n\u001b[1;32m 107\u001b[0m gradient\u001b[38;5;241m=\u001b[39mgradient_\n\u001b[1;32m 108\u001b[0m )\n\u001b[0;32m--> 110\u001b[0m y_pred \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpredict\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 112\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlosses\u001b[38;5;241m.\u001b[39mappend(optimizer\u001b[38;5;241m.\u001b[39mloss(y_pred, vis))\n\u001b[1;32m 114\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m n \u001b[38;5;241m%\u001b[39m (iterations \u001b[38;5;241m/\u001b[39m\u001b[38;5;241m/\u001b[39m \u001b[38;5;241m10\u001b[39m) \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m0\u001b[39m:\n", - "File \u001b[0;32m/export/home/fornax/Development/calviper/src/calviper/math/solver/least_squares.py:65\u001b[0m, in \u001b[0;36mLeastSquaresSolver.predict\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 62\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21mpredict\u001b[39m(\u001b[38;5;28mself\u001b[39m):\n\u001b[1;32m 64\u001b[0m n_channel, n_polarizations, n_antennas \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mparameter\u001b[38;5;241m.\u001b[39mshape\n\u001b[0;32m---> 65\u001b[0m parameter_matrix_ \u001b[38;5;241m=\u001b[39m \u001b[43mnp\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43midentity\u001b[49m\u001b[43m(\u001b[49m\u001b[43mn_antennas\u001b[49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mparameter\u001b[49m\n\u001b[1;32m 66\u001b[0m cache_ \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39mdot(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel_, parameter_matrix_)\n\u001b[1;32m 68\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m np\u001b[38;5;241m.\u001b[39mdot(parameter_matrix_\u001b[38;5;241m.\u001b[39mconj(), cache_)\n", - "\u001b[0;31mValueError\u001b[0m: operands could not be broadcast together with shapes (9,9) (8,4,9) " - ] - } - ], + "outputs": [], "source": [ "gain_solutions = solver.solve(\n", " vis=v_,\n", - " iterations=40,\n", - " optimizer=cv.math.optimizer.MeanSquaredError(alpha=0.2),\n", + " iterations=50,\n", + " optimizer=cv.math.optimizer.MeanSquaredError(alpha=0.25),\n", " stopping=1e-4\n", ")" ] @@ -471,32 +240,6 @@ "source": [ "solver.parameter.shape" ] - }, - { - "cell_type": "code", - "execution_count": 19, - "id": "3ef8f034-65d3-439c-a7c7-e9142fcf148b", - "metadata": {}, - "outputs": [], - "source": [ - "x = np.linspace(1, 3, 3)\n", - "X = np.tile(x, [3, 2, 1])" - ] - }, - { - "cell_type": "code", - "execution_count": 21, - "id": "b303fecf-8ac7-4869-a6c3-a0b0196b9582", - "metadata": {}, - "outputs": [], - "source": [ - "m = np.ones((3, 3))\n", - "\n", - "I = np.ones((3, 3))\n", - "np.fill_diagonal(I, 0.)\n", - "\n", - "Z = np.tile(I, [3, 2, 1, 1])" - ] } ], "metadata": { @@ -515,7 +258,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.19" + "version": "3.12.8" } }, "nbformat": 4, diff --git a/src/calviper/math/solver/least_squares.py b/src/calviper/math/solver/least_squares.py index a1f4caa..8ad5e57 100644 --- a/src/calviper/math/solver/least_squares.py +++ b/src/calviper/math/solver/least_squares.py @@ -1,3 +1,5 @@ +import numba + import numpy as np import toolviper.utils.logger as logger @@ -59,7 +61,7 @@ def _solve(self, vis, iterations, loss=mse, optimizer=None, alpha=0.1): return _gains ''' - def predict(self): + def predict_(self): n_channel, n_polarizations, n_antennas = self.parameter.shape parameter_matrix_ = np.identity(n_antennas) * self.parameter @@ -67,6 +69,24 @@ def predict(self): return np.dot(parameter_matrix_.conj(), cache_) + @staticmethod + @numba.njit() + def predict(model_, parameter)->np.ndarray: + n_channel, n_polarizations, n_antennas = parameter.shape + prediction = np.zeros_like(model_) + + # This can definitely be optimized, but I just want to test for now. + for channel in range(n_channel): + for polarization in range(n_polarizations): + for i in range(n_antennas): + for j in range(n_antennas): + if i == j: + continue + + prediction[channel, polarization, i, j] = parameter[channel, polarization, i] * model_[channel, polarization, i, j] * np.conj(parameter[channel, polarization, j]) + + return prediction + def solve(self, vis, iterations, optimizer=MeanSquaredError(), stopping=1e-3): # This is an attempt to do the solving in a vectorized way @@ -107,15 +127,15 @@ def solve(self, vis, iterations, optimizer=MeanSquaredError(), stopping=1e-3): gradient=gradient_ ) - #y_pred = self.predict() + y_pred = self.predict(self.model_, self.parameter) - #self.losses.append(optimizer.loss(y_pred, vis)) + self.losses.append(optimizer.loss(y_pred, vis)) - #if n % (iterations // 10) == 0: - # logger.info(f"iteration: {n}\tloss: {np.abs(self.losses[-1])}") + if n % (iterations // 10) == 0: + logger.info(f"iteration: {n}\tloss: {np.abs(self.losses[-1])}") - #if self.losses[-1] < stopping: - # logger.info(f"Iteration: ({n})\tStopping criterion reached: {self.losses[-1]}") - # break + if self.losses[-1] < stopping: + logger.info(f"Iteration: ({n})\tStopping criterion reached: {self.losses[-1]}") + break return self.parameter