diff --git a/HISTORY.rst b/HISTORY.rst
index 6215303..629557c 100644
--- a/HISTORY.rst
+++ b/HISTORY.rst
@@ -6,6 +6,7 @@ History
 X.Y.Z (YYYY-MM-DD)
 ------------------
+* Add multi-model column expression (:pr:`76`)
 * Removed .travis and .travis.yml (:pr:`71`) and (:pr:`73`)
 * Github CI Actions (:pr:`71`)
diff --git a/setup.py b/setup.py
index 166b828..917ae42 100644
--- a/setup.py
+++ b/setup.py
@@ -9,13 +9,16 @@
 readme = readme_file.read()
 
 requirements = [
+    'dask-ms'
+    '@git+https://github.com/ska-sa/dask-ms.git'
+    '@master',
     'dask[array] >= 2.2.0',
     'donfig >= 0.4.0',
     'numpy >= 1.14.0',
     'numba >= 0.43.0',
     'scipy >= 1.2.0',
     'threadpoolctl >= 1.0.0',
-    'dask-ms >= 0.2.3',
+    # 'dask-ms >= 0.2.3',
     'zarr >= 2.3.1'
 ]
diff --git a/tricolour/apps/tricolour/app.py b/tricolour/apps/tricolour/app.py
index e420a7f..3f5dd1c 100644
--- a/tricolour/apps/tricolour/app.py
+++ b/tricolour/apps/tricolour/app.py
@@ -1,7 +1,7 @@
 # -*- coding: utf-8 -*-
 """ Main tricolour application """
-
+import warnings
 import re
 import argparse
 import contextlib
@@ -22,6 +22,7 @@
                               CacheProfiler, visualize)
 import numpy as np
 from daskms import xds_from_ms, xds_from_table, xds_to_table
+from daskms.expressions import data_column_expr, DataColumnParseError
 from threadpoolctl import threadpool_limits
 
 from tricolour.apps.tricolour.strat_executor import StrategyExecutor
@@ -183,7 +184,9 @@ def create_parser():
                    help="Number of channels to dilate as int "
                         "or string with units")
     p.add_argument("-dc", "--data-column", type=str, default="DATA",
-                   help="Name of visibility data column to flag")
+                   help="Name of visibility data column to flag "
+                        "or an expression composed of DATA columns: "
+                        "e.g. \"DATA / (DIR1_DATA + DIR2_DATA + DIR3_DATA)\"")
     p.add_argument("-fn", "--field-names", type=str, action='append',
                    default=[],
                    help="Name(s) of fields to flag. Defaults to flagging all")
@@ -211,8 +214,8 @@
     p.add_argument("-smc", "--subtract-model-column", default=None, type=str,
                    help="Subtracts specified column from data column "
                         "specified. "
-                        "Flagging will proceed on residual "
-                        "data.")
+                        "Flagging will proceed on residual data. "
+                        "Deprecated argument. Use --data-column instead")
 
     return p
@@ -268,8 +271,6 @@ def _main(args):
                  "Interactive Python Debugger, as per user request")
         post_mortem_handler.disable_pdb_on_error()
 
-    log.info("Flagging on the {0:s} column".format(args.data_column))
-    data_column = args.data_column
     masked_channels = [load_mask(fn, dilate=args.dilate_masks)
                        for fn in collect_masks()]
     GD = args.config
@@ -281,22 +282,24 @@
     # Index datasets by these columns
     index_cols = ['TIME']
 
-    # Reopen the datasets using the aggregated row ordering
-    columns = [data_column,
-               "FLAG",
-               "TIME",
-               "ANTENNA1",
-               "ANTENNA2"]
-
-    if args.subtract_model_column is not None:
-        columns.append(args.subtract_model_column)
-
     xds = list(xds_from_ms(args.ms,
-                           columns=tuple(columns),
                            group_cols=group_cols,
                            index_cols=index_cols,
                            chunks={"row": args.row_chunks}))
 
+    try:
+        data_columns = [getattr(ds, args.data_column).data for ds in xds]
+    except AttributeError:
+        try:
+            data_columns = data_column_expr(args.data_column, xds)
+        except DataColumnParseError:
+            raise ValueError(f"{args.data_column} is neither an "
+                             f"expression nor a valid column")
+
+        log.info(f"Flagging expression '{args.data_column}'")
+    else:
+        log.info(f"Flagging column '{args.data_column}'")
+
     # Get support tables
     st = support_tables(args.ms)
 
     ddid_ds = st["DATA_DESCRIPTION"]
@@ -352,7 +355,7 @@
     final_stats = []
 
     # Iterate through each dataset
-    for ds in xds:
+    for ds, vis in zip(xds, data_columns):
         if ds.FIELD_ID not in field_dict:
             continue
@@ -367,14 +370,15 @@
         spw_info = spw_ds[ddid_ds.SPECTRAL_WINDOW_ID.data[ds.DATA_DESC_ID]]
         pol_info = pol_ds[ddid_ds.POLARIZATION_ID.data[ds.DATA_DESC_ID]]
 
-        nrow, nchan, ncorr = getattr(ds, data_column).data.shape
+        nrow, nchan, ncorr = vis.shape
 
         # Visibilities from the dataset
-        vis = getattr(ds, data_column).data
         if args.subtract_model_column is not None:
+            warnings.warn("--subtract-model-column argument is deprecated. "
+                          "Use --data-column instead.")
             log.info("Forming residual data between '{0:s}' and "
                      "'{1:s}' for flagging.".format(
-                         data_column, args.subtract_model_column))
+                         args.data_column, args.subtract_model_column))
             vismod = getattr(ds, args.subtract_model_column).data
             vis = vis - vismod
diff --git a/tricolour/tests/test_acceptance.py b/tricolour/tests/test_acceptance.py
index 41d971b..fa46a62 100644
--- a/tricolour/tests/test_acceptance.py
+++ b/tricolour/tests/test_acceptance.py
@@ -6,7 +6,7 @@
 """
 
 import os
-from os.path import join as pjoin
+from pathlib import Path
 import shutil
 import subprocess
 import tarfile
@@ -17,6 +17,9 @@
 import requests
 import pytest
 
+import dask
+from daskms import xds_from_ms, xds_to_table
+
 _GOOGLE_FILE_ID = "1yxDIXUo3Xun9WXxA0x_hvX9Fmxo9Igpr"
 _MS_FILENAME = '1519747221.subset.ms'
@@ -53,28 +56,35 @@ def _download_file_from_google_drive(id, destination):
     _save_response_content(response, destination)
 
 
-# Set timeout to 6 minutes
-@pytest.fixture(params=[360], scope="module")
-def flagged_ms(request, tmp_path_factory):
-    """
-    fixture yielding an MS flagged by tricolour
-    """
+@pytest.fixture(scope="session")
+def ms_tarfile(tmp_path_factory):
     try:
-        tarred_ms_filename = os.environ["TRICOLOUR_TEST_MS"]
+        tarred_ms_filename = Path(os.environ["TRICOLOUR_TEST_MS"])
     except KeyError:
-        tar_dir = tmp_path_factory.mktemp("tar-download")
-        tarred_ms_filename = os.path.join(tar_dir, "test_data.tar.gz")
+        tar_dir = tmp_path_factory.mktemp("acceptance-download-")
+        tarred_ms_filename = tar_dir / "test_data.tar.gz"
 
         _download_file_from_google_drive(_GOOGLE_FILE_ID, tarred_ms_filename)
 
-    tmp_path = str(tmp_path_factory.mktemp('data'))
+    yield tarred_ms_filename
+
+
+@pytest.fixture(scope="function")
+def ms_filename(ms_tarfile, tmp_path_factory):
+    tmp_path = tmp_path_factory.mktemp("acceptance-data-")
 
-    # Open and extract tarred ms
-    tarred_ms = tarfile.open(tarred_ms_filename)
-    tarred_ms.extractall(tmp_path)
+    with tarfile.open(ms_tarfile) as tarred_ms:
+        tarred_ms.extractall(tmp_path)
 
-    # Set up our paths
-    ms_filename = pjoin(tmp_path, _MS_FILENAME)
+    yield str(Path(tmp_path / _MS_FILENAME))
+
+
+# Set timeout to 6 minutes
+@pytest.fixture(params=[360], scope="function")
+def flagged_ms(request, ms_filename):
+    """
+    fixture yielding an MS flagged by tricolour
+    """
     test_directory = os.path.dirname(__file__)
 
     args = ['tricolour',
@@ -103,10 +113,11 @@
     elif ret != 0:
         raise RuntimeError("Tricolour exited with non-zero return code")
 
-    yield ms_filename
-
-    # Remove MS
-    shutil.rmtree(ms_filename)
+    try:
+        yield ms_filename
+    finally:
+        # Remove MS
+        shutil.rmtree(ms_filename)
 
 
 @pytest.mark.parametrize("tol", [1e3])
@@ -239,3 +250,61 @@ def test_bandwidth_flagged(flagged_ms, tol):
     print("Percent bandwidth flagged for PKS1934-63: %.3f%%" %
           (100. * flagged_ratio))
     assert flagged_ratio < tol
+
+
+@pytest.fixture(params=[360], scope="function")
+def multi_model_ms(request, ms_filename):
+    """
+    Multi-model 'DATA' column
+    """
+    test_directory = os.path.dirname(__file__)
+
+    # Open ms
+    xds = xds_from_ms(ms_filename)
+    # Create 'MODEL_DATA' column
+    for i, ds in enumerate(xds):
+        dims = ds.DATA.dims
+        xds[i] = ds.assign(MODEL_DATA=(dims, ds.DATA.data / 2))
+
+    # Write 'MODEL_DATA' column - delayed operation
+    writes = xds_to_table(xds, ms_filename, "MODEL_DATA")
+    dask.compute(writes)
+
+    # Pass the expression to Tricolour
+    args = ['tricolour',
+            '-fs', 'total_power',
+            '-c', os.path.join(test_directory, 'custom.yaml'),
+            '-dc', 'DATA - MODEL_DATA',
+            ms_filename]
+
+    p = subprocess.Popen(args, env=os.environ.copy())
+    delay = 1.0
+    timeout = int(request.param / delay)
+
+    while p.poll() is None and timeout > 0:
+        time.sleep(delay)
+        timeout -= delay
+
+    # timeout reached, kill process if it is still rolling
+    ret = p.poll()
+
+    if ret is None:
+        p.kill()
+        ret = 99
+
+    if ret == 99:
+        raise RuntimeError("Test timeout reached. Killed flagger")
+    elif ret != 0:
+        raise RuntimeError("Tricolour exited with non-zero return code")
+
+    try:
+        yield ms_filename
+    finally:
+        shutil.rmtree(ms_filename)
+
+
+def test_multi_model(multi_model_ms):
+    """
+    Test Multi-model 'DATA' column
+    """
+    pass
diff --git a/tricolour/window_statistics.py b/tricolour/window_statistics.py
index 6057e70..ad935be 100644
--- a/tricolour/window_statistics.py
+++ b/tricolour/window_statistics.py
@@ -112,7 +112,6 @@ def window_stats(flag_window, ubls, chan_freqs,
         Dask array containing a single :class:`WindowStatistics` object.
         `prev_stats` is merged into this result, if present.
     """
-
    # Construct as array of per-baseline stats objects
    stats = da.blockwise(_window_stats, ("bl",),
                         flag_window, _WINDOW_SCHEMA,
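
With this patch, --data-column accepts either a plain column name or an expression over visibility
columns, resolved through daskms.expressions.data_column_expr. A minimal library-level sketch that
mirrors the usage in _main() and the new multi_model_ms fixture; the "my.ms" path and the presence
of a MODEL_DATA column are illustrative assumptions, not part of the patch:

    from daskms import xds_from_ms
    from daskms.expressions import data_column_expr

    # "my.ms" is a hypothetical Measurement Set assumed to carry a MODEL_DATA
    # column alongside DATA (the acceptance fixture writes one via xds_to_table)
    xds = xds_from_ms("my.ms")

    # Resolve the expression against each dataset, as _main() now does
    data_columns = data_column_expr("DATA - MODEL_DATA", xds)

    for ds, vis in zip(xds, data_columns):
        # Each vis is a lazily evaluated (row, chan, corr) array for that dataset
        nrow, nchan, ncorr = vis.shape

On the command line the same expression is passed through -dc/--data-column, e.g.
tricolour -fs total_power -dc "DATA - MODEL_DATA" my.ms, which is what the new
multi_model_ms fixture exercises.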