From e44846df351dd2ee999d1955db24ff2a13563672 Mon Sep 17 00:00:00 2001 From: Sakhile Masoka Date: Fri, 30 Oct 2020 11:39:52 +0200 Subject: [PATCH 01/15] support multi-column expression and depency on daskms master --- setup.py | 5 +- tricolour/apps/tricolour/app.py | 26 ++++++++- tricolour/apps/tricolour/somthing.py | 82 ++++++++++++++++++++++++++++ 3 files changed, 110 insertions(+), 3 deletions(-) create mode 100644 tricolour/apps/tricolour/somthing.py diff --git a/setup.py b/setup.py index 166b828..917ae42 100644 --- a/setup.py +++ b/setup.py @@ -9,13 +9,16 @@ readme = readme_file.read() requirements = [ + 'dask-ms' + '@git+https://github.com/ska-sa/dask-ms.git' + '@master', 'dask[array] >= 2.2.0', 'donfig >= 0.4.0', 'numpy >= 1.14.0', 'numba >= 0.43.0', 'scipy >= 1.2.0', 'threadpoolctl >= 1.0.0', - 'dask-ms >= 0.2.3', + # 'dask-ms >= 0.2.3', 'zarr >= 2.3.1' ] diff --git a/tricolour/apps/tricolour/app.py b/tricolour/apps/tricolour/app.py index 836ee49..7515f8c 100644 --- a/tricolour/apps/tricolour/app.py +++ b/tricolour/apps/tricolour/app.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- """ Main tricolour application """ - +import warnings import re import argparse import contextlib @@ -22,6 +22,7 @@ CacheProfiler, visualize) import numpy as np from daskms import xds_from_ms, xds_from_table, xds_to_table +from daskms.expressions import data_column_expr from threadpoolctl import threadpool_limits from tricolour.apps.tricolour.strat_executor import StrategyExecutor @@ -269,7 +270,14 @@ def _main(args): post_mortem_handler.disable_pdb_on_error() log.info("Flagging on the {0:s} column".format(args.data_column)) - data_column = args.data_column + if "/" in args.data_column: + data_string = args.data_column.split("/") + data_column = data_string[0] + string_columns = data_string[1] + else: + data_column = args.data_column + string_columns = None + masked_channels = [load_mask(fn, dilate=args.dilate_masks) for fn in collect_masks()] GD = args.config @@ -288,7 +296,18 @@ def _main(args): "ANTENNA1", "ANTENNA2"] + # Get the columns to read + if string_columns is not None and "(" in string_columns: + match = re.search(r'\(([A-Za-z0-9+-_]+)\)', string_columns) + multi_columns = re.split(r'\+|-', match.group(1)) + multi_columns = list(filter(None, multi_columns)) + columns.extend(multi_columns) + else: + if string_columns is not None: + columns.append(string_columns) + if args.subtract_model_column is not None: + warnings.warn("Use -dc arg") columns.append(args.subtract_model_column) xds = list(xds_from_ms(args.ms, @@ -297,6 +316,9 @@ def _main(args): index_cols=index_cols, chunks={"row": args.row_chunks})) + string = "EXPR = " + args.data_column + vis = data_column_expr(string, xds) + # Get support tables st = support_tables(args.ms) ddid_ds = st["DATA_DESCRIPTION"] diff --git a/tricolour/apps/tricolour/somthing.py b/tricolour/apps/tricolour/somthing.py new file mode 100644 index 0000000..332c06b --- /dev/null +++ b/tricolour/apps/tricolour/somthing.py @@ -0,0 +1,82 @@ +import re +import argparse +import logging +import logging.handlers + + +def create_logger(): + """ Create a console logger """ + log = logging.getLogger("tricolour") + cfmt = logging.Formatter(u'%(name)s - %(asctime)s ' + '%(levelname)s - %(message)s') + log.setLevel(logging.DEBUG) + filehandler = logging.FileHandler("tricolour.log") + filehandler.setFormatter(cfmt) + log.addHandler(filehandler) + log.setLevel(logging.INFO) + + console = logging.StreamHandler() + console.setLevel(logging.INFO) + console.setFormatter(cfmt) + + log.addHandler(console) + + return log + + +log = create_logger() + + +def create_parser(): + # warnings.warn("Use -dc flag") + formatter = argparse.ArgumentDefaultsHelpFormatter + p = argparse.ArgumentParser(formatter_class=formatter) + p.add_argument("ms", help="Measurement Set") + p.add_argument("-dc", "--data-column", type=str, default="DATA", + help="Name of visibility data column to flag. You can " + "specify DATA(+-)MODEL for column arithmetic") + p.add_argument("-smc", "--subtract-model-column", default=None, type=str, + help="Subtracts specified column from data column " + "specified. " + "Flagging will proceed on residual " + "data.") + + return p + + +if __name__ == "__main__": + args = create_parser().parse_args() + if "/" in args.data_column: + data_string = args.data_column.split("/") + data_column = data_string[0] + string_columns = data_string[1] + else: + data_column = args.data_column + string_columns = None + + print(data_column) + print(string_columns) + + columns = [data_column, + "FLAG", + "TIME", + "ANTENNA1", + "ANTENNA2"] + + if string_columns is not None and "(" in string_columns: + match = re.search(r'\(([A-Za-z0-9+-_]+)\)', string_columns) + multi_columns = re.split(r'\+|-', match.group(1)) + multi_columns = list(filter(None, multi_columns)) + print(str(multi_columns)) + columns.extend(multi_columns) + else: + if string_columns is not None: + columns.append(string_columns) + print(columns) +# Handles variety of cases +# -dc DATA +# -dc DATA/DATA_MODEL +# -dc DATA/(DATA_MODEL) +# -dc DATA/(DATA1 + DATA2 + DATA3) +# -dc DATA/(-DATA1 + DATA2 + DATA3) +# -dc DATA/(DATA1 + DATA2 - DATA3) From 5050a4337e1f40a21192e16ec8420ae4c7e84552 Mon Sep 17 00:00:00 2001 From: Sakhile Masoka Date: Fri, 30 Oct 2020 11:47:10 +0200 Subject: [PATCH 02/15] remove sothing.py --- tricolour/apps/tricolour/somthing.py | 82 ---------------------------- 1 file changed, 82 deletions(-) delete mode 100644 tricolour/apps/tricolour/somthing.py diff --git a/tricolour/apps/tricolour/somthing.py b/tricolour/apps/tricolour/somthing.py deleted file mode 100644 index 332c06b..0000000 --- a/tricolour/apps/tricolour/somthing.py +++ /dev/null @@ -1,82 +0,0 @@ -import re -import argparse -import logging -import logging.handlers - - -def create_logger(): - """ Create a console logger """ - log = logging.getLogger("tricolour") - cfmt = logging.Formatter(u'%(name)s - %(asctime)s ' - '%(levelname)s - %(message)s') - log.setLevel(logging.DEBUG) - filehandler = logging.FileHandler("tricolour.log") - filehandler.setFormatter(cfmt) - log.addHandler(filehandler) - log.setLevel(logging.INFO) - - console = logging.StreamHandler() - console.setLevel(logging.INFO) - console.setFormatter(cfmt) - - log.addHandler(console) - - return log - - -log = create_logger() - - -def create_parser(): - # warnings.warn("Use -dc flag") - formatter = argparse.ArgumentDefaultsHelpFormatter - p = argparse.ArgumentParser(formatter_class=formatter) - p.add_argument("ms", help="Measurement Set") - p.add_argument("-dc", "--data-column", type=str, default="DATA", - help="Name of visibility data column to flag. You can " - "specify DATA(+-)MODEL for column arithmetic") - p.add_argument("-smc", "--subtract-model-column", default=None, type=str, - help="Subtracts specified column from data column " - "specified. " - "Flagging will proceed on residual " - "data.") - - return p - - -if __name__ == "__main__": - args = create_parser().parse_args() - if "/" in args.data_column: - data_string = args.data_column.split("/") - data_column = data_string[0] - string_columns = data_string[1] - else: - data_column = args.data_column - string_columns = None - - print(data_column) - print(string_columns) - - columns = [data_column, - "FLAG", - "TIME", - "ANTENNA1", - "ANTENNA2"] - - if string_columns is not None and "(" in string_columns: - match = re.search(r'\(([A-Za-z0-9+-_]+)\)', string_columns) - multi_columns = re.split(r'\+|-', match.group(1)) - multi_columns = list(filter(None, multi_columns)) - print(str(multi_columns)) - columns.extend(multi_columns) - else: - if string_columns is not None: - columns.append(string_columns) - print(columns) -# Handles variety of cases -# -dc DATA -# -dc DATA/DATA_MODEL -# -dc DATA/(DATA_MODEL) -# -dc DATA/(DATA1 + DATA2 + DATA3) -# -dc DATA/(-DATA1 + DATA2 + DATA3) -# -dc DATA/(DATA1 + DATA2 - DATA3) From d18ccbe0d53393b34170bb53b7f77536ec2e4a25 Mon Sep 17 00:00:00 2001 From: Sakhile Masoka Date: Thu, 5 Nov 2020 09:24:03 +0200 Subject: [PATCH 03/15] 1. update argument help messages 2. extract visibility column name to support substract_model_column 3. removed columns arg from xds_from_ms to read entire ms 4. add warning for -smc arg depreciation --- HISTORY.rst | 1 + tricolour/apps/tricolour/app.py | 45 +++++++++------------------------ 2 files changed, 13 insertions(+), 33 deletions(-) diff --git a/HISTORY.rst b/HISTORY.rst index 6215303..629557c 100644 --- a/HISTORY.rst +++ b/HISTORY.rst @@ -6,6 +6,7 @@ History X.Y.Z (YYYY-MM-DD) ------------------ +* Add multi model column expression (:pr:`76`) * Removed .travis and .travis.yml (:pr:`71`) and (:pr:`73`) * Github CI Actions (:pr:`71`) diff --git a/tricolour/apps/tricolour/app.py b/tricolour/apps/tricolour/app.py index 7515f8c..b15505a 100644 --- a/tricolour/apps/tricolour/app.py +++ b/tricolour/apps/tricolour/app.py @@ -184,7 +184,10 @@ def create_parser(): help="Number of channels to dilate as int " "or string with units") p.add_argument("-dc", "--data-column", type=str, default="DATA", - help="Name of visibility data column to flag") + help="Name of visibility data column to flag." + "Now supports multi model columns expressions" + "e.g 'DATA / (DIR1_DATA + DIR2_DATA + DIR3_DATA)'" + "In future will replace subtract-model-column") p.add_argument("-fn", "--field-names", type=str, action='append', default=[], help="Name(s) of fields to flag. Defaults to flagging all") @@ -212,8 +215,8 @@ def create_parser(): p.add_argument("-smc", "--subtract-model-column", default=None, type=str, help="Subtracts specified column from data column " "specified. " - "Flagging will proceed on residual " - "data.") + "Flagging will proceed on residual data." + "Depreciating argurment. See --data-column") return p @@ -269,14 +272,10 @@ def _main(args): "Interactive Python Debugger, as per user request") post_mortem_handler.disable_pdb_on_error() - log.info("Flagging on the {0:s} column".format(args.data_column)) - if "/" in args.data_column: - data_string = args.data_column.split("/") - data_column = data_string[0] - string_columns = data_string[1] - else: - data_column = args.data_column - string_columns = None + # extract the name of the visibility column + # to support subtract_model_column + data_column = re.split(r'\+|-|/|\*|\(|\)', args.data_column)[0] + log.info("Flagging on the {0:s} column".format(data_column)) masked_channels = [load_mask(fn, dilate=args.dilate_masks) for fn in collect_masks()] @@ -289,29 +288,7 @@ def _main(args): # Index datasets by these columns index_cols = ['TIME'] - # Reopen the datasets using the aggregated row ordering - columns = [data_column, - "FLAG", - "TIME", - "ANTENNA1", - "ANTENNA2"] - - # Get the columns to read - if string_columns is not None and "(" in string_columns: - match = re.search(r'\(([A-Za-z0-9+-_]+)\)', string_columns) - multi_columns = re.split(r'\+|-', match.group(1)) - multi_columns = list(filter(None, multi_columns)) - columns.extend(multi_columns) - else: - if string_columns is not None: - columns.append(string_columns) - - if args.subtract_model_column is not None: - warnings.warn("Use -dc arg") - columns.append(args.subtract_model_column) - xds = list(xds_from_ms(args.ms, - columns=tuple(columns), group_cols=group_cols, index_cols=index_cols, chunks={"row": args.row_chunks})) @@ -398,6 +375,8 @@ def _main(args): # Visibilities from the dataset vis = getattr(ds, data_column).data if args.subtract_model_column is not None: + warnings.warn("-smc arg is depreciating." + "See -dc arg for expressions") log.info("Forming residual data between '{0:s}' and " "'{1:s}' for flagging.".format( data_column, args.subtract_model_column)) From 155c24430154e0ac0b0f4f1c6889244d929281f7 Mon Sep 17 00:00:00 2001 From: Sakhile Masoka Date: Wed, 6 Jan 2021 16:00:34 +0200 Subject: [PATCH 04/15] supports assignment (=) in expression --- tricolour/apps/tricolour/app.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/tricolour/apps/tricolour/app.py b/tricolour/apps/tricolour/app.py index b15505a..8ca0c5d 100644 --- a/tricolour/apps/tricolour/app.py +++ b/tricolour/apps/tricolour/app.py @@ -186,8 +186,7 @@ def create_parser(): p.add_argument("-dc", "--data-column", type=str, default="DATA", help="Name of visibility data column to flag." "Now supports multi model columns expressions" - "e.g 'DATA / (DIR1_DATA + DIR2_DATA + DIR3_DATA)'" - "In future will replace subtract-model-column") + "e.g 'EXPR = DATA / (DIR1_DATA + DIR2_DATA + DIR3_DATA)'") p.add_argument("-fn", "--field-names", type=str, action='append', default=[], help="Name(s) of fields to flag. Defaults to flagging all") @@ -216,7 +215,7 @@ def create_parser(): help="Subtracts specified column from data column " "specified. " "Flagging will proceed on residual data." - "Depreciating argurment. See --data-column") + "Deprecated argurment. Use --data-column instead") return p @@ -274,8 +273,13 @@ def _main(args): # extract the name of the visibility column # to support subtract_model_column - data_column = re.split(r'\+|-|/|\*|\(|\)', args.data_column)[0] - log.info("Flagging on the {0:s} column".format(data_column)) + # lhs - left hand side + data_column, *lhs = args.data_column.split("=") + log.info("Flagging on the {0:s} {1:s}".format(data_column, + "column" if not lhs else "expression")) + + if lhs: + data_column = re.split(r'\+|-|/|\*|\(|\)', lhs[0])[0] masked_channels = [load_mask(fn, dilate=args.dilate_masks) for fn in collect_masks()] @@ -293,8 +297,7 @@ def _main(args): index_cols=index_cols, chunks={"row": args.row_chunks})) - string = "EXPR = " + args.data_column - vis = data_column_expr(string, xds) + vis = data_column_expr(args.data_column, xds) # Get support tables st = support_tables(args.ms) @@ -375,8 +378,8 @@ def _main(args): # Visibilities from the dataset vis = getattr(ds, data_column).data if args.subtract_model_column is not None: - warnings.warn("-smc arg is depreciating." - "See -dc arg for expressions") + warnings.warn("-subtract-model-column argument is deprecated." + "Use --data-column instead.") log.info("Forming residual data between '{0:s}' and " "'{1:s}' for flagging.".format( data_column, args.subtract_model_column)) From 918a366032dab5e44cc583a3afb9cd57a410580d Mon Sep 17 00:00:00 2001 From: Sakhile Masoka Date: Thu, 7 Jan 2021 18:33:52 +0200 Subject: [PATCH 05/15] fix logic --- tricolour/apps/tricolour/app.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/tricolour/apps/tricolour/app.py b/tricolour/apps/tricolour/app.py index 8ca0c5d..f0222f3 100644 --- a/tricolour/apps/tricolour/app.py +++ b/tricolour/apps/tricolour/app.py @@ -278,9 +278,6 @@ def _main(args): log.info("Flagging on the {0:s} {1:s}".format(data_column, "column" if not lhs else "expression")) - if lhs: - data_column = re.split(r'\+|-|/|\*|\(|\)', lhs[0])[0] - masked_channels = [load_mask(fn, dilate=args.dilate_masks) for fn in collect_masks()] GD = args.config @@ -296,8 +293,8 @@ def _main(args): group_cols=group_cols, index_cols=index_cols, chunks={"row": args.row_chunks})) - - vis = data_column_expr(args.data_column, xds) + if lhs: + xds = data_column_expr(args.data_column, xds) # Get support tables st = support_tables(args.ms) From 0bce875e42c8277fb0c606c339a3607b2a19e24a Mon Sep 17 00:00:00 2001 From: Simon Perkins Date: Fri, 8 Jan 2021 10:49:59 +0200 Subject: [PATCH 06/15] Fix infinite recursion when trying to determine blockwise meta --- tricolour/window_statistics.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/tricolour/window_statistics.py b/tricolour/window_statistics.py index 135329a..ad935be 100644 --- a/tricolour/window_statistics.py +++ b/tricolour/window_statistics.py @@ -112,7 +112,6 @@ def window_stats(flag_window, ubls, chan_freqs, Dask array containing a single :class:`WindowStatistics` object. `prev_stats` is merged into this result, if present. """ - # Construct as array of per-baseline stats objects stats = da.blockwise(_window_stats, ("bl",), flag_window, _WINDOW_SCHEMA, @@ -123,7 +122,7 @@ def window_stats(flag_window, ubls, chan_freqs, field_name, None, ddid, None, nchanbins, None, - dtype=np.object) + meta=np.empty((0,), dtype=np.object)) # Create an empty stats object if the user hasn't supplied one if prev_stats is None: @@ -131,13 +130,13 @@ def _window_stat_creator(): return WindowStatistics(nchanbins) prev_stats = da.blockwise(_window_stat_creator, (), - dtype=np.object) + meta=np.empty((), dtype=np.object)) # Combine per-baseline stats into a single stats object return da.blockwise(_combine_baseline_window_stats, (), stats, ("bl",), prev_stats, (), - dtype=np.object) + meta=np.empty((), dtype=np.object)) def _combine_window_stats(*args): From f961d617e9bca96fabb23c2d40bfc31f86919786 Mon Sep 17 00:00:00 2001 From: Sakhile Masoka Date: Mon, 18 Jan 2021 13:18:38 +0200 Subject: [PATCH 07/15] adds test for expression and strip data_column --- tricolour/apps/tricolour/app.py | 1 + tricolour/tests/test_acceptance.py | 103 ++++++++++++++++++++++++++++- 2 files changed, 103 insertions(+), 1 deletion(-) diff --git a/tricolour/apps/tricolour/app.py b/tricolour/apps/tricolour/app.py index f0222f3..a0b236d 100644 --- a/tricolour/apps/tricolour/app.py +++ b/tricolour/apps/tricolour/app.py @@ -275,6 +275,7 @@ def _main(args): # to support subtract_model_column # lhs - left hand side data_column, *lhs = args.data_column.split("=") + data_column = data_column.strip() log.info("Flagging on the {0:s} {1:s}".format(data_column, "column" if not lhs else "expression")) diff --git a/tricolour/tests/test_acceptance.py b/tricolour/tests/test_acceptance.py index 41d971b..fab787d 100644 --- a/tricolour/tests/test_acceptance.py +++ b/tricolour/tests/test_acceptance.py @@ -17,6 +17,12 @@ import requests import pytest +import dask +from daskms import xds_from_ms, xds_to_table +from daskms.expressions import data_column_expr +from numpy.testing import assert_array_equal + + _GOOGLE_FILE_ID = "1yxDIXUo3Xun9WXxA0x_hvX9Fmxo9Igpr" _MS_FILENAME = '1519747221.subset.ms' @@ -108,7 +114,6 @@ def flagged_ms(request, tmp_path_factory): # Remove MS shutil.rmtree(ms_filename) - @pytest.mark.parametrize("tol", [1e3]) def test_mean_chisq(flagged_ms, tol): """ Tests for improvement in mean chisq per correlation """ @@ -239,3 +244,99 @@ def test_bandwidth_flagged(flagged_ms, tol): print("Percent bandwidth flagged for PKS1934-63: %.3f%%" % (100. * flagged_ratio)) assert flagged_ratio < tol + +@pytest.fixture(params=[360], scope="module") +def multi_model_ms(request, tmp_path_factory): + """ + Multi-model 'DATA' column + """ + try: + tarred_ms_filename = os.environ["TRICOLOUR_TEST_MS"] + except KeyError: + tar_dir = tmp_path_factory.mktemp("tar-download") + tarred_ms_filename = os.path.join(tar_dir, "test_data.tar.gz") + + _download_file_from_google_drive(_GOOGLE_FILE_ID, tarred_ms_filename) + + tmp_path = str(tmp_path_factory.mktemp('data')) + + # Open and extract tarred ms + tarred_ms = tarfile.open(tarred_ms_filename) + tarred_ms.extractall(tmp_path) + + # Set up our paths + ms_filename = pjoin(tmp_path, _MS_FILENAME) + test_directory = os.path.dirname(__file__) + + # Open ms + xds = xds_from_ms(ms_filename) + # Create 'MODEL_DATA' column + for i, ds in enumerate(xds): + dims = ds.DATA.dims + xds[i] = ds.assign(MODEL_DATA=(dims, ds.DATA.data / 2)) + + # Write 'MODEL_DATA column - delayed operation + writes = xds_to_table(xds, ms_filename, "MODEL_DATA") + dask.compute(writes) + + # pass the expression to Tricolour + args = ['tricolour', + '-fs', 'total_power', + '-c', os.path.join(test_directory, 'custom.yaml'), + '-dc', 'FLAG_DATA = DATA - MODEL_DATA', + ms_filename] + + p = subprocess.Popen(args, env=os.environ.copy()) + delay = 1.0 + timeout = int(request.param / delay) + + while p.poll() is None and timeout > 0: + time.sleep(delay) + timeout -= delay + + # timeout reached, kill process if it is still rolling + ret = p.poll() + + if ret is None: + p.kill() + ret = 99 + + if ret == 99: + raise RuntimeError("Test timeout reached. Killed flagger") + elif ret != 0: + raise RuntimeError("Tricolour exited with non-zero return code") + + yield ms_filename + + # Remove MS + shutil.rmtree(ms_filename) + +def test_multi_model(multi_model_ms): + """ + Test Multi-model 'DATA' column + """ + # Open ms + xds = xds_from_ms(multi_model_ms) + # Create 'MODEL_DATA' column + for i, ds in enumerate(xds): + dims = ds.DATA.dims + xds[i] = ds.assign(MODEL_DATA=(dims, ds.DATA.data / 2)) + + # Write 'MODEL_DATA column - delayed operation + writes = xds_to_table(xds, multi_model_ms, "MODEL_DATA") + dask.compute(writes) + + # Redundant but test data_column_expr + # expression FLAG_DATA = DATA - MODEL_DATA + expr = "FLAG_DATA = DATA - MODEL_DATA" + xds = data_column_expr(expr, xds) + + with tbl(multi_model_ms) as t: + data = t.getcol("DATA") + model_data = t.getcol("MODEL_DATA") + + assert_array_equal(model_data, data / 2) + + for i, ds in enumerate(xds): + assert_array_equal(ds.DATA.data - ds.MODEL_DATA.data, + ds.FLAG_DATA.data) From 5c38afcacb450ba11adcc390a479ad6ce57bea1c Mon Sep 17 00:00:00 2001 From: Sakhile Masoka Date: Mon, 18 Jan 2021 13:42:57 +0200 Subject: [PATCH 08/15] flake8 fixup --- tricolour/tests/test_acceptance.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/tricolour/tests/test_acceptance.py b/tricolour/tests/test_acceptance.py index fab787d..f548d24 100644 --- a/tricolour/tests/test_acceptance.py +++ b/tricolour/tests/test_acceptance.py @@ -114,6 +114,7 @@ def flagged_ms(request, tmp_path_factory): # Remove MS shutil.rmtree(ms_filename) + @pytest.mark.parametrize("tol", [1e3]) def test_mean_chisq(flagged_ms, tol): """ Tests for improvement in mean chisq per correlation """ @@ -245,6 +246,7 @@ def test_bandwidth_flagged(flagged_ms, tol): % (100. * flagged_ratio)) assert flagged_ratio < tol + @pytest.fixture(params=[360], scope="module") def multi_model_ms(request, tmp_path_factory): """ @@ -274,7 +276,7 @@ def multi_model_ms(request, tmp_path_factory): for i, ds in enumerate(xds): dims = ds.DATA.dims xds[i] = ds.assign(MODEL_DATA=(dims, ds.DATA.data / 2)) - + # Write 'MODEL_DATA column - delayed operation writes = xds_to_table(xds, ms_filename, "MODEL_DATA") dask.compute(writes) @@ -311,6 +313,7 @@ def multi_model_ms(request, tmp_path_factory): # Remove MS shutil.rmtree(ms_filename) + def test_multi_model(multi_model_ms): """ Test Multi-model 'DATA' column @@ -321,7 +324,7 @@ def test_multi_model(multi_model_ms): for i, ds in enumerate(xds): dims = ds.DATA.dims xds[i] = ds.assign(MODEL_DATA=(dims, ds.DATA.data / 2)) - + # Write 'MODEL_DATA column - delayed operation writes = xds_to_table(xds, multi_model_ms, "MODEL_DATA") dask.compute(writes) @@ -338,5 +341,5 @@ def test_multi_model(multi_model_ms): assert_array_equal(model_data, data / 2) for i, ds in enumerate(xds): - assert_array_equal(ds.DATA.data - ds.MODEL_DATA.data, + assert_array_equal(ds.DATA.data - ds.MODEL_DATA.data, ds.FLAG_DATA.data) From 6725ad17ebcd1245ad38c7adad390a1bcbe6eeac Mon Sep 17 00:00:00 2001 From: Simon Perkins Date: Thu, 11 Feb 2021 13:39:59 +0200 Subject: [PATCH 09/15] Rework with upstream dask changes --- tricolour/apps/tricolour/app.py | 31 ++++++----- tricolour/tests/test_acceptance.py | 87 +++++++++--------------------- 2 files changed, 42 insertions(+), 76 deletions(-) diff --git a/tricolour/apps/tricolour/app.py b/tricolour/apps/tricolour/app.py index a0b236d..9dc3c90 100644 --- a/tricolour/apps/tricolour/app.py +++ b/tricolour/apps/tricolour/app.py @@ -22,7 +22,7 @@ CacheProfiler, visualize) import numpy as np from daskms import xds_from_ms, xds_from_table, xds_to_table -from daskms.expressions import data_column_expr +from daskms.expressions import data_column_expr, DataColumnParseError from threadpoolctl import threadpool_limits from tricolour.apps.tricolour.strat_executor import StrategyExecutor @@ -271,13 +271,6 @@ def _main(args): "Interactive Python Debugger, as per user request") post_mortem_handler.disable_pdb_on_error() - # extract the name of the visibility column - # to support subtract_model_column - # lhs - left hand side - data_column, *lhs = args.data_column.split("=") - data_column = data_column.strip() - log.info("Flagging on the {0:s} {1:s}".format(data_column, - "column" if not lhs else "expression")) masked_channels = [load_mask(fn, dilate=args.dilate_masks) for fn in collect_masks()] @@ -294,8 +287,19 @@ def _main(args): group_cols=group_cols, index_cols=index_cols, chunks={"row": args.row_chunks})) - if lhs: - xds = data_column_expr(args.data_column, xds) + + try: + data_columns = data_column_expr(args.data_column, xds) + except DataColumnParseError: + try: + data_columns = [getattr(ds, args.data_column).data for ds in xds] + except AttributeError: + raise ValueError(f"{args.data_column} is neither an " + f"expression or a valid column") + + log.info(f"Flagging column '{args.data_column}") + else: + log.info(f"Flagging expression '{args.data_column}") # Get support tables st = support_tables(args.ms) @@ -355,7 +359,7 @@ def _main(args): final_stats = [] # Iterate through each dataset - for ds in xds: + for ds, vis in zip(xds, data_columns): if ds.FIELD_ID not in field_dict: continue @@ -371,16 +375,15 @@ def _main(args): spw_info = spw_ds[ddid.SPECTRAL_WINDOW_ID.data[0]] pol_info = pol_ds[ddid.POLARIZATION_ID.data[0]] - nrow, nchan, ncorr = getattr(ds, data_column).data.shape + nrow, nchan, ncorr = vis.shape # Visibilities from the dataset - vis = getattr(ds, data_column).data if args.subtract_model_column is not None: warnings.warn("-subtract-model-column argument is deprecated." "Use --data-column instead.") log.info("Forming residual data between '{0:s}' and " "'{1:s}' for flagging.".format( - data_column, args.subtract_model_column)) + args.data_column, args.subtract_model_column)) vismod = getattr(ds, args.subtract_model_column).data vis = vis - vismod diff --git a/tricolour/tests/test_acceptance.py b/tricolour/tests/test_acceptance.py index f548d24..aedc76e 100644 --- a/tricolour/tests/test_acceptance.py +++ b/tricolour/tests/test_acceptance.py @@ -6,7 +6,7 @@ """ import os -from os.path import join as pjoin +from pathlib import Path import shutil import subprocess import tarfile @@ -19,9 +19,6 @@ import dask from daskms import xds_from_ms, xds_to_table -from daskms.expressions import data_column_expr -from numpy.testing import assert_array_equal - _GOOGLE_FILE_ID = "1yxDIXUo3Xun9WXxA0x_hvX9Fmxo9Igpr" _MS_FILENAME = '1519747221.subset.ms' @@ -59,28 +56,34 @@ def _download_file_from_google_drive(id, destination): _save_response_content(response, destination) -# Set timeout to 6 minutes -@pytest.fixture(params=[360], scope="module") -def flagged_ms(request, tmp_path_factory): - """ - fixture yielding an MS flagged by tricolour - """ +@pytest.fixture(scope="session") +def ms_tarfile(tmp_path_factory): try: - tarred_ms_filename = os.environ["TRICOLOUR_TEST_MS"] + tarred_ms_filename = Path(os.environ["TRICOLOUR_TEST_MS"]) except KeyError: - tar_dir = tmp_path_factory.mktemp("tar-download") - tarred_ms_filename = os.path.join(tar_dir, "test_data.tar.gz") + tar_dir = tmp_path_factory.mktemp("acceptance-download-") + tarred_ms_filename = tar_dir / "test_data.tar.gz" _download_file_from_google_drive(_GOOGLE_FILE_ID, tarred_ms_filename) - tmp_path = str(tmp_path_factory.mktemp('data')) + yield tarred_ms_filename + + +@pytest.fixture(scope="session") +def ms_filename(ms_tarfile, tmp_path_factory): + tmp_path = tmp_path_factory.mktemp("acceptance-data-") + + with tarfile.open(ms_tarfile) as tarred_ms: + tarred_ms.extractall(tmp_path) - # Open and extract tarred ms - tarred_ms = tarfile.open(tarred_ms_filename) - tarred_ms.extractall(tmp_path) + yield str(Path(tmp_path / _MS_FILENAME)) - # Set up our paths - ms_filename = pjoin(tmp_path, _MS_FILENAME) +# Set timeout to 6 minutes +@pytest.fixture(params=[360], scope="module") +def flagged_ms(request, ms_filename): + """ + fixture yielding an MS flagged by tricolour + """ test_directory = os.path.dirname(__file__) args = ['tricolour', @@ -248,26 +251,10 @@ def test_bandwidth_flagged(flagged_ms, tol): @pytest.fixture(params=[360], scope="module") -def multi_model_ms(request, tmp_path_factory): +def multi_model_ms(request, ms_filename): """ Multi-model 'DATA' column """ - try: - tarred_ms_filename = os.environ["TRICOLOUR_TEST_MS"] - except KeyError: - tar_dir = tmp_path_factory.mktemp("tar-download") - tarred_ms_filename = os.path.join(tar_dir, "test_data.tar.gz") - - _download_file_from_google_drive(_GOOGLE_FILE_ID, tarred_ms_filename) - - tmp_path = str(tmp_path_factory.mktemp('data')) - - # Open and extract tarred ms - tarred_ms = tarfile.open(tarred_ms_filename) - tarred_ms.extractall(tmp_path) - - # Set up our paths - ms_filename = pjoin(tmp_path, _MS_FILENAME) test_directory = os.path.dirname(__file__) # Open ms @@ -285,7 +272,7 @@ def multi_model_ms(request, tmp_path_factory): args = ['tricolour', '-fs', 'total_power', '-c', os.path.join(test_directory, 'custom.yaml'), - '-dc', 'FLAG_DATA = DATA - MODEL_DATA', + '-dc', 'DATA - MODEL_DATA', ms_filename] p = subprocess.Popen(args, env=os.environ.copy()) @@ -318,28 +305,4 @@ def test_multi_model(multi_model_ms): """ Test Multi-model 'DATA' column """ - # Open ms - xds = xds_from_ms(multi_model_ms) - # Create 'MODEL_DATA' column - for i, ds in enumerate(xds): - dims = ds.DATA.dims - xds[i] = ds.assign(MODEL_DATA=(dims, ds.DATA.data / 2)) - - # Write 'MODEL_DATA column - delayed operation - writes = xds_to_table(xds, multi_model_ms, "MODEL_DATA") - dask.compute(writes) - - # Redundant but test data_column_expr - # expression FLAG_DATA = DATA - MODEL_DATA - expr = "FLAG_DATA = DATA - MODEL_DATA" - xds = data_column_expr(expr, xds) - - with tbl(multi_model_ms) as t: - data = t.getcol("DATA") - model_data = t.getcol("MODEL_DATA") - - assert_array_equal(model_data, data / 2) - - for i, ds in enumerate(xds): - assert_array_equal(ds.DATA.data - ds.MODEL_DATA.data, - ds.FLAG_DATA.data) + pass \ No newline at end of file From 48f8e73e6a6d625192163e5ab518bf9632077dc9 Mon Sep 17 00:00:00 2001 From: Simon Perkins Date: Thu, 11 Feb 2021 13:59:43 +0200 Subject: [PATCH 10/15] rework test fixtures so that download only occurs once --- tricolour/tests/test_acceptance.py | 25 +++++++++++++------------ 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/tricolour/tests/test_acceptance.py b/tricolour/tests/test_acceptance.py index aedc76e..3551aaa 100644 --- a/tricolour/tests/test_acceptance.py +++ b/tricolour/tests/test_acceptance.py @@ -69,7 +69,7 @@ def ms_tarfile(tmp_path_factory): yield tarred_ms_filename -@pytest.fixture(scope="session") +@pytest.fixture(scope="function") def ms_filename(ms_tarfile, tmp_path_factory): tmp_path = tmp_path_factory.mktemp("acceptance-data-") @@ -79,7 +79,7 @@ def ms_filename(ms_tarfile, tmp_path_factory): yield str(Path(tmp_path / _MS_FILENAME)) # Set timeout to 6 minutes -@pytest.fixture(params=[360], scope="module") +@pytest.fixture(params=[360], scope="function") def flagged_ms(request, ms_filename): """ fixture yielding an MS flagged by tricolour @@ -112,10 +112,11 @@ def flagged_ms(request, ms_filename): elif ret != 0: raise RuntimeError("Tricolour exited with non-zero return code") - yield ms_filename - - # Remove MS - shutil.rmtree(ms_filename) + try: + yield ms_filename + finally: + # Remove MS + shutil.rmtree(ms_filename) @pytest.mark.parametrize("tol", [1e3]) @@ -250,7 +251,7 @@ def test_bandwidth_flagged(flagged_ms, tol): assert flagged_ratio < tol -@pytest.fixture(params=[360], scope="module") +@pytest.fixture(params=[360], scope="function") def multi_model_ms(request, ms_filename): """ Multi-model 'DATA' column @@ -295,14 +296,14 @@ def multi_model_ms(request, ms_filename): elif ret != 0: raise RuntimeError("Tricolour exited with non-zero return code") - yield ms_filename - - # Remove MS - shutil.rmtree(ms_filename) + try: + yield ms_filename + finally: + shutil.rmtree(ms_filename) def test_multi_model(multi_model_ms): """ Test Multi-model 'DATA' column """ - pass \ No newline at end of file + pass From 01263b3530af1e2800b84db7be1f00a60b991f2e Mon Sep 17 00:00:00 2001 From: Simon Perkins Date: Thu, 11 Feb 2021 14:02:41 +0200 Subject: [PATCH 11/15] try the column case before the expression case --- tricolour/apps/tricolour/app.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tricolour/apps/tricolour/app.py b/tricolour/apps/tricolour/app.py index eb2695d..f5ac4b3 100644 --- a/tricolour/apps/tricolour/app.py +++ b/tricolour/apps/tricolour/app.py @@ -289,17 +289,17 @@ def _main(args): chunks={"row": args.row_chunks})) try: - data_columns = data_column_expr(args.data_column, xds) - except DataColumnParseError: + data_columns = [getattr(ds, args.data_column).data for ds in xds] + except AttributeError: try: - data_columns = [getattr(ds, args.data_column).data for ds in xds] - except AttributeError: + data_columns = data_column_expr(args.data_column, xds) + except DataColumnParseError: raise ValueError(f"{args.data_column} is neither an " f"expression or a valid column") - log.info(f"Flagging column '{args.data_column}") + log.info(f"Flagging expression '{args.data_column}'") else: - log.info(f"Flagging expression '{args.data_column}") + log.info(f"Flagging column '{args.data_column}'") # Get support tables st = support_tables(args.ms) From 86eeefb7a3bd8948b2b21ea0f0efa3ce11cac1c5 Mon Sep 17 00:00:00 2001 From: Simon Perkins Date: Thu, 11 Feb 2021 14:11:19 +0200 Subject: [PATCH 12/15] provoke a test run --- tricolour/apps/tricolour/app.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tricolour/apps/tricolour/app.py b/tricolour/apps/tricolour/app.py index f5ac4b3..e14cf36 100644 --- a/tricolour/apps/tricolour/app.py +++ b/tricolour/apps/tricolour/app.py @@ -288,6 +288,8 @@ def _main(args): index_cols=index_cols, chunks={"row": args.row_chunks})) + a = None + try: data_columns = [getattr(ds, args.data_column).data for ds in xds] except AttributeError: From de75c38da8a64f1283b3e37ce237f610bb878511 Mon Sep 17 00:00:00 2001 From: Simon Perkins Date: Thu, 11 Feb 2021 14:11:21 +0200 Subject: [PATCH 13/15] Revert "provoke a test run" This reverts commit 86eeefb7a3bd8948b2b21ea0f0efa3ce11cac1c5. --- tricolour/apps/tricolour/app.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tricolour/apps/tricolour/app.py b/tricolour/apps/tricolour/app.py index e14cf36..f5ac4b3 100644 --- a/tricolour/apps/tricolour/app.py +++ b/tricolour/apps/tricolour/app.py @@ -288,8 +288,6 @@ def _main(args): index_cols=index_cols, chunks={"row": args.row_chunks})) - a = None - try: data_columns = [getattr(ds, args.data_column).data for ds in xds] except AttributeError: From e3195d6e00eba49e008a3e856516442f81f41754 Mon Sep 17 00:00:00 2001 From: Simon Perkins Date: Thu, 11 Feb 2021 14:50:29 +0200 Subject: [PATCH 14/15] flake8 --- tricolour/apps/tricolour/app.py | 1 - tricolour/tests/test_acceptance.py | 1 + 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/tricolour/apps/tricolour/app.py b/tricolour/apps/tricolour/app.py index f5ac4b3..6d3402d 100644 --- a/tricolour/apps/tricolour/app.py +++ b/tricolour/apps/tricolour/app.py @@ -271,7 +271,6 @@ def _main(args): "Interactive Python Debugger, as per user request") post_mortem_handler.disable_pdb_on_error() - masked_channels = [load_mask(fn, dilate=args.dilate_masks) for fn in collect_masks()] GD = args.config diff --git a/tricolour/tests/test_acceptance.py b/tricolour/tests/test_acceptance.py index 3551aaa..fa46a62 100644 --- a/tricolour/tests/test_acceptance.py +++ b/tricolour/tests/test_acceptance.py @@ -78,6 +78,7 @@ def ms_filename(ms_tarfile, tmp_path_factory): yield str(Path(tmp_path / _MS_FILENAME)) + # Set timeout to 6 minutes @pytest.fixture(params=[360], scope="function") def flagged_ms(request, ms_filename): From 9af1b4e6446c98396b6194b688b42612f28e8f4f Mon Sep 17 00:00:00 2001 From: Simon Perkins Date: Thu, 11 Feb 2021 15:33:10 +0200 Subject: [PATCH 15/15] Update help --- tricolour/apps/tricolour/app.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tricolour/apps/tricolour/app.py b/tricolour/apps/tricolour/app.py index 6d3402d..3f5dd1c 100644 --- a/tricolour/apps/tricolour/app.py +++ b/tricolour/apps/tricolour/app.py @@ -184,9 +184,9 @@ def create_parser(): help="Number of channels to dilate as int " "or string with units") p.add_argument("-dc", "--data-column", type=str, default="DATA", - help="Name of visibility data column to flag." - "Now supports multi model columns expressions" - "e.g 'EXPR = DATA / (DIR1_DATA + DIR2_DATA + DIR3_DATA)'") + help="Name of visibility data column to flag " + "or an expression composed of DATA columns: " + "e.g \"DATA / (DIR1_DATA + DIR2_DATA + DIR3_DATA)\"") p.add_argument("-fn", "--field-names", type=str, action='append', default=[], help="Name(s) of fields to flag. Defaults to flagging all")