diff --git a/CHANGELOG.md b/CHANGELOG.md index 86db50a96..fd4ee5bbe 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,7 +6,10 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [PEP 440](https://www.python.org/dev/peps/pep-0440/) and uses [Semantic Versioning](https://semver.org/spec/v2.0.0.html). -## [Unreleased] +[Unreleased] +### Changed +* [672](https://github.com/dbekaert/RAiDER/pull/672) - Linted the project with `ruff`. + ### Fixed * [679](https://github.com/dbekaert/RAiDER/pull/679) - Fixed a bug causing test_updateTrue to falsely pass. diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index eb2173730..7a33f344b 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -111,6 +111,22 @@ git commit -a -m "Put here the synthetic commit message" git push my_user_name my_new_feature_branch ``` +### Formatting and linting with [Ruff](https://docs.astral.sh/ruff/) ### + +Format your code to follow the style of the project with: +``` +ruff format +``` +and check for linting problems with: +``` +ruff check +``` +Please ensure that any linting problems in your changes are resolved before +submitting a pull request. +> [!TIP] +> vscode users can [install the ruff extension](https://marketplace.visualstudio.com/items?itemName=charliermarsh.ruff) to run the linter automatically in the +editor. + ### Issue a pull request from GitHub UI ### commit locally and push. To get a reasonable history, you may need to diff --git a/pyproject.toml b/pyproject.toml index d5815cdb2..8ac601856 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -66,3 +66,28 @@ multi_line_output = 5 default_section = "THIRDPARTY" [tool.setuptools_scm] + +[tool.ruff] +line-length = 120 +src = ["tools", "test"] + +[tool.ruff.format] +indent-style = "space" +quote-style = "single" + +[tool.ruff.lint] +extend-select = [ + "I", # isort: https://docs.astral.sh/ruff/rules/#isort-i + "UP", # pyupgrade: https://docs.astral.sh/ruff/rules/#pyupgrade-up + "D", # pydocstyle: https://docs.astral.sh/ruff/rules/#pydocstyle-d + "ANN", # annotations: https://docs.astral.sh/ruff/rules/#flake8-annotations-ann + "PTH", # use-pathlib-pth: https://docs.astral.sh/ruff/rules/#flake8-use-pathlib-pth +] +ignore = ["ANN101", "D200", "D205", "D212"] + +[tool.ruff.lint.pydocstyle] +convention = "google" + +[tool.ruff.lint.isort] +case-sensitive = true +lines-after-imports = 2 diff --git a/setup.py b/setup.py index a8a76bc1e..9d18af2df 100755 --- a/setup.py +++ b/setup.py @@ -5,7 +5,6 @@ # RESERVED. United States Government Sponsorship acknowledged. # # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -import re from pathlib import Path import numpy as np diff --git a/test/__init__.py b/test/__init__.py index a948eaff4..71ef256d8 100644 --- a/test/__init__.py +++ b/test/__init__.py @@ -1,7 +1,4 @@ import os -import pytest -import subprocess -import shutil import string import random from contextlib import contextmanager @@ -31,37 +28,6 @@ def pushd(dir): os.chdir(prevdir) -def update_yaml(dct_cfg:dict, dst:str='temp.yaml'): - """ Write a new yaml file from a dictionary. - - Updates parameters in the default 'template.yaml' file. 
- Each key:value pair will in 'dct_cfg' will overwrite that in the default - """ - import RAiDER, yaml - - run_config_path = os.path.join( - os.path.dirname(RAiDER.__file__), - 'cli', - 'examples', - 'template', - 'template.yaml' - ) - - with open(run_config_path, 'r') as f: - try: - params = yaml.safe_load(f) - except yaml.YAMLError as exc: - print(exc) - raise ValueError(f'Something is wrong with the yaml file {run_config_path}') - - params = {**params, **dct_cfg} - - with open(dst, 'w') as fh: - yaml.safe_dump(params, fh, default_flow_style=False) - - return dst - - def makeLatLonGrid(bbox, reg, out_dir, spacing=0.1): """ Make lat lons at a specified spacing """ S, N, W, E = bbox diff --git a/test/_scenario_1.py b/test/_scenario_1.py index faaa61063..4f296f1e3 100755 --- a/test/_scenario_1.py +++ b/test/_scenario_1.py @@ -10,7 +10,7 @@ from RAiDER.delay import main from RAiDER.utilFcns import rio_open from RAiDER.checkArgs import makeDelayFileNames -from RAiDER.cli.validators import modelName2Module +from RAiDER.cli.validators import get_wm_by_name SCENARIO_DIR = os.path.join(TEST_DIR, "scenario_1") _RTOL = 1e-2 @@ -93,7 +93,7 @@ def core_test_tropo_delay(tmp_path, modelName): if not os.path.exists(wmLoc): os.mkdir(wmLoc) - _, model_obj = modelName2Module(modelName) + _, model_obj = get_wm_by_name(modelName) wet_file, hydro_file = makeDelayFileNames( time, Zenith, "envi", modelName, tmp_path ) diff --git a/test/_scenario_2.py b/test/_scenario_2.py index c02a77146..4f3cc9fef 100644 --- a/test/_scenario_2.py +++ b/test/_scenario_2.py @@ -9,7 +9,7 @@ from RAiDER.delay import main from RAiDER.losreader import Zenith -from RAiDER.cli.validators import modelName2Module +from RAiDER.cli.validators import get_wm_by_name SCENARIO_DIR = os.path.join(TEST_DIR, "scenario_2") _RTOL = 1e-2 @@ -34,7 +34,7 @@ def test_computeDelay(tmp_path): lats = stats['Lat'].values lons = stats['Lon'].values - _, model_obj = modelName2Module('ERA5') + _, model_obj = get_wm_by_name('ERA5') with pushd(tmp_path): diff --git a/test/scenario_1/raider_example_1.yaml b/test/scenario_1/raider_example_1.yaml index 0d2febc09..7305a400d 100644 --- a/test/scenario_1/raider_example_1.yaml +++ b/test/scenario_1/raider_example_1.yaml @@ -11,7 +11,3 @@ height_group: height_levels: 0 50 100 500 1000 # Return only these specific height levels los_group: # absent other options ZTD is calculated - runtime_group: - output_directory: test/scenario_1 - - diff --git a/test/synthetic_test/weather_files_synth/ERA-5_2019_11_17_T20_51_58_5S_2S_41W_37W.nc b/test/synthetic_test/weather_files_synth/ERA-5_2019_11_17_T20_51_58_5S_2S_41W_37W.nc new file mode 100644 index 000000000..1771cb7ba Binary files /dev/null and b/test/synthetic_test/weather_files_synth/ERA-5_2019_11_17_T20_51_58_5S_2S_41W_37W.nc differ diff --git a/test/test_GUNW.py b/test/test_GUNW.py index d35ae981b..bc6c52748 100644 --- a/test/test_GUNW.py +++ b/test/test_GUNW.py @@ -18,6 +18,7 @@ import RAiDER.cli.raider as raider import RAiDER.s1_azimuth_timing from RAiDER import aws +import RAiDER.aria.prepFromGUNW from RAiDER.aria.prepFromGUNW import ( check_hrrr_dataset_availablity_for_s1_azimuth_time_interpolation, check_weather_model_availability,_get_acq_time_from_gunw_id, @@ -369,6 +370,18 @@ def test_check_weather_model_availability_over_alaska(test_gunw_path_factory, we assert cond +@pytest.mark.parametrize('weather_model_name', ['ERA5', 'GMAO', 'MERRA2', 'HRRR']) +def test_check_weather_model_availability_2(weather_model_name): + gunw_id = 
Path("test/gunw_test_data/S1-GUNW-D-R-059-tops-20230320_20220418-180300-00179W_00051N-PP-c92e-v2_0_6.nc") + assert check_weather_model_availability(gunw_id, weather_model_name) + + +def test_check_weather_model_availability_3(): + gunw_id = Path("test/gunw_test_data/S1-GUNW-D-R-059-tops-20230320_20220418-180300-00179W_00051N-PP-c92e-v2_0_6.nc") + with pytest.raises(ValueError): + check_weather_model_availability(gunw_id, 'NotAModel') + + @pytest.mark.parametrize('weather_model_name', ['HRRR']) @pytest.mark.parametrize('location', ['california-t71', 'alaska']) def test_weather_model_availability_integration_using_valid_range(location, @@ -509,7 +522,7 @@ def test_hyp3_exits_succesfully_when_hrrr_not_available(mocker): side_effect=[False]) # The gunw id should not have a hyp3 file associated with it # This call will still hit the HRRR s3 API as done in the previous test - mocker.patch("RAiDER.aws.get_s3_file", side_effect=['hyp3-job-uuid-3ad24/S1-GUNW-A-R-106-tops-20160809_20140101-160001-00078W_00041N-PP-4be8-v3_0_0.nc']) + mocker.patch("RAiDER.aws.get_s3_file", side_effect=[Path('hyp3-job-uuid-3ad24/S1-GUNW-A-R-106-tops-20160809_20140101-160001-00078W_00041N-PP-4be8-v3_0_0.nc')]) mocker.patch('RAiDER.aria.prepFromGUNW.check_weather_model_availability') iargs = [ '--bucket', 's3://foo', @@ -622,22 +635,21 @@ def test_check_hrrr_availability_all_true(): gunw_id = "S1-GUNW-A-R-106-tops-20220115_20211222-225947-00078W_00041N-PP-4be8-v3_0_0" # Mock _get_acq_time_from_gunw_id to return expected times - result = check_hrrr_dataset_availablity_for_s1_azimuth_time_interpolation(gunw_id) - assert result == True + assert check_hrrr_dataset_availablity_for_s1_azimuth_time_interpolation(gunw_id) def test_get_slc_ids_from_gunw(): - test_path = 'test/gunw_test_data/S1-GUNW-D-R-059-tops-20230320_20220418-180300-00179W_00051N-PP-c92e-v2_0_6.nc' + test_path = Path('test/gunw_test_data/S1-GUNW-D-R-059-tops-20230320_20220418-180300-00179W_00051N-PP-c92e-v2_0_6.nc') assert get_slc_ids_from_gunw(test_path, 'reference') == 'S1A_IW_SLC__1SDV_20230320T180251_20230320T180309_047731_05BBDB_DCA0.zip' assert get_slc_ids_from_gunw(test_path, 'secondary') == 'S1A_IW_SLC__1SDV_20220418T180246_20220418T180305_042831_051CC3_3C47.zip' with pytest.raises(FileNotFoundError): - get_slc_ids_from_gunw('dummy.nc') + get_slc_ids_from_gunw(Path('dummy.nc')) with pytest.raises(ValueError): get_slc_ids_from_gunw(test_path, 'tertiary') with pytest.raises(OSError): - get_slc_ids_from_gunw('test/weather_files/ERA-5_2020_01_30_T13_52_45_32N_35N_120W_115W.nc') + get_slc_ids_from_gunw(Path('test/weather_files/ERA-5_2020_01_30_T13_52_45_32N_35N_120W_115W.nc')) def test_get_acq_time_valid_slc_id(): @@ -653,20 +665,3 @@ def test_get_acq_time_invalid_slc_id(): invalid_slc_id = "test/gunw_azimuth_test_data/S1B_OPER_AUX_POEORB_OPOD_20210731T111940_V20210710T225942_20210712T005942.EOF" with pytest.raises(ValueError): get_acq_time_from_slc_id(invalid_slc_id) - - -def test_check_weather_model_availability(): - gunw_id = "test/gunw_test_data/S1-GUNW-D-R-059-tops-20230320_20220418-180300-00179W_00051N-PP-c92e-v2_0_6.nc" - weather_models = ['ERA5', 'GMAO', 'MERRA2', 'HRRR'] - for wm in weather_models: - assert check_weather_model_availability(gunw_id, wm) - - with pytest.raises(ValueError): - check_weather_model_availability(gunw_id, 'NotAModel') - -def test_check_weather_model_availability_2(): - gunw_id = "test/gunw_test_data/S1-GUNW-D-R-059-tops-20230320_20220418-180300-00179W_00051N-PP-c92e-v2_0_6.nc" - weather_models = ['ERA5', 'GMAO', 'MERRA2', 
'HRRR'] - fail_check = [True, True, True, True] - for wm, check in zip(weather_models, fail_check): - assert check_weather_model_availability(gunw_id, wm)==check diff --git a/test/test_HRRR_ztd.py b/test/test_HRRR_ztd.py index e1ea38cfe..4777de4ad 100644 --- a/test/test_HRRR_ztd.py +++ b/test/test_HRRR_ztd.py @@ -1,42 +1,21 @@ -import os -import subprocess -import shutil -import glob - -from test import TEST_DIR, WM, update_yaml, pushd +from test import TEST_DIR, pushd import numpy as np import xarray as xr from RAiDER.cli.raider import calcDelays - def test_scenario_1(tmp_path, data_for_hrrr_ztd, mocker): + SCENARIO_DIR = TEST_DIR / "scenario_1" + test_path = SCENARIO_DIR / 'raider_example_1.yaml' + mocker.patch('RAiDER.processWM.prepareWeatherModel', + side_effect=[str(data_for_hrrr_ztd)]) + with pushd(tmp_path): - dct_group = { - "aoi_group": {"bounding_box": [36, 37, -92, -91]}, - "date_group": {"date_start": "20200101"}, - "time_group": {"time": "12:00:00", "interpolate_time": "none"}, - "weather_model": "HRRR", - "height_group": {"height_levels": [0, 50, 100, 500, 1000]}, - "look_dir": "right", - "runtime_group": {"output_directory": "test/scenario_1"}, - } - - cfg = update_yaml(dct_group, os.path.join(tmp_path, "temp.yaml")) - - SCENARIO_DIR = os.path.join(tmp_path, TEST_DIR, "scenario_1") - mocker.patch( - "RAiDER.processWM.prepareWeatherModel", side_effect=[str(data_for_hrrr_ztd)] - ) - calcDelays([os.path.join(tmp_path, "temp.yaml")]) + calcDelays([str(test_path)]) + new_data = xr.load_dataset('HRRR_tropo_20200101T120000_ztd.nc') - new_data = xr.load_dataset( - os.path.join( - tmp_path, "test", "scenario_1", "HRRR_tropo_20200101T120000_ztd.nc" - ) - ) - new_data1 = new_data.sel(x=-91.84, y=36.84, z=0, method="nearest") - golden_data = 2.2622863, 0.0361021 # hydro|wet + new_data1 = new_data.sel(x=-91.84, y=36.84, z=0, method='nearest') + golden_data = 2.2622863, 0.0361021 # hydro|wet - np.testing.assert_almost_equal(golden_data[0], new_data1["hydro"].data) - np.testing.assert_almost_equal(golden_data[1], new_data1["wet"].data) + np.testing.assert_almost_equal(golden_data[0], new_data1["hydro"].data) + np.testing.assert_almost_equal(golden_data[1], new_data1["wet"].data) diff --git a/test/test_checkArgs.py b/test/test_checkArgs.py index 8e0de7533..cf2f031c4 100644 --- a/test/test_checkArgs.py +++ b/test/test_checkArgs.py @@ -1,194 +1,173 @@ import datetime import os import shutil -import pytest +from pathlib import Path -import multiprocessing as mp -import numpy as np import pandas as pd +import pytest -from test import TEST_DIR, pushd - -from RAiDER.cli import DEFAULT_DICT -from RAiDER.checkArgs import checkArgs, makeDelayFileNames, get_raster_ext -from RAiDER.llreader import BoundingBox, StationFile, RasterRDR -from RAiDER.losreader import Zenith, Conventional, Raytracing +from RAiDER.checkArgs import checkArgs, get_raster_ext, makeDelayFileNames +from RAiDER.cli.types import AOIGroup, DateGroup, HeightGroupUnparsed, LOSGroup, RunConfig, RuntimeGroup, TimeGroup +from RAiDER.llreader import BoundingBox, RasterRDR, StationFile +from RAiDER.losreader import Zenith from RAiDER.models.gmao import GMAO +from test import TEST_DIR, pushd -SCENARIO_1 = os.path.join(TEST_DIR, "scenario_1") -SCENARIO_2 = os.path.join(TEST_DIR, "scenario_2") +SCENARIO_1 = os.path.join(TEST_DIR, 'scenario_1') +SCENARIO_2 = os.path.join(TEST_DIR, 'scenario_2') @pytest.fixture(autouse=True) def args(): - d = DEFAULT_DICT - d["date_list"] = [datetime.datetime(2018, 1, 1)] - d["time"] = datetime.time(12, 0, 0) - 
d["aoi"] = BoundingBox([38, 39, -92, -91]) - d["los"] = Zenith() - d["weather_model"] = GMAO() - - for f in "weather_files weather_dir".split(): - shutil.rmtree(f) if os.path.exists(f) else "" + d = RunConfig( + weather_model=GMAO(), + date_group=DateGroup(date_list=[datetime.datetime(2018, 1, 1)]), + time_group=TimeGroup(time=datetime.time(12, 0, 0)), + aoi_group=AOIGroup(aoi=BoundingBox([38, 39, -92, -91])), + los_group=LOSGroup(los=Zenith()), + height_group=HeightGroupUnparsed(), + runtime_group=RuntimeGroup(), + ) + + for f in 'weather_files weather_dir'.split(): + shutil.rmtree(f) if os.path.exists(f) else '' return d -def isWriteable(dirpath): - """Test whether a directory is writeable""" +def isWriteable(dirpath: Path) -> bool: + """Test whether a directory is writeable.""" try: - filehandle = open(os.path.join(dirpath, "tmp.txt"), "w") - filehandle.close() + with (dirpath / 'tmp.txt').open('w'): + pass return True except IOError: return False def test_checkArgs_outfmt_1(args): - """Test that passing height levels with hdf5 outformat works""" - args = args - args.file_format = "h5" - args.heightlvls = [10, 100, 1000] - checkArgs(args) - assert os.path.splitext(args.wetFilenames[0])[-1] == ".h5" + args.runtime_group.file_format = 'h5' + args.height_group.height_levels = [10, 100, 1000] + args = checkArgs(args) + assert os.path.splitext(args.wetFilenames[0])[-1] == '.h5' def test_checkArgs_outfmt_2(args): - """Test that passing a raster format with height levels throws an error""" - args = args - args.heightlvs = [10, 100, 1000] - args.file_format = "GTiff" + args.runtime_group.file_format = 'GTiff' + args.height_group.height_levels = [10, 100, 1000] args = checkArgs(args) - assert os.path.splitext(args.wetFilenames[0])[-1] == ".nc" + assert os.path.splitext(args.wetFilenames[0])[-1] == '.nc' def test_checkArgs_outfmt_3(args): - """Test that passing a raster format with height levels throws an error""" - args = args with pytest.raises(FileNotFoundError): - args.aoi = StationFile(os.path.join("fake_dir", "stations.csv")) + args.aoi_group.aoi = StationFile(os.path.join('fake_dir', 'stations.csv')) def test_checkArgs_outfmt_4(args): - """Test that passing a raster format with height levels throws an error""" - args = args - args.aoi = RasterRDR( - lat_file=os.path.join(SCENARIO_1, "geom", "lat.dat"), - lon_file=os.path.join(SCENARIO_1, "geom", "lon.dat"), + args.aoi_group.aoi = RasterRDR( + lat_file=os.path.join(SCENARIO_1, 'geom', 'lat.dat'), + lon_file=os.path.join(SCENARIO_1, 'geom', 'lon.dat'), ) - argDict = checkArgs(args) - assert argDict.aoi.type() == "radar_rasters" + args = checkArgs(args) + assert args.aoi_group.aoi.type() == 'radar_rasters' def test_checkArgs_outfmt_5(args, tmp_path): with pushd(tmp_path): - args = args - args.aoi = StationFile(os.path.join(SCENARIO_2, "stations.csv")) - argDict = checkArgs(args) - assert pd.read_csv(argDict["wetFilenames"][0]).shape == (8, 4) + args.aoi_group.aoi = StationFile(os.path.join(SCENARIO_2, 'stations.csv')) + args = checkArgs(args) + assert pd.read_csv(args.wetFilenames[0]).shape == (8, 4) def test_checkArgs_outloc_1(args): - """Test that the default output and weather model directories are correct""" + """Test that the default output and weather model directories are correct.""" args = args argDict = checkArgs(args) - out = argDict["output_directory"] - wmLoc = argDict["weather_model_directory"] + out = argDict.runtime_group.output_directory + wmLoc = argDict.runtime_group.weather_model_directory assert os.path.abspath(out) == 
os.getcwd() - assert os.path.abspath(wmLoc) == os.path.join(os.getcwd(), "weather_files") + assert os.path.abspath(wmLoc) == os.path.join(os.getcwd(), 'weather_files') def test_checkArgs_outloc_2(args, tmp_path): - """Tests that the correct output location gets assigned when provided""" + """Tests that the correct output location gets assigned when provided.""" with pushd(tmp_path): - args = args - args.output_directory = tmp_path + args.runtime_group.output_directory = tmp_path argDict = checkArgs(args) - out = argDict["output_directory"] + out = argDict.runtime_group.output_directory assert out == tmp_path def test_checkArgs_outloc_2b(args, tmp_path): - """Tests that the weather model directory gets passed through by itself""" + """Tests that the weather model directory gets passed through by itself.""" with pushd(tmp_path): - args = args - args.output_directory = tmp_path - args.weather_model_directory = "weather_dir" + args.runtime_group.output_directory = tmp_path + wm_dir = Path('weather_dir') + args.runtime_group.weather_model_directory = wm_dir argDict = checkArgs(args) - assert argDict["weather_model_directory"] == "weather_dir" + assert argDict.runtime_group.weather_model_directory == wm_dir def test_checkArgs_outloc_3(args, tmp_path): - """Tests that the weather model directory gets created when needed""" + """Tests that the weather model directory gets created when needed.""" with pushd(tmp_path): - args = args - args.output_directory = tmp_path + args.runtime_group.output_directory = tmp_path argDict = checkArgs(args) - assert os.path.isdir(argDict["weather_model_directory"]) + assert argDict.runtime_group.weather_model_directory.is_dir() def test_checkArgs_outloc_4(args): - """Tests for creating writeable weather model directory""" + """Tests for creating writeable weather model directory.""" args = args argDict = checkArgs(args) - assert isWriteable(argDict["weather_model_directory"]) + assert isWriteable(argDict.runtime_group.weather_model_directory) def test_filenames_1(args): - """tests that the correct filenames are generated""" + """tests that the correct filenames are generated.""" args = args argDict = checkArgs(args) - assert "Delay" not in argDict["wetFilenames"][0] - assert "wet" in argDict["wetFilenames"][0] - assert "hydro" in argDict["hydroFilenames"][0] - assert "20180101" in argDict["wetFilenames"][0] - assert "20180101" in argDict["hydroFilenames"][0] - assert len(argDict["hydroFilenames"]) == 1 + assert 'Delay' not in argDict.wetFilenames[0] + assert 'wet' in argDict.wetFilenames[0] + assert 'hydro' in argDict.hydroFilenames[0] + assert '20180101' in argDict.wetFilenames[0] + assert '20180101' in argDict.hydroFilenames[0] + assert len(argDict.hydroFilenames) == 1 def test_filenames_2(args): - """tests that the correct filenames are generated""" - args = args - args["output_directory"] = SCENARIO_2 - args.aoi = StationFile(os.path.join(SCENARIO_2, "stations.csv")) + """Tests that the correct filenames are generated.""" + args.runtime_group.output_directory = Path(SCENARIO_2) + args.aoi_group.aoi = StationFile(os.path.join(SCENARIO_2, 'stations.csv')) argDict = checkArgs(args) - assert "20180101" in argDict["wetFilenames"][0] - assert len(argDict["wetFilenames"]) == 1 + assert '20180101' in argDict.wetFilenames[0] + assert len(argDict.wetFilenames) == 1 def test_makeDelayFileNames_1(): - assert makeDelayFileNames(None, None, "h5", "name", "dir") == ( - "dir/name_wet_ztd.h5", - "dir/name_hydro_ztd.h5", - ) + assert makeDelayFileNames(None, None, 'h5', 'name', 
Path('dir')) == ('dir/name_wet_ztd.h5', 'dir/name_hydro_ztd.h5') def test_makeDelayFileNames_2(): - assert makeDelayFileNames(None, (), "h5", "name", "dir") == ( - "dir/name_wet_std.h5", - "dir/name_hydro_std.h5", - ) + assert makeDelayFileNames(None, (), 'h5', 'name', Path('dir')) == ('dir/name_wet_std.h5', 'dir/name_hydro_std.h5') def test_makeDelayFileNames_3(): - assert makeDelayFileNames( - datetime.datetime(2020, 1, 1, 1, 2, 3), None, "h5", "model_name", "dir" - ) == ( - "dir/model_name_wet_20200101T010203_ztd.h5", - "dir/model_name_hydro_20200101T010203_ztd.h5", + assert makeDelayFileNames(datetime.datetime(2020, 1, 1, 1, 2, 3), None, 'h5', 'model_name', Path('dir')) == ( + 'dir/model_name_wet_20200101T010203_ztd.h5', + 'dir/model_name_hydro_20200101T010203_ztd.h5', ) def test_makeDelayFileNames_4(): - assert makeDelayFileNames( - datetime.datetime(1900, 12, 31, 1, 2, 3), "los", "h5", "model_name", "dir" - ) == ( - "dir/model_name_wet_19001231T010203_std.h5", - "dir/model_name_hydro_19001231T010203_std.h5", + assert makeDelayFileNames(datetime.datetime(1900, 12, 31, 1, 2, 3), 'los', 'h5', 'model_name', Path('dir')) == ( + 'dir/model_name_wet_19001231T010203_std.h5', + 'dir/model_name_hydro_19001231T010203_std.h5', ) def test_get_raster_ext(): with pytest.raises(ValueError): - get_raster_ext("dummy_format") + get_raster_ext('dummy_format') diff --git a/test/test_datelist.py b/test/test_datelist.py index 37177ef53..ebf7a3096 100644 --- a/test/test_datelist.py +++ b/test/test_datelist.py @@ -1,63 +1,62 @@ import datetime import os import shutil -from test import TEST_DIR, WM, update_yaml + +from RAiDER.utilFcns import write_yaml +from test import TEST_DIR, WM, pushd from RAiDER.cli.raider import read_run_config_file -def test_datelist(): +def test_datelist(tmp_path): SCENARIO_DIR = os.path.join(TEST_DIR, "datelist") if os.path.exists(SCENARIO_DIR): shutil.rmtree(SCENARIO_DIR) os.makedirs(SCENARIO_DIR, exist_ok=False) - dates = ["20200124", "20200130"] - true_dates = [datetime.datetime(2020, 1, 24), datetime.datetime(2020, 1, 30)] + dates = ['20200124', '20200130'] + true_dates = [ + datetime.date(2020, 1, 24), + datetime.date(2020, 1, 30), + ] dct_group = { - "aoi_group": {"bounding_box": [28, 28.3, -116.3, -116]}, - "date_group": {"date_list": dates}, - "time_group": {"time": "00:00:00", "interpolate_time": "none"}, - "weather_model": WM, - "runtime_group": { - "output_directory": SCENARIO_DIR, - "weather_model_directory": os.path.join(SCENARIO_DIR, "weather_files"), - }, - } - - cfg = update_yaml(dct_group, "temp.yaml") - file_to_del = "temp.yaml" - param_dict = read_run_config_file(cfg) - assert param_dict["date_list"] == true_dates - - if os.path.exists(file_to_del): - os.remove(file_to_del) - - -def test_datestep(): + 'aoi_group': {'bounding_box': [28, 28.3, -116.3, -116]}, + 'date_group': {'date_list': dates}, + 'time_group': {'time': '00:00:00', 'interpolate_time': 'none'}, + 'weather_model': WM, + 'runtime_group': { + 'output_directory': SCENARIO_DIR, + 'weather_model_directory': os.path.join(SCENARIO_DIR, 'weather_files') + } + } + + with pushd(tmp_path): + cfg = write_yaml(dct_group, 'temp.yaml') + param_dict = read_run_config_file(cfg) + assert param_dict.date_group.date_list == true_dates + + +def test_datestep(tmp_path): SCENARIO_DIR = os.path.join(TEST_DIR, "scenario_5") st, en, step = "20200124", "20200130", 3 true_dates = [ - datetime.datetime(2020, 1, 24), - datetime.datetime(2020, 1, 27), - datetime.datetime(2020, 1, 30), + datetime.date(2020, 1, 24), + 
datetime.date(2020, 1, 27), + datetime.date(2020, 1, 30), ] dct_group = { - "aoi_group": {"bounding_box": [28, 39, -123, -112]}, - "date_group": {"date_start": st, "date_end": en, "date_step": step}, - "time_group": {"time": "00:00:00", "interpolate_time": "none"}, - "weather_model": WM, - "runtime_group": { - "output_directory": SCENARIO_DIR, - "weather_model_directory": os.path.join(SCENARIO_DIR, "weather_files"), - }, - } - - cfg = update_yaml(dct_group, "temp.yaml") - file_to_del = "temp.yaml" - param_dict = read_run_config_file(cfg) - assert param_dict["date_list"] == true_dates - - if os.path.exists(file_to_del): - os.remove(file_to_del) + 'aoi_group': {'bounding_box': [28, 39, -123, -112]}, + 'date_group': {'date_start': st, 'date_end': en, 'date_step': step}, + 'time_group': {'time': '00:00:00', 'interpolate_time': 'none'}, + 'weather_model': WM, + 'runtime_group': { + 'output_directory': SCENARIO_DIR, + 'weather_model_directory': os.path.join(SCENARIO_DIR, 'weather_files') + } + } + + with pushd(tmp_path): + cfg = write_yaml(dct_group, 'temp.yaml') + param_dict = read_run_config_file(cfg) + assert param_dict.date_group.date_list == true_dates diff --git a/test/test_dem.py b/test/test_dem.py index a008c7e56..97fc06e44 100644 --- a/test/test_dem.py +++ b/test/test_dem.py @@ -1,4 +1,3 @@ -import os import pytest from test import TEST_DIR, pushd @@ -6,12 +5,13 @@ def test_download_dem_1(): - SCENARIO_1 = os.path.join(TEST_DIR, "scenario_4") + SCENARIO_1 = TEST_DIR / "scenario_4" hts, meta = download_dem( - demName=os.path.join(SCENARIO_1,'warpedDEM.dem'), + dem_path=SCENARIO_1 / 'warpedDEM.dem', overwrite=False ) assert hts.shape == (45,226) + assert meta is not None assert meta['crs'] is None @@ -22,17 +22,18 @@ def test_download_dem_2(): def test_download_dem_3(tmp_path): with pushd(tmp_path): - fname = os.path.join(tmp_path, 'tmp_file.nc') + path = tmp_path / 'tmp_file.nc' with pytest.raises(ValueError): - download_dem(demName=fname) + download_dem(dem_path=path) @pytest.mark.long def test_download_dem_4(tmp_path): with pushd(tmp_path): - fname = os.path.join(tmp_path, 'tmp_file.nc') - z,m = download_dem(demName=fname, overwrite=True, ll_bounds=[37.9,38.,-91.8,-91.7], writeDEM=True) + path = tmp_path / 'tmp_file.nc' + z, m = download_dem(dem_path=path, overwrite=True, ll_bounds=[37.9,38.,-91.8,-91.7], writeDEM=True) assert len(z.shape) == 2 + assert m is not None assert 'crs' in m.keys() diff --git a/test/test_downloadGNSS.py b/test/test_downloadGNSS.py index a4865b237..8665f7b1b 100644 --- a/test/test_downloadGNSS.py +++ b/test/test_downloadGNSS.py @@ -1,10 +1,7 @@ -import os import pytest import requests from unittest import mock -from test import TEST_DIR, pushd -from RAiDER.dem import download_dem from RAiDER.gnss.downloadGNSSDelays import ( check_url, in_box, @@ -14,6 +11,8 @@ main, ) +from test import pushd + # Test check_url with a valid and invalid URL def test_check_url_valid(): @@ -35,15 +34,13 @@ def test_in_box_inside(): lat = 38.0 lon = -97.0 llbox = [30, 40, -100, -90] # Sample bounding box - assert in_box(lat, lon, llbox) == True - + assert in_box(lat, lon, llbox) def test_in_box_outside(): lat = 50.0 lon = -80.0 llbox = [30, 40, -100, -90] # Sample bounding box - assert in_box(lat, lon, llbox) == False - + assert not in_box(lat, lon, llbox) # Test fix_lons with various longitudes def test_fix_lons_positive(): @@ -84,14 +81,12 @@ def test_get_ID_invalid(): def test_download_UNR(tmp_path): + expected_path = 
"http://geodesy.unr.edu/gps_timeseries/trop/MORZ/MORZ.2020.trop.zip" + statID = "MORZ" + year = 2020 with pushd(tmp_path): - statID = "MORZ" - year = 2020 outDict = download_UNR(statID, year) - assert ( - outDict["path"] - == "http://geodesy.unr.edu/gps_timeseries/trop/MORZ/MORZ.2020.trop.zip" - ) + assert outDict["path"] == expected_path def test_download_UNR_2(): @@ -115,7 +110,7 @@ def test_download_UNR_4(): download_UNR(statID, year, baseURL="www.google.com") +@pytest.mark.skip def test_main(): - # iargs = None - # main(inps=iargs) - assert True + iargs = None + main(inps=iargs) diff --git a/test/test_gnss.py b/test/test_gnss.py index bd7598abe..caa7dddb8 100644 --- a/test/test_gnss.py +++ b/test/test_gnss.py @@ -1,3 +1,4 @@ +from pathlib import Path from RAiDER.models.customExceptions import NoStationDataFoundError from RAiDER.gnss.downloadGNSSDelays import ( get_stats_by_llh, get_station_list, download_tropo_delays, @@ -15,15 +16,12 @@ import pandas as pd from test import pushd, TEST_DIR -from unittest import mock SCENARIO2_DIR = os.path.join(TEST_DIR, "scenario_2") -def file_len(fname): - with open(fname) as f: - for i, l in enumerate(f): - pass - return i + 1 +def file_len(path: Path) -> int: + with path.open('rb') as f: + return sum(1 for _ in f) @pytest.fixture @@ -40,11 +38,11 @@ def temp_file(): def test_getDateTime(): - f1 = '20080101T060000' - f2 = '20080101T560000' - f3 = '20080101T0600000' - f4 = '20080101_060000' - f5 = '2008-01-01T06:00:00' + f1 = Path('20080101T060000') + f2 = Path('20080101T560000') + f3 = Path('20080101T0600000') + f4 = Path('20080101_060000') + f5 = Path('2008-01-01T06:00:00') assert getDateTime(f1) == datetime.datetime(2008, 1, 1, 6, 0, 0) with pytest.raises(ValueError): getDateTime(f2) @@ -59,10 +57,10 @@ def test_addDateTimeToFiles1(tmp_path, temp_file): df = temp_file with pushd(tmp_path): - new_name = os.path.join(tmp_path, 'tmp.csv') - df.to_csv(new_name, index=False) - addDateTimeToFiles([new_name]) - df = pd.read_csv(new_name) + new_path = tmp_path / 'tmp.csv' + df.to_csv(new_path, index=False) + addDateTimeToFiles([new_path]) + df = pd.read_csv(new_path) assert 'Datetime' not in df.columns @@ -71,13 +69,10 @@ def test_addDateTimeToFiles2(tmp_path, temp_file): df = temp_file with pushd(tmp_path): - new_name = os.path.join( - tmp_path, - 'tmp' + f1 + '.csv' - ) - df.to_csv(new_name, index=False) - addDateTimeToFiles([new_name]) - df = pd.read_csv(new_name) + new_path = tmp_path / f'tmp{f1}.csv' + df.to_csv(new_path, index=False) + addDateTimeToFiles([new_path]) + df = pd.read_csv(new_path) assert 'Datetime' in df.columns @@ -86,25 +81,19 @@ def test_concatDelayFiles(tmp_path, temp_file): df = temp_file with pushd(tmp_path): - new_name = os.path.join( - tmp_path, - 'tmp' + f1 + '.csv' - ) - new_name2 = os.path.join( - tmp_path, - 'tmp' + f1 + '_2.csv' - ) - df.to_csv(new_name, index=False) - df.to_csv(new_name2, index=False) - file_length = file_len(new_name) - addDateTimeToFiles([new_name, new_name2]) - - out_name = os.path.join(tmp_path, 'out.csv') + new_path1 = tmp_path / f'tmp{f1}_1.csv' + new_path2 = tmp_path / f'tmp{f1}_2.csv' + df.to_csv(new_path1, index=False) + df.to_csv(new_path2, index=False) + file_length = file_len(new_path1) + addDateTimeToFiles([new_path1, new_path2]) + + out_path = tmp_path / 'out.csv' concatDelayFiles( - [new_name, new_name2], - outName=out_name + [new_path1, new_path2], + outName=out_path ) - assert file_len(out_name) == file_length + assert file_len(out_path) == file_length def test_get_stats_by_llh2(): diff 
--git a/test/test_interpolator.py b/test/test_interpolator.py index 7799967a0..a13c573b1 100644 --- a/test/test_interpolator.py +++ b/test/test_interpolator.py @@ -1,4 +1,5 @@ import os +from pathlib import Path import numpy as np import rasterio as rio import pytest @@ -936,19 +937,19 @@ def test_interpolateDEM(): metadata = {'driver': 'GTiff', 'dtype': 'float32', 'width': s, 'height': s, 'count': 1} - demFile = './dem_tmp.tif' + dem_file = Path('./dem_tmp.tif') - with rio.open(demFile, 'w', **metadata) as ds: + with rio.open(dem_file, 'w', **metadata) as ds: ds.write(dem, 1) ds.update_tags(AREA_OR_POINT='Point') ## random points to interpolate to lons = np.array([4.5, 9.5]) lats = np.array([2.5, 9.5]) - out = interpolateDEM(demFile, (lats, lons)) + out = interpolateDEM(dem_file, (lats, lons)) gold = np.array([[36, 81], [8, 18]], dtype=float) assert np.allclose(out, gold) - os.remove(demFile) + dem_file.unlink() # TODO: implement an interpolator test that is similar to test_scenario_1. # Currently the scipy and C++ interpolators differ on that case. diff --git a/test/test_intersect.py b/test/test_intersect.py index 8d567ebc0..f7c217c0a 100644 --- a/test/test_intersect.py +++ b/test/test_intersect.py @@ -1,10 +1,15 @@ +from RAiDER.cli.raider import calcDelays +from RAiDER.utilFcns import write_yaml +import pytest +import os import pandas as pd -# import rasterio +import subprocess +import numpy as np from scipy.interpolate import griddata import rasterio -from test import * +from test import TEST_DIR, WM_DIR, pushd SCENARIO_DIR = os.path.join(TEST_DIR, "scenario_6") @@ -38,15 +43,10 @@ def test_cube_intersect(tmp_path, wm): } ## generate the default run config file and overwrite it with new parms - cfg = update_yaml(grp, "temp.yaml") + cfg = write_yaml(grp, 'temp.yaml') - # breakpoint() ## run raider and intersect - cmd = f"raider.py {cfg}" - proc = subprocess.run( - cmd.split(), stdout=subprocess.PIPE, universal_newlines=True - ) - assert proc.returncode == 0, "RAiDER Failed." + calcDelays([str(cfg)]) ## hard code what it should be and check it matches gold = {"ERA5": 2.2787, "GMAO": np.nan, "HRRR": np.nan} @@ -95,14 +95,10 @@ def test_gnss_intersect(tmp_path, wm): } ## generate the default run config file and overwrite it with new parms - cfg = update_yaml(grp) + cfg = write_yaml(grp, 'temp.yaml') ## run raider and intersect - cmd = f"raider.py {cfg}" - proc = subprocess.run( - cmd.split(), stdout=subprocess.PIPE, universal_newlines=True - ) - assert proc.returncode == 0, "RAiDER Failed." 
+ calcDelays([str(cfg)]) gold = {"ERA5": 2.34514, "GMAO": np.nan, "HRRR": np.nan} df = pd.read_csv( @@ -112,5 +108,3 @@ def test_gnss_intersect(tmp_path, wm): # test for equality with golden data np.testing.assert_almost_equal(td.item(), gold[wm], decimal=4) - - return diff --git a/test/test_llreader.py b/test/test_llreader.py index c8b584014..b22ad3a73 100644 --- a/test/test_llreader.py +++ b/test/test_llreader.py @@ -1,4 +1,5 @@ import os +from pathlib import Path import pytest import numpy as np @@ -11,13 +12,12 @@ from RAiDER.utilFcns import rio_open from RAiDER.llreader import ( - StationFile, RasterRDR, BoundingBox, GeocodedFile, Geocube, - bounds_from_latlon_rasters, bounds_from_csv + StationFile, RasterRDR, BoundingBox, GeocodedFile, bounds_from_latlon_rasters, bounds_from_csv ) -SCENARIO0_DIR = os.path.join(TEST_DIR, "scenario_0") -SCENARIO1_DIR = os.path.join(TEST_DIR, "scenario_1", "geom") -SCENARIO2_DIR = os.path.join(TEST_DIR, "scenario_2") +SCENARIO0_DIR = TEST_DIR / "scenario_0" +SCENARIO1_DIR = TEST_DIR / "scenario_1/geom" +SCENARIO2_DIR = TEST_DIR / "scenario_2" @pytest.fixture @@ -27,12 +27,12 @@ def parser(): @pytest.fixture def station_file(): - return os.path.join(SCENARIO2_DIR, 'stations.csv') + return SCENARIO2_DIR / 'stations.csv' @pytest.fixture def llfiles(): - return os.path.join(SCENARIO1_DIR, 'lat.dat'), os.path.join(SCENARIO1_DIR, 'lon.dat') + return SCENARIO1_DIR / 'lat.dat', SCENARIO1_DIR / 'lon.dat' def test_latlon_reader_2(): @@ -69,12 +69,12 @@ def test_set_xygrid(): def test_latlon_reader(): - latfile = os.path.join(GEOM_DIR, 'lat.rdr') - lonfile = os.path.join(GEOM_DIR, 'lon.rdr') - lat_true = rio_open(latfile) - lon_true = rio_open(lonfile) + latfile = Path(GEOM_DIR) / 'lat.rdr' + lonfile = Path(GEOM_DIR) / 'lon.rdr' + lat_true, _ = rio_open(latfile) + lon_true, _ = rio_open(lonfile) - query = RasterRDR(lat_file=latfile, lon_file=lonfile) + query = RasterRDR(lat_file=str(latfile), lon_file=str(lonfile)) lats, lons = query.readLL() assert lats.shape == (45, 226) assert lons.shape == (45, 226) @@ -122,9 +122,9 @@ def test_read_station_file(station_file): def test_bounds_from_latlon_rasters(): - latfile = os.path.join(GEOM_DIR, 'lat.rdr') - lonfile = os.path.join(GEOM_DIR, 'lon.rdr') - snwe, _, _ = bounds_from_latlon_rasters(latfile, lonfile) + lat_path = Path(GEOM_DIR) / 'lat.rdr' + lon_path = Path(GEOM_DIR) / 'lon.rdr' + snwe, _, _ = bounds_from_latlon_rasters(str(lat_path), str(lon_path)) bounds_true =[15.7637, 21.4936, -101.6384, -98.2418] assert all([np.allclose(b, t, rtol=1e-4) for b, t in zip(snwe, bounds_true)]) @@ -142,10 +142,8 @@ def test_readZ_sf(station_file): def test_GeocodedFile(): - aoi = GeocodedFile(os.path.join(SCENARIO0_DIR, 'small_dem.tif'), is_dem=True) + aoi = GeocodedFile(SCENARIO0_DIR / 'small_dem.tif', is_dem=True) z = aoi.readZ() x,y = aoi.readLL() assert z.shape == (569,558) assert x.shape == z.shape - assert True - diff --git a/test/test_losreader.py b/test/test_losreader.py index 2a7125354..16fb0a0c3 100644 --- a/test/test_losreader.py +++ b/test/test_losreader.py @@ -1,3 +1,5 @@ +import pytest +import os import datetime import numpy as np import RAiDER @@ -12,7 +14,7 @@ Zenith, ) -from test import * +from test import ORB_DIR @pytest.fixture diff --git a/test/test_processWM.py b/test/test_processWM.py index e0dbbb4e1..7749fbbb7 100644 --- a/test/test_processWM.py +++ b/test/test_processWM.py @@ -1,7 +1,6 @@ import os import pytest -import numpy as np from test import TEST_DIR diff --git a/test/test_raiderDelay.py 
b/test/test_raiderDelay.py index 446ee0b91..f3a86947e 100644 --- a/test/test_raiderDelay.py +++ b/test/test_raiderDelay.py @@ -1,5 +1,3 @@ -from argparse import ArgumentParser, ArgumentTypeError -from datetime import datetime, time from RAiDER.cli.raider import drop_nans diff --git a/test/test_scenario_2.py b/test/test_scenario_2.py index 9d366c919..21517943f 100644 --- a/test/test_scenario_2.py +++ b/test/test_scenario_2.py @@ -1,11 +1,5 @@ -import os -import pytest -import subprocess -from test import TEST_DIR -import numpy as np -import xarray as xr #TODO: include GNSS station test # @pytest.mark.long diff --git a/test/test_scenario_4.py b/test/test_scenario_4.py index c5d1bd7f3..f1ed25efd 100644 --- a/test/test_scenario_4.py +++ b/test/test_scenario_4.py @@ -6,7 +6,6 @@ import numpy as np from pyproj import CRS -import RAiDER from RAiDER.delay import tropo_delay, _get_delays_on_cube from RAiDER.llreader import RasterRDR from RAiDER.losreader import Zenith diff --git a/test/test_slant.py b/test/test_slant.py index cbe405b12..e13f0b530 100644 --- a/test/test_slant.py +++ b/test/test_slant.py @@ -1,9 +1,17 @@ +from RAiDER.cli.raider import calcDelays +import pytest import glob +import os +import subprocess +import shutil import numpy as np import xarray as xr -from test import * +from test import ( + TEST_DIR, WM_DIR, ORB_DIR, make_delay_name +) +from RAiDER.utilFcns import write_yaml @pytest.mark.parametrize('weather_model_name', ['ERA5']) def test_slant_proj(weather_model_name): @@ -33,12 +41,10 @@ def test_slant_proj(weather_model_name): } ## generate the default run config file and overwrite it with new parms - cfg = update_yaml(grp, 'temp.yaml') + cfg = write_yaml(grp, 'temp.yaml') ## run raider and intersect - cmd = f'raider.py {cfg}' - proc = subprocess.run(cmd.split(), stdout=subprocess.PIPE, universal_newlines=True) - assert proc.returncode == 0, 'RAiDER Failed.' + calcDelays([str(cfg)]) gold = {'ERA5': [33.4, -117.8, 0, 2.333865144]} lat, lon, hgt, val = gold[weather_model_name] @@ -84,12 +90,10 @@ def test_ray_tracing(weather_model_name): } ## generate the default run config file and overwrite it with new parms - cfg = update_yaml(grp, 'temp.yaml') + cfg = write_yaml(grp, 'temp.yaml') ## run raider and intersect - cmd = f'raider.py {cfg}' - proc = subprocess.run(cmd.split(), stdout=subprocess.PIPE, universal_newlines=True) - assert proc.returncode == 0, 'RAiDER Failed.' 
+ calcDelays([str(cfg)]) # model to lat/lon/correct value gold = {'ERA5': [33.4, -117.8, 0, 2.97711681]} diff --git a/test/test_synthetic.py b/test/test_synthetic.py index d37ead691..7249bda40 100644 --- a/test/test_synthetic.py +++ b/test/test_synthetic.py @@ -1,14 +1,22 @@ +import os import os.path as op +import shutil +import subprocess from dataclasses import dataclass from datetime import datetime +from pathlib import Path +import numpy as np +import pytest +import xarray as xr + +from RAiDER.cli.raider import calcDelays +from RAiDER.cli.validators import get_wm_by_name from RAiDER.llreader import BoundingBox -from RAiDER.models.weatherModel import make_weather_model_filename from RAiDER.losreader import Raytracing, build_ray -from RAiDER.utilFcns import lla2ecef -from RAiDER.cli.validators import modelName2Module - -from test import * +from RAiDER.models.weatherModel import make_weather_model_filename +from RAiDER.utilFcns import lla2ecef, write_yaml +from test import ORB_DIR, WM_DIR, pushd def update_model(wm_file: str, wm_eq_type: str, wm_dir: str = "weather_files_synth"): @@ -25,8 +33,8 @@ def update_model(wm_file: str, wm_eq_type: str, wm_dir: str = "weather_files_syn ), "Set wm_eq_type to hydro, wet_linear, or wet_nonlinear" # initialize dummy wm to calculate constant delays # any model will do as 1) all constants same 2) all equations same - model = op.basename(wm_file).split("_")[0].upper().replace("-", "") - Obj = modelName2Module(model)[1]() + model = op.basename(wm_file).split('_')[0].upper().replace("-", "") + Obj = get_wm_by_name(model)[1]() ds = xr.open_dataset(wm_file) t = ds["t"] p = ds["p"] @@ -110,7 +118,7 @@ def __init__(self, region: str, wmName: str, path: str): self.dts = self.dt.strftime("%Y_%m_%d_T%H_%M_%S") self.ttime = self.dt.strftime("%H:%M:%S") - self.wmObj = modelName2Module(self.wmName.upper().replace("-", ""))[1]() + self.wmObj = get_wm_by_name(self.wmName.upper().replace("-", ""))[1]() self.hgt_lvls = np.arange(-500, 9500, 500) self._cube_spacing_m = 10000.0 @@ -168,14 +176,16 @@ def setup_region(self): def make_config_dict(self): dct = { - "aoi_group": {"bounding_box": list(self.SNWE)}, - "height_group": {"height_levels": self.hgt_lvls.tolist()}, - "time_group": {"time": self.ttime, "interpolate_time": "none"}, - "date_group": {"date_list": datetime.strftime(self.dt, "%Y%m%d")}, - "cube_spacing_in_m": str(self._cube_spacing_m), - "los_group": {"ray_trace": True, "orbit_file": self.orbit}, - "weather_model": self.wmName, - "runtime_group": {"output_directory": self.wd}, + 'aoi_group': {'bounding_box': list(self.SNWE)}, + 'height_group': {'height_levels': self.hgt_lvls.tolist()}, + 'time_group': {'time': self.ttime, 'interpolate_time': 'none'}, + 'date_group': {'date_list': datetime.strftime(self.dt, '%Y%m%d')}, + 'los_group': {'ray_trace': True, 'orbit_file': self.orbit}, + 'weather_model': self.wmName, + 'runtime_group': { + 'output_directory': self.wd, + 'cube_spacing_in_m': self._cube_spacing_m, + }, } return dct @@ -196,14 +206,12 @@ def test_dl_real(tmp_path, region, mod="ERA5"): ) dct_cfg["download_only"] = True - cfg = update_yaml(dct_cfg) - ## run raider to download the real weather model - cmd = f"raider.py {cfg}" + cfg = write_yaml(dct_cfg, 'temp.yaml') - proc = subprocess.run( - cmd.split(), stdout=subprocess.PIPE, universal_newlines=True - ) - assert proc.returncode == 0, "RAiDER did not complete successfully" + ## run raider to download the real weather model + cmd = f'raider.py {cfg}' + proc = subprocess.run(cmd.split(), 
stdout=subprocess.PIPE, universal_newlines=True) + assert proc.returncode == 0, 'RAiDER did not complete successfully' @pytest.mark.parametrize("region", "AK LA Fort".split()) @@ -231,16 +239,11 @@ def test_hydrostatic_eq(tmp_path, region, mod="ERA-5"): dct_cfg["download_only"] = False ## update the weather model; t = p for hydrostatic - path_synth = update_model(SAobj.path_wm_real, "hydro", SAobj.wm_dir_synth) - - cfg = update_yaml(dct_cfg) + update_model(SAobj.path_wm_real, "hydro", SAobj.wm_dir_synth) ## run raider with the synthetic model - cmd = f"raider.py {cfg}" - proc = subprocess.run( - cmd.split(), stdout=subprocess.PIPE, universal_newlines=True - ) - assert proc.returncode == 0, "RAiDER did not complete successfully" + cfg = write_yaml(dct_cfg, 'temp.yaml') + calcDelays([str(cfg)]) # get the just created synthetic delays wm_name = SAobj.wmName.replace("-", "") # incase of ERA-5 @@ -310,17 +313,11 @@ def test_wet_eq_linear(tmp_path, region, mod="ERA-5"): dct_cfg["download_only"] = False ## update the weather model; t = e for wet1 - path_synth = update_model(SAobj.path_wm_real, "wet_linear", SAobj.wm_dir_synth) - - cfg = update_yaml(dct_cfg) + update_model(SAobj.path_wm_real, "wet_linear", SAobj.wm_dir_synth) ## run raider with the synthetic model - cmd = f"raider.py {cfg}" - proc = subprocess.run( - cmd.split(), stdout=subprocess.PIPE, universal_newlines=True - ) - - assert proc.returncode == 0, "RAiDER did not complete successfully" + cfg = write_yaml(dct_cfg, 'temp.yaml') + calcDelays([str(cfg)]) # get the just created synthetic delays wm_name = SAobj.wmName.replace("-", "") # incase of ERA-5 @@ -393,18 +390,11 @@ def test_wet_eq_nonlinear(tmp_path, region, mod="ERA-5"): dct_cfg["download_only"] = False ## update the weather model; t = e for wet1 - path_synth = update_model( - SAobj.path_wm_real, "wet_nonlinear", SAobj.wm_dir_synth - ) - - cfg = update_yaml(dct_cfg) + update_model(SAobj.path_wm_real, "wet_nonlinear", SAobj.wm_dir_synth) ## run raider with the synthetic model - cmd = f"raider.py {cfg}" - proc = subprocess.run( - cmd.split(), stdout=subprocess.PIPE, universal_newlines=True - ) - assert proc.returncode == 0, "RAiDER did not complete successfully" + cfg = write_yaml(dct_cfg, 'temp.yaml') + calcDelays([str(cfg)]) # get the just created synthetic delays wm_name = SAobj.wmName.replace("-", "") # incase of ERA-5 @@ -436,9 +426,9 @@ def test_wet_eq_nonlinear(tmp_path, region, mod="ERA-5"): np.testing.assert_almost_equal(0, resid, decimal=6) da.close() - os.remove("./temp.yaml") - os.remove("./error.log") - os.remove("./debug.log") + Path('./temp.yaml').unlink(missing_ok=True) + Path('./error.log').unlink(missing_ok=True) + Path('./debug.log').unlink(missing_ok=True) del da # delete temp directory diff --git a/test/test_temporal_interpolate.py b/test/test_temporal_interpolate.py index f78de9c90..d738ee383 100644 --- a/test/test_temporal_interpolate.py +++ b/test/test_temporal_interpolate.py @@ -1,11 +1,18 @@ +import pytest import glob import shutil +import os +import subprocess +import numpy as np +import xarray as xr -import pandas as pd -from test import * +from test import ( + WM, TEST_DIR +) from RAiDER.logger import logger +from RAiDER.utilFcns import write_yaml wm = 'ERA5' if WM == 'ERA-5' else WM @@ -34,7 +41,7 @@ def test_cube_timemean(): for hr in [hr1, hr2]: grp['time_group'].update({'time': f'{hr}:00:00'}) ## generate the default run config file and overwrite it with new parms - cfg = update_yaml(grp) + cfg = write_yaml(grp, 'temp.yaml') ## run raider for the 
default date cmd = f'raider.py {cfg}' @@ -43,7 +50,7 @@ def test_cube_timemean(): ## run interpolation in the middle of the two grp['time_group'] = {'time': ti, 'interpolate_time': 'center_time'} - cfg = update_yaml(grp) + cfg = write_yaml(grp, 'temp.yaml') cmd = f'raider.py {cfg}' proc = subprocess.run(cmd.split(), stdout=subprocess.PIPE, universal_newlines=True) @@ -96,7 +103,7 @@ def test_cube_weighting(): for hr in [hr1, hr2]: grp['time_group'].update({'time': f'{hr}:00:00'}) ## generate the default run config file and overwrite it with new parms - cfg = update_yaml(grp) + cfg = write_yaml(grp, 'temp.yaml') ## run raider for the default date cmd = f'raider.py {cfg}' @@ -105,7 +112,7 @@ def test_cube_weighting(): ## run interpolation very near the first grp['time_group'] = {'time': ti, 'interpolate_time': 'center_time'} - cfg = update_yaml(grp) + cfg = write_yaml(grp, 'temp.yaml') cmd = f'raider.py {cfg}' proc = subprocess.run(cmd.split(), stdout=subprocess.PIPE, universal_newlines=True) diff --git a/test/test_util.py b/test/test_util.py index eae32a734..ded21ed64 100644 --- a/test/test_util.py +++ b/test/test_util.py @@ -1,6 +1,6 @@ import datetime -import h5py import os +from pathlib import Path import pytest import numpy as np @@ -22,6 +22,7 @@ _R_EARTH = 6378138 SCENARIO_DIR = os.path.join(TEST_DIR, "scenario_1") +SCENARIO0_DIR = TEST_DIR / "scenario_0" @pytest.fixture @@ -121,7 +122,7 @@ def test_cosd(): def test_rio_open(): - out = rio_open(os.path.join(TEST_DIR, "test_geom", "lat.rdr"), False) + out, _ = rio_open(TEST_DIR / "test_geom/lat.rdr", False) assert np.allclose(out.shape, (45, 226)) @@ -130,10 +131,10 @@ def test_writeArrayToRaster(tmp_path): array = np.transpose( np.array([np.arange(0, 10)]) ) * np.arange(0, 10) - filename = str(tmp_path / 'dummy.out') + path = tmp_path / 'dummy.out' - writeArrayToRaster(array, filename) - with rasterio.open(filename) as src: + writeArrayToRaster(array, path) + with rasterio.open(path) as src: band = src.read(1) noval = src.nodatavals[0] @@ -144,35 +145,35 @@ def test_writeArrayToRaster(tmp_path): def test_writeArrayToRaster_2(): test = np.random.randn(10,10,10) with pytest.raises(RuntimeError): - writeArrayToRaster(test, 'dummy_file') + writeArrayToRaster(test, Path('dummy_file')) def test_writeArrayToRaster_3(tmp_path): test = np.random.randn(10,10) test = test + test * 1j with pushd(tmp_path): - fname = os.path.join(tmp_path, 'tmp_file.tif') - writeArrayToRaster(test, fname) - tmp = rio_profile(fname) + path = tmp_path / 'tmp_file.tif' + writeArrayToRaster(test, path) + tmp = rio_profile(path) assert tmp['dtype'] == 'complex64' def test_writeArrayToRaster_4(tmp_path): - SCENARIO0_DIR = os.path.join(TEST_DIR, "scenario_0") - geotif = os.path.join(SCENARIO0_DIR, 'small_dem.tif') + SCENARIO0_DIR = TEST_DIR / "scenario_0" + geotif = SCENARIO0_DIR / 'small_dem.tif' profile = rio_profile(geotif) - data = rio_open(geotif) + data, _ = rio_open(geotif) with pushd(tmp_path): - fname = os.path.join(tmp_path, 'tmp_file.nc') + path = tmp_path / 'tmp_file.nc' writeArrayToRaster( data, - fname, + path, proj=profile['crs'], gt=profile['transform'], fmt='nc', ) - new_fname = os.path.join(tmp_path, 'tmp_file.tif') - prof = rio_profile(new_fname) + new_path = tmp_path / 'tmp_file.tif' + prof = rio_profile(new_path) assert prof['driver'] == 'GTiff' @@ -253,16 +254,17 @@ def test_least_nonzero_2(): def test_rio_extent(): # Create a simple georeferenced test file - with rasterio.open("test.tif", mode="w", + test_file = Path("test.tif") + with 
rasterio.open(test_file, mode="w", width=11, height=11, count=1, dtype=np.float64, crs=pyproj.CRS.from_epsg(4326), transform=rasterio.Affine.from_gdal( 17.0, 0.1, 0, 18.0, 0, -0.1 )) as dst: dst.write(np.random.randn(11, 11), 1) - profile = rio_profile("test.tif") + profile = rio_profile(test_file) assert rio_extents(profile) == (17.0, 18.0, 17.0, 18.0) - os.remove("test.tif") + test_file.unlink() def test_getTimeFromFile(): @@ -520,15 +522,13 @@ def test_get_nearest_wmtimes_4(): def test_rio(): - SCENARIO0_DIR = os.path.join(TEST_DIR, "scenario_0") - geotif = os.path.join(SCENARIO0_DIR, 'small_dem.tif') + geotif = SCENARIO0_DIR / 'small_dem.tif' profile = rio_profile(geotif) assert profile['crs'] is not None def test_rio_2(): - SCENARIO0_DIR = os.path.join(TEST_DIR, "scenario_0") - geotif = os.path.join(SCENARIO0_DIR, 'small_dem.tif') + geotif = SCENARIO0_DIR / 'small_dem.tif' prof = rio_profile(geotif) del prof['transform'] with pytest.raises(KeyError): @@ -536,16 +536,16 @@ def test_rio_2(): def test_rio_3(): - SCENARIO0_DIR = os.path.join(TEST_DIR, "scenario_0") - geotif = os.path.join(SCENARIO0_DIR, 'small_dem.tif') - data = rio_open(geotif, returnProj=False, userNDV=None, band=1) + geotif = SCENARIO0_DIR / 'small_dem.tif' + data, _ = rio_open(geotif, userNDV=None, band=1) assert data.shape == (569,558) def test_rio_4(): - SCENARIO_DIR = os.path.join(TEST_DIR, "scenario_4") - los = os.path.join(SCENARIO_DIR, 'los.rdr') - inc, hd = rio_open(los, returnProj=False) + SCENARIO_DIR = TEST_DIR / "scenario_4" + los_path = SCENARIO_DIR / 'los.rdr' + los, _ = rio_open(los_path) + inc, hd = los assert len(inc.shape) == 2 assert len(hd.shape) == 2 diff --git a/test/test_validators.py b/test/test_validators.py index 4e348695d..d8a41aeab 100644 --- a/test/test_validators.py +++ b/test/test_validators.py @@ -1,23 +1,21 @@ from argparse import ArgumentParser -from datetime import datetime, time +from datetime import datetime, time, date import os import pytest import numpy as np -from test import TEST_DIR, pushd -SCENARIO = os.path.join(TEST_DIR, "scenario_4") - -from RAiDER.cli import AttributeDict - +from test import TEST_DIR +from RAiDER.cli.types import DateGroupUnparsed, LOSGroupUnparsed, TimeGroup from RAiDER.cli.validators import ( - modelName2Module, getBufferedExtent, isOutside, isInside, - enforce_valid_dates as date_type, convert_time as time_type, - enforce_bbox, parse_dates, enforce_wm, get_los + getBufferedExtent, isOutside, isInside, + coerce_into_date, + parse_bbox, parse_dates, parse_weather_model, get_los ) +SCENARIO = os.path.join(TEST_DIR, "scenario_4") @pytest.fixture def parser(): @@ -55,29 +53,34 @@ def llarray(): @pytest.fixture def args1(): test_file = os.path.join(SCENARIO, 'los.rdr') - args = AttributeDict({'los_file': test_file, 'los_convention': 'isce','ray_trace': False}) + args = { + 'los_file': test_file, + 'los_convention': 'isce', + 'ray_trace': False, + } return args def test_enforce_wm(): with pytest.raises(NotImplementedError): - enforce_wm('notamodel', 'fakeaoi') + parse_weather_model('notamodel', 'fakeaoi') def test_get_los_ray(args1): args = args1 - los = get_los(args) + los_group_unparsed = LOSGroupUnparsed(**args) + los = get_los(los_group_unparsed) assert not los.ray_trace() assert los.is_Projected() def test_date_type(): - assert date_type("2020-10-1") == datetime(2020, 10, 1) - assert date_type("2020101") == datetime(2020, 10, 1) + assert coerce_into_date("2020-10-1") == date(2020, 10, 1) + assert coerce_into_date("2020101") == date(2020, 10, 1) with 
pytest.raises(ValueError): - date_type("foobar") + coerce_into_date("foobar") @pytest.mark.parametrize("input,expected", ( @@ -96,44 +99,42 @@ def test_date_type(): )) @pytest.mark.parametrize("timezone", ("", "z", "+0000")) def test_time_type(input, timezone, expected): - assert time_type(input + timezone) == expected + assert TimeGroup.coerce_into_time(input + timezone) == expected def test_time_type_error(): with pytest.raises(ValueError): - time_type("foobar") + TimeGroup.coerce_into_time("foobar") def test_date_list_action(): - date_list = { - 'date_start':'20200101', - } - assert date_type(date_list['date_start']) == datetime(2020,1,1) - - - assert parse_dates(date_list) == [datetime(2020,1,1)] + date_group_unparsed = DateGroupUnparsed( + date_start='20200101', + ) + assert coerce_into_date(date_group_unparsed.date_start) == date(2020,1,1) + assert parse_dates(date_group_unparsed).date_list == [date(2020,1,1)] - date_list['date_end'] = '20200103' - assert date_type(date_list['date_end']) == datetime(2020,1,3) - assert parse_dates(date_list) == [datetime(2020,1,1), datetime(2020,1,2), datetime(2020,1,3)] + date_group_unparsed.date_end = '20200103' + assert coerce_into_date(date_group_unparsed.date_end) == date(2020,1,3) + assert parse_dates(date_group_unparsed).date_list == [date(2020,1,1), date(2020,1,2), date(2020,1,3)] - date_list['date_end'] = '20200112' - date_list['date_step'] = '5' - assert parse_dates(date_list) == [datetime(2020,1,1), datetime(2020,1,6), datetime(2020,1,11)] + date_group_unparsed.date_end = '20200112' + date_group_unparsed.date_step = '5' + assert parse_dates(date_group_unparsed).date_list == [date(2020,1,1), date(2020,1,6), date(2020,1,11)] def test_bbox_action(): bbox_str = "45 46 -72 -70" - assert len(enforce_bbox(bbox_str)) == 4 + assert len(parse_bbox(bbox_str)) == 4 - assert enforce_bbox(bbox_str) == [45, 46, -72, -70] + assert parse_bbox(bbox_str) == (45, 46, -72, -70) with pytest.raises(ValueError): - enforce_bbox("20 20 30 30") + parse_bbox("20 20 30 30") with pytest.raises(ValueError): - enforce_bbox("30 100 20 40") + parse_bbox("30 100 20 40") with pytest.raises(ValueError): - enforce_bbox("10 30 40 190") + parse_bbox("10 30 40 190") def test_ll1(llsimple): @@ -157,13 +158,18 @@ def test_ll4(llarray): def test_isOutside1(llsimple): - assert isOutside(getBufferedExtent(*llsimple), getBufferedExtent(*llsimple) + 1) + extent1 = getBufferedExtent(*llsimple) + extent2 = extent1[0] + 1, extent1[1] + 1, extent1[2] + 1, extent1[3] + 1 + assert isOutside(extent1, extent2) def test_isOutside2(llsimple): - assert not isOutside(getBufferedExtent(*llsimple), getBufferedExtent(*llsimple)) + extent = getBufferedExtent(*llsimple) + assert not isOutside(extent, extent) def test_isInside(llsimple): - assert isInside(getBufferedExtent(*llsimple), getBufferedExtent(*llsimple)) - assert not isInside(getBufferedExtent(*llsimple), getBufferedExtent(*llsimple) + 1) + extent1 = getBufferedExtent(*llsimple) + extent2 = extent1[0] + 1, extent1[1] + 1, extent1[2] + 1, extent1[3] + 1 + assert isInside(extent1, extent1) + assert not isInside(extent1, extent2) diff --git a/test/test_weather_model.py b/test/test_weather_model.py index 441c1aebc..bd6c0db81 100644 --- a/test/test_weather_model.py +++ b/test/test_weather_model.py @@ -23,7 +23,7 @@ from RAiDER.models.gmao import GMAO from RAiDER.models.merra2 import MERRA2 from RAiDER.models.ncmr import NCMR -from RAiDER.models.customExceptions import * +from RAiDER.models.customExceptions import DatetimeOutsideRange _LON0 = 0 @@ 
-315,10 +315,11 @@ def test_hrrr(hrrr: HRRR): wm.checkTime(datetime.datetime(2010, 7, 15).replace(tzinfo=datetime.timezone(offset=datetime.timedelta()))) wm.checkTime(datetime.datetime(2018, 7, 12).replace(tzinfo=datetime.timezone(offset=datetime.timedelta()))) - assert isinstance(wm.checkValidBounds([35, 40, -95, -90]), HRRR) + assert isinstance(wm, HRRR) + wm.checkValidBounds(np.array([35, 40, -95, -90])) with pytest.raises(ValueError): - wm.checkValidBounds([45, 47, 300, 310]) + wm.checkValidBounds(np.array([45, 47, 300, 310])) def test_hrrrak(hrrrak: HRRRAK): @@ -326,10 +327,11 @@ def test_hrrrak(hrrrak: HRRRAK): assert wm._Name == 'HRRR-AK' assert wm._valid_range[0] == datetime.datetime(2018, 7, 13).replace(tzinfo=datetime.timezone(offset=datetime.timedelta())) - assert isinstance(wm.checkValidBounds([45, 47, 200, 210]), HRRRAK) + assert isinstance(wm, HRRRAK) + wm.checkValidBounds(np.array([45, 47, 200, 210])) with pytest.raises(ValueError): - wm.checkValidBounds([15, 20, 265, 270]) + wm.checkValidBounds(np.array([15, 20, 265, 270])) with pytest.raises(DatetimeOutsideRange): wm.checkTime(datetime.datetime(2018, 7, 12).replace(tzinfo=datetime.timezone(offset=datetime.timedelta()))) diff --git a/test/weather_files/ERA-5_2020_01_30_T13_52_45_14N_18N_102W_99W.nc b/test/weather_files/ERA-5_2020_01_30_T13_52_45_14N_18N_102W_99W.nc deleted file mode 100644 index 6e31b5f70..000000000 Binary files a/test/weather_files/ERA-5_2020_01_30_T13_52_45_14N_18N_102W_99W.nc and /dev/null differ diff --git a/tools/RAiDER/__init__.py b/tools/RAiDER/__init__.py index 9c933fb7a..f736f7b8e 100644 --- a/tools/RAiDER/__init__.py +++ b/tools/RAiDER/__init__.py @@ -3,8 +3,10 @@ Copyright (c) 2019-2022, California Institute of Technology ("Caltech"). All rights reserved. """ + from importlib.metadata import version + __version__ = version(__name__) __copyright__ = 'Copyright (c) 2019-2022, California Institute of Technology ("Caltech"). All rights reserved.' diff --git a/tools/RAiDER/aria/calcGUNW.py b/tools/RAiDER/aria/calcGUNW.py index 12c18cf91..2f6af058d 100644 --- a/tools/RAiDER/aria/calcGUNW.py +++ b/tools/RAiDER/aria/calcGUNW.py @@ -1,27 +1,31 @@ """ -Calculate the interferometric phase from the 4 delays files of a GUNW -Write it to disk +Calculate the interferometric phase from the 4 delays files of a GUNW and write it to disk. """ + +import datetime as dt import os -import xarray as xr +from pathlib import Path + +import h5py +import netCDF4 import numpy as np +import xarray as xr + import RAiDER -from RAiDER.utilFcns import rio_open from RAiDER.logger import logger -from datetime import datetime -import h5py -from netCDF4 import Dataset -## ToDo: - # Check difference direction + +# ToDo: +# Check difference direction TROPO_GROUP = 'science/grids/corrections/external/troposphere' TROPO_NAMES = ['troposphereWet', 'troposphereHydrostatic'] -DIM_NAMES = ['heightsMeta', 'latitudeMeta', 'longitudeMeta'] +DIM_NAMES = ['heightsMeta', 'latitudeMeta', 'longitudeMeta'] -def compute_delays_slc(cube_filenames: list, wavelength: float) -> xr.Dataset: - """Get delays from standard RAiDER output formatting ouput including radian +def compute_delays_slc(cube_paths: list[Path], wavelength: float) -> xr.Dataset: + """ + Get delays from standard RAiDER output formatting ouput including radian conversion and metadata. Parameters @@ -31,25 +35,25 @@ def compute_delays_slc(cube_filenames: list, wavelength: float) -> xr.Dataset: wavelength : float Depends on sensor, e.g. 
for Sentinel-1 it is ~.05 - Returns + Returns: ------- xr.Dataset Formatted dataset for GUNW """ # parse date from filename - dct_delays = {} - for f in cube_filenames: - date = datetime.strptime(os.path.basename(f).split('_')[2], '%Y%m%dT%H%M%S') - dct_delays[date] = f + dct_delays: dict[dt.datetime, Path] = {} + for path in cube_paths: + date = dt.datetime.strptime(path.name.split('_')[2], '%Y%m%dT%H%M%S') + dct_delays[date] = path sec, ref = sorted(dct_delays.keys()) - wet_delays = [] - hyd_delays = [] - attrs_lst = [] + wet_delays: list[xr.DataArray] = [] + hyd_delays: list[xr.DataArray] = [] + attrs_lst: list[dict] = [] phase2range = (-4 * np.pi) / float(wavelength) - for dt in [ref, sec]: - path = dct_delays[dt] + for datetime in [ref, sec]: + path = dct_delays[datetime] with xr.open_dataset(path) as ds: da_wet = ds['wet'] * phase2range da_hydro = ds['hydro'] * phase2range @@ -58,13 +62,13 @@ def compute_delays_slc(cube_filenames: list, wavelength: float) -> xr.Dataset: hyd_delays.append(da_hydro) attrs_lst.append(ds.attrs) - chunk_sizes = da_wet.shape[0], da_wet.shape[1]/3, da_wet.shape[2]/3 + chunk_sizes = da_wet.shape[0], da_wet.shape[1] / 3, da_wet.shape[2] / 3 # open one to copy and store new data ds_slc = xr.open_dataset(path).copy() encoding = ds_slc['wet'].encoding # chunksizes and fill value encoding['contiguous'] = False - encoding['_FillValue'] = 0. + encoding['_FillValue'] = 0.0 encoding['chunksizes'] = tuple([np.floor(cs) for cs in chunk_sizes]) del ds_slc['wet'], ds_slc['hydro'] @@ -72,31 +76,30 @@ def compute_delays_slc(cube_filenames: list, wavelength: float) -> xr.Dataset: ds_slc[f'{key}_{TROPO_NAMES[0]}'] = wet_delays[i] ds_slc[f'{key}_{TROPO_NAMES[1]}'] = hyd_delays[i] - model = os.path.basename(path).split('_')[0] + model = path.name.split('_')[0] attrs = { - 'units': 'radians', - 'grid_mapping': 'crs', - } + 'units': 'radians', + 'grid_mapping': 'crs', + } - ## no data (fill value?) chunk size? + # no data (fill value?) chunk size? 
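The refactored `compute_delays_slc` above keeps the same meters-to-radians conversion for the delay cubes. Below is a minimal, self-contained sketch of just that step; the wavelength value is an assumption for Sentinel-1 (roughly the "~.05" noted in the docstring) and the delay values are made up.

```python
# Illustrative only: convert a toy wet-delay cube from meters to radians the
# same way compute_delays_slc does (delay * (-4 * pi / wavelength)).
import numpy as np
import xarray as xr

wavelength = 0.0555  # assumed Sentinel-1 C-band wavelength in meters
phase2range = (-4 * np.pi) / float(wavelength)

delay_m = xr.DataArray(np.full((3, 4), 0.1))  # made-up wet delays in meters
delay_rad = delay_m * phase2range
print(float(delay_rad[0, 0]))  # about -22.6 radians for a 0.1 m delay
```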
for name in TROPO_NAMES: for k, key in enumerate(['reference', 'secondary']): descrip = f"Delay due to {name.lstrip('troposphere')} component of troposphere" - da_attrs = {**attrs, - 'description': descrip, - 'long_name': name, - 'standard_name': name, - 'RAiDER version': RAiDER.__version__, - 'model_times_used': attrs_lst[k]['model_times_used'], - 'scene_center_time': attrs_lst[k]['reference_time'], - 'time_interpolation_method': attrs_lst[k]['interpolation_method'] - } + da_attrs = { + **attrs, + 'description': descrip, + 'long_name': name, + 'standard_name': name, + 'RAiDER version': RAiDER.__version__, + 'model_times_used': attrs_lst[k]['model_times_used'], + 'scene_center_time': attrs_lst[k]['reference_time'], + 'time_interpolation_method': attrs_lst[k]['interpolation_method'], + } ds_slc[f'{key}_{name}'] = ds_slc[f'{key}_{name}'].assign_attrs(da_attrs) ds_slc[f'{key}_{name}'].encoding = encoding - ds_slc = ds_slc.assign_attrs(model=model, - method='ray tracing' - ) + ds_slc = ds_slc.assign_attrs(model=model, method='ray tracing') # force these to float32 to prevent stitching errors coords = {coord: ds_slc[coord].astype(np.float32) for coord in ds_slc.coords} @@ -104,10 +107,12 @@ def compute_delays_slc(cube_filenames: list, wavelength: float) -> xr.Dataset: return ds_slc.rename(z=DIM_NAMES[0], y=DIM_NAMES[1], x=DIM_NAMES[2]) + # first need to delete the variable; only can seem to with h5 + -def update_gunw_slc(path_gunw:str, ds_slc): - """ Update the path_gunw file using the slc delays in ds_slc """ - ## first need to delete the variable; only can seem to with h5 + +def update_gunw_slc(path_gunw: Path, ds_slc: xr.Dataset) -> None: + """Update the path_gunw file using the slc delays in ds_slc.""" with h5py.File(path_gunw, 'a') as h5: for k in TROPO_GROUP.split(): h5 = h5[k] @@ -122,87 +127,79 @@ def update_gunw_slc(path_gunw:str, ds_slc): if k in h5.keys(): del h5[k] - - with Dataset(path_gunw, mode='a') as ds: - ds_grp = ds[TROPO_GROUP] + with netCDF4.Dataset(path_gunw, mode='a') as ds: + ds_grp = ds[TROPO_GROUP] ds_grp.createGroup(ds_slc.attrs['model'].upper()) ds_grp_wm = ds_grp[ds_slc.attrs['model'].upper()] - - ## create and store new data e.g., corrections/troposphere/GMAO/reference/troposphereWet + # create and store new data e.g., corrections/troposphere/GMAO/reference/troposphereWet for rs in 'reference secondary'.split(): ds_grp_wm.createGroup(rs) ds_grp_rs = ds_grp_wm[rs] - ## create the new dimensions e.g., corrections/troposphere/GMAO/reference/latitudeMeta + # create the new dimensions e.g., corrections/troposphere/GMAO/reference/latitudeMeta for dim in DIM_NAMES: - ## dimension may already exist if updating + # dimension may already exist if updating try: ds_grp_rs.createDimension(dim, len(ds_slc.coords[dim])) - ## necessary for transform - v = ds_grp_rs.createVariable(dim, np.float32, dim) + # necessary for transform + v = ds_grp_rs.createVariable(dim, np.float32, dim) v[:] = ds_slc[dim] v.setncatts(ds_slc[dim].attrs) except RuntimeError: pass - ## add the projection if it doesnt exist + # add the projection if it doesnt exist try: v_proj = ds_grp_rs.createVariable('crs', 'i') except RuntimeError: v_proj = ds_grp_rs['crs'] - v_proj.setncatts(ds_slc["crs"].attrs) + v_proj.setncatts(ds_slc['crs'].attrs) - ## update the actual tropo data + # update the actual tropo data for name in TROPO_NAMES: - da = ds_slc[f'{rs}_{name}'] - nodata = da.encoding['_FillValue'] + da = ds_slc[f'{rs}_{name}'] + nodata = da.encoding['_FillValue'] chunksize = da.encoding['chunksizes'] - ## in 
case updating + # in case updating try: - v = ds_grp_rs.createVariable(name, np.float32, DIM_NAMES, - chunksizes=chunksize, fill_value=nodata) + v = ds_grp_rs.createVariable(name, np.float32, DIM_NAMES, chunksizes=chunksize, fill_value=nodata) except RuntimeError: - v = ds_grp_rs[name] + v = ds_grp_rs[name] v[:] = da.data v.setncatts(da.attrs) - logger.info('Updated %s group in: %s', os.path.basename(TROPO_GROUP), path_gunw) - return -def update_gunw_version(path_gunw): - """ temporary hack for updating version to test aria-tools """ - with Dataset(path_gunw, mode='a') as ds: +def update_gunw_version(path_gunw: Path) -> None: + """Temporary hack for updating version to test aria-tools.""" + with netCDF4.Dataset(path_gunw, mode='a') as ds: ds.version = '1c' - return -def tropo_gunw_slc(cube_filenames: list, - path_gunw: str, - wavelength: float) -> xr.Dataset: +def tropo_gunw_slc(cube_paths: list[Path], path_gunw: Path, wavelength: float) -> xr.Dataset: """ - Computes and formats the troposphere phase delay for GUNW from RAiDER outputs. + Compute and format the troposphere phase delay for GUNW from RAiDER outputs. Parameters ---------- - cube_filenames : list + cube_filenames : list[Path] list with filename of delay cube for ref and sec date (netcdf) - path_gunw : str + path_gunw : Path GUNW netcdf path wavelength : float Wavelength of SAR - Returns + Returns: ------- xr.Dataset Output cube that will be included in GUNW """ - ds_slc = compute_delays_slc(cube_filenames, wavelength) + ds_slc = compute_delays_slc(cube_paths, wavelength) # write the interferometric delay to disk update_gunw_slc(path_gunw, ds_slc) diff --git a/tools/RAiDER/aria/prepFromGUNW.py b/tools/RAiDER/aria/prepFromGUNW.py index 2d76bfaa7..7b2cb3720 100644 --- a/tools/RAiDER/aria/prepFromGUNW.py +++ b/tools/RAiDER/aria/prepFromGUNW.py @@ -5,30 +5,32 @@ # RESERVED. United States Government Sponsorship acknowledged. 
# # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -import os -from datetime import datetime, timezone, timedelta +import datetime as dt +import sys +from pathlib import Path + +from RAiDER.aria.types import CalcDelaysArgs import numpy as np -import xarray as xr -import rasterio import pandas as pd -import yaml +import rasterio import shapely.wkt -from dataclasses import dataclass -import sys +import xarray as xr from shapely.geometry import box -import RAiDER from RAiDER.logger import logger from RAiDER.models import credentials -from RAiDER.models.hrrr import HRRR_CONUS_COVERAGE_POLYGON, AK_GEO, check_hrrr_dataset_availability +from RAiDER.models.hrrr import AK_GEO, HRRR_CONUS_COVERAGE_POLYGON, check_hrrr_dataset_availability from RAiDER.s1_azimuth_timing import get_times_for_azimuth_interpolation from RAiDER.s1_orbits import get_orbits_from_slc_ids_hyp3lib +from RAiDER.types import BB, LookDir +from RAiDER.utilFcns import write_yaml + -## cube spacing in degrees for each model +# cube spacing in degrees for each model DCT_POSTING = {'HRRR': 0.05, 'HRES': 0.10, 'GMAO': 0.10, 'ERA5': 0.10, 'ERA5T': 0.10, 'MERRA2': 0.1} -def _get_acq_time_from_gunw_id(gunw_id: str, reference_or_secondary: str) -> datetime: +def _get_acq_time_from_gunw_id(gunw_id: str, reference_or_secondary: str) -> dt.datetime: # Ex: S1-GUNW-A-R-106-tops-20220115_20211222-225947-00078W_00041N-PP-4be8-v3_0_0 if reference_or_secondary not in ['reference', 'secondary']: raise ValueError('Reference_or_secondary must "reference" or "secondary"') @@ -36,29 +38,32 @@ def _get_acq_time_from_gunw_id(gunw_id: str, reference_or_secondary: str) -> dat date_tokens = tokens[6].split('_') date_token = date_tokens[0] if reference_or_secondary == 'reference' else date_tokens[1] center_time_token = tokens[7] - cen_acq_time = datetime(int(date_token[:4]), - int(date_token[4:6]), - int(date_token[6:]), - int(center_time_token[:2]), - int(center_time_token[2:4]), - int(center_time_token[4:])) + cen_acq_time = dt.datetime( + int(date_token[:4]), + int(date_token[4:6]), + int(date_token[6:]), + int(center_time_token[:2]), + int(center_time_token[2:4]), + int(center_time_token[4:]), + ) return cen_acq_time - def check_hrrr_dataset_availablity_for_s1_azimuth_time_interpolation(gunw_id: str) -> bool: - """Determines if all the times for azimuth interpolation are available using Herbie; note that not all 1 hour times - are available within the said date range of HRRR. + """ + Determine if all the times for azimuth interpolation are available using + Herbie. Note that not all 1 hour times are available within the said date + range of HRRR. 
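For reference, here is a small standalone walk-through of the token parsing that `_get_acq_time_from_gunw_id` performs, using the example GUNW id quoted in its comment. This is illustrative only; the real function also selects between the reference and secondary date tokens.

```python
import datetime as dt

gunw_id = 'S1-GUNW-A-R-106-tops-20220115_20211222-225947-00078W_00041N-PP-4be8-v3_0_0'
tokens = gunw_id.split('-')
date_token = tokens[6].split('_')[0]  # reference date: '20220115'
center_time_token = tokens[7]         # center time: '225947'
cen_acq_time = dt.datetime(
    int(date_token[:4]), int(date_token[4:6]), int(date_token[6:]),
    int(center_time_token[:2]), int(center_time_token[2:4]), int(center_time_token[4:]),
)
print(cen_acq_time)  # 2022-01-15 22:59:47
```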
Parameters ---------- gunw_id : str - Returns + Returns: ------- bool - Example: + Example: check_hrrr_dataset_availablity_for_s1_azimuth_time_interpolation(S1-GUNW-A-R-106-tops-20220115_20211222-225947-00078W_00041N-PP-4be8-v3_0_0) should return True """ @@ -74,9 +79,8 @@ def check_hrrr_dataset_availablity_for_s1_azimuth_time_interpolation(gunw_id: st return all(ref_dataset_availability) and all(sec_dataset_availability) -def get_slc_ids_from_gunw(gunw_path: str, - reference_or_secondary: str = 'reference') -> list[str]: - #Example input: test/gunw_test_data/S1-GUNW-D-R-059-tops-20230320_20220418-180300-00179W_00051N-PP-c92e-v2_0_6.nc +def get_slc_ids_from_gunw(gunw_path: Path, reference_or_secondary: str = 'reference') -> list[str]: + # Example input: test/gunw_test_data/S1-GUNW-D-R-059-tops-20230320_20220418-180300-00179W_00051N-PP-c92e-v2_0_6.nc if reference_or_secondary not in ['reference', 'secondary']: raise ValueError('"reference_or_secondary" must be either "reference" or "secondary"') group = f'science/radarMetaData/inputSLC/{reference_or_secondary}' @@ -91,24 +95,24 @@ def get_acq_time_from_slc_id(slc_id: str) -> pd.Timestamp: return pd.Timestamp(ts_str) - -def check_weather_model_availability(gunw_path: str, - weather_model_name: str) -> bool: - """Checks weather reference and secondary dates of GUNW occur within - weather model valid range +def check_weather_model_availability(gunw_path: Path, weather_model_name: str) -> bool: + """ + Check weather reference and secondary dates of GUNW occur within + weather model valid range. Parameters ---------- - gunw_path : str + gunw_path : Path weather_model_name : str Should be one of 'HRRR', 'HRES', 'ERA5', 'ERA5T', 'GMAO', 'MERRA2'. - Returns + + Returns: ------- bool: True if both reference and secondary acquisitions are within the valid range. We assume that reference_date > secondary_date (i.e. 
reference scenes are most recent) - Raises + Raises: ------ ValueError - If weather model is not correctly referencing the Class from RAiDER.models @@ -117,8 +121,8 @@ def check_weather_model_availability(gunw_path: str, ref_slc_ids = get_slc_ids_from_gunw(gunw_path, reference_or_secondary='reference') sec_slc_ids = get_slc_ids_from_gunw(gunw_path, reference_or_secondary='secondary') - ref_ts = get_acq_time_from_slc_id(ref_slc_ids[0]).replace(tzinfo=timezone(offset=timedelta())) - sec_ts = get_acq_time_from_slc_id(sec_slc_ids[0]).replace(tzinfo=timezone(offset=timedelta())) + ref_ts = get_acq_time_from_slc_id(ref_slc_ids[0]).replace(tzinfo=dt.timezone(offset=dt.timedelta())) + sec_ts = get_acq_time_from_slc_id(sec_slc_ids[0]).replace(tzinfo=dt.timezone(offset=dt.timedelta())) if weather_model_name == 'HRRR': group = '/science/grids/data/' @@ -144,189 +148,186 @@ def check_weather_model_availability(gunw_path: str, weather_model = weather_model_cls() wm_start_date, wm_end_date = weather_model._valid_range - if not isinstance(wm_end_date, datetime): - raise ValueError(f'the weather model\'s end date is not valid: {wm_end_date}') + if not isinstance(wm_end_date, dt.datetime): + raise ValueError(f"the weather model's end date is not valid: {wm_end_date}") ref_cond = ref_ts <= wm_end_date sec_cond = sec_ts >= wm_start_date return ref_cond and sec_cond -@dataclass class GUNW: - path_gunw: str - wm: str - out_dir: str - - def __post_init__(self): - self.SNWE = self.get_bbox() - self.heights = np.arange(-500, 9500, 500).tolist() + path_gunw: Path + wm: str # TODO(garlic-os): probably a known weather model name + out_dir: Path + SNWE: BB.SNWE + heights: list[int] + dates: list[int] # ints in YYYYMMDD form + mid_time: str # str in HH:MM:SS form + look_dir: LookDir + wavelength: float + name: str + orbit_file: ... 
+ spacing_m: int + + def __init__(self, path_gunw: str, wm: str, out_dir: str) -> None: + self.path_gunw = Path(path_gunw) + self.wm = wm + self.out_dir = Path(out_dir) + + self.SNWE = self.get_bbox() + self.heights = np.arange(-500, 9500, 500).tolist() # self.heights = [-500, 0] self.dates, self.mid_time = self.get_datetimes() - - self.look_dir = self.get_look_dir() + self.look_dir = self.get_look_dir() self.wavelength = self.get_wavelength() - self.name = self.make_fname() - self.OrbitFile = self.get_orbit_file() - self.spacing_m = int(DCT_POSTING[self.wm] * 1e5) - - ## not implemented + self.name = self.make_fname() + self.orbit_file = self.get_orbit_file() + self.spacing_m = int(DCT_POSTING[self.wm] * 1e5) + # not implemented # self.spacing_m = self.calc_spacing_UTM() # probably wrong/unnecessary # self.lat_file, self.lon_file = self.makeLatLonGrid_native() # self.path_cube = self.make_cube() # not needed - - def get_bbox(self): - """ Get the bounding box (SNWE) from an ARIA GUNW product """ + def get_bbox(self) -> BB.SNWE: + """Get the bounding box (SNWE) from an ARIA GUNW product.""" with xr.open_dataset(self.path_gunw) as ds: poly_str = ds['productBoundingBox'].data[0].decode('utf-8') - - poly = shapely.wkt.loads(poly_str) + poly = shapely.wkt.loads(poly_str) W, S, E, N = poly.bounds + return S, N, W, E - return [S, N, W, E] - - - def make_fname(self): - """ Match the ref/sec filename (SLC dates may be different around edge cases) """ - ref, sec = os.path.basename(self.path_gunw).split('-')[6].split('_') - mid_time = os.path.basename(self.path_gunw).split('-')[7] + def make_fname(self) -> str: + """Match the ref/sec filename (SLC dates may be different around edge cases).""" + ref, sec = self.path_gunw.name.split('-')[6].split('_') + mid_time = self.path_gunw.name.split('-')[7] return f'{ref}-{sec}_{mid_time}' - - def get_datetimes(self): - """ Get the datetimes and set the satellite for orbit """ - ref_sec = self.get_slc_dt() - middates = [] - for aq in ref_sec: - st, en = aq - midpt = st + (en-st)/2 - middates.append(int(midpt.date().strftime('%Y%m%d'))) - midtime = midpt.time().strftime('%H:%M:%S') - return middates, midtime - - - def get_slc_dt(self): - """ Grab the SLC start date and time from the GUNW """ - group = 'science/radarMetaData/inputSLC' - lst_sten = [] - for i, key in enumerate('reference secondary'.split()): - ds = xr.open_dataset(self.path_gunw, group=f'{group}/{key}') - slcs = ds['L1InputGranules'] + def get_datetimes(self) -> tuple[list[int], str]: + """Get the datetimes and set the satellite for orbit.""" + ref_sec = self.get_slc_dt() + mid_dates: list[int] = [] # dates in YYYYMMDD format + for st, en in ref_sec: + midpoint = st + (en - st) / 2 + mid_dates.append(int(midpoint.date().strftime('%Y%m%d'))) + mid_time = midpoint.time().strftime('%H:%M:%S') + return mid_dates, mid_time + + def get_slc_dt(self) -> list[tuple[dt.datetime, dt.datetime]]: + """Grab the SLC start date and time from the GUNW.""" + group = 'science/radarMetaData/inputSLC' + lst_sten: list[tuple[dt.datetime, dt.datetime]] = [] + for key in 'reference secondary'.split(): + with xr.open_dataset(self.path_gunw, group=f'{group}/{key}') as ds: + slcs = ds['L1InputGranules'] nslcs = slcs.count().item() # single slc if nslcs == 1: - slc = slcs.item() + slc = slcs.item() assert slc, f'Missing {key} SLC metadata in GUNW: {self.f}' - st = datetime.strptime(slc.split('_')[5], '%Y%m%dT%H%M%S') - en = datetime.strptime(slc.split('_')[6], '%Y%m%dT%H%M%S') + st = dt.datetime.strptime(slc.split('_')[5], 
'%Y%m%dT%H%M%S') + en = dt.datetime.strptime(slc.split('_')[6], '%Y%m%dT%H%M%S') else: - st, en = datetime(1989, 3, 1), datetime(1989, 3, 1) + st, en = dt.datetime(1989, 3, 1), dt.datetime(1989, 3, 1) for j in range(nslcs): slc = slcs.data[j] if slc: - ## get the maximum range - st_tmp = datetime.strptime(slc.split('_')[5], '%Y%m%dT%H%M%S') - en_tmp = datetime.strptime(slc.split('_')[6], '%Y%m%dT%H%M%S') + # get the maximum range + st_tmp = dt.datetime.strptime(slc.split('_')[5], '%Y%m%dT%H%M%S') + en_tmp = dt.datetime.strptime(slc.split('_')[6], '%Y%m%dT%H%M%S') - ## check the second SLC is within one day of the previous - if st > datetime(1989, 3, 1): + # check the second SLC is within one day of the previous + if st > dt.datetime(1989, 3, 1): stdiff = np.abs((st_tmp - st).days) endiff = np.abs((en_tmp - en).days) assert stdiff < 2 and endiff < 2, 'SLCs granules are too far apart in time. Incorrect metadata' - st = st_tmp if st_tmp > st else st en = en_tmp if en_tmp > en else en - assert st>datetime(1989, 3, 1), f'Missing {key} SLC metadata in GUNW: {self.f}' + assert st > dt.datetime(1989, 3, 1), \ + f'Missing {key} SLC metadata in GUNW: {self.f}' - lst_sten.append([st, en]) + lst_sten.append((st, en)) return lst_sten - - def get_look_dir(self): - look_dir = os.path.basename(self.path_gunw).split('-')[3].lower() + def get_look_dir(self) -> LookDir: + look_dir = self.path_gunw.name.split('-')[3].lower() return 'right' if look_dir == 'r' else 'left' - def get_wavelength(self): - group ='science/radarMetaData' + group = 'science/radarMetaData' with xr.open_dataset(self.path_gunw, group=group) as ds: wavelength = ds['wavelength'].item() return wavelength - - def get_orbit_file(self): - """ Get orbit file for reference (GUNW: first & later date)""" - orbit_dir = os.path.join(self.out_dir, 'orbits') - os.makedirs(orbit_dir, exist_ok=True) + # TODO(garlic-os): sounds like this returns one thing but it returns a list? + def get_orbit_file(self) -> list[str]: + """Get orbit file for reference (GUNW: first & later date).""" + orbit_dir = self.out_dir / 'orbits' + orbit_dir.mkdir(parents=True, exist_ok=True) # just to get the correct satellite - group = 'science/radarMetaData/inputSLC/reference' + group = 'science/radarMetaData/inputSLC/reference' - ds = xr.open_dataset(self.path_gunw, group=f'{group}') - slcs = ds['L1InputGranules'] + with xr.open_dataset(self.path_gunw, group=f'{group}') as ds: + slcs = ds['L1InputGranules'] # Convert to list of strings slcs_lst = [slc for slc in slcs.data.tolist() if slc] - # Remove .zip from the granule ids included in this field + # Remove ".zip" from the granule ids included in this field slcs_lst = list(map(lambda slc: slc.replace('.zip', ''), slcs_lst)) path_orb = get_orbits_from_slc_ids_hyp3lib(slcs_lst) return [str(o) for o in path_orb] - - ## ------ methods below are not used + # ------ methods below are not used def get_version(self): with xr.open_dataset(self.path_gunw) as ds: version = ds.attrs['version'] return version - def getHeights(self): - """ Get the 4 height levels within a GUNW """ - group ='science/grids/imagingGeometry' + """Get the 4 height levels within a GUNW.""" + group = 'science/grids/imagingGeometry' with xr.open_dataset(self.path_gunw, group=group) as ds: hgts = ds.heightsMeta.data.tolist() return hgts - - def calc_spacing_UTM(self, posting:float=0.01): - """ Convert desired horizontal posting in degrees to meters + def calc_spacing_UTM(self, posting: float = 0.01): + """Convert desired horizontal posting in degrees to meters. 
Want to calculate delays close to native model resolution (3 km for HRR) """ from RAiDER.utilFcns import WGS84_to_UTM + group = 'science/grids/data' with xr.open_dataset(self.path_gunw, group=group) as ds0: lats = ds0.latitude.data lons = ds0.longitude.data - lat0, lon0 = lats[0], lons[0] lat1, lon1 = lat0 + posting, lon0 + posting - res = WGS84_to_UTM(np.array([lon0, lon1]), np.array([lat0, lat1])) + res = WGS84_to_UTM(np.array([lon0, lon1]), np.array([lat0, lat1])) lon_spacing_m = np.subtract(*res[2][::-1]) lat_spacing_m = np.subtract(*res[3][::-1]) return np.mean([lon_spacing_m, lat_spacing_m]) - - def makeLatLonGrid_native(self): - """ Make LatLonGrid at GUNW spacing (90m = 0.00083333º) """ + def makeLatLonGrid_native(self) -> tuple[Path, Path]: + """Make LatLonGrid at GUNW spacing (90m = 0.00083333º).""" group = 'science/grids/data' with xr.open_dataset(self.path_gunw, group=group) as ds0: lats = ds0.latitude.data lons = ds0.longitude.data - Lat, Lon = np.meshgrid(lats, lons) + Lat, Lon = np.meshgrid(lats, lons) - dims = 'longitude latitude'.split() + dims = 'longitude latitude'.split() da_lon = xr.DataArray(Lon.T, coords=[Lon[0, :], Lat[:, 0]], dims=dims) da_lat = xr.DataArray(Lat.T, coords=[Lon[0, :], Lat[:, 0]], dims=dims) - dst_lat = os.path.join(self.out_dir, 'latitude.geo') - dst_lon = os.path.join(self.out_dir, 'longitude.geo') + dst_lat = self.out_dir / 'latitude.geo' + dst_lon = self.out_dir / 'longitude.geo' da_lat.to_netcdf(dst_lat) da_lon.to_netcdf(dst_lon) @@ -335,9 +336,8 @@ def makeLatLonGrid_native(self): logger.debug('Wrote: %s', dst_lon) return dst_lat, dst_lon - - def make_cube(self): - """ Make LatLonGrid at GUNW spacing (90m = 0.00083333º) """ + def make_cube(self) -> Path: + """Make LatLonGrid at GUNW spacing (90m = 0.00083333º).""" group = 'science/grids/data' with xr.open_dataset(self.path_gunw, group=group) as ds0: lats0 = ds0.latitude.data @@ -346,76 +346,45 @@ def make_cube(self): lat_st, lat_en = np.floor(lats0.min()), np.ceil(lats0.max()) lon_st, lon_en = np.floor(lons0.min()), np.ceil(lons0.max()) - lats = np.arange(lat_st, lat_en, DCT_POSTING[self.wmodel]) - lons = np.arange(lon_st, lon_en, DCT_POSTING[self.wmodel]) + lats = np.arange(lat_st, lat_en, DCT_POSTING[self.wm]) + lons = np.arange(lon_st, lon_en, DCT_POSTING[self.wm]) - ds = xr.Dataset(coords={'latitude': lats, 'longitude': lons, 'heights': self.heights}) - dst_cube = os.path.join(self.out_dir, f'GeoCube_{self.name}.nc') - ds.to_netcdf(dst_cube) + dst_cube = self.out_dir / f'GeoCube_{self.name}.nc' + with xr.Dataset(coords={'latitude': lats, 'longitude': lons, 'heights': self.heights}) as ds: + ds.to_netcdf(dst_cube) - logger.info('Wrote cube to: %s', dst_cube) + logger.info('Wrote cube to: %s', str(dst_cube)) return dst_cube - -def update_yaml(dct_cfg:dict, dst:str='GUNW.yaml'): - """ Write a new yaml file from a dictionary. - - Updates parameters in the default 'template.yaml' file. 
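The `make_cube` helper above builds a regular lat/lon/height grid at the model posting before writing it out as netCDF. A rough, self-contained sketch of that grid construction follows; the bounding values and the choice of HRRR posting are assumptions for the example.

```python
import numpy as np
import xarray as xr

posting = 0.05  # DCT_POSTING['HRRR'] from the module-level table
heights = np.arange(-500, 9500, 500).tolist()
lats = np.arange(33.0, 35.0, posting)      # assumed AOI latitude bounds
lons = np.arange(-119.0, -117.0, posting)  # assumed AOI longitude bounds

ds = xr.Dataset(coords={'latitude': lats, 'longitude': lons, 'heights': heights})
ds.to_netcdf('GeoCube_example.nc')  # make_cube writes to out_dir / f'GeoCube_{name}.nc'
```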
- Each key:value pair will in 'dct_cfg' will overwrite that in the default - """ - - run_config_path = os.path.join( - os.path.dirname(RAiDER.__file__), - 'cli', - 'examples', - 'template', - 'template.yaml' - ) - - with open(run_config_path, 'r') as f: - try: - params = yaml.safe_load(f) - except yaml.YAMLError as exc: - print(exc) - raise ValueError(f'Something is wrong with the yaml file {run_config_path}') - - params = {**params, **dct_cfg} - - with open(dst, 'w') as fh: - yaml.safe_dump(params, fh, default_flow_style=False) - - logger.info (f'Wrote new cfg file: %s', dst) - return dst - - -def main(args): - """ Read parameters needed for RAiDER from ARIA Standard Products (GUNW) """ - +def main(args: CalcDelaysArgs) -> tuple[Path, float]: + """Read parameters needed for RAiDER from ARIA Standard Products (GUNW).""" # Check if WEATHER MODEL API credentials hidden file exists, if not create it or raise ERROR credentials.check_api(args.weather_model, args.api_uid, args.api_key) GUNWObj = GUNW(args.file, args.weather_model, args.output_directory) - raider_cfg = { - 'weather_model': args.weather_model, - 'look_dir': GUNWObj.look_dir, - 'cube_spacing_in_m': GUNWObj.spacing_m, - 'aoi_group' : {'bounding_box': GUNWObj.SNWE}, - 'height_group' : {'height_levels': GUNWObj.heights}, - 'date_group': {'date_list': GUNWObj.dates}, - 'time_group': {'time': GUNWObj.mid_time, - # Options are 'none', 'center_time', and 'azimuth_time_grid' - 'interpolate_time': args.interpolate_time}, - 'los_group' : {'ray_trace': True, - 'orbit_file': GUNWObj.OrbitFile, - 'wavelength': GUNWObj.wavelength, - }, - - 'runtime_group': {'raster_format': 'nc', - 'output_directory': args.output_directory, - } + raider_cfg = { + 'weather_model': args.weather_model, + 'look_dir': GUNWObj.look_dir, + 'aoi_group': {'bounding_box': GUNWObj.SNWE}, + 'height_group': {'height_levels': GUNWObj.heights}, + 'date_group': {'date_list': GUNWObj.dates}, + 'time_group': { + 'time': GUNWObj.mid_time, + # Options are 'none', 'center_time', and 'azimuth_time_grid' + 'interpolate_time': args.interpolate_time, + }, + 'los_group': { + 'ray_trace': True, + 'orbit_file': GUNWObj.orbit_file, + }, + 'runtime_group': { + 'raster_format': 'nc', + 'output_directory': args.output_directory, + 'cube_spacing_in_m': GUNWObj.spacing_m, + }, } - path_cfg = f'GUNW_{GUNWObj.name}.yaml' - update_yaml(raider_cfg, path_cfg) + path_cfg = Path(f'GUNW_{GUNWObj.name}.yaml') + write_yaml(raider_cfg, path_cfg) return path_cfg, GUNWObj.wavelength diff --git a/tools/RAiDER/aria/types.py b/tools/RAiDER/aria/types.py new file mode 100644 index 000000000..56034ddd0 --- /dev/null +++ b/tools/RAiDER/aria/types.py @@ -0,0 +1,28 @@ +import argparse +from pathlib import Path +from typing import Optional + +from RAiDER.types import TimeInterpolationMethod + + +class CalcDelaysArgsUnparsed(argparse.Namespace): + bucket: Optional[str] + bucket_prefix: Optional[str] + input_bucket_prefix: Optional[str] + file: Optional[Path] + weather_model: str + api_uid: Optional[str] + api_key: Optional[str] + interpolate_time: TimeInterpolationMethod + output_directory: Path + +class CalcDelaysArgs(argparse.Namespace): + bucket: Optional[str] + bucket_prefix: Optional[str] + input_bucket_prefix: Optional[str] + file: Path + weather_model: str + api_uid: Optional[str] + api_key: Optional[str] + interpolate_time: TimeInterpolationMethod + output_directory: Path diff --git a/tools/RAiDER/aws.py b/tools/RAiDER/aws.py index 6afb44e5a..c8a308fef 100644 --- a/tools/RAiDER/aws.py +++ b/tools/RAiDER/aws.py 
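The new `CalcDelaysArgsUnparsed`/`CalcDelaysArgs` classes in `tools/RAiDER/aria/types.py` use the typed-`argparse.Namespace` pattern, where attribute annotations document what `parse_args` will fill in. A generic sketch of that pattern is below; the argument names are illustrative, not RAiDER's full CLI.

```python
import argparse
from pathlib import Path
from typing import Optional


class Args(argparse.Namespace):
    # Annotations document the attributes argparse is expected to set.
    file: Optional[Path]
    weather_model: str


parser = argparse.ArgumentParser()
parser.add_argument('--file', type=Path)
parser.add_argument('--weather_model', default='HRRR')
args = parser.parse_args(['--file', 'gunw.nc'], namespace=Args())
print(args.file, args.weather_model)  # gunw.nc HRRR
```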
@@ -1,29 +1,18 @@ -from typing import Optional, Union from mimetypes import guess_type from pathlib import Path +from typing import Optional, Union import boto3 from RAiDER.logger import logger -S3_CLIENT = boto3.client('s3') - -def get_tag_set(): - tag_set = { - 'TagSet': [ - { - 'Key': 'file_type', - 'Value': 'product' - } - ] - } - return tag_set +S3_CLIENT = boto3.client('s3') def get_content_type(file_location: Union[Path, str]) -> str: content_type = guess_type(file_location)[0] - if not content_type: + if content_type is None: content_type = 'application/octet-stream' return content_type @@ -36,20 +25,24 @@ def upload_file_to_s3(path_to_file: Union[str, Path], bucket: str, prefix: str = logger.info(f'Uploading s3://{bucket}/{key}') S3_CLIENT.upload_file(str(path_to_file), bucket, key, extra_args) - tag_set = get_tag_set() + tag_set = { + 'TagSet': [ + { + 'Key': 'file_type', + 'Value': 'product' + } + ] + } S3_CLIENT.put_object_tagging(Bucket=bucket, Key=key, Tagging=tag_set) -def get_s3_file(bucket_name: str, bucket_prefix: str, file_type: str) -> Optional[str]: - result = S3_CLIENT.list_objects_v2( - Bucket=bucket_name, - Prefix=bucket_prefix - ) +def get_s3_file(bucket_name: str, bucket_prefix: str, file_type: str) -> Optional[Path]: + result = S3_CLIENT.list_objects_v2(Bucket=bucket_name, Prefix=bucket_prefix) for s3_object in result['Contents']: key = s3_object['Key'] if key.endswith(file_type): file_name = Path(key).name logger.info(f'Downloading s3://{bucket_name}/{key} to {file_name}') S3_CLIENT.download_file(bucket_name, key, file_name) - return file_name + return Path(file_name) diff --git a/tools/RAiDER/checkArgs.py b/tools/RAiDER/checkArgs.py index 69c170b61..a9844e89e 100644 --- a/tools/RAiDER/checkArgs.py +++ b/tools/RAiDER/checkArgs.py @@ -5,105 +5,100 @@ # RESERVED. United States Government Sponsorship acknowledged. 
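`get_content_type` now checks explicitly for `None`, since `mimetypes.guess_type` returns `(None, None)` for unrecognized extensions. A quick standalone illustration (the file names are made up):

```python
from mimetypes import guess_type

for name in ('upload.json', 'upload.unknownext'):
    content_type = guess_type(name)[0]
    if content_type is None:
        content_type = 'application/octet-stream'
    print(name, '->', content_type)
# upload.json -> application/json
# upload.unknownext -> application/octet-stream
```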
# # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -import os +import datetime as dt +from pathlib import Path +from typing import Optional import pandas as pd import rasterio.drivers as rd -from datetime import datetime - - -from RAiDER.losreader import Zenith -from RAiDER.llreader import BoundingBox +from RAiDER.cli.types import RunConfig +from RAiDER.llreader import BoundingBox, StationFile from RAiDER.logger import logger +from RAiDER.losreader import LOS, Zenith -def checkArgs(args): - ''' - Helper fcn for checking argument compatibility and returns the - correct variables - ''' - - ######################################################################################################################### +def checkArgs(run_config: RunConfig) -> RunConfig: + """Check argument compatibility and return the correct variables.""" + ############################################################################ # Directories - if args.weather_model_directory is None: - args.weather_model_directory = os.path.join(args.output_directory, 'weather_files') + run_config.runtime_group.output_directory.mkdir(exist_ok=True) + run_config.runtime_group.weather_model_directory.mkdir(exist_ok=True) + run_config.weather_model.set_wmLoc(str(run_config.runtime_group.weather_model_directory)) - os.makedirs(args.output_directory, exist_ok=True) - os.makedirs(args.weather_model_directory, exist_ok=True) - args['weather_model'].set_wmLoc(args.weather_model_directory) - - ######################################################################################################################### + ############################################################################ # Date and Time parsing - args.date_list = [datetime.combine(d, args.time) for d in args.date_list] - if (len(args.date_list) > 1) & (args.orbit_file is not None): - logger.warning('Only one orbit file is being used to get the ' - 'look vectors for all requested times, if you ' - 'want to use separate orbit files you will ' - 'need to run raider separately for each time.') - - args.los.setTime(args.date_list[0]) - - ######################################################################################################################### + run_config.date_group.date_list = [ + dt.datetime.combine(d, run_config.time_group.time) + for d in run_config.date_group.date_list + ] + if len(run_config.date_group.date_list) > 1 and run_config.los_group.orbit_file is not None: + logger.warning( + 'Only one orbit file is being used to get the look vectors for all requested times. If you want to use ' + 'separate orbit files you will need to run RAiDER separately for each time.' 
+ ) + + run_config.los_group.los.setTime(run_config.date_group.date_list[0]) + + ############################################################################ # filenames - wetNames, hydroNames = [], [] - for d in args.date_list: - if (args.aoi.type() != 'bounding_box'): - + wetNames: list[str] = [] + hydroNames: list[str] = [] + for d in run_config.date_group.date_list: + if not isinstance(run_config.aoi_group.aoi, BoundingBox): # Handle the GNSS station file - if (args.aoi.type()=='station_file'): - wetFilename = os.path.join( - args.output_directory, - f'{args.weather_model._dataset.upper()}_Delay'\ - f'_{d.strftime("%Y%m%dT%H%M%S")}_ztd.csv' + if isinstance(run_config.aoi_group.aoi, StationFile): + wetFilename = str( + run_config.runtime_group.output_directory / + f'{run_config.weather_model._dataset.upper()}_Delay_{d.strftime("%Y%m%dT%H%M%S")}_ztd.csv' ) - hydroFilename = '' # only the 'wetFilename' is used for the station_file + hydroFilename = '' # only the 'wetFilename' is used for the station_file # copy the input station file to the output location for editing - indf = pd.read_csv(args.aoi._filename).drop_duplicates(subset=["Lat", "Lon"]) + indf = pd.read_csv(run_config.aoi_group.aoi._filename) \ + .drop_duplicates(subset=['Lat', 'Lon']) indf.to_csv(wetFilename, index=False) else: # This implies rasters - fmt = get_raster_ext(args.file_format) + fmt = get_raster_ext(run_config.runtime_group.file_format) wetFilename, hydroFilename = makeDelayFileNames( d, - args.los, + run_config.los_group.los, fmt, - args.weather_model._dataset.upper(), - args.output_directory, + run_config.weather_model._dataset.upper(), + run_config.runtime_group.output_directory, ) - else: # In this case a cube file format is needed - if args.file_format not in '.nc .h5 h5 hdf5 .hdf5 nc'.split(): + if run_config.runtime_group.file_format not in '.nc .h5 h5 hdf5 .hdf5 nc'.split(): fmt = 'nc' - logger.debug('Invalid extension %s for cube. Defaulting to .nc', args.file_format) + logger.debug('Invalid extension %s for cube. Defaulting to .nc', run_config.runtime_group.file_format) else: - fmt = args.file_format.strip('.').replace('df', '') + fmt = run_config.runtime_group.file_format.strip('.').replace('df', '') wetFilename, hydroFilename = makeDelayFileNames( d, - args.los, + run_config.los_group.los, fmt, - args.weather_model._dataset.upper(), - args.output_directory, + run_config.weather_model._dataset.upper(), + run_config.runtime_group.output_directory, ) wetNames.append(wetFilename) hydroNames.append(hydroFilename) - args.wetFilenames = wetNames - args.hydroFilenames = hydroNames + run_config.wetFilenames = wetNames + run_config.hydroFilenames = hydroNames - return args + return run_config def get_raster_ext(fmt): drivers = rd.raster_driver_extensions() - extensions = {value.upper():key for key, value in drivers.items()} + extensions = {value.upper(): key for key, value in drivers.items()} # add in ENVI/ISCE formats with generic extension extensions['ENVI'] = '.dat' @@ -112,29 +107,27 @@ def get_raster_ext(fmt): try: return extensions[fmt.upper()] except KeyError: - raise ValueError('{} is not a valid gdal/rasterio file format for rasters'.format(fmt)) + raise ValueError(f'{fmt} is not a valid gdal/rasterio file format for rasters') -def makeDelayFileNames(time, los, outformat, weather_model_name, out): - ''' +def makeDelayFileNames(date: Optional[dt.date], los: Optional[LOS], outformat: str, weather_model_name: str, out: Path) -> tuple[str, str]: + """ return names for the wet and hydrostatic delays. 
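The `makeDelayFileNames` rewrite in the following hunk keeps the two-stage filename template: the outer `format` fills in the model, time, LOS tag, and extension while leaving a `{}` placeholder, and a second `format` call fills that with `'hydro'` or `'wet'`. A standalone rendering of that expansion (the model name and timestamp here are assumptions, and the real function's output depends on its `date`/`los` arguments):

```python
# Two-stage str.format used for the delay file names.
format_string = '{model_name}_{{}}_{time}{los}.{ext}'.format(
    model_name='ERA5',
    time='20200101T120000_',
    los='ztd',  # Zenith (or None) maps to 'ztd', otherwise 'std'
    ext='nc',
)
hydroname, wetname = (format_string.format(dtyp) for dtyp in ('hydro', 'wet'))
print(wetname)    # ERA5_wet_20200101T120000_ztd.nc
print(hydroname)  # ERA5_hydro_20200101T120000_ztd.nc
```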
# Examples: - >>> makeDelayFileNames(datetime(2020, 1, 1, 0, 0, 0), None, "h5", "model_name", "some_dir") + >>> makeDelayFileNames(dt.datetime(2020, 1, 1, 0, 0, 0), None, "h5", "model_name", "some_dir") ('some_dir/model_name_wet_00_00_00_ztd.h5', 'some_dir/model_name_hydro_00_00_00_ztd.h5') >>> makeDelayFileNames(None, None, "h5", "model_name", "some_dir") ('some_dir/model_name_wet_ztd.h5', 'some_dir/model_name_hydro_ztd.h5') - ''' - format_string = "{model_name}_{{}}_{time}{los}.{ext}".format( + """ + format_string = '{model_name}_{{}}_{time}{los}.{ext}'.format( model_name=weather_model_name, - time=time.strftime("%Y%m%dT%H%M%S_") if time is not None else "", - los="ztd" if (isinstance(los, Zenith) or los is None) else "std", - ext=outformat - ) - hydroname, wetname = ( - format_string.format(dtyp) for dtyp in ('hydro', 'wet') + time=date.strftime('%Y%m%dT%H%M%S_') if date is not None else '', + los='ztd' if (isinstance(los, Zenith) or los is None) else 'std', + ext=outformat, ) + hydroname, wetname = (format_string.format(dtyp) for dtyp in ('hydro', 'wet')) - hydro_file_name = os.path.join(out, hydroname) - wet_file_name = os.path.join(out, wetname) + hydro_file_name = str(out / hydroname) + wet_file_name = str(out / wetname) return wet_file_name, hydro_file_name diff --git a/tools/RAiDER/cli/__init__.py b/tools/RAiDER/cli/__init__.py index 1f9447773..e69de29bb 100644 --- a/tools/RAiDER/cli/__init__.py +++ b/tools/RAiDER/cli/__init__.py @@ -1,44 +0,0 @@ -import os -from RAiDER.constants import _ZREF, _CUBE_SPACING_IN_M - -class AttributeDict(dict): - __getattr__ = dict.__getitem__ - __setattr__ = dict.__setitem__ - __delattr__ = dict.__delitem__ - -DEFAULT_DICT = AttributeDict( - dict( - look_dir='right', - date_start=None, - date_end=None, - date_step=None, - date_list=None, - time=None, - end_time=None, - weather_model=None, - lat_file=None, - lon_file=None, - station_file=None, - bounding_box=None, - geocoded_file=None, - dem=None, - use_dem_latlon=False, - height_levels=None, - height_file_rdr=None, - ray_trace=False, - zref=_ZREF, - cube_spacing_in_m=_CUBE_SPACING_IN_M, - los_file=None, - los_convention='isce', - los_cube=None, - orbit_file=None, - verbose=True, - raster_format='GTiff', - file_format='GTiff', - download_only=False, - output_directory='.', - weather_model_directory=None, - output_projection='EPSG:4326', - interpolate_time='center_time', - ) -) diff --git a/tools/RAiDER/cli/__main__.py b/tools/RAiDER/cli/__main__.py index 6e3401ebe..8786cb0fa 100644 --- a/tools/RAiDER/cli/__main__.py +++ b/tools/RAiDER/cli/__main__.py @@ -1,21 +1,18 @@ import argparse import sys - from importlib.metadata import entry_points +from pathlib import Path import RAiDER.cli.conf as conf -def main(): - parser = argparse.ArgumentParser( - prefix_chars='+', - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) +def main() -> None: + parser = argparse.ArgumentParser(prefix_chars='+', formatter_class=argparse.ArgumentDefaultsHelpFormatter) parser.add_argument( '++process', choices=['calcDelays', 'downloadGNSS', 'calcDelaysGUNW'], default='calcDelays', - help='Select the entrypoint to use' + help='Select the entrypoint to use', ) parser.add_argument( '++logger_path', @@ -27,20 +24,18 @@ def main(): args, unknowns = parser.parse_known_args() # Needed for a global logging path - conf.setLoggerPath(args.logger_path) + logger_path = Path(args.logger_path) if args.logger_path else None + conf.setLoggerPath(logger_path) sys.argv = [args.process, *unknowns] try: # python >=3.10 interface - 
(process_entry_point,) = entry_points( - group='console_scripts', name=f'{args.process}.py') + (process_entry_point,) = entry_points(group='console_scripts', name=f'{args.process}.py') except TypeError: # python 3.8 and 3.9 interface scripts = entry_points()['console_scripts'] - process_entry_point = [ - ep for ep in scripts if ep.name == f'{args.process}.py' - ][0] + process_entry_point = [ep for ep in scripts if ep.name == f'{args.process}.py'][0] process_entry_point.load()() diff --git a/tools/RAiDER/cli/conf.py b/tools/RAiDER/cli/conf.py index 8d1faaa03..eeb50a2c5 100644 --- a/tools/RAiDER/cli/conf.py +++ b/tools/RAiDER/cli/conf.py @@ -1,7 +1,10 @@ +from pathlib import Path +from typing import Optional -LOGGER_PATH = None -def setLoggerPath(path): +LOGGER_PATH: Optional[Path] = None + + +def setLoggerPath(path: Optional[Path]) -> None: global LOGGER_PATH LOGGER_PATH = path - diff --git a/tools/RAiDER/cli/parser.py b/tools/RAiDER/cli/parser.py index e340f3cce..d7aa9f2e7 100644 --- a/tools/RAiDER/cli/parser.py +++ b/tools/RAiDER/cli/parser.py @@ -3,17 +3,17 @@ from RAiDER.cli.validators import BBoxAction, IntegerMappingType -def add_cpus(parser: argparse.ArgumentParser): + +def add_cpus(parser: argparse.ArgumentParser) -> None: parser.add_argument( '--cpus', - help='The number of cpus to be used for multiprocessing or "all" for ' - 'all available cpus.', + help='The number of cpus to be used for multiprocessing or "all" for all available cpus.', type=IntegerMappingType(0, all=os.cpu_count()), default='all', ) -def add_verbose(parser: argparse.ArgumentParser): +def add_verbose(parser: argparse.ArgumentParser) -> None: parser.add_argument( '--verbose', '-v', help='Run in verbose mode', @@ -22,21 +22,18 @@ def add_verbose(parser: argparse.ArgumentParser): ) -def add_out(parser: argparse.ArgumentParser): - parser.add_argument( - '--out', - help='Output directory', - default='.' 
- ) +def add_out(parser: argparse.ArgumentParser) -> None: + parser.add_argument('--out', help='Output directory', default='.') -def add_bbox(parser: argparse.ArgumentParser): +def add_bbox(parser: argparse.ArgumentParser) -> None: parser.add_argument( - '--bbox', '-b', - help="Bounding box", + '--bbox', + '-b', + help='Bounding box', nargs=4, type=float, dest='query_area', action=BBoxAction, - metavar=('S', 'N', 'W', 'E') + metavar=('S', 'N', 'W', 'E'), ) diff --git a/tools/RAiDER/cli/raider.py b/tools/RAiDER/cli/raider.py index 07546cb41..ced83c261 100644 --- a/tools/RAiDER/cli/raider.py +++ b/tools/RAiDER/cli/raider.py @@ -1,29 +1,48 @@ import argparse -import datetime -import glob -import os +import datetime as dt import json +import os import shutil import sys -import yaml +from collections.abc import Sequence +from pathlib import Path +from textwrap import dedent +from typing import Any, Optional, cast import numpy as np import xarray as xr +import yaml -from textwrap import dedent -from pathlib import Path - -import RAiDER.aria.prepFromGUNW import RAiDER.aria.calcGUNW +import RAiDER.aria.prepFromGUNW from RAiDER import aws -from RAiDER.logger import logger, logging -from RAiDER.cli import DEFAULT_DICT, AttributeDict -from RAiDER.cli.parser import add_out, add_cpus, add_verbose +from RAiDER.aria.types import CalcDelaysArgs, CalcDelaysArgsUnparsed +from RAiDER.cli.parser import add_cpus, add_out, add_verbose +from RAiDER.cli.types import ( + AOIGroup, + AOIGroupUnparsed, + DateGroupUnparsed, + HeightGroupUnparsed, + LOSGroup, + LOSGroupUnparsed, + RAiDERArgs, + RunConfig, + RuntimeGroup, + TimeGroup, +) from RAiDER.cli.validators import DateListAction, date_type +from RAiDER.gnss.types import RAiDERCombineArgs +from RAiDER.logger import logger, logging +from RAiDER.losreader import Raytracing from RAiDER.models.allowed import ALLOWED_MODELS -from RAiDER.models.customExceptions import * +from RAiDER.models.customExceptions import DatetimeFailed, NoWeatherModelData, TryToKeepGoingError, WrongNumberOfFiles +from RAiDER.s1_azimuth_timing import ( + get_inverse_weights_for_dates, + get_s1_azimuth_time_grid, + get_times_for_azimuth_interpolation, +) +from RAiDER.types import TimeInterpolationMethod from RAiDER.utilFcns import get_dt -from RAiDER.s1_azimuth_timing import get_s1_azimuth_time_grid, get_inverse_weights_for_dates, get_times_for_azimuth_interpolation TIME_INTERPOLATION_METHODS = ['none', 'center_time', 'azimuth_time_grid'] @@ -42,91 +61,90 @@ raider.py run_config_file.yaml """ -DEFAULT_RUN_CONFIG_PATH = os.path.abspath("./raider.yaml") +DEFAULT_RUN_CONFIG_PATH = Path('./examples/template/template.yaml') -def read_run_config_file(fname): +def read_run_config_file(path: Path) -> RunConfig: """ Read the run config file into a dictionary structure. 
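`add_bbox` above wires up a four-value `--bbox` flag in S N W E order. Here is a minimal sketch of the same wiring with a bare parser; the real flag also attaches RAiDER's `BBoxAction` validator, which is omitted in this example.

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    '--bbox', '-b', help='Bounding box',
    nargs=4, type=float, dest='query_area', metavar=('S', 'N', 'W', 'E'),
)
args = parser.parse_args(['--bbox', '45', '46', '-72', '-70'])
print(args.query_area)  # [45.0, 46.0, -72.0, -70.0]
```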
+ Args: - fname (str): full path to the run config file + path (Path): path to the run config file Returns: - dict: arguments to pass to RAiDER functions + RAiDERArgs: arguments to pass to RAiDER functions Examples: >>> run_config = read_run_config_file('raider.yaml') """ from RAiDER.cli.validators import ( - enforce_time, parse_dates, get_query_region, get_heights, get_los, enforce_wm + get_heights, + get_los, + get_query_region, + parse_dates, + parse_weather_model, ) - with open(fname, 'r') as f: + + with path.open() as f: try: - params = yaml.safe_load(f) + yaml_data: dict[str, Any] = yaml.safe_load(f) except yaml.YAMLError as exc: print(exc) - raise ValueError( - 'Something is wrong with the yaml file {}'.format(fname)) + raise ValueError(f'Something is wrong with the yaml file {path}') # Drop any values not specified - params = drop_nans(params) + yaml_data = drop_nans(yaml_data) - # Need to ensure that all the groups exist, even if they are not specified by the user - group_keys = ['date_group', 'time_group', 'aoi_group', - 'height_group', 'los_group', 'runtime_group'] - for key in group_keys: - if not key in params.keys(): - params[key] = {} + # Ensure that all the groups exist, even if they are not specified by the user + for key in ('date_group', 'time_group', 'aoi_group', 'height_group', 'los_group', 'runtime_group'): + if key not in yaml_data or yaml_data[key] is None: + yaml_data[key] = {} - # Parse the user-provided arguments - run_config = DEFAULT_DICT.copy() - for key, value in params.items(): - if key == 'runtime_group': - for k, v in value.items(): - if v is not None: - run_config[k] = v - if key == 'time_group': - run_config.update(enforce_time(AttributeDict(value))) - if key == 'date_group': - run_config['date_list'] = parse_dates(AttributeDict(value)) - if key == 'aoi_group': - # in case a DEM is passed and should be used - dct_temp = {**AttributeDict(value), - **AttributeDict(params['height_group'])} - run_config['aoi'] = get_query_region(AttributeDict(dct_temp)) - - if key == 'los_group': - run_config['los'] = get_los(AttributeDict(value)) - run_config['zref'] = AttributeDict(value).get('zref') - if key == 'look_dir': - if value.lower() not in ['right', 'left']: - raise ValueError(f"Unknown look direction {value}") - run_config['look_dir'] = value.lower() - if key == 'cube_spacing_in_m': - run_config[key] = float(value) if isinstance(value, str) else value - if key == 'download_only': - run_config[key] = bool(value) - - # Have to guarantee that certain variables exist prior to looking at heights - for key, value in params.items(): - if key == 'height_group': - run_config.update( - get_heights( - AttributeDict(value), - run_config['output_directory'], - run_config['station_file'], - run_config['bounding_box'], - ) - ) + # Validate look direction + if not isinstance(yaml_data['look_dir'], str) or yaml_data['look_dir'].lower() not in ('right', 'left'): + raise ValueError(f'Unknown look direction {yaml_data["look_dir"]}') + + # Support for deprecated location for cube_spacing_in_m + if 'cube_spacing_in_m' in yaml_data: + logger.warning( + 'Run config option cube_spacing_in_m is deprecated. Please use runtime_group.cube_spacing_in_m instead.' 
+ ) + yaml_data['runtime_group']['cube_spacing_in_m'] = yaml_data['cube_spacing_in_m'] - if key == 'weather_model': - run_config[key] = enforce_wm(value, run_config['aoi']) + # Parse the user-provided arguments + height_group_unparsed = HeightGroupUnparsed(**yaml_data['height_group']) + aoi_group_unparsed = AOIGroupUnparsed(**yaml_data['aoi_group']) + runtime_group = RuntimeGroup(**yaml_data['runtime_group']) + aoi_group = AOIGroup( + aoi=get_query_region( + aoi_group_unparsed, + height_group_unparsed, + cube_spacing_in_m=runtime_group.cube_spacing_in_m, + ) + ) - run_config['aoi']._cube_spacing_m = run_config['cube_spacing_in_m'] - return AttributeDict(run_config) + return RunConfig( + look_dir=yaml_data['look_dir'].lower(), + weather_model=parse_weather_model(yaml_data['weather_model'], aoi_group.aoi), + date_group=parse_dates(DateGroupUnparsed(**yaml_data['date_group'])), + time_group=TimeGroup(**yaml_data['time_group']), + aoi_group=aoi_group, + height_group=get_heights( + height_group=height_group_unparsed, + aoi_group=aoi_group_unparsed, + runtime_group=runtime_group, + ), + los_group=LOSGroup( + los=get_los(LOSGroupUnparsed(**yaml_data['los_group'])), + **yaml_data['los_group'] + ), + runtime_group=runtime_group, + ) -def drop_nans(d): +def drop_nans(d: dict[str, Any]) -> dict[str, Any]: + # Must iterate over a copy of the dict's keys because dict.keys() cannot + # be used directly when the dict's size is going to change. for key in list(d.keys()): if d[key] is None: del d[key] @@ -137,24 +155,28 @@ def drop_nans(d): return d -def calcDelays(iargs=None): - """ Parse command line arguments using argparse. """ +def calcDelays(iargs: Optional[Sequence[str]]=None) -> list[Path]: + """Parse command line arguments using argparse.""" import RAiDER import RAiDER.processWM - from RAiDER.delay import tropo_delay from RAiDER.checkArgs import checkArgs - from RAiDER.utilFcns import writeDelays, get_nearest_wmtimes + from RAiDER.delay import tropo_delay + from RAiDER.utilFcns import get_nearest_wmtimes, writeDelays + examples = 'Examples of use:' \ '\n\t raider.py run_config_file.yaml' \ '\n\t raider.py --generate_config template' p = argparse.ArgumentParser( description=HELP_MESSAGE, - epilog=examples, formatter_class=argparse.RawDescriptionHelpFormatter) + epilog=examples, + formatter_class=argparse.RawDescriptionHelpFormatter + ) p.add_argument( '--download_only', action='store_true', + default=False, help='only download a weather model.' ) @@ -162,7 +184,8 @@ def calcDelays(iargs=None): # run with a configuration file. 
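For context on the `drop_nans` helper retyped a few hunks above: it prunes run-config keys whose value is `None` before the groups are parsed. A condensed, top-level-only illustration of that idea follows (simplified; only the pruning visible in the hunk is shown, and the sample dict is made up).

```python
from typing import Any


def drop_nans(d: dict[str, Any]) -> dict[str, Any]:
    # Iterate over a snapshot of the keys; entries are deleted as we go.
    for key in list(d.keys()):
        if d[key] is None:
            del d[key]
    return d


print(drop_nans({'weather_model': 'ERA5', 'time_group': None}))
# {'weather_model': 'ERA5'}
```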
group = p.add_mutually_exclusive_group(required=True) group.add_argument( - '--generate_config', '-g', + '--generate_config', + '-g', nargs='?', choices=[ 'template', @@ -170,44 +193,41 @@ def calcDelays(iargs=None): 'example_LA_GNSS', 'example_UK_isce', ], - help='Generate an example run configuration and exit' + help='Generate an example run configuration and exit', ) group.add_argument( 'run_config_file', nargs='?', + type=lambda s: Path(s).absolute(), help='a YAML file with arguments to RAiDER' ) # if not None, will replace first argument (run_config_file) - args = p.parse_args(args=iargs) + args: RAiDERArgs = p.parse_args(args=iargs, namespace=RAiDERArgs()) # Default example run configuration file ex_run_config_name = args.generate_config or 'template' - ex_run_config_dir = ( - Path(RAiDER.__file__).parent / - 'cli/examples' / ex_run_config_name - ) + ex_run_config_dir = Path(RAiDER.__file__).parent / 'cli/examples' / ex_run_config_name if args.generate_config is not None: for filename in ex_run_config_dir.glob('*'): - dest_path = Path(os.getcwd()) / filename.name + dest_path = Path.cwd() / filename.name if dest_path.exists(): print(f'File {dest_path} already exists. Overwrite? [y/n]') if input().lower() != 'y': continue - shutil.copy(filename, os.getcwd()) + shutil.copy(filename, str(Path.cwd())) logger.info('Wrote: %s', filename) sys.exit() - # args.generate_config now guaranteed to be None # If no run configuration file is provided, look for a ./raider.yaml if args.run_config_file is not None: - if not os.path.isfile(args.run_config_file): - raise FileNotFoundError(args.run_config_file) + if not args.run_config_file.exists(): + raise FileNotFoundError(str(args.run_config_file)) else: - if not os.path.isfile(DEFAULT_RUN_CONFIG_PATH): + if not DEFAULT_RUN_CONFIG_PATH.is_file(): msg = ( - "No run configuration file provided! Specify a run configuration " + 'No run configuration file provided! Specify a run configuration ' "file or have a 'raider.yaml' file in the current directory." 
) p.print_usage() @@ -216,42 +236,39 @@ def calcDelays(iargs=None): args.run_config_file = DEFAULT_RUN_CONFIG_PATH # Read the run config file - params = read_run_config_file(args.run_config_file) + run_config = read_run_config_file(args.run_config_file) # Verify the run config file's parameters - params = checkArgs(params) - dl_only = params['download_only'] or args.download_only + run_config = checkArgs(run_config) + dl_only = run_config.runtime_group.download_only or args.download_only - if not params.verbose: + if not run_config.runtime_group.verbose: logger.setLevel(logging.INFO) # Extract and buffer the AOI - los = params['los'] - aoi = params['aoi'] - model = params['weather_model'] + los = run_config.los_group.los + aoi = run_config.aoi_group.aoi + model = run_config.weather_model # adjust user requested AOI by grid size and buffer slightly aoi.add_buffer(model.getLLRes()) # define the xy grid within the buffered bounding box - aoi.set_output_xygrid(params['output_projection']) + aoi.set_output_xygrid(run_config.runtime_group.output_projection) # add a buffer determined by latitude for ray tracing - if los.ray_trace(): - wm_bounds = aoi.calc_buffer_ray(los.getSensorDirection(), - lookDir=los.getLookDirection(), incAngle=30) + if isinstance(los, Raytracing): + wm_bounds = aoi.calc_buffer_ray(los.getSensorDirection(), lookDir=los.getLookDirection(), incAngle=30) else: wm_bounds = aoi.bounds() model.set_latlon_bounds(wm_bounds, output_spacing=aoi.get_output_spacing()) - wet_filenames = [] - for t, w, f in zip( - params['date_list'], - params['wetFilenames'], - params['hydroFilenames'] - ): - + wet_paths: list[Path] = [] + t: dt.datetime + w: str + f: str + for t, w, f in zip(run_config.date_group.date_list, run_config.wetFilenames, run_config.hydroFilenames): ########################################################### # Weather model calculation ########################################################### @@ -259,16 +276,19 @@ def calcDelays(iargs=None): logger.debug(f'Requested date,time: {t.strftime("%Y%m%d, %H:%M")}') logger.debug('Beginning weather model pre-processing') - interp_method = params.get('interpolate_time') + interp_method = run_config.time_group.interpolate_time if interp_method is None: interp_method = 'none' - logger.warning('interp_method is not specified, defaulting to \'none\', i.e. nearest datetime for delay ' - 'calculation') + logger.warning( + "interp_method is not specified, defaulting to 'none', i.e. nearest datetime for delay calculation" + ) - if (interp_method != 'azimuth_time_grid'): - times = get_nearest_wmtimes( - t, [model.dtime() if model.dtime() is not None else 6][0] - ) if interp_method == 'center_time' else [t] + if interp_method != 'azimuth_time_grid': + times = ( + get_nearest_wmtimes(t, [model.dtime() if model.dtime() is not None else 6][0]) + if interp_method == 'center_time' + else [t] + ) elif interp_method == 'azimuth_time_grid': step = model.dtime() @@ -276,21 +296,22 @@ def calcDelays(iargs=None): # Will yield 2 or 3 dates depending if t is within 5 minutes of time step times = get_times_for_azimuth_interpolation(t, time_step_hours) else: - raise NotImplementedError('Only none, center_time, and azimuth_time_grid are accepted values for ' - 'interp_method.') - wfiles = [] + raise NotImplementedError( + 'Only none, center_time, and azimuth_time_grid are accepted values for interp_method.' 
+ ) + wfiles: list[Path] = [] for tt in times: try: wfile = RAiDER.processWM.prepareWeatherModel( model, tt, aoi.bounds(), - makePlots=params['verbose'] + makePlots=run_config.runtime_group.verbose ) - wfiles.append(wfile) + wfiles.append(Path(wfile)) except TryToKeepGoingError: - if interp_method in ['azimuth_time_grid', 'none']: + if interp_method in ('azimuth_time_grid', 'none'): raise DatetimeFailed(model.Model(), tt) else: continue @@ -298,16 +319,11 @@ def calcDelays(iargs=None): # log when something else happens and then continue with the next time except Exception as e: S, N, W, E = wm_bounds - logger.info( - 'Weather model point bounds are ' - f'{S:.2f}/{N:.2f}/{W:.2f}/{E:.2f}' - ) + logger.info('Weather model point bounds are ' f'{S:.2f}/{N:.2f}/{W:.2f}/{E:.2f}') logger.info(f'Query datetime: {tt}') logger.error(e) - logger.error('Weather model files are: {}'.format(wfiles)) - logger.error( - f'Downloading and/or preparation of {model._Name} failed.' - ) + logger.error(f'Weather model files are: {wfiles}') + logger.error(f'Downloading and/or preparation of {model._Name} failed.') continue # dont process the delays for download only @@ -315,72 +331,73 @@ def calcDelays(iargs=None): continue # Get the weather model file - weather_model_file = getWeatherFile( - wfiles, times, t, model._Name, interp_method) + weather_model_file = getWeatherFile(wfiles, times, t, model._Name, interp_method) # Now process the delays try: wet_delay, hydro_delay = tropo_delay( - t, weather_model_file, aoi, los, - height_levels=params['height_levels'], - out_proj=params['output_projection'], - zref=params['zref'] + t, + str(weather_model_file), + aoi, + los, + height_levels=run_config.height_group.height_levels, + out_proj=run_config.runtime_group.output_projection, + zref=run_config.los_group.zref, ) except RuntimeError: - logger.exception("Datetime %s failed", t) + logger.exception('Datetime %s failed', t) continue # Different options depending on the inputs if los.is_Projected(): - out_filename = w.replace("_ztd", "_std") - f = f.replace("_ztd", "_std") + out_filename = w.replace('_ztd', '_std') + hydro_filename = f.replace('_ztd', '_std') elif los.ray_trace(): - out_filename = w.replace("_std", "_ray") - f = f.replace("_std", "_ray") + out_filename = w.replace('_std', '_ray') + hydro_filename = f.replace('_std', '_ray') else: out_filename = w + hydro_filename = f # A dataset was returned by the above # Dataset returned: Cube e.g. 
GUNW workflow if hydro_delay is None: + out_path = Path(out_filename.replace('wet', 'tropo')) ds = wet_delay - ext = os.path.splitext(out_filename)[1] - out_filename = out_filename.replace('wet', 'tropo') + ext = out_path.suffix # data provenance: include metadata for model and times used - times_str = [t.strftime("%Y%m%dT%H:%M:%S") for t in sorted(times)] - ds = ds.assign_attrs(model_name=model._Name, - model_times_used=times_str, - interpolation_method=interp_method) - if ext not in ['.nc', '.h5']: - out_filename = f'{os.path.splitext(out_filename)[0]}.nc' - - if out_filename.endswith(".nc"): - ds.to_netcdf(out_filename, mode="w") - elif out_filename.endswith(".h5"): - ds.to_netcdf(out_filename, engine="h5netcdf", - invalid_netcdf=True) - - logger.info( - '\nSuccessfully wrote delay cube to: %s\n', out_filename) + times_str = [t.strftime('%Y%m%dT%H:%M:%S') for t in sorted(times)] + ds = ds.assign_attrs(model_name=model._Name, model_times_used=times_str, interpolation_method=interp_method) + if ext not in ('.nc', '.h5'): + out_path = Path(out_path.stem + '.nc') + + if out_path.suffix == '.nc': + ds.to_netcdf(out_path, mode='w') + elif out_path.suffix == '.h5': + ds.to_netcdf(out_path, engine='h5netcdf', invalid_netcdf=True) + + logger.info('\nSuccessfully wrote delay cube to: %s\n', out_path) # Dataset returned: station files, radar_raster, geocoded_file else: + out_path = Path(out_filename) + hydro_path = Path(hydro_filename) if aoi.type() == 'station_file': - out_filename = f'{os.path.splitext(out_filename)[0]}.csv' + out_path = out_path.with_suffix('.csv') - if aoi.type() in ['station_file', 'radar_rasters', 'geocoded_file']: - writeDelays(aoi, wet_delay, hydro_delay, out_filename, - f, outformat=params['raster_format']) + if aoi.type() in ('station_file', 'radar_rasters', 'geocoded_file'): + writeDelays(aoi, wet_delay, hydro_delay, out_path, hydro_path, outformat=run_config.runtime_group.raster_format) - wet_filenames.append(out_filename) + wet_paths.append(out_path) - return wet_filenames + return wet_paths # ------------------------------------------------------ downloadGNSSDelays.py -def downloadGNSS(): +def downloadGNSS() -> None: """Parse command line arguments using argparse.""" from RAiDER.gnss.downloadGNSSDelays import main as dlGNSS + p = argparse.ArgumentParser( formatter_class=argparse.RawDescriptionHelpFormatter, description=""" \ @@ -411,26 +428,42 @@ def downloadGNSS(): directory, across specified range of time (in YYMMDD YYMMDD) and specified time of day, and confined to specified geographic bounding box : downloadGNSSdelay.py --download --out products -y 20100101 20141231 --returntime '00:00:00' -b '39 40 -79 -78' - """) + """, + ) # Stations to check/download area = p.add_argument_group( - 'Stations to check/download. Can be a lat/lon bounding box or file, or will run the whole world if not specified') + 'Stations to check/download. Can be a lat/lon bounding box or file, or will run the whole world if not specified' + ) area.add_argument( - '--station_file', '-f', default=None, dest='station_file', - help=('Text file containing a list of 4-char station IDs separated by newlines')) + '--station_file', + '-f', + default=None, + dest='station_file', + help=('Text file containing a list of 4-char station IDs separated by newlines'), + ) area.add_argument( - '-b', '--bounding_box', dest='bounding_box', type=str, default=None, - help="Provide either valid shapefile or Lat/Lon Bounding SNWE. 
-- Example : '19 20 -99.5 -98.5'") + '-b', + '--bounding_box', + dest='bounding_box', + type=str, + default=None, + help="Provide either valid shapefile or Lat/Lon Bounding SNWE. -- Example : '19 20 -99.5 -98.5'", + ) area.add_argument( - '--gpsrepo', '-gr', default='UNR', dest='gps_repo', - help=('Specify GPS repository you wish to query. Currently supported archives: UNR.')) + '--gpsrepo', + '-gr', + default='UNR', + dest='gps_repo', + help=('Specify GPS repository you wish to query. Currently supported archives: UNR.'), + ) - misc = p.add_argument_group("Run parameters") + misc = p.add_argument_group('Run parameters') add_out(misc) misc.add_argument( - '--date', dest='dateList', + '--date', + dest='dateList', help=dedent("""\ Date to calculate delay. Can be a single date, a list of two dates (earlier, later) with 1-day interval, or a list of two dates and interval in days (earlier, later, interval). @@ -438,22 +471,27 @@ def downloadGNSS(): YYYYMMDD or YYYYMMDD YYYYMMDD YYYYMMDD YYYYMMDD N - """), + """), nargs="+", action=DateListAction, type=date_type, - required=True + required=True, ) misc.add_argument( - '--returntime', dest='returnTime', + '--returntime', + dest='returnTime', help="Return delays closest to this specified time. If not specified, the GPS delays for all times will be returned. Input in 'HH:MM:SS', e.g. '16:00:00'", - default=None) + default=None, + ) misc.add_argument( '--download', help='Physically download data. Note this option is not necessary to proceed with statistical analyses, as data can be handled virtually in the program.', - action='store_true', dest='download', default=False) + action='store_true', + dest='download', + default=False, + ) add_cpus(misc) add_verbose(misc) @@ -461,79 +499,90 @@ def downloadGNSS(): args = p.parse_args() dlGNSS(args) - return # ------------------------------------------------------------ prepFromGUNW.py -def calcDelaysGUNW(iargs: list[str] = None) -> xr.Dataset: - +def calcDelaysGUNW(iargs: Optional[list[str]] = None) -> Optional[xr.Dataset]: p = argparse.ArgumentParser( description='Calculate a cube of interferometic delays for GUNW files', - formatter_class=argparse.ArgumentDefaultsHelpFormatter) + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) p.add_argument( '--bucket', - help='S3 bucket containing ARIA GUNW NetCDF file. Will be ignored if the --file argument is provided.' + help='S3 bucket containing ARIA GUNW NetCDF file. Will be ignored if the --file argument is provided.', ) p.add_argument( - '--bucket-prefix', default='', + '--bucket-prefix', + default='', help='S3 bucket prefix which may contain an ARIA GUNW NetCDF file to calculate delays for and which the final ' - 'ARIA GUNW NetCDF file will be upload to. Will be ignored if the --file argument is provided.' + 'ARIA GUNW NetCDF file will be upload to. Will be ignored if the --file argument is provided.', ) p.add_argument( '--input-bucket-prefix', help='S3 bucket prefix that contains an ARIA GUNW NetCDF file to calculate delays for. ' - 'If not provided, will look in --bucket-prefix for an ARIA GUNW NetCDF file. ' - 'Will be ignored if the --file argument is provided.' + 'If not provided, will look in --bucket-prefix for an ARIA GUNW NetCDF file. 
' + 'Will be ignored if the --file argument is provided.', ) p.add_argument( - '-f', '--file', type=str, - help='1 ARIA GUNW netcdf file' + '-f', + '--file', + type=lambda s: Path(s).absolute(), + help='1 ARIA GUNW netcdf file', ) p.add_argument( - '-m', '--weather-model', default='HRRR', type=str, - choices=['None'] + ALLOWED_MODELS, help='Weather model.' + '-m', + '--weather-model', + default='HRRR', + choices=['None'] + ALLOWED_MODELS, + help='Weather model.' ) p.add_argument( - '-uid', '--api_uid', default=None, type=str, - help='Weather model API UID [uid, email, username], depending on model.' + '-uid', + '--api_uid', + default=None, + help='Weather model API UID [uid, email, username], depending on model.', ) p.add_argument( - '-key', '--api_key', default=None, type=str, + '-key', + '--api_key', + default=None, help='Weather model API KEY [key, password], depending on model.' ) p.add_argument( - '-interp', '--interpolate-time', default='azimuth_time_grid', type=str, + '-interp', + '--interpolate-time', + default='azimuth_time_grid', choices=TIME_INTERPOLATION_METHODS, - help=('How to interpolate across model time steps. Possible options are: ' - '[\'none\', \'center_time\', \'azimuth_time_grid\'] ' - 'None: means nearest model time; center_time: linearly across center time; ' - 'Azimuth_time_grid: means every pixel is weighted with respect to azimuth time of S1;' - ) + help=( + 'How to interpolate across model time steps. Possible options are: ' + f"{TIME_INTERPOLATION_METHODS} " + 'None: means nearest model time; center_time: linearly across center time; ' + 'Azimuth_time_grid: means every pixel is weighted with respect to azimuth time of S1' + ), ) p.add_argument( - '-o', '--output-directory', default=os.getcwd(), type=str, + '-o', + '--output-directory', + default=Path.cwd(), + type=lambda s: Path(s).absolute(), help='Directory to store results.' ) - iargs = p.parse_args(iargs) + args: CalcDelaysArgsUnparsed = p.parse_args(iargs, namespace=CalcDelaysArgsUnparsed()) - if not iargs.input_bucket_prefix: - iargs.input_bucket_prefix = iargs.bucket_prefix + if args.input_bucket_prefix is None: + args.input_bucket_prefix = args.bucket_prefix - if iargs.interpolate_time not in ['none', 'center_time', 'azimuth_time_grid']: - raise ValueError( - 'interpolate_time arg must be in [\'none\', \'center_time\', \'azimuth_time_grid\']') - - if iargs.weather_model == 'None': + if args.weather_model == 'None': # NOTE: HyP3's current step function implementation does not have a good way of conditionally # running processing steps. This allows HyP3 to always run this step but exit immediately # and do nothing if tropospheric correction via RAiDER is not selected. 
This should not cause @@ -541,123 +590,134 @@ def calcDelaysGUNW(iargs: list[str] = None) -> xr.Dataset: print('Nothing to do!') return - if iargs.file and (iargs.weather_model == 'HRRR') and (iargs.interpolate_time == 'azimuth_time_grid'): - file_name = iargs.file.split('/')[-1] - gunw_id = file_name.replace('.nc', '') + if ( + args.file is not None and + args.weather_model == 'HRRR' and + args.interpolate_time == 'azimuth_time_grid' + ): + gunw_id = args.file.name.replace('.nc', '') if not RAiDER.aria.prepFromGUNW.check_hrrr_dataset_availablity_for_s1_azimuth_time_interpolation(gunw_id): - raise NoWeatherModelData( - 'The required HRRR data for time-grid interpolation is not available') + raise NoWeatherModelData('The required HRRR data for time-grid interpolation is not available') + + if args.file is None: + if args.bucket is None or args.bucket_prefix is None: + raise ValueError('Either argument --file or --bucket must be provided') - if not iargs.file and iargs.bucket: # only use GUNW ID for checking if HRRR available - iargs.file = aws.get_s3_file( - iargs.bucket, iargs.input_bucket_prefix, '.nc') - if iargs.file is None: + args.file = aws.get_s3_file( + args.bucket, + cast(str, args.input_bucket_prefix), # guaranteed not None at this point + '.nc' + ) + if args.file is None: raise ValueError( - 'GUNW product file could not be found at' - f's3://{iargs.bucket}/{iargs.input_bucket_prefix}' + 'GUNW product file could not be found at' f's3://{args.bucket}/{args.input_bucket_prefix}' ) - if iargs.weather_model == 'HRRR' and (iargs.interpolate_time == 'azimuth_time_grid'): - file_name_str = str(iargs.file) - gunw_nc_name = file_name_str.split('/')[-1] - gunw_id = gunw_nc_name.replace('.nc', '') + if args.weather_model == 'HRRR' and args.interpolate_time == 'azimuth_time_grid': + gunw_id = args.file.name.replace('.nc', '') if not RAiDER.aria.prepFromGUNW.check_hrrr_dataset_availablity_for_s1_azimuth_time_interpolation(gunw_id): - print('The required HRRR data for time-grid interpolation is not available; returning None and not modifying GUNW dataset') + print( + 'The required HRRR data for time-grid interpolation is not available; returning None and not modifying GUNW dataset' + ) return # Download file to obtain metadata - if not RAiDER.aria.prepFromGUNW.check_weather_model_availability(iargs.file, iargs.weather_model): + if not RAiDER.aria.prepFromGUNW.check_weather_model_availability(args.file, args.weather_model): # NOTE: We want to submit jobs that are outside of acceptable weather model range # and still deliver these products to the DAAC without this layer. Therefore # we include this within this portion of the control flow. 
print('Nothing to do because outside of weather model range') return json_file_path = aws.get_s3_file( - iargs.bucket, iargs.input_bucket_prefix, '.json') + args.bucket, + cast(str, args.input_bucket_prefix), + '.json' + ) if json_file_path is None: raise ValueError( - 'GUNW metadata file could not be found at' - f's3://{iargs.bucket}/{iargs.input_bucket_prefix}' + 'GUNW metadata file could not be found at' f's3://{args.bucket}/{args.input_bucket_prefix}' ) - json_data = json.load(open(json_file_path)) - json_data['metadata'].setdefault( - 'weather_model', []).append(iargs.weather_model) - json.dump(json_data, open(json_file_path, 'w')) + with json_file_path.open() as f: + json_data = json.load(f) + json_data['metadata'].setdefault('weather_model', []).append(args.weather_model) + with json_file_path.open('w') as f: + json.dump(json_data, f) # also get browse image -- if RAiDER is running in its own HyP3 job, the browse image will be needed for ingest - browse_file_path = aws.get_s3_file( - iargs.bucket, iargs.input_bucket_prefix, '.png') + browse_file_path = aws.get_s3_file(args.bucket, args.input_bucket_prefix, '.png') if browse_file_path is None: raise ValueError( - 'GUNW browse image could not be found at' - f's3://{iargs.bucket}/{iargs.input_bucket_prefix}' + 'GUNW browse image could not be found at' f's3://{args.bucket}/{args.input_bucket_prefix}' ) - elif not iargs.file: - raise ValueError('Either argument --file or --bucket must be provided') + args = cast(CalcDelaysArgs, args) # prep the config needed for delay calcs - path_cfg, wavelength = RAiDER.aria.prepFromGUNW.main(iargs) + path_cfg, wavelength = RAiDER.aria.prepFromGUNW.main(args) # write delay cube (nc) to disk using config # return a list with the path to cube for each date - cube_filenames = calcDelays([path_cfg]) + cube_filenames = calcDelays([str(path_cfg)]) assert len(cube_filenames) == 2, 'Incorrect number of delay files written.' # calculate the interferometric phase and write it out - ds = RAiDER.aria.calcGUNW.tropo_gunw_slc(cube_filenames, - iargs.file, - wavelength, - ) + ds = RAiDER.aria.calcGUNW.tropo_gunw_slc( + cube_filenames, + args.file, + wavelength, + ) # upload to s3 - if iargs.bucket: - aws.upload_file_to_s3(iargs.file, iargs.bucket, iargs.bucket_prefix) - aws.upload_file_to_s3( - json_file_path, iargs.bucket, iargs.bucket_prefix) - aws.upload_file_to_s3( - browse_file_path, iargs.bucket, iargs.bucket_prefix) + if args.bucket is not None: + aws.upload_file_to_s3(args.file, args.bucket, args.bucket_prefix) + aws.upload_file_to_s3(json_file_path, args.bucket, args.bucket_prefix) + aws.upload_file_to_s3(browse_file_path, args.bucket, args.bucket_prefix) return ds # ------------------------------------------------------------ processDelays.py -def combineZTDFiles(): - ''' - Command-line program to process delay files from RAiDER and GNSS into a single file. 
- ''' - from RAiDER.gnss.processDelayFiles import main, combineDelayFiles, create_parser +def combineZTDFiles() -> None: + """Command-line program to process delay files from RAiDER and GNSS into a single file.""" + from RAiDER.gnss.processDelayFiles import combineDelayFiles, create_parser, main p = create_parser() - args = p.parse_args() + args: RAiDERCombineArgs = p.parse_args(namespace=RAiDERCombineArgs()) - if not os.path.exists(args.raider_file): + if not args.raider_file.exists(): combineDelayFiles(args.raider_file, loc=args.raider_folder) - if not os.path.exists(args.gnss_file): - combineDelayFiles(args.gnss_file, loc=args.gnss_folder, source='GNSS', - ref=args.raider_file, col_name=args.column_name) - - if args.gnss_file is not None: - main( - args.raider_file, - args.gnss_file, - col_name=args.column_name, - raider_delay=args.raider_column_name, - outName=args.out_name, - localTime=args.local_time + if args.gnss_file is None: + return + + if not args.gnss_file.exists(): + combineDelayFiles( + args.gnss_file, loc=args.gnss_folder, source='GNSS', ref=args.raider_file, col_name=args.column_name ) + main( + args.raider_file, + args.gnss_file, + col_name=args.column_name, + raider_delay=args.raider_column_name, + out_path=args.out_name, + local_time=args.local_time, + ) + -def getWeatherFile(wfiles, times, t, model, interp_method='none'): - ''' - # Time interpolation - # - # Need to handle various cases, including if the exact weather model time is - # requested, or if one or more datetimes are not available from the weather - # model data provider - ''' +def getWeatherFile( + wfiles: list[Path], + times: list, + time: dt.datetime, + model: str, + interp_method: TimeInterpolationMethod='none' +) -> Optional[Path]: + """Time interpolation. + Need to handle various cases, including if the exact weather model time is + requested, or if one or more datetimes are not available from the weather + model data provider + """ # time interpolation method: number of expected files EXPECTED_NUM_FILES = {'none': 1, 'center_time': 2, 'azimuth_time_grid': 3} @@ -667,11 +727,10 @@ def getWeatherFile(wfiles, times, t, model, interp_method='none'): try: Nfiles_expected = EXPECTED_NUM_FILES[interp_method] except KeyError: - raise ValueError( - 'getWeatherFile: interp_method {} is not known'.format(interp_method)) + raise ValueError(f'getWeatherFile: interp_method {interp_method} is not known') - Nmatch = (Nfiles_expected == Nfiles) - Tmatch = (Nfiles == Ntimes) + Nmatch = Nfiles_expected == Nfiles + Tmatch = Nfiles == Ntimes # Case 1: no files downloaded if Nfiles == 0: @@ -679,74 +738,62 @@ def getWeatherFile(wfiles, times, t, model, interp_method='none'): return None # Case 2 - nearest weather model time is requested and retrieved - if (interp_method == 'none'): + if interp_method == 'none': weather_model_file = wfiles[0] - elif (interp_method == 'center_time'): - + elif interp_method == 'center_time': if Nmatch: # Case 3: two weather files downloaded - weather_model_file = combine_weather_files( - wfiles, - t, - model, - interp_method='center_time' - ) + weather_model_file = combine_weather_files(wfiles, time, model, interp_method='center_time') elif Tmatch: # Case 4: Exact time is available without interpolation - logger.warning( - 'Time interpolation is not needed as exact time is available') + logger.warning('Time interpolation is not needed as exact time is available') weather_model_file = wfiles[0] elif Nfiles == 1: # Case 5: one file does not download for some reason logger.warning( - 
'getWeatherFile: One datetime is not available to download, defaulting to nearest available date') + 'getWeatherFile: One datetime is not available to download, defaulting to nearest available date' + ) weather_model_file = wfiles[0] else: raise WrongNumberOfFiles(Nfiles_expected, Nfiles) - elif (interp_method) == 'azimuth_time_grid': - + elif interp_method == 'azimuth_time_grid': if Nmatch or Tmatch: # Case 6: all files downloaded - weather_model_file = combine_weather_files( - wfiles, - t, - model, - interp_method='azimuth_time_grid' - ) + weather_model_file = combine_weather_files(wfiles, time, model, interp_method='azimuth_time_grid') else: raise WrongNumberOfFiles(Nfiles_expected, Nfiles) # Case 7 - Anything else errors out else: - N = len(wfiles) - raise NotImplementedError(f'The {interp_method} with {N} retrieved weather model files was not well posed ' - 'for the current workflow.') + raise NotImplementedError( + f'The {interp_method} with {len(wfiles)} retrieved weather model files was not well posed ' + 'for the current workflow.' + ) return weather_model_file -def combine_weather_files(wfiles, t, model, interp_method='center_time'): - '''Interpolate downloaded weather files and save to a single file''' - - STYLE = {'center_time': '_timeInterp_', - 'azimuth_time_grid': '_timeInterpAziGrid_'} +def combine_weather_files(wfiles: list[Path], time: dt.datetime, model: str, interp_method: TimeInterpolationMethod='center_time') -> Path: + """Interpolate downloaded weather files and save to a single file.""" + STYLE = {'center_time': '_timeInterp_', 'azimuth_time_grid': '_timeInterpAziGrid_'} # read the individual datetime datasets datasets = [xr.open_dataset(f) for f in wfiles] # Pull the datetimes from the datasets - times = [] + times: list[dt.datetime] = [] for ds in datasets: - times.append(datetime.datetime.strptime( - ds.attrs['datetime'], '%Y_%m_%dT%H_%M_%S')) + times.append(dt.datetime.strptime(ds.attrs['datetime'], '%Y_%m_%dT%H_%M_%S')) if len(times) == 0: raise NoWeatherModelData() # calculate relative weights of each dataset if interp_method == 'center_time': - wgts = get_weights_time_interp(times, t) + wgts = get_weights_time_interp(times, time) elif interp_method == 'azimuth_time_grid': - time_grid = get_time_grid_for_aztime_interp(datasets, t, model) + time_grid = get_time_grid_for_aztime_interp(datasets, time, model) wgts = get_inverse_weights_for_dates(time_grid, times) + else: # interp_method == 'none' + raise ValueError('Interpolating weather files is not available with interpolation method "none"') # combine datasets ds_out = datasets[0] @@ -756,11 +803,12 @@ def combine_weather_files(wfiles, t, model, interp_method='center_time'): ds_out.attrs['Date2'] = 0 # Give the weighted combination a new file name - weather_model_file = os.path.join( - os.path.dirname(wfiles[0]), - os.path.basename(wfiles[0]).split('_')[0] + '_' + - t.strftime('%Y_%m_%dT%H_%M_%S') + STYLE[interp_method] + - '_'.join(wfiles[0].split('_')[-4:]), + weather_model_file = wfiles[0].parent / ( + wfiles[0].name.split('_')[0] + + '_' + + time.strftime('%Y_%m_%dT%H_%M_%S') + + STYLE[interp_method] + + '_'.join(wfiles[0].name.split('_')[-4:]) ) # write the combined results to disk @@ -769,21 +817,19 @@ def combine_weather_files(wfiles, t, model, interp_method='center_time'): return weather_model_file -def combine_files_using_azimuth_time(wfiles, t, times): - '''Combine files using azimuth time interpolation''' - +def combine_files_using_azimuth_time(wfiles, time: dt.datetime, times: list[dt.datetime]): + 
"""Combine files using azimuth time interpolation.""" # read the individual datetime datasets datasets = [xr.open_dataset(f) for f in wfiles] # Pull the datetimes from the datasets - times = [] + times: list[dt.datetime] = [] for ds in datasets: - times.append(datetime.datetime.strptime( - ds.attrs['datetime'], '%Y_%m_%dT%H_%M_%S')) + times.append(dt.datetime.strptime(ds.attrs['datetime'], '%Y_%m_%dT%H_%M_%S')) model = datasets[0].attrs['model_name'] - time_grid = get_time_grid_for_aztime_interp(datasets, times, t, model) + time_grid = get_time_grid_for_aztime_interp(datasets, times, time, model) wgts = get_inverse_weights_for_dates(time_grid, times) @@ -797,8 +843,11 @@ def combine_files_using_azimuth_time(wfiles, t, times): # Give the weighted combination a new file name weather_model_file = os.path.join( os.path.dirname(wfiles[0]), - os.path.basename(wfiles[0]).split('_')[0] + '_' + t.strftime( - '%Y_%m_%dT%H_%M_%S') + '_timeInterpAziGrid_' + '_'.join(wfiles[0].split('_')[-4:]), + os.path.basename(wfiles[0]).split('_')[0] + + '_' + + time.strftime('%Y_%m_%dT%H_%M_%S') + + '_timeInterpAziGrid_' + + '_'.join(wfiles[0].split('_')[-4:]), ) # write the combined results to disk @@ -807,25 +856,22 @@ def combine_files_using_azimuth_time(wfiles, t, times): return weather_model_file -def get_weights_time_interp(times, t): - '''Calculate weights for time interpolation using simple inverse linear weighting''' +def get_weights_time_interp(times: list[dt.datetime], time: dt.datetime) -> Optional[list[float]]: + """Calculate weights for time interpolation using simple inverse linear weighting.""" date1, date2 = times - wgts = [1 - get_dt(t, date1) / get_dt(date2, date1), 1 - - get_dt(date2, t) / get_dt(date2, date1)] + wgts = [1 - get_dt(time, date1) / get_dt(date2, date1), 1 - get_dt(date2, time) / get_dt(date2, date1)] try: assert np.isclose(np.sum(wgts), 1) except AssertionError: - logger.error( - 'Time interpolation weights do not sum to one; something is off with query datetime: %s', t) + logger.error('Time interpolation weights do not sum to one; something is off with query datetime: %s', time) return None return wgts -def get_time_grid_for_aztime_interp(datasets, t, model): - '''Calculate the time-varying grid for use with azimuth time interpolation''' - +def get_time_grid_for_aztime_interp(datasets: list[xr.Dataset], time: dt.datetime, model: str) -> np.ndarray: + """Calculate the time-varying grid for use with azimuth time interpolation.""" # Each model will require some inspection here # the subsequent s1 azimuth time grid requires dimension # inputs to all have same dimensions and either be @@ -843,13 +889,10 @@ def get_time_grid_for_aztime_interp(datasets, t, model): hgt = np.broadcast_to(z_1d[:, None, None], (m, n, p)) else: - raise NotImplementedError( - 'Azimuth Time is currently only implemented for HRRR') + raise NotImplementedError('Azimuth Time is currently only implemented for HRRR') - time_grid = get_s1_azimuth_time_grid( - lon, lat, hgt, t) # This is the acq time from loop + time_grid = get_s1_azimuth_time_grid(lon, lat, hgt, time) # This is the acq time from loop if np.any(np.isnan(time_grid)): - raise ValueError( - 'The Time Grid return nans meaning no orbit was downloaded.') + raise ValueError('The Time Grid return nans meaning no orbit was downloaded.') return time_grid diff --git a/tools/RAiDER/cli/statsPlot.py b/tools/RAiDER/cli/statsPlot.py index d75526689..bddde5cb9 100755 --- a/tools/RAiDER/cli/statsPlot.py +++ b/tools/RAiDER/cli/statsPlot.py @@ -5,34 +5,39 @@ # 
RESERVED. United States Government Sponsorship acknowledged. # # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -from RAiDER.logger import logger, logging -from RAiDER.cli.parser import add_cpus -from RAiDER.utilFcns import WGS84_to_UTM -from rasterio.transform import Affine -from scipy import optimize -from scipy.optimize import OptimizeWarning -from shapely.strtree import STRtree -from shapely.geometry import Point, Polygon -from matplotlib import pyplot as plt -import pandas as pd -import numpy as np -import rasterio import argparse import copy import datetime as dt import itertools -import multiprocessing +import multiprocessing as mp import os import warnings import matplotlib as mpl +import numpy as np +import pandas as pd +import rasterio +from matplotlib import pyplot as plt +from rasterio.transform import Affine +from scipy import optimize +from scipy.optimize import OptimizeWarning +from shapely.geometry import Point, Polygon +from shapely.strtree import STRtree + +from RAiDER.cli.parser import add_cpus +from RAiDER.logger import logger, logging +from RAiDER.utilFcns import WGS84_to_UTM + + # must switch to Agg to avoid multiprocessing crashes mpl.use('Agg') def create_parser(): """Parse command line arguments using argparse.""" - parser = argparse.ArgumentParser(formatter_class=argparse.RawDescriptionHelpFormatter, description=""" + parser = argparse.ArgumentParser( + formatter_class=argparse.RawDescriptionHelpFormatter, + description=""" Perform basic statistical analyses concerning the spatiotemporal distribution of zenith delays. Specifically, make any of the following specified plot(s): @@ -48,117 +53,340 @@ def create_parser(): Example call to plot gridded station variogram in a specific time interval and through explicitly the summer seasons: raiderStats.py -f -grid_delay_mean -ti '2016-01-01 2018-01-01' --seasonalinterval '06-21 09-21' -variogramplot -""") +""", + ) # User inputs - userinps = parser.add_argument_group( - 'User inputs/options for which especially careful review is recommended') - userinps.add_argument('-f', '--file', dest='fname', - type=str, required=True, help='Final output file generated from downloadGNSSDelays.py which contains GPS zenith delays for a specified time period and spatial footprint. ') - userinps.add_argument('-c', '--column_name', dest='col_name', type=str, default='ZTD', - help='Name of the input column to plot. Input assumed to be in units of meters') - userinps.add_argument('-u', '--unit', dest='unit', type=str, default='m', - help='Specified output unit (as distance or time), by default m. Input unit assumed to be m following convention in downloadGNSSDelays.py. Refer to "convert_SI" for supported units. Note if you specify time unit here, you must specify input for "--obs_errlimit" to be in units of m') - userinps.add_argument('-w', '--workdir', dest='workdir', default='./', - help='Specify directory to deposit all outputs. Default is local directory where script is launched.') + userinps = parser.add_argument_group('User inputs/options for which especially careful review is recommended') + userinps.add_argument( + '-f', + '--file', + dest='fname', + type=str, + required=True, + help='Final output file generated from downloadGNSSDelays.py which contains GPS zenith delays for a specified time period and spatial footprint. ', + ) + userinps.add_argument( + '-c', + '--column_name', + dest='col_name', + type=str, + default='ZTD', + help='Name of the input column to plot. 
Input assumed to be in units of meters', + ) + userinps.add_argument( + '-u', + '--unit', + dest='unit', + type=str, + default='m', + help='Specified output unit (as distance or time), by default m. Input unit assumed to be m following convention in downloadGNSSDelays.py. Refer to "convert_SI" for supported units. Note if you specify time unit here, you must specify input for "--obs_errlimit" to be in units of m', + ) + userinps.add_argument( + '-w', + '--workdir', + dest='workdir', + default='./', + help='Specify directory to deposit all outputs. Default is local directory where script is launched.', + ) add_cpus(userinps) - userinps.add_argument('-verbose', '--verbose', action='store_true', dest='verbose', - help="Run in verbose (debug) mode. Default False") + userinps.add_argument( + '-verbose', + '--verbose', + action='store_true', + dest='verbose', + help='Run in verbose (debug) mode. Default False' + ) # Spatiotemporal subset options - dtsubsets = parser.add_argument_group( - 'Controls for spatiotemporal subsetting.') - dtsubsets.add_argument('-b', '--bounding_box', dest='bounding_box', type=str, default=None, - help="Provide either valid shapefile or Lat/Lon Bounding SNWE. -- Example : '19 20 -99.5 -98.5'") - dtsubsets.add_argument('-sp', '--spacing', dest='spacing', type=float, default='1', - help='Specify spacing of grid-cells for statistical analyses. By default 1 deg.') - dtsubsets.add_argument('-ti', '--timeinterval', dest='timeinterval', type=str, default=None, - help="Subset in time by specifying earliest YYYY-MM-DD date followed by latest date YYYY-MM-DD. -- Example : '2016-01-01 2019-01-01'.") - dtsubsets.add_argument('-si', '--seasonalinterval', dest='seasonalinterval', type=str, default=None, - help="Subset in by an specific interval for each year by specifying earliest MM-DD time followed by latest MM-DD time. -- Example : '03-21 06-21'.") - dtsubsets.add_argument('-oe', '--obs_errlimit', dest='obs_errlimit', type=float, default='inf', - help="Observation error threshold to discard observations with large uncertainties.") + dtsubsets = parser.add_argument_group('Controls for spatiotemporal subsetting.') + dtsubsets.add_argument( + '-b', + '--bounding_box', + dest='bounding_box', + type=str, + default=None, + help="Provide either valid shapefile or Lat/Lon Bounding SNWE. -- Example : '19 20 -99.5 -98.5'", + ) + dtsubsets.add_argument( + '-sp', + '--spacing', + dest='spacing', + type=float, + default='1', + help='Specify spacing of grid-cells for statistical analyses. By default 1 deg.', + ) + dtsubsets.add_argument( + '-ti', + '--timeinterval', + dest='timeinterval', + type=str, + default=None, + help="Subset in time by specifying earliest YYYY-MM-DD date followed by latest date YYYY-MM-DD. -- Example : '2016-01-01 2019-01-01'.", + ) + dtsubsets.add_argument( + '-si', + '--seasonalinterval', + dest='seasonalinterval', + type=str, + default=None, + help="Subset in by an specific interval for each year by specifying earliest MM-DD time followed by latest MM-DD time. 
-- Example : '03-21 06-21'.", + ) + dtsubsets.add_argument( + '-oe', + '--obs_errlimit', + dest='obs_errlimit', + type=float, + default='inf', + help='Observation error threshold to discard observations with large uncertainties.', + ) # Plot formatting/options - pltformat = parser.add_argument_group( - 'Optional controls for plot formatting/options.') - pltformat.add_argument('-figdpi', '--figdpi', dest='figdpi', type=int, - default=100, help='DPI to use for saving figures') - pltformat.add_argument('-title', '--user_title', dest='user_title', type=str, - default=None, help='Specify custom title for plots.') - pltformat.add_argument('-fmt', '--plot_format', dest='plot_fmt', type=str, - default='png', help='Plot format to use for saving figures') - pltformat.add_argument('-cb', '--color_bounds', dest='cbounds', type=str, - default=None, help='List of two floats to use as color axis bounds') - pltformat.add_argument('-cp', '--colorpercentile', dest='colorpercentile', type=float, default=None, nargs=2, - help='Set low and upper percentile for plot colorbars. By default 25%% and 95%%, respectively.') - pltformat.add_argument('-cm', '--colormap', dest='usr_colormap', type=str, default='hot_r', - help='Specify matplotlib colorbar.') - pltformat.add_argument('-dt', '--densitythreshold', dest='densitythreshold', type=int, default='10', - help='For variogram plots, given grid-cell is only valid if it contains this specified threshold of stations. By default 10 stations.') - pltformat.add_argument('-sg', '--stationsongrids', dest='stationsongrids', action='store_true', - help='In gridded plots, superimpose your gridded array with a scatterplot of station locations.') - pltformat.add_argument('-dg', '--drawgridlines', dest='drawgridlines', - action='store_true', help='Draw gridlines on gridded plots.') - pltformat.add_argument('-tl', '--time_lines', dest='time_lines', - action='store_true', help='Draw central longitudinal lines with respect to datetime. Most useful for local-time analyses.') - pltformat.add_argument('-plotall', '--plotall', action='store_true', dest='plotall', - help="Generate all supported plots, including variogram plots.") - pltformat.add_argument('-min_span', '--min_span', dest='min_span', type=float, - default=[2, 0.6], nargs=2, help="Minimum TS span (years) and minimum fractional observations in span (fraction) imposed for seasonal amplitude/phase analyses to be performed for a given station.") - pltformat.add_argument('-period_limit', '--period_limit', dest='period_limit', type=float, - default=0., help="period limit (years) imposed for seasonal amplitude/phase analyses to be performed for a given station.") + pltformat = parser.add_argument_group('Optional controls for plot formatting/options.') + pltformat.add_argument( + '-figdpi', + '--figdpi', + dest='figdpi', + type=int, + default=100, + help='DPI to use for saving figures' + ) + pltformat.add_argument( + '-title', + '--user_title', + dest='user_title', + type=str, + default=None, + help='Specify custom title for plots.' + ) + pltformat.add_argument( + '-fmt', + '--plot_format', + dest='plot_fmt', + type=str, + default='png', + help='Plot format to use for saving figures' + ) + pltformat.add_argument( + '-cb', + '--color_bounds', + dest='cbounds', + type=str, + default=None, + help='List of two floats to use as color axis bounds', + ) + pltformat.add_argument( + '-cp', + '--colorpercentile', + dest='colorpercentile', + type=float, + default=None, + nargs=2, + help='Set low and upper percentile for plot colorbars. 
By default 25%% and 95%%, respectively.', + ) + pltformat.add_argument( + '-cm', '--colormap', dest='usr_colormap', type=str, default='hot_r', help='Specify matplotlib colorbar.' + ) + pltformat.add_argument( + '-dt', + '--densitythreshold', + dest='densitythreshold', + type=int, + default='10', + help='For variogram plots, given grid-cell is only valid if it contains this specified threshold of stations. By default 10 stations.', + ) + pltformat.add_argument( + '-sg', + '--stationsongrids', + dest='stationsongrids', + action='store_true', + help='In gridded plots, superimpose your gridded array with a scatterplot of station locations.', + ) + pltformat.add_argument( + '-dg', '--drawgridlines', dest='drawgridlines', action='store_true', help='Draw gridlines on gridded plots.' + ) + pltformat.add_argument( + '-tl', + '--time_lines', + dest='time_lines', + action='store_true', + help='Draw central longitudinal lines with respect to datetime. Most useful for local-time analyses.', + ) + pltformat.add_argument( + '-plotall', + '--plotall', + action='store_true', + dest='plotall', + help='Generate all supported plots, including variogram plots.', + ) + pltformat.add_argument( + '-min_span', + '--min_span', + dest='min_span', + type=float, + default=[2, 0.6], + nargs=2, + help='Minimum TS span (years) and minimum fractional observations in span (fraction) imposed for seasonal amplitude/phase analyses to be performed for a given station.', + ) + pltformat.add_argument( + '-period_limit', + '--period_limit', + dest='period_limit', + type=float, + default=0.0, + help='period limit (years) imposed for seasonal amplitude/phase analyses to be performed for a given station.', + ) # All plot types # Station scatter-plots - pltscatter = parser.add_argument_group( - 'Supported types of individual station scatter-plots.') - pltscatter.add_argument('-station_distribution', '--station_distribution', - action='store_true', dest='station_distribution', help="Plot station distribution.") - pltscatter.add_argument('-station_delay_mean', '--station_delay_mean', - action='store_true', dest='station_delay_mean', help="Plot station mean delay.") - pltscatter.add_argument('-station_delay_median', '--station_delay_median', - action='store_true', dest='station_delay_median', help="Plot station median delay.") - pltscatter.add_argument('-station_delay_stdev', '--station_delay_stdev', - action='store_true', dest='station_delay_stdev', help="Plot station delay stdev.") - pltscatter.add_argument('-station_seasonal_phase', '--station_seasonal_phase', - action='store_true', dest='station_seasonal_phase', help="Plot station delay phase/amplitude.") - pltscatter.add_argument('-phaseamp_per_station', '--phaseamp_per_station', - action='store_true', dest='phaseamp_per_station', help="Save debug figures of curve-fit vs data per station.") + pltscatter = parser.add_argument_group('Supported types of individual station scatter-plots.') + pltscatter.add_argument( + '-station_distribution', + '--station_distribution', + action='store_true', + dest='station_distribution', + help='Plot station distribution.', + ) + pltscatter.add_argument( + '-station_delay_mean', + '--station_delay_mean', + action='store_true', + dest='station_delay_mean', + help='Plot station mean delay.', + ) + pltscatter.add_argument( + '-station_delay_median', + '--station_delay_median', + action='store_true', + dest='station_delay_median', + help='Plot station median delay.', + ) + pltscatter.add_argument( + '-station_delay_stdev', + '--station_delay_stdev', + 
action='store_true', + dest='station_delay_stdev', + help='Plot station delay stdev.', + ) + pltscatter.add_argument( + '-station_seasonal_phase', + '--station_seasonal_phase', + action='store_true', + dest='station_seasonal_phase', + help='Plot station delay phase/amplitude.', + ) + pltscatter.add_argument( + '-phaseamp_per_station', + '--phaseamp_per_station', + action='store_true', + dest='phaseamp_per_station', + help='Save debug figures of curve-fit vs data per station.', + ) # Gridded plots pltgrids = parser.add_argument_group('Supported types of gridded plots.') - pltgrids.add_argument('-grid_heatmap', '--grid_heatmap', action='store_true', - dest='grid_heatmap', help="Plot gridded station heatmap.") - pltgrids.add_argument('-grid_delay_mean', '--grid_delay_mean', action='store_true', - dest='grid_delay_mean', help="Plot gridded station-wise mean delay.") - pltgrids.add_argument('-grid_delay_median', '--grid_delay_median', action='store_true', - dest='grid_delay_median', help="Plot gridded station-wise median delay.") - pltgrids.add_argument('-grid_delay_stdev', '--grid_delay_stdev', action='store_true', - dest='grid_delay_stdev', help="Plot gridded station-wise delay stdev.") - pltgrids.add_argument('-grid_seasonal_phase', '--grid_seasonal_phase', action='store_true', - dest='grid_seasonal_phase', help="Plot gridded station-wise delay phase/amplitude.") - pltgrids.add_argument('-grid_delay_absolute_mean', '--grid_delay_absolute_mean', action='store_true', - dest='grid_delay_absolute_mean', help="Plot absolute gridded station mean delay.") - pltgrids.add_argument('-grid_delay_absolute_median', '--grid_delay_absolute_median', action='store_true', - dest='grid_delay_absolute_median', help="Plot absolute gridded station median delay.") - pltgrids.add_argument('-grid_delay_absolute_stdev', '--grid_delay_absolute_stdev', action='store_true', - dest='grid_delay_absolute_stdev', help="Plot absolute gridded station delay stdev.") - pltgrids.add_argument('-grid_seasonal_absolute_phase', '--grid_seasonal_absolute_phase', action='store_true', - dest='grid_seasonal_absolute_phase', help="Plot absolute gridded station delay phase/amplitude.") - pltgrids.add_argument('-grid_to_raster', '--grid_to_raster', action='store_true', - dest='grid_to_raster', help="Save gridded array as raster. 
May directly load/plot in successive script call.") + pltgrids.add_argument( + '-grid_heatmap', + '--grid_heatmap', + action='store_true', + dest='grid_heatmap', + help='Plot gridded station heatmap.', + ) + pltgrids.add_argument( + '-grid_delay_mean', + '--grid_delay_mean', + action='store_true', + dest='grid_delay_mean', + help='Plot gridded station-wise mean delay.', + ) + pltgrids.add_argument( + '-grid_delay_median', + '--grid_delay_median', + action='store_true', + dest='grid_delay_median', + help='Plot gridded station-wise median delay.', + ) + pltgrids.add_argument( + '-grid_delay_stdev', + '--grid_delay_stdev', + action='store_true', + dest='grid_delay_stdev', + help='Plot gridded station-wise delay stdev.', + ) + pltgrids.add_argument( + '-grid_seasonal_phase', + '--grid_seasonal_phase', + action='store_true', + dest='grid_seasonal_phase', + help='Plot gridded station-wise delay phase/amplitude.', + ) + pltgrids.add_argument( + '-grid_delay_absolute_mean', + '--grid_delay_absolute_mean', + action='store_true', + dest='grid_delay_absolute_mean', + help='Plot absolute gridded station mean delay.', + ) + pltgrids.add_argument( + '-grid_delay_absolute_median', + '--grid_delay_absolute_median', + action='store_true', + dest='grid_delay_absolute_median', + help='Plot absolute gridded station median delay.', + ) + pltgrids.add_argument( + '-grid_delay_absolute_stdev', + '--grid_delay_absolute_stdev', + action='store_true', + dest='grid_delay_absolute_stdev', + help='Plot absolute gridded station delay stdev.', + ) + pltgrids.add_argument( + '-grid_seasonal_absolute_phase', + '--grid_seasonal_absolute_phase', + action='store_true', + dest='grid_seasonal_absolute_phase', + help='Plot absolute gridded station delay phase/amplitude.', + ) + pltgrids.add_argument( + '-grid_to_raster', + '--grid_to_raster', + action='store_true', + dest='grid_to_raster', + help='Save gridded array as raster. May directly load/plot in successive script call.', + ) # Variogram plots pltvario = parser.add_argument_group('Supported types of variogram plots.') - pltvario.add_argument('-variogramplot', '--variogramplot', action='store_true', - dest='variogramplot', help="Plot gridded station variogram.") - pltvario.add_argument('-binnedvariogram', '--binnedvariogram', action='store_true', dest='binnedvariogram', - help="Apply experimental variogram fit to total binned empirical variograms for each time slice. Default is to pass total unbinned empiricial variogram.") - pltvario.add_argument('-variogram_per_timeslice', '--variogram_per_timeslice', action='store_true', dest='variogram_per_timeslice', - help="Generate variogram plots per gridded station AND time-slice.") - pltvario.add_argument('-variogram_errlimit', '--variogram_errlimit', dest='variogram_errlimit', type=float, default='inf', - help="Variogram RMSE threshold to discard grid-cells with large uncertainties.") + pltvario.add_argument( + '-variogramplot', + '--variogramplot', + action='store_true', + dest='variogramplot', + help='Plot gridded station variogram.', + ) + pltvario.add_argument( + '-binnedvariogram', + '--binnedvariogram', + action='store_true', + dest='binnedvariogram', + help='Apply experimental variogram fit to total binned empirical variograms for each time slice. 
Default is to pass total unbinned empiricial variogram.', + ) + pltvario.add_argument( + '-variogram_per_timeslice', + '--variogram_per_timeslice', + action='store_true', + dest='variogram_per_timeslice', + help='Generate variogram plots per gridded station AND time-slice.', + ) + pltvario.add_argument( + '-variogram_errlimit', + '--variogram_errlimit', + dest='variogram_errlimit', + type=float, + default='inf', + help='Variogram RMSE threshold to discard grid-cells with large uncertainties.', + ) return parser @@ -169,33 +397,28 @@ def cmd_line_parse(iargs=None): def convert_SI(val, unit_in, unit_out): - ''' - Convert input to desired units - ''' - - SI = {'mm': 0.001, 'cm': 0.01, 'm': 1.0, 'km': 1000., - 'mm^2': 1e-6, 'cm^2': 1e-4, 'm^2': 1.0, 'km^2': 1e+6} + """Convert input to desired units.""" + SI = {'mm': 0.001, 'cm': 0.01, 'm': 1.0, 'km': 1000.0, 'mm^2': 1e-6, 'cm^2': 1e-4, 'm^2': 1.0, 'km^2': 1e6} # avoid conversion if output unit in time if unit_out in ['minute', 'hour', 'day', 'year']: # adjust if input isn't datetime, and assume it to be part of workflow # e.g. sigZTD filter, already extracted datetime object try: - return eval('val.apply(pd.to_datetime).dt.{}.astype(float).astype("Int32")'.format(unit_out)) + datetime = val.apply(pd.to_datetime).dt + return getattr(datetime, unit_out).astype(float).astype("Int32") except AttributeError: return val # check if output spatial unit is supported if unit_out not in SI: - raise ValueError("User-specified output unit {} not recognized.".format(unit_out)) + raise ValueError(f'User-specified output unit {unit_out} not recognized.') return val * SI[unit_in] / SI[unit_out] def midpoint(p1, p2): - ''' - Calculate central longitude for '--time_lines' option - ''' + """Calculate central longitude for '--time_lines' option.""" import math if p1[1] == p2[1]: @@ -207,15 +430,23 @@ def midpoint(p1, p2): dy = math.cos(lat2) * math.sin(dlon) lon3 = lon1 + math.atan2(dy, math.cos(lat1) + dx) - return (int(math.degrees(lon3))) + return int(math.degrees(lon3)) -def save_gridfile(df, gridfile_type, fname, plotbbox, spacing, unit, - colorbarfmt='%.2f', stationsongrids=False, time_lines=False, - dtype="float32", noData=np.nan): - ''' - Function to save gridded-arrays as GDAL-readable file. - ''' +def save_gridfile( + df, + gridfile_type, + fname, + plotbbox, + spacing, + unit, + colorbarfmt='%.2f', + stationsongrids=False, + time_lines=False, + dtype='float32', + noData=np.nan, +): + """Function to save gridded-arrays as GDAL-readable file.""" # Pass metadata metadata_dict = {} metadata_dict['gridfile_type'] = gridfile_type @@ -237,11 +468,18 @@ def save_gridfile(df, gridfile_type, fname, plotbbox, spacing, unit, metadata_dict['time_lines'] = 'False' # Write data to file - transform = Affine(spacing, 0., plotbbox[0], 0., -1*spacing, plotbbox[-1]) - with rasterio.open(fname, mode="w", count=1, - width=df.shape[1], height=df.shape[0], - dtype=dtype, nodata=noData, - crs='+proj=latlong', transform=transform) as dst: + transform = Affine(spacing, 0.0, plotbbox[0], 0.0, -1 * spacing, plotbbox[-1]) + with rasterio.open( + fname, + mode='w', + count=1, + width=df.shape[1], + height=df.shape[0], + dtype=dtype, + nodata=noData, + crs='+proj=latlong', + transform=transform, + ) as dst: dst.update_tags(0, **metadata_dict) dst.write(df, 1) @@ -249,10 +487,7 @@ def save_gridfile(df, gridfile_type, fname, plotbbox, spacing, unit, def load_gridfile(fname, unit): - ''' - Function to load gridded-arrays saved from previous runs. 
- ''' - + """Function to load gridded-arrays saved from previous runs.""" try: with rasterio.open(fname) as src: grid_array = src.read(1).astype(float) @@ -305,12 +540,23 @@ def load_gridfile(fname, unit): return grid_array, plotbbox, spacing, colorbarfmt, stationsongrids, time_lines -class VariogramAnalysis(): - ''' - Class which ingests dataframe output from 'RaiderStats' class and performs variogram analysis. - ''' - - def __init__(self, filearg, gridpoints, col_name, unit='m', workdir='./', seasonalinterval=None, densitythreshold=10, binnedvariogram=False, numCPUs=8, variogram_per_timeslice=False, variogram_errlimit='inf'): +class VariogramAnalysis: + """Class which ingests dataframe output from 'RaiderStats' class and performs variogram analysis.""" + + def __init__( + self, + filearg, + gridpoints, + col_name, + unit='m', + workdir='./', + seasonalinterval=None, + densitythreshold=10, + binnedvariogram=False, + numCPUs=8, + variogram_per_timeslice=False, + variogram_errlimit='inf', + ) -> None: self.df = filearg self.col_name = col_name self.unit = unit @@ -324,10 +570,9 @@ def __init__(self, filearg, gridpoints, col_name, unit='m', workdir='./', season self.variogram_errlimit = float(variogram_errlimit) def _get_samples(self, data, Nsamp=1000): - ''' - pull samples from a 2D image for variogram analysis - ''' + """Pull samples from a 2D image for variogram analysis.""" import random + if len(data) < self.densitythreshold: logger.warning('Less than {} points for this gridcell', self.densitythreshold) logger.info('Will pass empty list') @@ -346,32 +591,24 @@ def _get_samples(self, data, Nsamp=1000): return d, indpars def _get_XY(self, x2d, y2d, indpars): - ''' - Given a list of indices, return the x,y locations - from two matrices - ''' + """Given a list of indices, return the x,y locations from two matrices.""" x = np.array([[x2d[r[0]], x2d[r[1]]] for r in indpars]) y = np.array([[y2d[r[0]], y2d[r[1]]] for r in indpars]) return x, y def _get_distances(self, XY): - ''' - Return the distances between each point in a list of points - ''' + """Return the distances between each point in a list of points.""" from scipy.spatial.distance import cdist + return np.diag(cdist(XY[:, :, 0], XY[:, :, 1], metric='euclidean')) def _get_variogram(self, XY, xy=None): - ''' - Return variograms - ''' + """Return variograms.""" return 0.5 * np.square(XY - xy) # XY = 1st col xy= 2nd col def _emp_vario(self, x, y, data, Nsamp=1000): - ''' - Compute empirical semivariance - ''' + """Compute empirical semivariance.""" # remove NaNs if possible mask = ~np.isnan(data) if False in mask: @@ -387,20 +624,17 @@ def _emp_vario(self, x, y, data, Nsamp=1000): samples, indpars = self._get_samples(data, Nsamp) x, y = self._get_XY(x, y, indpars) - dists = self._get_distances( - np.array([[x[:, 0], y[:, 0]], [x[:, 1], y[:, 1]]]).T) + dists = self._get_distances(np.array([[x[:, 0], y[:, 0]], [x[:, 1], y[:, 1]]]).T) vario = self._get_variogram(samples[:, 0], samples[:, 1]) return dists, vario def _binned_vario(self, hEff, rawVario, xBin=None): - ''' - return a binned empirical variogram - ''' + """Return a binned empirical variogram.""" if xBin is None: with warnings.catch_warnings(): - warnings.filterwarnings("ignore", message="All-NaN slice encountered") - xBin = np.linspace(0, np.nanmax(hEff) * .67, 20) + warnings.filterwarnings('ignore', message='All-NaN slice encountered') + xBin = np.linspace(0, np.nanmax(hEff) * 0.67, 20) nBins = len(xBin) - 1 hExp, expVario = [], [] @@ -410,10 +644,10 @@ def _binned_vario(self, 
hEff, rawVario, xBin=None): # circumvent indexing try: with warnings.catch_warnings(): - warnings.filterwarnings("ignore", message="Mean of empty slice") + warnings.filterwarnings('ignore', message='Mean of empty slice') hExp.append(np.nanmean(hEff[iBinMask])) expVario.append(np.nanmean(rawVario[iBinMask])) - except BaseException: # TODO: Which error(s)? + except: # TODO: Which error(s)? pass if False in ~np.isnan(hExp): @@ -424,23 +658,19 @@ def _binned_vario(self, hEff, rawVario, xBin=None): return np.array(hExp), np.array(expVario) def _fit_vario(self, dists, vario, model=None, x0=None, Nparm=None, ub=None): - ''' - Fit a variogram model to data - ''' + """Fit a variogram model to data.""" from scipy.optimize import least_squares def resid(x, d, v, m): - return (m(x, d) - v) + return m(x, d) - v if ub is None: with warnings.catch_warnings(): - warnings.filterwarnings("ignore", message="All-NaN slice encountered") - ub = np.array([np.nanmax(dists) * 0.8, np.nanmax(vario) - * 0.8, np.nanmax(vario) * 0.8]) + warnings.filterwarnings('ignore', message='All-NaN slice encountered') + ub = np.array([np.nanmax(dists) * 0.8, np.nanmax(vario) * 0.8, np.nanmax(vario) * 0.8]) if x0 is None and Nparm is None: - raise RuntimeError( - 'Must specify either x0 or the number of model parameters') + raise RuntimeError('Must specify either x0 or the number of model parameters') if x0 is not None: lb = np.zeros(len(x0)) if Nparm is not None: @@ -452,28 +682,30 @@ def resid(x, d, v, m): d = dists[~mask].copy() v = vario[~mask].copy() - res_robust = least_squares(resid, x0, bounds=bounds, - loss='soft_l1', f_scale=0.1, - args=(d, v, model)) + res_robust = least_squares( + resid, + x0, + bounds=bounds, + loss='soft_l1', + f_scale=0.1, + args=(d, v, model), + ) with warnings.catch_warnings(): - warnings.filterwarnings("ignore", message="All-NaN slice encountered") + warnings.filterwarnings('ignore', message='All-NaN slice encountered') d_test = np.linspace(0, np.nanmax(dists), 100) # v_test is my y., # res_robust.x =a, b, c, where a = range, b = sill, and c = nugget model, d_test=x v_test = model(res_robust.x, d_test) return res_robust, d_test, v_test - # this would be expontential plus nugget + # this would be exponential plus nugget def __exponential__(self, parms, h, nugget=False): - ''' - returns a variogram model given a set of arguments and - key-word arguments - ''' + """Return variogram model given a set of arguments and keyword arguments.""" # a = range, b = sill, c = nugget model a, b, c = parms with warnings.catch_warnings(): - warnings.filterwarnings("ignore", message="overflow encountered in true_divide") + warnings.filterwarnings('ignore', message='overflow encountered in true_divide') if nugget: return b * (1 - np.exp(-h / a)) + c else: @@ -481,16 +713,12 @@ def __exponential__(self, parms, h, nugget=False): # this would be gaussian plus nugget def __gaussian__(self, parms, h): - ''' - returns a Gaussian variogram model - ''' + """Returns a Gaussian variogram model.""" a, b, c = parms return b * (1 - np.exp(-np.square(h) / (a**2))) + c def _append_variogram(self, grid_ind, grid_subset): - ''' - For a given grid-cell, iterate through time slices to generate/append empirical variogram(s) - ''' + """For a given grid-cell, iterate through time slices to generate/append empirical variogram(s).""" # Comprehensive arrays recording data across all time epochs for given station dists_arr = [] vario_arr = [] @@ -503,40 +731,54 @@ def _append_variogram(self, grid_ind, grid_subset): # If insufficient sample 
size, skip slice and record occurence if len(np.array(grid_subset[grid_subset['Date'] == j][self.col_name])) < self.densitythreshold: # Record skipped [gridnode, timeslice] - self.skipped_slices.append([grid_ind, j.strftime("%Y-%m-%d")]) + self.skipped_slices.append([grid_ind, j.strftime('%Y-%m-%d')]) else: - self.gridcenterlist.append(['grid{} '.format( - grid_ind) + 'Lat:{} Lon:{}'.format( - str(self.gridpoints[grid_ind][1]), str(self.gridpoints[grid_ind][0]))]) - lonarr = np.array( - grid_subset[grid_subset['Date'] == j]['Lon']) - latarr = np.array( - grid_subset[grid_subset['Date'] == j]['Lat']) - delayarray = np.array( - grid_subset[grid_subset['Date'] == j][self.col_name]) + self.gridcenterlist.append( + [ + f'grid{grid_ind} ' + + f'Lat:{str(self.gridpoints[grid_ind][1])} Lon:{str(self.gridpoints[grid_ind][0])}' + ] + ) + lonarr = np.array(grid_subset[grid_subset['Date'] == j]['Lon']) + latarr = np.array(grid_subset[grid_subset['Date'] == j]['Lat']) + delayarray = np.array(grid_subset[grid_subset['Date'] == j][self.col_name]) # fit empirical variogram for each time AND grid dists, vario = self._emp_vario(lonarr, latarr, delayarray) - dists_binned, vario_binned = self._binned_vario( - dists, vario) + dists_binned, vario_binned = self._binned_vario(dists, vario) # fit experimental variogram for each time AND grid, model default is exponential res_robust, d_test, v_test = self._fit_vario( - dists_binned, vario_binned, model=self.__exponential__, x0=None, Nparm=3) + dists_binned, vario_binned, model=self.__exponential__, x0=None, Nparm=3 + ) # Plot empirical + experimental variogram for this gridnode and timeslice - if not os.path.exists(os.path.join(self.workdir, 'variograms/grid{}'.format(grid_ind))): - os.makedirs(os.path.join( - self.workdir, 'variograms/grid{}'.format(grid_ind))) + if not os.path.exists(os.path.join(self.workdir, f'variograms/grid{grid_ind}')): + os.makedirs(os.path.join(self.workdir, f'variograms/grid{grid_ind}')) # Make variogram plots for each time-slice if self.variogram_per_timeslice: # Plot empirical variogram for this gridnode and timeslice - self.plot_variogram(grid_ind, j.strftime("%Y%m%d"), [self.gridpoints[grid_ind][1], self.gridpoints[grid_ind][0]], - workdir=os.path.join(self.workdir, 'variograms/grid{}'.format(grid_ind)), dists=dists, vario=vario, - dists_binned=dists_binned, vario_binned=vario_binned) + self.plot_variogram( + grid_ind, + j.strftime('%Y%m%d'), + [self.gridpoints[grid_ind][1], self.gridpoints[grid_ind][0]], + workdir=os.path.join(self.workdir, f'variograms/grid{grid_ind}'), + dists=dists, + vario=vario, + dists_binned=dists_binned, + vario_binned=vario_binned, + ) # Plot experimental variogram for this gridnode and timeslice - self.plot_variogram(grid_ind, j.strftime("%Y%m%d"), [self.gridpoints[grid_ind][1], self.gridpoints[grid_ind][0]], - workdir=os.path.join(self.workdir, 'variograms/grid{}'.format(grid_ind)), d_test=d_test, v_test=v_test, - res_robust=res_robust.x, dists_binned=dists_binned, vario_binned=vario_binned) + self.plot_variogram( + grid_ind, + j.strftime('%Y%m%d'), + [self.gridpoints[grid_ind][1], self.gridpoints[grid_ind][0]], + workdir=os.path.join(self.workdir, f'variograms/grid{grid_ind}'), + d_test=d_test, + v_test=v_test, + res_robust=res_robust.x, + dists_binned=dists_binned, + vario_binned=vario_binned, + ) # append for plotting - self.good_slices.append([grid_ind, j.strftime("%Y%m%d")]) + self.good_slices.append([grid_ind, j.strftime('%Y%m%d')]) dists_arr.append(dists) vario_arr.append(vario) 
dists_binned_arr.append(dists_binned) @@ -555,28 +797,45 @@ def _append_variogram(self, grid_ind, grid_subset): vario_binned_arr = np.concatenate(vario_binned_arr).ravel() else: # dists_binned_arr = dists_arr ; vario_binned_arr = vario_arr - dists_binned_arr, vario_binned_arr = self._binned_vario( - dists_arr, vario_arr) + dists_binned_arr, vario_binned_arr = self._binned_vario(dists_arr, vario_arr) TOT_res_robust, TOT_d_test, TOT_v_test = self._fit_vario( - dists_binned_arr, vario_binned_arr, model=self.__exponential__, x0=None, Nparm=3) + dists_binned_arr, vario_binned_arr, model=self.__exponential__, x0=None, Nparm=3 + ) tot_timetag = self.good_slices[0][1] + '–' + self.good_slices[-1][1] # Append TOT arrays self.TOT_good_slices.append([grid_ind, tot_timetag]) self.TOT_res_robust_arr.append(TOT_res_robust.x) self.TOT_tot_timetag.append(tot_timetag) - var_rmse = np.sqrt(np.nanmean((TOT_res_robust.fun)**2)) + var_rmse = np.sqrt(np.nanmean((TOT_res_robust.fun) ** 2)) if var_rmse <= self.variogram_errlimit: self.TOT_res_robust_rmse.append(var_rmse) else: self.TOT_res_robust_rmse.append(np.array(np.nan)) # Plot empirical variogram for this gridnode - self.plot_variogram(grid_ind, tot_timetag, [self.gridpoints[grid_ind][1], self.gridpoints[grid_ind][0]], - workdir=os.path.join(self.workdir, 'variograms/grid{}'.format(grid_ind)), dists=dists_arr, vario=vario_arr, - dists_binned=dists_binned_arr, vario_binned=vario_binned_arr, seasonalinterval=self.seasonalinterval) + self.plot_variogram( + grid_ind, + tot_timetag, + [self.gridpoints[grid_ind][1], self.gridpoints[grid_ind][0]], + workdir=os.path.join(self.workdir, f'variograms/grid{grid_ind}'), + dists=dists_arr, + vario=vario_arr, + dists_binned=dists_binned_arr, + vario_binned=vario_binned_arr, + seasonalinterval=self.seasonalinterval, + ) # Plot experimental variogram for this gridnode - self.plot_variogram(grid_ind, tot_timetag, [self.gridpoints[grid_ind][1], self.gridpoints[grid_ind][0]], - workdir=os.path.join(self.workdir, 'variograms/grid{}'.format(grid_ind)), d_test=TOT_d_test, v_test=TOT_v_test, - res_robust=TOT_res_robust.x, seasonalinterval=self.seasonalinterval, dists_binned=dists_binned_arr, vario_binned=vario_binned_arr) + self.plot_variogram( + grid_ind, + tot_timetag, + [self.gridpoints[grid_ind][1], self.gridpoints[grid_ind][0]], + workdir=os.path.join(self.workdir, f'variograms/grid{grid_ind}'), + d_test=TOT_d_test, + v_test=TOT_v_test, + res_robust=TOT_res_robust.x, + seasonalinterval=self.seasonalinterval, + dists_binned=dists_binned_arr, + vario_binned=vario_binned_arr, + ) # Record sparse grids which didn't have sufficient sample size of data through any of the timeslices else: self.sparse_grids.append(grid_ind) @@ -584,9 +843,7 @@ def _append_variogram(self, grid_ind, grid_subset): return self.TOT_good_slices, self.TOT_res_robust_arr, self.TOT_res_robust_rmse, self.gridcenterlist def create_variograms(self): - ''' - Iterate through grid-cells and time slices to generate empirical variogram(s) - ''' + """Iterate through grid-cells and time slices to generate empirical variogram(s).""" # track data for plotting self.TOT_good_slices = [] self.TOT_res_robust_arr = [] @@ -604,7 +861,7 @@ def create_variograms(self): grid_subset = self.df[self.df['gridnode'] == i] args.append((i, grid_subset)) # Parallelize iteration through all grid-cells and time slices - with multiprocessing.Pool(self.numCPUs) as multipool: + with mp.Pool(self.numCPUs) as multipool: for i, j, k, l in multipool.starmap(self._append_variogram, args): 
self.TOT_good_slices.extend(i) self.TOT_res_robust_arr.extend(j) @@ -612,34 +869,45 @@ def create_variograms(self): self.gridcenterlist.extend(l) # save grid-center lookup table - self.gridcenterlist = [list(i) for i in set(tuple(j) - for j in self.gridcenterlist)] + self.gridcenterlist = [list(i) for i in set(tuple(j) for j in self.gridcenterlist)] self.gridcenterlist.sort(key=lambda x: int(x[0][4:6])) - gridcenter = open( - (os.path.join(self.workdir, 'variograms/gridlocation_lookup.txt')), "w") + gridcenter = open((os.path.join(self.workdir, 'variograms/gridlocation_lookup.txt')), 'w') for element in self.gridcenterlist: - gridcenter.writelines("\n".join(element)) - gridcenter.write("\n") + gridcenter.writelines('\n'.join(element)) + gridcenter.write('\n') gridcenter.close() TOT_grids = [i[0] for i in self.TOT_good_slices] return TOT_grids, self.TOT_res_robust_arr, self.TOT_res_robust_rmse - def plot_variogram(self, gridID, timeslice, coords, workdir='./', d_test=None, v_test=None, res_robust=None, dists=None, vario=None, dists_binned=None, vario_binned=None, seasonalinterval=None): - ''' - Make empirical and/or experimental variogram fit plots - ''' + def plot_variogram( + self, + gridID, + timeslice, + coords, + workdir='./', + d_test=None, + v_test=None, + res_robust=None, + dists=None, + vario=None, + dists_binned=None, + vario_binned=None, + seasonalinterval=None, + ) -> None: + """Make empirical and/or experimental variogram fit plots.""" # If specified workdir doesn't exist, create it if not os.path.exists(workdir): os.mkdir(workdir) # make plot title - title_str = ' \nLat:{:.2f} Lon:{:.2f}\nTime:{}'.format( - coords[1], coords[0], str(timeslice)) + title_str = f' \nLat:{coords[1]:.2f} Lon:{coords[0]:.2f}\nTime:{str(timeslice)}' if seasonalinterval: - title_str += ' Season(mm/dd): {}/{} – {}/{}'.format(int(timeslice[4:6]), int( - timeslice[6:8]), int(timeslice[-4:-2]), int(timeslice[-2:])) + title_str += ( + ' Season(mm/dd): ' + f'{int(timeslice[4:6])}/{int(timeslice[6:8])} – {int(timeslice[-4:-2])}/{int(timeslice[-2:])}' + ) if dists is not None and vario is not None: # scale from m to user-defined units @@ -650,51 +918,68 @@ def plot_variogram(self, gridID, timeslice, coords, workdir='./', d_test=None, v dists_binned = [convert_SI(i, 'm', self.unit) for i in dists_binned] plt.plot(dists_binned, vario_binned, 'bo', label='binned') if res_robust is not None: - plt.axhline(y=res_robust[1], color='g', - linestyle='--', label='ɣ\u0332\u00b2({}\u00b2)'.format(self.unit)) + plt.axhline(y=res_robust[1], color='g', linestyle='--', label=f'ɣ\u0332\u00b2({self.unit}\u00b2)') # scale from m to user-defined units res_robust[0] = convert_SI(res_robust[0], 'm', self.unit) - plt.axvline(x=res_robust[0], color='c', - linestyle='--', label='h ({})'.format(self.unit)) + plt.axvline(x=res_robust[0], color='c', linestyle='--', label=f'h ({self.unit})') if d_test is not None and v_test is not None: # scale from m to user-defined units d_test = [convert_SI(i, 'm', self.unit) for i in d_test] plt.plot(d_test, v_test, 'r-', label='experimental fit') - plt.xlabel('Distance ({})'.format(self.unit)) - plt.ylabel('Dissimilarity ({}\u00b2)'.format(self.unit)) - plt.legend(bbox_to_anchor=(1.02, 1), - loc='upper left', borderaxespad=0., framealpha=1.) 
+ plt.xlabel(f'Distance ({self.unit})') + plt.ylabel(f'Dissimilarity ({self.unit}\u00b2)') + plt.legend(bbox_to_anchor=(1.02, 1), loc='upper left', borderaxespad=0.0, framealpha=1.0) # Plot empirical variogram if d_test is None and v_test is None: plt.title('Empirical variogram' + title_str) plt.tight_layout() - plt.savefig(os.path.join( - workdir, 'grid{}_timeslice{}_justEMPvariogram.eps'.format(gridID, timeslice))) + plt.savefig(os.path.join(workdir, f'grid{gridID}_timeslice{timeslice}_justEMPvariogram.eps')) # Plot just experimental variogram else: plt.title('Experimental variogram' + title_str) plt.tight_layout() - plt.savefig(os.path.join( - workdir, 'grid{}_timeslice{}_justEXPvariogram.eps'.format(gridID, timeslice))) + plt.savefig(os.path.join(workdir, f'grid{gridID}_timeslice{timeslice}_justEXPvariogram.eps')) plt.close() - return - -class RaiderStats(object): - ''' - Class which loads standard weather model/GPS delay files and generates a series of user-requested statistics and graphics. - ''' +class RaiderStats: + """Class which loads standard weather model/GPS delay files and generates a series of user-requested statistics and graphics.""" # import dependencies import glob - def __init__(self, filearg, col_name, unit='m', workdir='./', bbox=None, spacing=1, timeinterval=None, seasonalinterval=None, - obs_errlimit='inf', time_lines=False, stationsongrids=False, station_seasonal_phase=False, cbounds=None, colorpercentile=[25, 95], - usr_colormap='hot_r', grid_heatmap=False, grid_delay_mean=False, grid_delay_median=False, grid_delay_stdev=False, - grid_seasonal_phase=False, grid_delay_absolute_mean=False, grid_delay_absolute_median=False, - grid_delay_absolute_stdev=False, grid_seasonal_absolute_phase=False, grid_to_raster=False, min_span=[2, 0.6], - period_limit=0., numCPUs=8, phaseamp_per_station=False): + def __init__( + self, + filearg, + col_name, + unit='m', + workdir='./', + bbox=None, + spacing=1, + timeinterval=None, + seasonalinterval=None, + obs_errlimit='inf', + time_lines=False, + stationsongrids=False, + station_seasonal_phase=False, + cbounds=None, + colorpercentile=[25, 95], + usr_colormap='hot_r', + grid_heatmap=False, + grid_delay_mean=False, + grid_delay_median=False, + grid_delay_stdev=False, + grid_seasonal_phase=False, + grid_delay_absolute_mean=False, + grid_delay_absolute_median=False, + grid_delay_absolute_stdev=False, + grid_seasonal_absolute_phase=False, + grid_to_raster=False, + min_span=[2, 0.6], + period_limit=0.0, + numCPUs=8, + phaseamp_per_station=False, + ) -> None: self.fname = filearg self.col_name = col_name self.unit = unit @@ -752,117 +1037,309 @@ def __init__(self, filearg, col_name, unit='m', workdir='./', bbox=None, spacing if self.colorpercentile is None: self.colorpercentile = [25, 95] if self.colorpercentile[0] > self.colorpercentile[1]: - raise Exception('Input colorpercentile lower threshold {} higher than upper threshold {}'.format( - self.colorpercentile[0], self.colorpercentile[1])) + raise Exception( + f'Input colorpercentile lower threshold {self.colorpercentile[0]} higher than upper threshold {self.colorpercentile[1]}' + ) # load dataframe directly if previously generated TIF grid-file if self.fname.endswith('.tif'): if 'grid_heatmap' in self.fname: - self.grid_heatmap, self.plotbbox, self.spacing, self.colorbarfmt, self.stationsongrids, self.time_lines = load_gridfile(self.fname, self.unit) + ( + self.grid_heatmap, + self.plotbbox, + self.spacing, + self.colorbarfmt, + self.stationsongrids, + self.time_lines, + ) = 
load_gridfile(self.fname, self.unit) self.col_name = os.path.basename(self.fname).split('_' + 'grid_heatmap')[0] if 'grid_delay_mean' in self.fname: - self.grid_delay_mean, self.plotbbox, self.spacing, self.colorbarfmt, self.stationsongrids, self.time_lines = load_gridfile(self.fname, self.unit) + ( + self.grid_delay_mean, + self.plotbbox, + self.spacing, + self.colorbarfmt, + self.stationsongrids, + self.time_lines, + ) = load_gridfile(self.fname, self.unit) self.col_name = os.path.basename(self.fname).split('_' + 'grid_delay_mean')[0] if 'grid_delay_median' in self.fname: - self.grid_delay_median, self.plotbbox, self.spacing, self.colorbarfmt, self.stationsongrids, self.time_lines = load_gridfile(self.fname, self.unit) + ( + self.grid_delay_median, + self.plotbbox, + self.spacing, + self.colorbarfmt, + self.stationsongrids, + self.time_lines, + ) = load_gridfile(self.fname, self.unit) self.col_name = os.path.basename(self.fname).split('_' + 'grid_delay_median')[0] if 'grid_delay_stdev' in self.fname: - self.grid_delay_stdev, self.plotbbox, self.spacing, self.colorbarfmt, self.stationsongrids, self.time_lines = load_gridfile(self.fname, self.unit) + ( + self.grid_delay_stdev, + self.plotbbox, + self.spacing, + self.colorbarfmt, + self.stationsongrids, + self.time_lines, + ) = load_gridfile(self.fname, self.unit) self.col_name = os.path.basename(self.fname).split('_' + 'grid_delay_stdev')[0] if 'grid_seasonal_phase' in self.fname: - self.grid_seasonal_phase, self.plotbbox, self.spacing, self.colorbarfmt, self.stationsongrids, self.time_lines = load_gridfile(self.fname, self.unit) + ( + self.grid_seasonal_phase, + self.plotbbox, + self.spacing, + self.colorbarfmt, + self.stationsongrids, + self.time_lines, + ) = load_gridfile(self.fname, self.unit) self.col_name = os.path.basename(self.fname).split('_' + 'grid_seasonal_phase')[0] if 'grid_seasonal_period' in self.fname: - self.grid_seasonal_period, self.plotbbox, self.spacing, self.colorbarfmt, self.stationsongrids, self.time_lines = load_gridfile(self.fname, self.unit) + ( + self.grid_seasonal_period, + self.plotbbox, + self.spacing, + self.colorbarfmt, + self.stationsongrids, + self.time_lines, + ) = load_gridfile(self.fname, self.unit) self.col_name = os.path.basename(self.fname).split('_' + 'grid_seasonal_period')[0] if 'grid_seasonal_amplitude' in self.fname: - self.grid_seasonal_amplitude, self.plotbbox, self.spacing, self.colorbarfmt, self.stationsongrids, self.time_lines = load_gridfile(self.fname, self.unit) + ( + self.grid_seasonal_amplitude, + self.plotbbox, + self.spacing, + self.colorbarfmt, + self.stationsongrids, + self.time_lines, + ) = load_gridfile(self.fname, self.unit) self.col_name = os.path.basename(self.fname).split('_' + 'grid_seasonal_amplitude')[0] if 'grid_seasonal_phase_stdev' in self.fname: - self.grid_seasonal_phase_stdev, self.plotbbox, self.spacing, self.colorbarfmt, self.stationsongrids, self.time_lines = load_gridfile(self.fname, self.unit) + ( + self.grid_seasonal_phase_stdev, + self.plotbbox, + self.spacing, + self.colorbarfmt, + self.stationsongrids, + self.time_lines, + ) = load_gridfile(self.fname, self.unit) self.col_name = os.path.basename(self.fname).split('_' + 'grid_seasonal_phase_stdev')[0] if 'grid_seasonal_amplitude_stdev' in self.fname: - self.grid_seasonal_amplitude_stdev, self.plotbbox, self.spacing, self.colorbarfmt, self.stationsongrids, self.time_lines = load_gridfile(self.fname, self.unit) + ( + self.grid_seasonal_amplitude_stdev, + self.plotbbox, + self.spacing, + self.colorbarfmt, + 
self.stationsongrids, + self.time_lines, + ) = load_gridfile(self.fname, self.unit) self.col_name = os.path.basename(self.fname).split('_' + 'grid_seasonal_amplitude_stdev')[0] if 'grid_seasonal_period_stdev' in self.fname: - self.grid_seasonal_period_stdev, self.plotbbox, self.spacing, self.colorbarfmt, self.stationsongrids, self.time_lines = load_gridfile(self.fname, self.unit) + ( + self.grid_seasonal_period_stdev, + self.plotbbox, + self.spacing, + self.colorbarfmt, + self.stationsongrids, + self.time_lines, + ) = load_gridfile(self.fname, self.unit) self.col_name = os.path.basename(self.fname).split('_' + 'grid_seasonal_period_stdev')[0] if 'grid_seasonal_fit_rmse' in self.fname: - self.grid_seasonal_fit_rmse, self.plotbbox, self.spacing, self.colorbarfmt, self.stationsongrids, self.time_lines = load_gridfile(self.fname, self.unit) + ( + self.grid_seasonal_fit_rmse, + self.plotbbox, + self.spacing, + self.colorbarfmt, + self.stationsongrids, + self.time_lines, + ) = load_gridfile(self.fname, self.unit) self.col_name = os.path.basename(self.fname).split('_' + 'grid_seasonal_fit_rmse')[0] if 'grid_delay_absolute_mean' in self.fname: - self.grid_delay_absolute_mean, self.plotbbox, self.spacing, self.colorbarfmt, self.stationsongrids, self.time_lines = load_gridfile(self.fname, self.unit) + ( + self.grid_delay_absolute_mean, + self.plotbbox, + self.spacing, + self.colorbarfmt, + self.stationsongrids, + self.time_lines, + ) = load_gridfile(self.fname, self.unit) self.col_name = os.path.basename(self.fname).split('_' + 'grid_delay_absolute_mean')[0] if 'grid_delay_absolute_median' in self.fname: - self.grid_delay_absolute_median, self.plotbbox, self.spacing, self.colorbarfmt, self.stationsongrids, self.time_lines = load_gridfile(self.fname, self.unit) + ( + self.grid_delay_absolute_median, + self.plotbbox, + self.spacing, + self.colorbarfmt, + self.stationsongrids, + self.time_lines, + ) = load_gridfile(self.fname, self.unit) self.col_name = os.path.basename(self.fname).split('_' + 'grid_delay_absolute_median')[0] if 'grid_delay_absolute_stdev' in self.fname: - self.grid_delay_absolute_stdev, self.plotbbox, self.spacing, self.colorbarfmt, self.stationsongrids, self.time_lines = load_gridfile(self.fname, self.unit) + ( + self.grid_delay_absolute_stdev, + self.plotbbox, + self.spacing, + self.colorbarfmt, + self.stationsongrids, + self.time_lines, + ) = load_gridfile(self.fname, self.unit) self.col_name = os.path.basename(self.fname).split('_' + 'grid_delay_absolute_stdev')[0] if 'grid_seasonal_absolute_phase' in self.fname: - self.grid_seasonal_absolute_phase, self.plotbbox, self.spacing, self.colorbarfmt, self.stationsongrids, self.time_lines = load_gridfile(self.fname, self.unit) + ( + self.grid_seasonal_absolute_phase, + self.plotbbox, + self.spacing, + self.colorbarfmt, + self.stationsongrids, + self.time_lines, + ) = load_gridfile(self.fname, self.unit) self.col_name = os.path.basename(self.fname).split('_' + 'grid_seasonal_absolute_phase')[0] if 'grid_seasonal_absolute_period' in self.fname: - self.grid_seasonal_absolute_period, self.plotbbox, self.spacing, self.colorbarfmt, self.stationsongrids, self.time_lines = load_gridfile(self.fname, self.unit) + ( + self.grid_seasonal_absolute_period, + self.plotbbox, + self.spacing, + self.colorbarfmt, + self.stationsongrids, + self.time_lines, + ) = load_gridfile(self.fname, self.unit) self.col_name = os.path.basename(self.fname).split('_' + 'grid_seasonal_absolute_period')[0] if 'grid_seasonal_absolute_amplitude' in self.fname: - 
self.grid_seasonal_absolute_amplitude, self.plotbbox, self.spacing, self.colorbarfmt, self.stationsongrids, self.time_lines = load_gridfile(self.fname, self.unit) + ( + self.grid_seasonal_absolute_amplitude, + self.plotbbox, + self.spacing, + self.colorbarfmt, + self.stationsongrids, + self.time_lines, + ) = load_gridfile(self.fname, self.unit) self.col_name = os.path.basename(self.fname).split('_' + 'grid_seasonal_absolute_amplitude')[0] if 'grid_seasonal_absolute_phase_stdev' in self.fname: - self.grid_seasonal_absolute_phase_stdev, self.plotbbox, self.spacing, self.colorbarfmt, self.stationsongrids, self.time_lines = load_gridfile(self.fname, self.unit) + ( + self.grid_seasonal_absolute_phase_stdev, + self.plotbbox, + self.spacing, + self.colorbarfmt, + self.stationsongrids, + self.time_lines, + ) = load_gridfile(self.fname, self.unit) self.col_name = os.path.basename(self.fname).split('_' + 'grid_seasonal_absolute_phase_stdev')[0] if 'grid_seasonal_absolute_amplitude_stdev' in self.fname: - self.grid_seasonal_absolute_amplitude_stdev, self.plotbbox, self.spacing, self.colorbarfmt, self.stationsongrids, self.time_lines = load_gridfile(self.fname, self.unit) + ( + self.grid_seasonal_absolute_amplitude_stdev, + self.plotbbox, + self.spacing, + self.colorbarfmt, + self.stationsongrids, + self.time_lines, + ) = load_gridfile(self.fname, self.unit) self.col_name = os.path.basename(self.fname).split('_' + 'grid_seasonal_absolute_amplitude_stdev')[0] if 'grid_seasonal_absolute_period_stdev' in self.fname: - self.grid_seasonal_absolute_period_stdev, self.plotbbox, self.spacing, self.colorbarfmt, self.stationsongrids, self.time_lines = load_gridfile(self.fname, self.unit) + ( + self.grid_seasonal_absolute_period_stdev, + self.plotbbox, + self.spacing, + self.colorbarfmt, + self.stationsongrids, + self.time_lines, + ) = load_gridfile(self.fname, self.unit) self.col_name = os.path.basename(self.fname).split('_' + 'grid_seasonal_absolute_period_stdev')[0] if 'grid_seasonal_absolute_fit_rmse' in self.fname: - self.grid_seasonal_absolute_fit_rmse, self.plotbbox, self.spacing, self.colorbarfmt, self.stationsongrids, self.time_lines = load_gridfile(self.fname, self.unit) + ( + self.grid_seasonal_absolute_fit_rmse, + self.plotbbox, + self.spacing, + self.colorbarfmt, + self.stationsongrids, + self.time_lines, + ) = load_gridfile(self.fname, self.unit) self.col_name = os.path.basename(self.fname).split('_' + 'grid_seasonal_absolute_fit_rmse')[0] if 'grid_range' in self.fname: - self.grid_range, self.plotbbox, self.spacing, self.colorbarfmt, self.stationsongrids, self.time_lines = load_gridfile(self.fname, self.unit) + ( + self.grid_range, + self.plotbbox, + self.spacing, + self.colorbarfmt, + self.stationsongrids, + self.time_lines, + ) = load_gridfile(self.fname, self.unit) self.col_name = os.path.basename(self.fname).split('_' + 'grid_range')[0] if 'grid_variance' in self.fname: - self.grid_variance, self.plotbbox, self.spacing, self.colorbarfmt, self.stationsongrids, self.time_lines = load_gridfile(self.fname, self.unit) + ( + self.grid_variance, + self.plotbbox, + self.spacing, + self.colorbarfmt, + self.stationsongrids, + self.time_lines, + ) = load_gridfile(self.fname, self.unit) self.col_name = os.path.basename(self.fname).split('_' + 'grid_variance')[0] if 'grid_variogram_rmse' in self.fname: - self.grid_variogram_rmse, self.plotbbox, self.spacing, self.colorbarfmt, self.stationsongrids, self.time_lines = load_gridfile(self.fname, self.unit) + ( + self.grid_variogram_rmse, + self.plotbbox, + 
self.spacing, + self.colorbarfmt, + self.stationsongrids, + self.time_lines, + ) = load_gridfile(self.fname, self.unit) self.col_name = os.path.basename(self.fname).split('_' + 'grid_variogram_rmse')[0] # setup dataframe for statistical analyses (if CSV) if self.fname.endswith('.csv'): self.create_DF() def _get_extent(self): # dataset, spacing=1, userbbox=None - """ Get the bbox, spacing in deg (by default 1deg), optionally pass user-specified bbox. Output array in WESN degrees """ - extent = [np.floor(min(self.df['Lon'])), np.ceil(max(self.df['Lon'])), - np.floor(min(self.df['Lat'])), np.ceil(max(self.df['Lat']))] + """Get the bbox, spacing in deg (by default 1deg), optionally pass user-specified bbox. Output array in WESN degrees.""" + extent = [ + np.floor(min(self.df['Lon'])), + np.ceil(max(self.df['Lon'])), + np.floor(min(self.df['Lat'])), + np.ceil(max(self.df['Lat'])), + ] if self.bbox is not None: - dfextents_poly = Polygon(np.column_stack((np.array([extent[0], extent[0], extent[1], extent[1], extent[0]]), - np.array([extent[2], extent[3], extent[3], extent[2], extent[2]])))) - userbbox_poly = Polygon(np.column_stack((np.array([self.bbox[2], self.bbox[3], self.bbox[3], self.bbox[2], self.bbox[2]]), - np.array([self.bbox[0], self.bbox[0], self.bbox[1], self.bbox[1], self.bbox[0]])))) + dfextents_poly = Polygon( + np.column_stack( + ( + np.array([extent[0], extent[0], extent[1], extent[1], extent[0]]), + np.array([extent[2], extent[3], extent[3], extent[2], extent[2]]), + ) + ) + ) + userbbox_poly = Polygon( + np.column_stack( + ( + np.array([self.bbox[2], self.bbox[3], self.bbox[3], self.bbox[2], self.bbox[2]]), + np.array([self.bbox[0], self.bbox[0], self.bbox[1], self.bbox[1], self.bbox[0]]), + ) + ) + ) if userbbox_poly.intersects(dfextents_poly): extent = [np.floor(self.bbox[2]), np.ceil(self.bbox[-1]), np.floor(self.bbox[0]), np.ceil(self.bbox[1])] else: - raise Exception("User-specified bounds do not overlap with dataset bounds, adjust bounds and re-run program.") - if extent[0] < -180. or extent[1] > 180. or extent[2] < -90. or extent[3] > 90.: - raise Exception("Specified bounds exceed -180/180 lon and/or -90/90 lat, adjust bounds and re-run program.") + raise Exception( + 'User-specified bounds do not overlap with dataset bounds, adjust bounds and re-run program.' + ) + if extent[0] < -180.0 or extent[1] > 180.0 or extent[2] < -90.0 or extent[3] > 90.0: + raise Exception( + 'Specified bounds exceed -180/180 lon and/or -90/90 lat, adjust bounds and re-run program.' + ) del dfextents_poly, userbbox_poly # ensure that extents do not exceed -180/180 lon and -90/90 lat - if extent[0] < -180.: - extent[0] = -180. - if extent[1] > 180.: - extent[1] = 180. - if extent[2] < -90.: - extent[2] = -90. - if extent[3] > 90.: - extent[3] = 90. 
+ if extent[0] < -180.0: + extent[0] = -180.0 + if extent[1] > 180.0: + extent[1] = 180.0 + if extent[2] < -90.0: + extent[2] = -90.0 + if extent[3] > 90.0: + extent[3] = 90.0 # ensure even spacing, set spacing to 1 if specified spacing is not even multiple of bounds if (extent[1] - extent[0]) % self.spacing != 0 or (extent[-1] - extent[-2]) % self.spacing: - logger.warning("User-specified spacing %s is not even multiple of bounds, resetting spacing to 1\N{DEGREE SIGN}", self.spacing) + logger.warning( + 'User-specified spacing %s is not even multiple of bounds, resetting spacing to 1\N{DEGREE SIGN}', + self.spacing, + ) self.spacing = 1 # Create corners of rectangle to be transformed to a grid @@ -870,8 +1347,7 @@ def _get_extent(self): # dataset, spacing=1, userbbox=None se = [extent[1] - (self.spacing / 2), extent[2] + (self.spacing / 2)] # Store grid dimension [y,x] - grid_dim = [int((extent[1] - extent[0]) / self.spacing), - int((extent[-1] - extent[-2]) / self.spacing)] + grid_dim = [int((extent[1] - extent[0]) / self.spacing), int((extent[-1] - extent[-2]) / self.spacing)] # Iterate over 2D area gridpoints = [] @@ -891,12 +1367,16 @@ def _get_extent(self): # dataset, spacing=1, userbbox=None return extent, grid_dim, gridpoints def _check_stationgrid_intersection(self, stat_ID): - ''' + """ Return index of grid cell which intersects with station - Note: Fast, but assumes station locations don't change - ''' - coord = Point((self.unique_points[1][self.unique_points[0].index( - stat_ID)], self.unique_points[2][self.unique_points[0].index(stat_ID)])) + Note: Fast, but assumes station locations don't change. + """ + coord = Point( + ( + self.unique_points[1][self.unique_points[0].index(stat_ID)], + self.unique_points[2][self.unique_points[0].index(stat_ID)], + ) + ) # Get grid cell polygon which intersect with station coordinate grid_int = self.polygon_tree.query(coord) # Pass corresponding grid cell index @@ -906,24 +1386,23 @@ def _check_stationgrid_intersection(self, stat_ID): return 'NaN' def _reader(self): - ''' - Read a input file - ''' + """Read a input file.""" try: data = pd.read_csv(self.fname, parse_dates=['Datetime']) data['Date'] = data['Datetime'].apply(lambda x: x.date()) - data['Date'] = data['Date'].apply(lambda x: dt.datetime.strptime(x.strftime("%Y-%m-%d"), "%Y-%m-%d")) - except BaseException: + data['Date'] = data['Date'].apply(lambda x: dt.datetime.strptime(x.strftime('%Y-%m-%d'), '%Y-%m-%d')) + except: data = pd.read_csv(self.fname, parse_dates=['Date']) # check if user-specified key is valid if self.col_name not in data.keys(): raise Exception( - 'User-specified key {} not found in input file {}. Must specify valid key.' .format(self.col_name, self.fname)) + f'User-specified key {self.col_name} not found in input file {self.fname}. Must specify valid key.' + ) # if user-specified key is the same as the 'Date' field, rename if self.col_name == 'Date': - logger.warning('Input key {} same as "Date" field name, rename the former'.format(self.col_name)) + logger.warning(f'Input key {self.col_name} same as "Date" field name, rename the former') self.col_name += '_plot' data[self.col_name] = data['Date'] @@ -940,10 +1419,8 @@ def _reader(self): return data - def create_DF(self): - ''' - Create dataframe. 
-        '''
+    def create_DF(self) -> None:
+        """Create dataframe."""
         # Open file
         self.df = self._reader()
@@ -955,46 +1432,56 @@ def create_DF(self):

         # time-interval filter
         if self.timeinterval:
-            self.timeinterval = [dt.datetime.strptime(
-                val, '%Y-%m-%d') for val in self.timeinterval.split()]
-            self.df = self.df[(self.df['Date'] >= self.timeinterval[0]) & (
-                self.df['Date'] <= self.timeinterval[-1])]
+            self.timeinterval = [dt.datetime.strptime(val, '%Y-%m-%d') for val in self.timeinterval.split()]
+            self.df = self.df[(self.df['Date'] >= self.timeinterval[0]) & (self.df['Date'] <= self.timeinterval[-1])]

         # seasonal filter
         if self.seasonalinterval:
             self.seasonalinterval = self.seasonalinterval.split()
             # get day of year
-            self.seasonalinterval = [dt.datetime.strptime('2001-' + self.seasonalinterval[0], '%Y-%m-%d').timetuple(
-            ).tm_yday, dt.datetime.strptime('2001-' + self.seasonalinterval[-1], '%Y-%m-%d').timetuple().tm_yday]
+            self.seasonalinterval = [
+                dt.datetime.strptime('2001-' + self.seasonalinterval[0], '%Y-%m-%d').timetuple().tm_yday,
+                dt.datetime.strptime('2001-' + self.seasonalinterval[-1], '%Y-%m-%d').timetuple().tm_yday,
+            ]
             # track input order and wrap around year if necessary
             # e.g. month/day: 03/01 to 06/01
             if self.seasonalinterval[0] < self.seasonalinterval[1]:
                 # non leap-year
-                filtered_self = self.df[(self.df['Date'].dt.is_leap_year == False) & (
-                    self.df['Date'].dt.dayofyear >= self.seasonalinterval[0]) & (self.df['Date'].dt.dayofyear <= self.seasonalinterval[-1])]
+                filtered_self = self.df[
+                    (~self.df['Date'].dt.is_leap_year)
+                    & (self.df['Date'].dt.dayofyear >= self.seasonalinterval[0])
+                    & (self.df['Date'].dt.dayofyear <= self.seasonalinterval[-1])
+                ]
                 # leap-year
-                self.seasonalinterval = [i + 1 if i >
-                                         59 else i for i in self.seasonalinterval]
-                filtered_self_ly = self.df[(self.df['Date'].dt.is_leap_year == True) & (
-                    self.df['Date'].dt.dayofyear >= self.seasonalinterval[0]) & (self.df['Date'].dt.dayofyear <= self.seasonalinterval[-1])]
+                self.seasonalinterval = [i + 1 if i > 59 else i for i in self.seasonalinterval]
+                filtered_self_ly = self.df[
+                    (self.df['Date'].dt.is_leap_year)
+                    & (self.df['Date'].dt.dayofyear >= self.seasonalinterval[0])
+                    & (self.df['Date'].dt.dayofyear <= self.seasonalinterval[-1])
+                ]
                 self.df = pd.concat([filtered_self, filtered_self_ly], ignore_index=True)
                 del filtered_self

             # e.g. month/day: 12/01 to 03/01
             if self.seasonalinterval[0] > self.seasonalinterval[1]:
                 # non leap-year
-                filtered_self = self.df[(self.df['Date'].dt.is_leap_year == False) & (
-                    self.df['Date'].dt.dayofyear >= self.seasonalinterval[-1]) & (self.df['Date'].dt.dayofyear <= self.seasonalinterval[0])]
+                filtered_self = self.df[
+                    (~self.df['Date'].dt.is_leap_year)
+                    & (self.df['Date'].dt.dayofyear >= self.seasonalinterval[-1])
+                    & (self.df['Date'].dt.dayofyear <= self.seasonalinterval[0])
+                ]
                 # leap-year
-                self.seasonalinterval = [i + 1 if i >
-                                         59 else i for i in self.seasonalinterval]
-                filtered_self_ly = self.df[(self.df['Date'].dt.is_leap_year == True) & (
-                    self.df['Date'].dt.dayofyear >= self.seasonalinterval[-1]) & (self.df['Date'].dt.dayofyear <= self.seasonalinterval[0])]
+                self.seasonalinterval = [i + 1 if i > 59 else i for i in self.seasonalinterval]
+                filtered_self_ly = self.df[
+                    (self.df['Date'].dt.is_leap_year)
+                    & (self.df['Date'].dt.dayofyear >= self.seasonalinterval[-1])
+                    & (self.df['Date'].dt.dayofyear <= self.seasonalinterval[0])
+                ]
                 self.df = pd.concat([filtered_self, filtered_self_ly], ignore_index=True)
                 del filtered_self

         # estimate central longitude lines if '--time_lines' specified
         if self.time_lines and 'Datetime' in self.df.keys():
-            self.df['Date_hr'] = self.df['Datetime'].dt.hour.astype(float).astype("Int32")
+            self.df['Date_hr'] = self.df['Datetime'].dt.hour.astype(float).astype('Int32')

             # get list of unique times
             all_hrs = sorted(set(self.df['Date_hr']))
@@ -1002,8 +1489,7 @@ def create_DF(self):
             central_points = []
             # if single time, avoid loop
             if len(all_hrs) == 1:
-                central_points.append(([0, max(self.df['Lon'])],
-                                       [0, min(self.df['Lon'])]))
+                central_points.append(([0, max(self.df['Lon'])], [0, min(self.df['Lon'])]))
             else:
                 for i in enumerate(all_hrs):
                     # last entry
@@ -1013,10 +1499,10 @@ def create_DF(self):
                     elif i[0] == 0:
                         lons = self.df[self.df['Date_hr'] < all_hrs[i[0] + 1]]
                     else:
-                        lons = self.df[(self.df['Date_hr'] > all_hrs[i[0] - 1])
-                                       & (self.df['Date_hr'] < all_hrs[i[0] + 1])]
-                        central_points.append(([0, max(lons['Lon'])],
-                                               [0, min(lons['Lon'])]))
+                        lons = self.df[
+                            (self.df['Date_hr'] > all_hrs[i[0] - 1]) & (self.df['Date_hr'] < all_hrs[i[0] + 1])
+                        ]
+                        central_points.append(([0, max(lons['Lon'])], [0, min(lons['Lon'])]))

             # get central longitudes
             self.time_lines = [midpoint(i[0], i[1]) for i in central_points]
@@ -1025,24 +1511,40 @@ def create_DF(self):
         if self.bbox is not None:
             try:
                 self.bbox = [float(val) for val in self.bbox.split()]
-            except BaseException:
+            except:
                 raise Exception(
-                    'Cannot understand the --bounding_box argument. String input is incorrect or path does not exist.')
+                    'Cannot understand the --bounding_box argument. String input is incorrect or path does not exist.'
+ ) self.plotbbox, self.grid_dim, self.gridpoints = self._get_extent() # generate list of grid-polygons append_poly = [] for i in self.gridpoints: - bbox = [i[1] - (self.spacing / 2), i[1] + (self.spacing / 2), - i[0] - (self.spacing / 2), i[0] + (self.spacing / 2)] - append_poly.append(Polygon(np.column_stack((np.array([bbox[2], bbox[3], bbox[3], bbox[2], bbox[2]]), - np.array([bbox[0], bbox[0], bbox[1], bbox[1], bbox[0]]))))) # Pass lons/lats to create polygon + bbox = [ + i[1] - (self.spacing / 2), + i[1] + (self.spacing / 2), + i[0] - (self.spacing / 2), + i[0] + (self.spacing / 2), + ] + append_poly.append( + Polygon( + np.column_stack( + ( + np.array([bbox[2], bbox[3], bbox[3], bbox[2], bbox[2]]), + np.array([bbox[0], bbox[0], bbox[1], bbox[1], bbox[0]]), + ) + ) + ) + ) # Pass lons/lats to create polygon # Check for grid cell intersection with each station idtogrid_dict = {} self.unique_points = self.df.groupby(['ID', 'Lon', 'Lat']).size() - self.unique_points = [self.unique_points.index.get_level_values('ID').tolist(), self.unique_points.index.get_level_values( - 'Lon').tolist(), self.unique_points.index.get_level_values('Lat').tolist()] + self.unique_points = [ + self.unique_points.index.get_level_values('ID').tolist(), + self.unique_points.index.get_level_values('Lon').tolist(), + self.unique_points.index.get_level_values('Lat').tolist(), + ] # Initiate R-tree of gridded array domain self.polygon_tree = STRtree(append_poly) for stat_ID in self.unique_points[0]: @@ -1059,108 +1561,236 @@ def create_DF(self): # If specified, pass station locations to superimpose on gridplots if self.stationsongrids: unique_points = self.df.groupby(['Lon', 'Lat']).size() - self.stationsongrids = [unique_points.index.get_level_values( - 'Lon').tolist(), unique_points.index.get_level_values('Lat').tolist()] + self.stationsongrids = [ + unique_points.index.get_level_values('Lon').tolist(), + unique_points.index.get_level_values('Lat').tolist(), + ] # If specified, setup gridded array(s) if self.grid_heatmap: - self.grid_heatmap = np.array([np.nan if i[0] not in self.df['gridnode'].values[:] else int(len(np.unique( - self.df['ID'][self.df['gridnode'] == i[0]]))) for i in enumerate(self.gridpoints)]).reshape(self.grid_dim).T + self.grid_heatmap = ( + np.array( + [ + np.nan + if i[0] not in self.df['gridnode'].values[:] + else int(len(np.unique(self.df['ID'][self.df['gridnode'] == i[0]]))) + for i in enumerate(self.gridpoints) + ] + ) + .reshape(self.grid_dim) + .T + ) # If specified, save gridded array(s) if self.grid_to_raster: gridfile_name = os.path.join(self.workdir, self.col_name + '_' + 'grid_heatmap' + '.tif') - save_gridfile(self.grid_heatmap, 'grid_heatmap', gridfile_name, self.plotbbox, self.spacing, - self.unit, colorbarfmt='%1i', - stationsongrids=self.stationsongrids, - time_lines=self.time_lines, dtype='int16', - noData=0) + save_gridfile( + self.grid_heatmap, + 'grid_heatmap', + gridfile_name, + self.plotbbox, + self.spacing, + self.unit, + colorbarfmt='%1i', + stationsongrids=self.stationsongrids, + time_lines=self.time_lines, + dtype='int16', + noData=0, + ) if self.grid_delay_mean: # Take mean of station-wise means per gridcell unique_points = self.df.groupby(['ID', 'Lon', 'Lat', 'gridnode'], as_index=False)[self.col_name].mean() unique_points = unique_points.groupby(['gridnode'])[self.col_name].mean() unique_points.dropna(how='any', inplace=True) - self.grid_delay_mean = np.array([np.nan if i[0] not in unique_points.index.get_level_values('gridnode').tolist( - ) else 
unique_points[i[0]] for i in enumerate(self.gridpoints)]).reshape(self.grid_dim).T + self.grid_delay_mean = ( + np.array( + [ + np.nan + if i[0] not in unique_points.index.get_level_values('gridnode').tolist() + else unique_points[i[0]] + for i in enumerate(self.gridpoints) + ] + ) + .reshape(self.grid_dim) + .T + ) # If specified, save gridded array(s) if self.grid_to_raster: gridfile_name = os.path.join(self.workdir, self.col_name + '_' + 'grid_delay_mean' + '.tif') - save_gridfile(self.grid_delay_mean, 'grid_delay_mean', gridfile_name, self.plotbbox, self.spacing, - self.unit, colorbarfmt='%.2f', - stationsongrids=self.stationsongrids, - time_lines=self.time_lines, dtype='float32') + save_gridfile( + self.grid_delay_mean, + 'grid_delay_mean', + gridfile_name, + self.plotbbox, + self.spacing, + self.unit, + colorbarfmt='%.2f', + stationsongrids=self.stationsongrids, + time_lines=self.time_lines, + dtype='float32', + ) if self.grid_delay_median: # Take mean of station-wise medians per gridcell unique_points = self.df.groupby(['ID', 'Lon', 'Lat', 'gridnode'], as_index=False)[self.col_name].median() unique_points = unique_points.groupby(['gridnode'])[self.col_name].mean() unique_points.dropna(how='any', inplace=True) - self.grid_delay_median = np.array([np.nan if i[0] not in unique_points.index.get_level_values('gridnode').tolist( - ) else unique_points[i[0]] for i in enumerate(self.gridpoints)]).reshape(self.grid_dim).T + self.grid_delay_median = ( + np.array( + [ + np.nan + if i[0] not in unique_points.index.get_level_values('gridnode').tolist() + else unique_points[i[0]] + for i in enumerate(self.gridpoints) + ] + ) + .reshape(self.grid_dim) + .T + ) # If specified, save gridded array(s) if self.grid_to_raster: gridfile_name = os.path.join(self.workdir, self.col_name + '_' + 'grid_delay_median' + '.tif') - save_gridfile(self.grid_delay_median, 'grid_delay_median', gridfile_name, self.plotbbox, self.spacing, - self.unit, colorbarfmt='%.2f', - stationsongrids=self.stationsongrids, - time_lines=self.time_lines, dtype='float32') + save_gridfile( + self.grid_delay_median, + 'grid_delay_median', + gridfile_name, + self.plotbbox, + self.spacing, + self.unit, + colorbarfmt='%.2f', + stationsongrids=self.stationsongrids, + time_lines=self.time_lines, + dtype='float32', + ) if self.grid_delay_stdev: # Take mean of station-wise stdev per gridcell unique_points = self.df.groupby(['ID', 'Lon', 'Lat', 'gridnode'], as_index=False)[self.col_name].std() unique_points = unique_points.groupby(['gridnode'])[self.col_name].mean() unique_points.dropna(how='any', inplace=True) - self.grid_delay_stdev = np.array([np.nan if i[0] not in unique_points.index.get_level_values('gridnode').tolist( - ) else unique_points[i[0]] for i in enumerate(self.gridpoints)]).reshape(self.grid_dim).T + self.grid_delay_stdev = ( + np.array( + [ + np.nan + if i[0] not in unique_points.index.get_level_values('gridnode').tolist() + else unique_points[i[0]] + for i in enumerate(self.gridpoints) + ] + ) + .reshape(self.grid_dim) + .T + ) # If specified, save gridded array(s) if self.grid_to_raster: gridfile_name = os.path.join(self.workdir, self.col_name + '_' + 'grid_delay_stdev' + '.tif') - save_gridfile(self.grid_delay_stdev, 'grid_delay_stdev', gridfile_name, self.plotbbox, self.spacing, - self.unit, colorbarfmt='%.2f', - stationsongrids=self.stationsongrids, - time_lines=self.time_lines, dtype='float32') + save_gridfile( + self.grid_delay_stdev, + 'grid_delay_stdev', + gridfile_name, + self.plotbbox, + self.spacing, + self.unit, + 
colorbarfmt='%.2f', + stationsongrids=self.stationsongrids, + time_lines=self.time_lines, + dtype='float32', + ) if self.grid_delay_absolute_mean: # Take mean of all data per gridcell unique_points = self.df.groupby(['gridnode'])[self.col_name].mean() unique_points.dropna(how='any', inplace=True) - self.grid_delay_absolute_mean = np.array([np.nan if i[0] not in unique_points.index.get_level_values('gridnode').tolist( - ) else unique_points[i[0]] for i in enumerate(self.gridpoints)]).reshape(self.grid_dim).T + self.grid_delay_absolute_mean = ( + np.array( + [ + np.nan + if i[0] not in unique_points.index.get_level_values('gridnode').tolist() + else unique_points[i[0]] + for i in enumerate(self.gridpoints) + ] + ) + .reshape(self.grid_dim) + .T + ) # If specified, save gridded array(s) if self.grid_to_raster: gridfile_name = os.path.join(self.workdir, self.col_name + '_' + 'grid_delay_absolute_mean' + '.tif') - save_gridfile(self.grid_delay_absolute_mean, 'grid_delay_absolute_mean', gridfile_name, self.plotbbox, self.spacing, - self.unit, colorbarfmt='%.2f', - stationsongrids=self.stationsongrids, - time_lines=self.time_lines, dtype='float32') + save_gridfile( + self.grid_delay_absolute_mean, + 'grid_delay_absolute_mean', + gridfile_name, + self.plotbbox, + self.spacing, + self.unit, + colorbarfmt='%.2f', + stationsongrids=self.stationsongrids, + time_lines=self.time_lines, + dtype='float32', + ) if self.grid_delay_absolute_median: # Take median of all data per gridcell unique_points = self.df.groupby(['gridnode'])[self.col_name].median() unique_points.dropna(how='any', inplace=True) - self.grid_delay_absolute_median = np.array([np.nan if i[0] not in unique_points.index.get_level_values('gridnode').tolist( - ) else unique_points[i[0]] for i in enumerate(self.gridpoints)]).reshape(self.grid_dim).T + self.grid_delay_absolute_median = ( + np.array( + [ + np.nan + if i[0] not in unique_points.index.get_level_values('gridnode').tolist() + else unique_points[i[0]] + for i in enumerate(self.gridpoints) + ] + ) + .reshape(self.grid_dim) + .T + ) # If specified, save gridded array(s) if self.grid_to_raster: gridfile_name = os.path.join(self.workdir, self.col_name + '_' + 'grid_delay_absolute_median' + '.tif') - save_gridfile(self.grid_delay_absolute_median, 'grid_delay_absolute_median', gridfile_name, self.plotbbox, self.spacing, - self.unit, colorbarfmt='%.2f', - stationsongrids=self.stationsongrids, - time_lines=self.time_lines, dtype='float32') + save_gridfile( + self.grid_delay_absolute_median, + 'grid_delay_absolute_median', + gridfile_name, + self.plotbbox, + self.spacing, + self.unit, + colorbarfmt='%.2f', + stationsongrids=self.stationsongrids, + time_lines=self.time_lines, + dtype='float32', + ) if self.grid_delay_absolute_stdev: # Take stdev of all data per gridcell unique_points = self.df.groupby(['gridnode'])[self.col_name].std() unique_points.dropna(how='any', inplace=True) - self.grid_delay_absolute_stdev = np.array([np.nan if i[0] not in unique_points.index.get_level_values('gridnode').tolist( - ) else unique_points[i[0]] for i in enumerate(self.gridpoints)]).reshape(self.grid_dim).T + self.grid_delay_absolute_stdev = ( + np.array( + [ + np.nan + if i[0] not in unique_points.index.get_level_values('gridnode').tolist() + else unique_points[i[0]] + for i in enumerate(self.gridpoints) + ] + ) + .reshape(self.grid_dim) + .T + ) # If specified, save gridded array(s) if self.grid_to_raster: gridfile_name = os.path.join(self.workdir, self.col_name + '_' + 'grid_delay_absolute_stdev' + '.tif') 
- save_gridfile(self.grid_delay_absolute_stdev, 'grid_delay_absolute_stdev', gridfile_name, self.plotbbox, self.spacing, - self.unit, colorbarfmt='%.2f', - stationsongrids=self.stationsongrids, - time_lines=self.time_lines, dtype='float32') + save_gridfile( + self.grid_delay_absolute_stdev, + 'grid_delay_absolute_stdev', + gridfile_name, + self.plotbbox, + self.spacing, + self.unit, + colorbarfmt='%.2f', + stationsongrids=self.stationsongrids, + time_lines=self.time_lines, + dtype='float32', + ) # If specified, compute phase/amplitude fits if self.station_seasonal_phase or self.grid_seasonal_phase or self.grid_seasonal_absolute_phase: @@ -1178,9 +1808,18 @@ def create_DF(self): args = [] for i in sorted(list(set(unique_points['ID']))): # pass all values corresponding to station (ID, data = y, time = x) - args.append((i, unique_points[unique_points['ID'] == i]['Date'].to_list(), unique_points[unique_points['ID'] == i][self.col_name].to_list(), self.min_span[0], self.min_span[1], self.period_limit)) + args.append( + ( + i, + unique_points[unique_points['ID'] == i]['Date'].to_list(), + unique_points[unique_points['ID'] == i][self.col_name].to_list(), + self.min_span[0], + self.min_span[1], + self.period_limit, + ) + ) # Parallelize iteration through all grid-cells and time slices - with multiprocessing.Pool(self.numCPUs) as multipool: + with mp.Pool(self.numCPUs) as multipool: for i, j, k, l, m, n, o in multipool.starmap(self._amplitude_and_phase, args): self.ampfit.extend(i) self.phsfit.extend(j) @@ -1196,8 +1835,9 @@ def create_DF(self): self.df['phsfit'] = self.df['ID'].map(self.phsfit) # check if there are any valid data values if self.df['phsfit'].isnull().values.all(axis=0): - raise Exception("No valid data values, adjust --min_span inputs for time span in years {} and/or fractional obs. {}". - format(self.min_span[0], self.min_span[1])) + raise Exception( + f'No valid data values, adjust --min_span inputs for time span in years {self.min_span[0]} and/or fractional obs. 
{self.min_span[1]}' + ) self.df['ampfit'] = self.df['ID'].map(self.ampfit) self.df['periodfit'] = self.df['ID'].map(self.periodfit) self.phsfit_c = {k: v for d in self.phsfit_c for k, v in d.items()} @@ -1216,191 +1856,465 @@ def create_DF(self): unique_points = self.df.groupby(['ID', 'Lon', 'Lat', 'gridnode'], as_index=False)['phsfit'].mean() unique_points = unique_points.groupby(['gridnode'])['phsfit'].mean() unique_points.dropna(how='any', inplace=True) - self.grid_seasonal_phase = np.array([np.nan if i[0] not in unique_points.index.get_level_values('gridnode').tolist( - ) else unique_points[i[0]] for i in enumerate(self.gridpoints)]).reshape(self.grid_dim).T + self.grid_seasonal_phase = ( + np.array( + [ + np.nan + if i[0] not in unique_points.index.get_level_values('gridnode').tolist() + else unique_points[i[0]] + for i in enumerate(self.gridpoints) + ] + ) + .reshape(self.grid_dim) + .T + ) # If specified, save gridded array(s) if self.grid_to_raster: gridfile_name = os.path.join(self.workdir, self.col_name + '_' + 'grid_seasonal_phase' + '.tif') - save_gridfile(self.grid_seasonal_phase, 'grid_seasonal_phase', gridfile_name, self.plotbbox, self.spacing, - 'days', colorbarfmt='%.1i', - stationsongrids=self.stationsongrids, - time_lines=self.time_lines, dtype='float32') + save_gridfile( + self.grid_seasonal_phase, + 'grid_seasonal_phase', + gridfile_name, + self.plotbbox, + self.spacing, + 'days', + colorbarfmt='%.1i', + stationsongrids=self.stationsongrids, + time_lines=self.time_lines, + dtype='float32', + ) # Pass mean amplitude of station-wise means per gridcell unique_points = self.df.groupby(['ID', 'Lon', 'Lat', 'gridnode'], as_index=False)['ampfit'].mean() unique_points = unique_points.groupby(['gridnode'])['ampfit'].mean() unique_points.dropna(how='any', inplace=True) - self.grid_seasonal_amplitude = np.array([np.nan if i[0] not in unique_points.index.get_level_values('gridnode').tolist( - ) else unique_points[i[0]] for i in enumerate(self.gridpoints)]).reshape(self.grid_dim).T + self.grid_seasonal_amplitude = ( + np.array( + [ + np.nan + if i[0] not in unique_points.index.get_level_values('gridnode').tolist() + else unique_points[i[0]] + for i in enumerate(self.gridpoints) + ] + ) + .reshape(self.grid_dim) + .T + ) # If specified, save gridded array(s) if self.grid_to_raster: gridfile_name = os.path.join(self.workdir, self.col_name + '_' + 'grid_seasonal_amplitude' + '.tif') - save_gridfile(self.grid_seasonal_amplitude, 'grid_seasonal_amplitude', gridfile_name, self.plotbbox, self.spacing, - self.unit, colorbarfmt='%.3f', - stationsongrids=self.stationsongrids, - time_lines=self.time_lines, dtype='float32') + save_gridfile( + self.grid_seasonal_amplitude, + 'grid_seasonal_amplitude', + gridfile_name, + self.plotbbox, + self.spacing, + self.unit, + colorbarfmt='%.3f', + stationsongrids=self.stationsongrids, + time_lines=self.time_lines, + dtype='float32', + ) # Pass mean period of station-wise means per gridcell unique_points = self.df.groupby(['ID', 'Lon', 'Lat', 'gridnode'], as_index=False)['periodfit'].mean() unique_points = unique_points.groupby(['gridnode'])['periodfit'].mean() unique_points.dropna(how='any', inplace=True) - self.grid_seasonal_period = np.array([np.nan if i[0] not in unique_points.index.get_level_values('gridnode').tolist( - ) else unique_points[i[0]] for i in enumerate(self.gridpoints)]).reshape(self.grid_dim).T + self.grid_seasonal_period = ( + np.array( + [ + np.nan + if i[0] not in unique_points.index.get_level_values('gridnode').tolist() + else 
unique_points[i[0]] + for i in enumerate(self.gridpoints) + ] + ) + .reshape(self.grid_dim) + .T + ) # If specified, save gridded array(s) if self.grid_to_raster: gridfile_name = os.path.join(self.workdir, self.col_name + '_' + 'grid_seasonal_period' + '.tif') - save_gridfile(self.grid_seasonal_period, 'grid_seasonal_period', gridfile_name, self.plotbbox, self.spacing, - 'years', colorbarfmt='%.2f', - stationsongrids=self.stationsongrids, - time_lines=self.time_lines, dtype='float32') + save_gridfile( + self.grid_seasonal_period, + 'grid_seasonal_period', + gridfile_name, + self.plotbbox, + self.spacing, + 'years', + colorbarfmt='%.2f', + stationsongrids=self.stationsongrids, + time_lines=self.time_lines, + dtype='float32', + ) ######################################################################################################################## # Pass mean phase stdev of station-wise means per gridcell unique_points = self.df.groupby(['ID', 'Lon', 'Lat', 'gridnode'], as_index=False)['phsfit_c'].mean() unique_points = unique_points.groupby(['gridnode'])['phsfit_c'].mean() unique_points.dropna(how='any', inplace=True) - self.grid_seasonal_phase_stdev = np.array([np.nan if i[0] not in unique_points.index.get_level_values('gridnode').tolist( - ) else unique_points[i[0]] for i in enumerate(self.gridpoints)]).reshape(self.grid_dim).T + self.grid_seasonal_phase_stdev = ( + np.array( + [ + np.nan + if i[0] not in unique_points.index.get_level_values('gridnode').tolist() + else unique_points[i[0]] + for i in enumerate(self.gridpoints) + ] + ) + .reshape(self.grid_dim) + .T + ) # If specified, save gridded array(s) if self.grid_to_raster: - gridfile_name = os.path.join(self.workdir, self.col_name + '_' + 'grid_seasonal_phase_stdev' + '.tif') - save_gridfile(self.grid_seasonal_phase_stdev, 'grid_seasonal_phase_stdev', gridfile_name, self.plotbbox, self.spacing, - 'days', colorbarfmt='%.1i', - stationsongrids=self.stationsongrids, - time_lines=self.time_lines, dtype='float32') + gridfile_name = os.path.join( + self.workdir, self.col_name + '_' + 'grid_seasonal_phase_stdev' + '.tif' + ) + save_gridfile( + self.grid_seasonal_phase_stdev, + 'grid_seasonal_phase_stdev', + gridfile_name, + self.plotbbox, + self.spacing, + 'days', + colorbarfmt='%.1i', + stationsongrids=self.stationsongrids, + time_lines=self.time_lines, + dtype='float32', + ) # Pass mean amplitude stdev of station-wise means per gridcell unique_points = self.df.groupby(['ID', 'Lon', 'Lat', 'gridnode'], as_index=False)['ampfit_c'].mean() unique_points = unique_points.groupby(['gridnode'])['ampfit_c'].mean() unique_points.dropna(how='any', inplace=True) - self.grid_seasonal_amplitude_stdev = np.array([np.nan if i[0] not in unique_points.index.get_level_values('gridnode').tolist( - ) else unique_points[i[0]] for i in enumerate(self.gridpoints)]).reshape(self.grid_dim).T + self.grid_seasonal_amplitude_stdev = ( + np.array( + [ + np.nan + if i[0] not in unique_points.index.get_level_values('gridnode').tolist() + else unique_points[i[0]] + for i in enumerate(self.gridpoints) + ] + ) + .reshape(self.grid_dim) + .T + ) # If specified, save gridded array(s) if self.grid_to_raster: - gridfile_name = os.path.join(self.workdir, self.col_name + '_' + 'grid_seasonal_amplitude_stdev' + '.tif') - save_gridfile(self.grid_seasonal_amplitude_stdev, 'grid_seasonal_amplitude_stdev', gridfile_name, self.plotbbox, self.spacing, - self.unit, colorbarfmt='%.3f', - stationsongrids=self.stationsongrids, - time_lines=self.time_lines, dtype='float32') + 
gridfile_name = os.path.join( + self.workdir, self.col_name + '_' + 'grid_seasonal_amplitude_stdev' + '.tif' + ) + save_gridfile( + self.grid_seasonal_amplitude_stdev, + 'grid_seasonal_amplitude_stdev', + gridfile_name, + self.plotbbox, + self.spacing, + self.unit, + colorbarfmt='%.3f', + stationsongrids=self.stationsongrids, + time_lines=self.time_lines, + dtype='float32', + ) # Pass mean period stdev of station-wise means per gridcell unique_points = self.df.groupby(['ID', 'Lon', 'Lat', 'gridnode'], as_index=False)['periodfit_c'].mean() unique_points = unique_points.groupby(['gridnode'])['periodfit_c'].mean() unique_points.dropna(how='any', inplace=True) - self.grid_seasonal_period_stdev = np.array([np.nan if i[0] not in unique_points.index.get_level_values('gridnode').tolist( - ) else unique_points[i[0]] for i in enumerate(self.gridpoints)]).reshape(self.grid_dim).T + self.grid_seasonal_period_stdev = ( + np.array( + [ + np.nan + if i[0] not in unique_points.index.get_level_values('gridnode').tolist() + else unique_points[i[0]] + for i in enumerate(self.gridpoints) + ] + ) + .reshape(self.grid_dim) + .T + ) # If specified, save gridded array(s) if self.grid_to_raster: - gridfile_name = os.path.join(self.workdir, self.col_name + '_' + 'grid_seasonal_period_stdev' + '.tif') - save_gridfile(self.grid_seasonal_period_stdev, 'grid_seasonal_period_stdev', gridfile_name, self.plotbbox, self.spacing, - 'years', colorbarfmt='%.2e', - stationsongrids=self.stationsongrids, - time_lines=self.time_lines, dtype='float32') + gridfile_name = os.path.join( + self.workdir, self.col_name + '_' + 'grid_seasonal_period_stdev' + '.tif' + ) + save_gridfile( + self.grid_seasonal_period_stdev, + 'grid_seasonal_period_stdev', + gridfile_name, + self.plotbbox, + self.spacing, + 'years', + colorbarfmt='%.2e', + stationsongrids=self.stationsongrids, + time_lines=self.time_lines, + dtype='float32', + ) # Pass mean seasonal fit RMSE of station-wise means per gridcell - unique_points = self.df.groupby(['ID', 'Lon', 'Lat', 'gridnode'], as_index=False)['seasonalfit_rmse'].mean() + unique_points = self.df.groupby(['ID', 'Lon', 'Lat', 'gridnode'], as_index=False)[ + 'seasonalfit_rmse' + ].mean() unique_points = unique_points.groupby(['gridnode'])['seasonalfit_rmse'].mean() unique_points.dropna(how='any', inplace=True) - self.grid_seasonal_fit_rmse = np.array([np.nan if i[0] not in unique_points.index.get_level_values('gridnode').tolist( - ) else unique_points[i[0]] for i in enumerate(self.gridpoints)]).reshape(self.grid_dim).T + self.grid_seasonal_fit_rmse = ( + np.array( + [ + np.nan + if i[0] not in unique_points.index.get_level_values('gridnode').tolist() + else unique_points[i[0]] + for i in enumerate(self.gridpoints) + ] + ) + .reshape(self.grid_dim) + .T + ) # If specified, save gridded array(s) if self.grid_to_raster: gridfile_name = os.path.join(self.workdir, self.col_name + '_' + 'grid_seasonal_fit_rmse' + '.tif') - save_gridfile(self.grid_seasonal_fit_rmse, 'grid_seasonal_fit_rmse', gridfile_name, self.plotbbox, self.spacing, - self.unit, colorbarfmt='%.3f', - stationsongrids=self.stationsongrids, - time_lines=self.time_lines, dtype='float32') + save_gridfile( + self.grid_seasonal_fit_rmse, + 'grid_seasonal_fit_rmse', + gridfile_name, + self.plotbbox, + self.spacing, + self.unit, + colorbarfmt='%.3f', + stationsongrids=self.stationsongrids, + time_lines=self.time_lines, + dtype='float32', + ) ######################################################################################################################## 
if self.grid_seasonal_absolute_phase: # Pass absolute mean phase of all data per gridcell unique_points = self.df.groupby(['gridnode'])['phsfit'].mean() unique_points.dropna(how='any', inplace=True) - self.grid_seasonal_absolute_phase = np.array([np.nan if i[0] not in unique_points.index.get_level_values('gridnode').tolist( - ) else unique_points[i[0]] for i in enumerate(self.gridpoints)]).reshape(self.grid_dim).T + self.grid_seasonal_absolute_phase = ( + np.array( + [ + np.nan + if i[0] not in unique_points.index.get_level_values('gridnode').tolist() + else unique_points[i[0]] + for i in enumerate(self.gridpoints) + ] + ) + .reshape(self.grid_dim) + .T + ) # If specified, save gridded array(s) if self.grid_to_raster: - gridfile_name = os.path.join(self.workdir, self.col_name + '_' + 'grid_seasonal_absolute_phase' + '.tif') - save_gridfile(self.grid_seasonal_absolute_phase, 'grid_seasonal_absolute_phase', gridfile_name, self.plotbbox, self.spacing, - 'days', colorbarfmt='%.1i', - stationsongrids=self.stationsongrids, - time_lines=self.time_lines, dtype='float32') + gridfile_name = os.path.join( + self.workdir, self.col_name + '_' + 'grid_seasonal_absolute_phase' + '.tif' + ) + save_gridfile( + self.grid_seasonal_absolute_phase, + 'grid_seasonal_absolute_phase', + gridfile_name, + self.plotbbox, + self.spacing, + 'days', + colorbarfmt='%.1i', + stationsongrids=self.stationsongrids, + time_lines=self.time_lines, + dtype='float32', + ) # Pass absolute mean amplitude of all data per gridcell unique_points = self.df.groupby(['gridnode'])['ampfit'].mean() unique_points.dropna(how='any', inplace=True) - self.grid_seasonal_absolute_amplitude = np.array([np.nan if i[0] not in unique_points.index.get_level_values('gridnode').tolist( - ) else unique_points[i[0]] for i in enumerate(self.gridpoints)]).reshape(self.grid_dim).T + self.grid_seasonal_absolute_amplitude = ( + np.array( + [ + np.nan + if i[0] not in unique_points.index.get_level_values('gridnode').tolist() + else unique_points[i[0]] + for i in enumerate(self.gridpoints) + ] + ) + .reshape(self.grid_dim) + .T + ) # If specified, save gridded array(s) if self.grid_to_raster: - gridfile_name = os.path.join(self.workdir, self.col_name + '_' + 'grid_seasonal_absolute_amplitude' + '.tif') - save_gridfile(self.grid_seasonal_absolute_amplitude, 'grid_seasonal_absolute_amplitude', gridfile_name, self.plotbbox, self.spacing, - self.unit, colorbarfmt='%.3f', - stationsongrids=self.stationsongrids, - time_lines=self.time_lines, dtype='float32') + gridfile_name = os.path.join( + self.workdir, self.col_name + '_' + 'grid_seasonal_absolute_amplitude' + '.tif' + ) + save_gridfile( + self.grid_seasonal_absolute_amplitude, + 'grid_seasonal_absolute_amplitude', + gridfile_name, + self.plotbbox, + self.spacing, + self.unit, + colorbarfmt='%.3f', + stationsongrids=self.stationsongrids, + time_lines=self.time_lines, + dtype='float32', + ) # Pass absolute mean period of all data per gridcell unique_points = self.df.groupby(['gridnode'])['periodfit'].mean() unique_points.dropna(how='any', inplace=True) - self.grid_seasonal_absolute_period = np.array([np.nan if i[0] not in unique_points.index.get_level_values('gridnode').tolist( - ) else unique_points[i[0]] for i in enumerate(self.gridpoints)]).reshape(self.grid_dim).T + self.grid_seasonal_absolute_period = ( + np.array( + [ + np.nan + if i[0] not in unique_points.index.get_level_values('gridnode').tolist() + else unique_points[i[0]] + for i in enumerate(self.gridpoints) + ] + ) + .reshape(self.grid_dim) + .T + ) # 
If specified, save gridded array(s) if self.grid_to_raster: - gridfile_name = os.path.join(self.workdir, self.col_name + '_' + 'grid_seasonal_absolute_period' + '.tif') - save_gridfile(self.grid_seasonal_absolute_period, 'grid_seasonal_absolute_period', gridfile_name, self.plotbbox, self.spacing, - 'years', colorbarfmt='%.2f', - stationsongrids=self.stationsongrids, - time_lines=self.time_lines, dtype='float32') + gridfile_name = os.path.join( + self.workdir, self.col_name + '_' + 'grid_seasonal_absolute_period' + '.tif' + ) + save_gridfile( + self.grid_seasonal_absolute_period, + 'grid_seasonal_absolute_period', + gridfile_name, + self.plotbbox, + self.spacing, + 'years', + colorbarfmt='%.2f', + stationsongrids=self.stationsongrids, + time_lines=self.time_lines, + dtype='float32', + ) ######################################################################################################################## # Pass absolute mean phase stdev of all data per gridcell unique_points = self.df.groupby(['gridnode'])['phsfit_c'].mean() unique_points.dropna(how='any', inplace=True) - self.grid_seasonal_absolute_phase_stdev = np.array([np.nan if i[0] not in unique_points.index.get_level_values('gridnode').tolist( - ) else unique_points[i[0]] for i in enumerate(self.gridpoints)]).reshape(self.grid_dim).T + self.grid_seasonal_absolute_phase_stdev = ( + np.array( + [ + np.nan + if i[0] not in unique_points.index.get_level_values('gridnode').tolist() + else unique_points[i[0]] + for i in enumerate(self.gridpoints) + ] + ) + .reshape(self.grid_dim) + .T + ) # If specified, save gridded array(s) if self.grid_to_raster: - gridfile_name = os.path.join(self.workdir, self.col_name + '_' + 'grid_seasonal_absolute_phase_stdev' + '.tif') - save_gridfile(self.grid_seasonal_absolute_phase_stdev, 'grid_seasonal_absolute_phase_stdev', gridfile_name, self.plotbbox, self.spacing, - 'days', colorbarfmt='%.1i', - stationsongrids=self.stationsongrids, - time_lines=self.time_lines, dtype='float32') + gridfile_name = os.path.join( + self.workdir, self.col_name + '_' + 'grid_seasonal_absolute_phase_stdev' + '.tif' + ) + save_gridfile( + self.grid_seasonal_absolute_phase_stdev, + 'grid_seasonal_absolute_phase_stdev', + gridfile_name, + self.plotbbox, + self.spacing, + 'days', + colorbarfmt='%.1i', + stationsongrids=self.stationsongrids, + time_lines=self.time_lines, + dtype='float32', + ) # Pass absolute mean amplitude stdev of all data per gridcell unique_points = self.df.groupby(['gridnode'])['ampfit_c'].mean() unique_points.dropna(how='any', inplace=True) - self.grid_seasonal_absolute_amplitude_stdev = np.array([np.nan if i[0] not in unique_points.index.get_level_values('gridnode').tolist( - ) else unique_points[i[0]] for i in enumerate(self.gridpoints)]).reshape(self.grid_dim).T + self.grid_seasonal_absolute_amplitude_stdev = ( + np.array( + [ + np.nan + if i[0] not in unique_points.index.get_level_values('gridnode').tolist() + else unique_points[i[0]] + for i in enumerate(self.gridpoints) + ] + ) + .reshape(self.grid_dim) + .T + ) # If specified, save gridded array(s) if self.grid_to_raster: - gridfile_name = os.path.join(self.workdir, self.col_name + '_' + 'grid_seasonal_absolute_amplitude_stdev' + '.tif') - save_gridfile(self.grid_seasonal_absolute_amplitude_stdev, 'grid_seasonal_absolute_amplitude_stdev', gridfile_name, self.plotbbox, self.spacing, - self.unit, colorbarfmt='%.3f', - stationsongrids=self.stationsongrids, - time_lines=self.time_lines, dtype='float32') + gridfile_name = os.path.join( + self.workdir, 
self.col_name + '_' + 'grid_seasonal_absolute_amplitude_stdev' + '.tif' + ) + save_gridfile( + self.grid_seasonal_absolute_amplitude_stdev, + 'grid_seasonal_absolute_amplitude_stdev', + gridfile_name, + self.plotbbox, + self.spacing, + self.unit, + colorbarfmt='%.3f', + stationsongrids=self.stationsongrids, + time_lines=self.time_lines, + dtype='float32', + ) # Pass absolute mean period stdev of all data per gridcell unique_points = self.df.groupby(['gridnode'])['periodfit_c'].mean() unique_points.dropna(how='any', inplace=True) - self.grid_seasonal_absolute_period_stdev = np.array([np.nan if i[0] not in unique_points.index.get_level_values('gridnode').tolist( - ) else unique_points[i[0]] for i in enumerate(self.gridpoints)]).reshape(self.grid_dim).T + self.grid_seasonal_absolute_period_stdev = ( + np.array( + [ + np.nan + if i[0] not in unique_points.index.get_level_values('gridnode').tolist() + else unique_points[i[0]] + for i in enumerate(self.gridpoints) + ] + ) + .reshape(self.grid_dim) + .T + ) # If specified, save gridded array(s) if self.grid_to_raster: - gridfile_name = os.path.join(self.workdir, self.col_name + '_' + 'grid_seasonal_absolute_period_stdev' + '.tif') - save_gridfile(self.grid_seasonal_absolute_period_stdev, 'grid_seasonal_absolute_period_stdev', gridfile_name, self.plotbbox, self.spacing, - 'years', colorbarfmt='%.2e', - stationsongrids=self.stationsongrids, - time_lines=self.time_lines, dtype='float32') + gridfile_name = os.path.join( + self.workdir, self.col_name + '_' + 'grid_seasonal_absolute_period_stdev' + '.tif' + ) + save_gridfile( + self.grid_seasonal_absolute_period_stdev, + 'grid_seasonal_absolute_period_stdev', + gridfile_name, + self.plotbbox, + self.spacing, + 'years', + colorbarfmt='%.2e', + stationsongrids=self.stationsongrids, + time_lines=self.time_lines, + dtype='float32', + ) # Pass absolute mean seasonal fit RMSE of all data per gridcell unique_points = self.df.groupby(['gridnode'])['seasonalfit_rmse'].mean() unique_points.dropna(how='any', inplace=True) - self.grid_seasonal_absolute_fit_rmse = np.array([np.nan if i[0] not in unique_points.index.get_level_values('gridnode').tolist( - ) else unique_points[i[0]] for i in enumerate(self.gridpoints)]).reshape(self.grid_dim).T + self.grid_seasonal_absolute_fit_rmse = ( + np.array( + [ + np.nan + if i[0] not in unique_points.index.get_level_values('gridnode').tolist() + else unique_points[i[0]] + for i in enumerate(self.gridpoints) + ] + ) + .reshape(self.grid_dim) + .T + ) # If specified, save gridded array(s) if self.grid_to_raster: - gridfile_name = os.path.join(self.workdir, self.col_name + '_' + 'grid_seasonal_absolute_fit_rmse' + '.tif') - save_gridfile(self.grid_seasonal_absolute_fit_rmse, 'grid_seasonal_absolute_fit_rmse', gridfile_name, self.plotbbox, self.spacing, - self.unit, colorbarfmt='%.2e', - stationsongrids=self.stationsongrids, - time_lines=self.time_lines, dtype='float32') - - def _amplitude_and_phase(self, station, tt, yy, min_span=2, min_frac=0.6, period_limit=0.): - ''' + gridfile_name = os.path.join( + self.workdir, self.col_name + '_' + 'grid_seasonal_absolute_fit_rmse' + '.tif' + ) + save_gridfile( + self.grid_seasonal_absolute_fit_rmse, + 'grid_seasonal_absolute_fit_rmse', + gridfile_name, + self.plotbbox, + self.spacing, + self.unit, + colorbarfmt='%.2e', + stationsongrids=self.stationsongrids, + time_lines=self.time_lines, + dtype='float32', + ) + + def _amplitude_and_phase(self, station, tt, yy, min_span=2, min_frac=0.6, period_limit=0.0): + """ Fit sin to the input time 
sequence, and return fitting parameters: "amp", "omega", "phase", "offset", "freq", "period" and "fitfunc". Minimum time span in years (min_span), minimum fractional observations in span (min_frac), and period limit (period_limit) enforced for statistical analysis. - Source: https://stackoverflow.com/questions/16716302/how-do-i-fit-a-sine-curve-to-my-data-with-pylab-and-numpy - ''' + Source: https://stackoverflow.com/questions/16716302/how-do-i-fit-a-sine-curve-to-my-data-with-pylab-and-numpy. + """ ampfit = {} phsfit = {} periodfit = {} @@ -1416,15 +2330,17 @@ def _amplitude_and_phase(self, station, tt, yy, min_span=2, min_frac=0.6, period periodfit_c[station] = np.nan seasonalfit_rmse[station] = np.nan # Fit with custom fit function with fixed period, if specified - if period_limit != 0.: + if period_limit != 0.0: # convert from years to radians/seconds - w = (1 / period_limit) * (1 / 31556952) * (2. * np.pi) + w = (1 / period_limit) * (1 / 31556952) * (2.0 * np.pi) def custom_sine_function_base(t, A, p, c): return self._sine_function_base(t, A, w, p, c) else: + def custom_sine_function_base(t, A, w, p, c): return self._sine_function_base(t, A, w, p, c) + # If station TS does not span specified time period, pass NaNs time_span_yrs = (max(tt) - min(tt)) / 31556952 if time_span_yrs >= min_span and len(list(set(tt))) / (time_span_yrs * 365.25) >= min_frac: @@ -1433,16 +2349,16 @@ def custom_sine_function_base(t, A, w, p, c): ff = np.fft.fftfreq(len(tt), (tt[1] - tt[0])) # assume uniform spacing Fyy = abs(np.fft.fft(yy)) guess_freq = abs(ff[np.argmax(Fyy[1:]) + 1]) # excluding the zero period "peak", which is related to offset - guess_amp = np.std(yy) * 2.**0.5 + guess_amp = np.std(yy) * 2.0**0.5 guess_offset = np.mean(yy) - guess = np.array([guess_amp, 2. * np.pi * guess_freq, 0., guess_offset]) + guess = np.array([guess_amp, 2.0 * np.pi * guess_freq, 0.0, guess_offset]) # Adjust frequency guess to reflect fixed period, if specified - if period_limit != 0.: - guess = np.array([guess_amp, 0., guess_offset]) + if period_limit != 0.0: + guess = np.array([guess_amp, 0.0, guess_offset]) # Catch warning where covariance cannot be estimated # I.e. 
OptimizeWarning: Covariance of the parameters could not be estimated with warnings.catch_warnings(): - warnings.simplefilter("error", OptimizeWarning) + warnings.simplefilter('error', OptimizeWarning) try: optimize_warning = False try: @@ -1450,29 +2366,46 @@ def custom_sine_function_base(t, A, w, p, c): popt, pcov = optimize.curve_fit(custom_sine_function_base, tt, yy, p0=guess, maxfev=int(1e6)) # If sparse input such that fitting is not possible, pass NaNs except TypeError: - self.ampfit.append(np.nan), self.phsfit.append(np.nan), self.periodfit.append(np.nan), \ - self.ampfit_c.append(np.nan), self.phsfit_c.append(np.nan), \ - self.periodfit_c.append(np.nan), self.seasonalfit_rmse.append(np.nan) - return self.ampfit, self.phsfit, self.periodfit, self.ampfit_c, \ - self.phsfit_c, self.periodfit_c, self.seasonalfit_rmse + ( + self.ampfit.append(np.nan), + self.phsfit.append(np.nan), + self.periodfit.append(np.nan), + self.ampfit_c.append(np.nan), + self.phsfit_c.append(np.nan), + self.periodfit_c.append(np.nan), + self.seasonalfit_rmse.append(np.nan), + ) + return ( + self.ampfit, + self.phsfit, + self.periodfit, + self.ampfit_c, + self.phsfit_c, + self.periodfit_c, + self.seasonalfit_rmse, + ) except OptimizeWarning: optimize_warning = True - warnings.simplefilter("ignore", OptimizeWarning) + warnings.simplefilter('ignore', OptimizeWarning) popt, pcov = optimize.curve_fit(custom_sine_function_base, tt, yy, p0=guess, maxfev=int(1e6)) - print('OptimizeWarning: Covariance for station {} could not be estimated. Refer to debug figure here {} \ - '.format(station, os.path.join(self.workdir, 'phaseamp_per_station', 'station{}.png'.format(station)))) + debug_figure_path = os.path.join(self.workdir, 'phaseamp_per_station', f'station{station}.png') + print( + f'OptimizeWarning: Covariance for station {station} could not be estimated. ' + f'Refer to debug figure here {debug_figure_path}' + ) pass # Adjust expected output to reflect fixed period, if specified - if period_limit != 0.: + if period_limit != 0.0: A, p, c = popt else: A, w, p, c = popt # convert from radians/seconds to years - f = (w / (2. 
* np.pi)) * (31556952) + f = (w / (2.0 * np.pi)) * (31556952) f = 1 / f def fitfunc(t): return A * np.sin(w * t + p) + c + # Outputs = "amp": A, "angular frequency": w, "phase": p, "offset": c, "freq": f, "period": 1./f, # "fitfunc": fitfunc, "maxcov": np.max(pcov), "rawres": (guess,popt,pcov) # Pass amplitude (specified units) and phase (days) and stdev @@ -1487,13 +2420,14 @@ def fitfunc(t): with np.errstate(invalid='raise'): try: # pass covariance for each parameter - ampfit_c[station] = pcov[0, 0]**0.5 - periodfit_c[station] = pcov[1, 1]**0.5 - phsfit_c[station] = pcov[2, 2]**0.5 + ampfit_c[station] = pcov[0, 0] ** 0.5 + periodfit_c[station] = pcov[1, 1] ** 0.5 + phsfit_c[station] = pcov[2, 2] ** 0.5 # pass RMSE of fit seasonalfit_rmse[station] = yy - custom_sine_function_base(tt, *popt) - seasonalfit_rmse[station] = (np.sum(seasonalfit_rmse[station]**2) / - (seasonalfit_rmse[station].size - 2))**0.5 + seasonalfit_rmse[station] = ( + np.sum(seasonalfit_rmse[station] ** 2) / (seasonalfit_rmse[station].size - 2) + ) ** 0.5 except FloatingPointError: pass if self.phaseamp_per_station or optimize_warning: @@ -1502,9 +2436,9 @@ def fitfunc(t): tt_plot = copy.deepcopy(tt) tt_plot -= min(tt_plot) tt_plot /= 31556952 - plt.plot(tt_plot, yy, "ok", label="input") - plt.xlabel("time (years)") - plt.ylabel("data ({})".format(self.unit)) + plt.plot(tt_plot, yy, 'ok', label='input') + plt.xlabel('time (years)') + plt.ylabel(f'data ({self.unit})') num_testpoints = len(tt) * 10 if num_testpoints > 1000: num_testpoints = 1000 @@ -1513,12 +2447,15 @@ def fitfunc(t): tt2_plot = copy.deepcopy(tt2) tt2_plot -= min(tt2_plot) tt2_plot /= 31556952 - plt.plot(tt2_plot, fitfunc(tt2), "r-", label="fit", linewidth=2) - plt.legend(loc="best") + plt.plot(tt2_plot, fitfunc(tt2), 'r-', label='fit', linewidth=2) + plt.legend(loc='best') if not os.path.exists(os.path.join(self.workdir, 'phaseamp_per_station')): os.mkdir(os.path.join(self.workdir, 'phaseamp_per_station')) - plt.savefig(os.path.join(self.workdir, 'phaseamp_per_station', 'station{}.png'.format(station)), - format='png', bbox_inches='tight') + plt.savefig( + os.path.join(self.workdir, 'phaseamp_per_station', f'station{station}.png'), + format='png', + bbox_inches='tight', + ) plt.close() optimize_warning = False @@ -1530,19 +2467,33 @@ def fitfunc(t): self.periodfit_c.append(periodfit_c) self.seasonalfit_rmse.append(seasonalfit_rmse) - return self.ampfit, self.phsfit, self.periodfit, self.ampfit_c, \ - self.phsfit_c, self.periodfit_c, self.seasonalfit_rmse + return ( + self.ampfit, + self.phsfit, + self.periodfit, + self.ampfit_c, + self.phsfit_c, + self.periodfit_c, + self.seasonalfit_rmse, + ) def _sine_function_base(self, t, A, w, p, c): - ''' - Base function for modeling sinusoidal amplitude/phase fits. - ''' + """Base function for modeling sinusoidal amplitude/phase fits.""" return A * np.sin(w * t + p) + c - def __call__(self, gridarr, plottype, workdir='./', drawgridlines=False, colorbarfmt='%.2f', stationsongrids=None, resValue=5, plotFormat='pdf', userTitle=None): - ''' - Visualize a suite of statistics w.r.t. stations. Pass either a list of points or a gridded array as the first argument. Alternatively, you may superimpose your gridded array with a supplementary list of points by passing the latter through the stationsongrids argument. 
- ''' + def __call__( + self, + gridarr, + plottype, + workdir='./', + drawgridlines=False, + colorbarfmt='%.2f', + stationsongrids=None, + resValue=5, + plotFormat='pdf', + userTitle=None, + ): + """Visualize a suite of statistics w.r.t. stations. Pass either a list of points or a gridded array as the first argument. Alternatively, you may superimpose your gridded array with a supplementary list of points by passing the latter through the stationsongrids argument.""" from cartopy import crs as ccrs from cartopy import feature as cfeature from cartopy.mpl.ticker import LatitudeFormatter, LongitudeFormatter @@ -1563,53 +2514,46 @@ def __call__(self, gridarr, plottype, workdir='./', drawgridlines=False, colorba fig, axes = plt.subplots(subplot_kw={'projection': ccrs.PlateCarree()}) # by default set background to white - axes.add_feature(cfeature.NaturalEarthFeature( - 'physical', 'land', '50m', facecolor='white'), zorder=0) + axes.add_feature(cfeature.NaturalEarthFeature('physical', 'land', '50m', facecolor='white'), zorder=0) axes.set_extent(self.plotbbox, ccrs.PlateCarree()) # add coastlines - axes.coastlines(linewidth=0.2, color="gray", zorder=4) + axes.coastlines(linewidth=0.2, color='gray', zorder=4) cmap = copy.copy(mpl.cm.get_cmap(self.usr_colormap)) # cmap.set_bad('black', 0.) # extract all colors from the hot map cmaplist = [cmap(i) for i in range(cmap.N)] # create the new map - cmap = mpl.colors.LinearSegmentedColormap.from_list( - 'Custom cmap', cmaplist) + cmap = mpl.colors.LinearSegmentedColormap.from_list('Custom cmap', cmaplist) axes.set_xlabel('Longitude', weight='bold', zorder=2) axes.set_ylabel('Latitude', weight='bold', zorder=2) # set ticks - axes.set_xticks(np.linspace( - self.plotbbox[0], self.plotbbox[1], 5), crs=ccrs.PlateCarree()) - axes.set_yticks(np.linspace( - self.plotbbox[2], self.plotbbox[3], 5), crs=ccrs.PlateCarree()) - lon_formatter = LongitudeFormatter( - number_format='.0f', degree_symbol='') - lat_formatter = LatitudeFormatter( - number_format='.0f', degree_symbol='') + axes.set_xticks(np.linspace(self.plotbbox[0], self.plotbbox[1], 5), crs=ccrs.PlateCarree()) + axes.set_yticks(np.linspace(self.plotbbox[2], self.plotbbox[3], 5), crs=ccrs.PlateCarree()) + lon_formatter = LongitudeFormatter(number_format='.0f', degree_symbol='') + lat_formatter = LatitudeFormatter(number_format='.0f', degree_symbol='') axes.xaxis.set_major_formatter(lon_formatter) axes.yaxis.set_major_formatter(lat_formatter) # draw central longitude lines corresponding to respective datetimes if self.time_lines: - tl = axes.grid(axis='x', linewidth=1.5, - color='blue', alpha=0.5, linestyle='-', - zorder=3) + tl = axes.grid(axis='x', linewidth=1.5, color='blue', alpha=0.5, linestyle='-', zorder=3) # If individual stations passed if isinstance(gridarr, list): # spatial distribution of stations - if plottype == "station_distribution": - im = axes.scatter(gridarr[0], gridarr[1], zorder=1, s=0.5, - marker='.', color='b', transform=ccrs.PlateCarree()) + if plottype == 'station_distribution': + im = axes.scatter( + gridarr[0], gridarr[1], zorder=1, s=0.5, marker='.', color='b', transform=ccrs.PlateCarree() + ) # passing 3rd column as z-value if len(gridarr) > 2: # set land/water background to light gray/blue respectively so station point data can be seen - axes.add_feature(cfeature.NaturalEarthFeature( - 'physical', 'land', '50m', facecolor='#A9A9A9'), zorder=0) - axes.add_feature(cfeature.NaturalEarthFeature( - 'physical', 'ocean', '50m', facecolor='#ADD8E6'), zorder=0) + 
axes.add_feature(cfeature.NaturalEarthFeature('physical', 'land', '50m', facecolor='#A9A9A9'), zorder=0) + axes.add_feature( + cfeature.NaturalEarthFeature('physical', 'ocean', '50m', facecolor='#ADD8E6'), zorder=0 + ) # set masked values as nans zvalues = gridarr[2] for i in nodat_arr: @@ -1618,15 +2562,18 @@ def __call__(self, gridarr, plottype, workdir='./', drawgridlines=False, colorba # define the bins and normalize if cbounds is None: # avoid "ufunc 'isnan'" error by casting array as float - cbounds = [np.nanpercentile(zvalues.astype('float'), self.colorpercentile[0]), np.nanpercentile( - zvalues.astype('float'), self.colorpercentile[1])] + cbounds = [ + np.nanpercentile(zvalues.astype('float'), self.colorpercentile[0]), + np.nanpercentile(zvalues.astype('float'), self.colorpercentile[1]), + ] # if upper/lower bounds identical, overwrite lower bound as 75% of upper bound to avoid plotting ValueError if cbounds[0] == cbounds[1]: cbounds[0] *= 0.75 cbounds.sort() # adjust precision for colorbar if necessary - if (abs(np.nanmax(zvalues) - np.nanmin(zvalues)) < 1 and (np.nanmean(zvalues)) < 1) \ - or abs(np.nanmax(zvalues) - np.nanmin(zvalues)) > 500: + if (abs(np.nanmax(zvalues) - np.nanmin(zvalues)) < 1 and (np.nanmean(zvalues)) < 1) or abs( + np.nanmax(zvalues) - np.nanmin(zvalues) + ) > 500: colorbarfmt = '%.2e' colorbounds = np.linspace(cbounds[0], cbounds[1], 256) @@ -1635,13 +2582,29 @@ def __call__(self, gridarr, plottype, workdir='./', drawgridlines=False, colorba colorbounds_ticks = np.linspace(cbounds[0], cbounds[1], 10) # plot data and initiate colorbar - im = axes.scatter(gridarr[0], gridarr[1], c=zvalues, cmap=cmap, norm=norm, - zorder=1, s=0.5, marker='.', transform=ccrs.PlateCarree()) + im = axes.scatter( + gridarr[0], + gridarr[1], + c=zvalues, + cmap=cmap, + norm=norm, + zorder=1, + s=0.5, + marker='.', + transform=ccrs.PlateCarree(), + ) # initiate colorbar and control height of colorbar divider = make_axes_locatable(axes) - cax = divider.append_axes("right", size="5%", pad=0.05, axes_class=plt.Axes) - cbar_ax = fig.colorbar(im, spacing='proportional', - ticks=colorbounds_ticks, boundaries=colorbounds, format=colorbarfmt, pad=0.1, cax=cax) + cax = divider.append_axes('right', size='5%', pad=0.05, axes_class=plt.Axes) + cbar_ax = fig.colorbar( + im, + spacing='proportional', + ticks=colorbounds_ticks, + boundaries=colorbounds, + format=colorbarfmt, + pad=0.1, + cax=cax, + ) cbar_ax.ax.minorticks_off() # If gridded area passed @@ -1651,21 +2614,22 @@ def __call__(self, gridarr, plottype, workdir='./', drawgridlines=False, colorba gridarr = np.ma.masked_where(gridarr == i, gridarr) gridarr = np.ma.filled(gridarr, np.nan) # set land/water background to light gray/blue respectively so grid cells can be seen - axes.add_feature(cfeature.NaturalEarthFeature( - 'physical', 'land', '50m', facecolor='#A9A9A9'), zorder=0) - axes.add_feature(cfeature.NaturalEarthFeature( - 'physical', 'ocean', '50m', facecolor='#ADD8E6'), zorder=0) + axes.add_feature(cfeature.NaturalEarthFeature('physical', 'land', '50m', facecolor='#A9A9A9'), zorder=0) + axes.add_feature(cfeature.NaturalEarthFeature('physical', 'ocean', '50m', facecolor='#ADD8E6'), zorder=0) # define the bins and normalize if cbounds is None: - cbounds = [np.nanpercentile(gridarr, self.colorpercentile[0]), np.nanpercentile( - gridarr, self.colorpercentile[1])] + cbounds = [ + np.nanpercentile(gridarr, self.colorpercentile[0]), + np.nanpercentile(gridarr, self.colorpercentile[1]), + ] # if upper/lower bounds identical, overwrite 
lower bound as 75% of upper bound to avoid plotting ValueError if cbounds[0] == cbounds[1]: cbounds[0] *= 0.75 cbounds.sort() # plot data and initiate colorbar - if (abs(np.nanmax(gridarr) - np.nanmin(gridarr)) < 1 and abs(np.nanmean(gridarr)) < 1) \ - or abs(np.nanmax(gridarr) - np.nanmin(gridarr)) > 500: + if (abs(np.nanmax(gridarr) - np.nanmin(gridarr)) < 1 and abs(np.nanmean(gridarr)) < 1) or abs( + np.nanmax(gridarr) - np.nanmin(gridarr) + ) > 500: colorbarfmt = '%.2e' colorbounds = np.linspace(cbounds[0], cbounds[1], 256) @@ -1674,65 +2638,127 @@ def __call__(self, gridarr, plottype, workdir='./', drawgridlines=False, colorba colorbounds_ticks = np.linspace(cbounds[0], cbounds[1], 10) # plot data - im = axes.imshow(gridarr, cmap=cmap, norm=norm, extent=self.plotbbox, - zorder=1, origin='upper', transform=ccrs.PlateCarree()) + im = axes.imshow( + gridarr, + cmap=cmap, + norm=norm, + extent=self.plotbbox, + zorder=1, + origin='upper', + transform=ccrs.PlateCarree(), + ) # initiate colorbar and control height of colorbar divider = make_axes_locatable(axes) - cax = divider.append_axes("right", size="5%", pad=0.05, axes_class=plt.Axes) - cbar_ax = fig.colorbar(im, spacing='proportional', ticks=colorbounds_ticks, - boundaries=colorbounds, format=colorbarfmt, pad=0.1, cax=cax) + cax = divider.append_axes('right', size='5%', pad=0.05, axes_class=plt.Axes) + cbar_ax = fig.colorbar( + im, + spacing='proportional', + ticks=colorbounds_ticks, + boundaries=colorbounds, + format=colorbarfmt, + pad=0.1, + cax=cax, + ) cbar_ax.ax.minorticks_off() # superimpose your gridded array with a supplementary list of point, if specified if self.stationsongrids: - axes.scatter(self.stationsongrids[0], self.stationsongrids[1], zorder=2, - s=0.5, marker='.', color='b', transform=ccrs.PlateCarree()) + axes.scatter( + self.stationsongrids[0], + self.stationsongrids[1], + zorder=2, + s=0.5, + marker='.', + color='b', + transform=ccrs.PlateCarree(), + ) # draw gridlines, if specified if drawgridlines: - gl = axes.gridlines(crs=ccrs.PlateCarree( - ), linewidth=0.5, color='black', alpha=0.5, linestyle='-', zorder=3) - gl.xlocator = mticker.FixedLocator(np.arange( - self.plotbbox[0], self.plotbbox[1] + self.spacing, self.spacing).tolist()) - gl.ylocator = mticker.FixedLocator(np.arange( - self.plotbbox[2], self.plotbbox[3] + self.spacing, self.spacing).tolist()) + gl = axes.gridlines( + crs=ccrs.PlateCarree(), linewidth=0.5, color='black', alpha=0.5, linestyle='-', zorder=3 + ) + gl.xlocator = mticker.FixedLocator( + np.arange(self.plotbbox[0], self.plotbbox[1] + self.spacing, self.spacing).tolist() + ) + gl.ylocator = mticker.FixedLocator( + np.arange(self.plotbbox[2], self.plotbbox[3] + self.spacing, self.spacing).tolist() + ) # Add labels to colorbar, if necessary if 'cbar_ax' in locals(): # experimental variogram fit sill heatmap - if plottype == "grid_variance": - cbar_ax.set_label(" ".join(plottype.replace('grid_', '').split('_')).title() + ' ({}\u00b2)'.format(self.unit), - rotation=-90, labelpad=10) + if plottype == 'grid_variance': + cbar_ax.set_label( + ' '.join(plottype.replace('grid_', '').split('_')).title() + f' ({self.unit}\u00b2)', + rotation=-90, + labelpad=10, + ) # specify appropriate units for mean/median/std/amplitude/experimental variogram fit heatmap - elif plottype == "grid_delay_mean" or plottype == "grid_delay_median" or plottype == "grid_delay_stdev" or \ - plottype == "grid_seasonal_amplitude" or plottype == "grid_range" or plottype == "station_delay_mean" or \ - plottype == 
"station_delay_median" or plottype == "station_delay_stdev" or \ - plottype == "station_seasonal_amplitude" or plottype == "grid_delay_absolute_mean" or \ - plottype == "grid_delay_absolute_median" or plottype == "grid_delay_absolute_stdev" or \ - plottype == "grid_seasonal_absolute_amplitude" or plottype == "grid_seasonal_amplitude_stdev" or \ - plottype == "grid_seasonal_absolute_amplitude_stdev" or plottype == "grid_seasonal_fit_rmse" or \ - plottype == "grid_seasonal_absolute_fit_rmse" or plottype == "grid_variogram_rmse": + elif ( + plottype == 'grid_delay_mean' + or plottype == 'grid_delay_median' + or plottype == 'grid_delay_stdev' + or plottype == 'grid_seasonal_amplitude' + or plottype == 'grid_range' + or plottype == 'station_delay_mean' + or plottype == 'station_delay_median' + or plottype == 'station_delay_stdev' + or plottype == 'station_seasonal_amplitude' + or plottype == 'grid_delay_absolute_mean' + or plottype == 'grid_delay_absolute_median' + or plottype == 'grid_delay_absolute_stdev' + or plottype == 'grid_seasonal_absolute_amplitude' + or plottype == 'grid_seasonal_amplitude_stdev' + or plottype == 'grid_seasonal_absolute_amplitude_stdev' + or plottype == 'grid_seasonal_fit_rmse' + or plottype == 'grid_seasonal_absolute_fit_rmse' + or plottype == 'grid_variogram_rmse' + ): # update label if sigZTD if 'sig' in self.col_name: - cbar_ax.set_label("sig ZTD " + " ".join(plottype.replace('grid_', - '').replace('delay_', '').split('_')).title() + ' ({})'.format(self.unit), - rotation=-90, labelpad=10) + cbar_ax.set_label( + 'sig ZTD ' + + ' '.join(plottype.replace('grid_', '').replace('delay_', '').split('_')).title() + + f' ({self.unit})', + rotation=-90, + labelpad=10, + ) else: - cbar_ax.set_label(" ".join(plottype.replace('grid_', '').split('_')).title() + ' ({})'.format(self.unit), - rotation=-90, labelpad=10) + cbar_ax.set_label( + ' '.join(plottype.replace('grid_', '').split('_')).title() + f' ({self.unit})', + rotation=-90, + labelpad=10, + ) # specify appropriate units for phase heatmap (days) - elif plottype == "station_seasonal_phase" or plottype == "grid_seasonal_phase" or plottype == "grid_seasonal_absolute_phase" or \ - plottype == "grid_seasonal_absolute_phase_stdev" or plottype == "grid_seasonal_phase_stdev": - cbar_ax.set_label(" ".join(plottype.replace('grid_', '').split('_')).title() + ' ({})'.format('days'), - rotation=-90, labelpad=10) + elif ( + plottype == 'station_seasonal_phase' + or plottype == 'grid_seasonal_phase' + or plottype == 'grid_seasonal_absolute_phase' + or plottype == 'grid_seasonal_absolute_phase_stdev' + or plottype == 'grid_seasonal_phase_stdev' + ): + cbar_ax.set_label( + ' '.join(plottype.replace('grid_', '').split('_')).title() + ' (days)', + rotation=-90, + labelpad=10, + ) # specify appropriate units for period heatmap (years) - elif plottype == "station_delay_period" or plottype == "grid_seasonal_period" or plottype == "grid_seasonal_absolute_period" or \ - plottype == "grid_seasonal_absolute_period_stdev" or plottype == "grid_seasonal_period_stdev": - cbar_ax.set_label(" ".join(plottype.replace('grid_', '').split('_')).title() + ' ({})'.format('years'), - rotation=-90, labelpad=10) + elif ( + plottype == 'station_delay_period' + or plottype == 'grid_seasonal_period' + or plottype == 'grid_seasonal_absolute_period' + or plottype == 'grid_seasonal_absolute_period_stdev' + or plottype == 'grid_seasonal_period_stdev' + ): + cbar_ax.set_label( + ' '.join(plottype.replace('grid_', '').split('_')).title() + ' (years)', + rotation=-90, + 
labelpad=10, + ) # gridmap of station density has no units else: - cbar_ax.set_label(" ".join(plottype.replace('grid_', '').split('_')).title(), rotation=-90, labelpad=10) + cbar_ax.set_label(' '.join(plottype.replace('grid_', '').split('_')).title(), rotation=-90, labelpad=10) # Add title to plots, if specified if userTitle: @@ -1741,13 +2767,14 @@ def __call__(self, gridarr, plottype, workdir='./', drawgridlines=False, colorba # save/close figure # cbar_ax.ax.locator_params(nbins=10) # for label in cbar_ax.ax.xaxis.get_ticklabels()[::25]: - # label.set_visible(False) - plt.savefig(os.path.join(workdir, self.col_name + '_' + plottype + '.' + plotFormat), - format=plotFormat, bbox_inches='tight') + # label.set_visible(False) + plt.savefig( + os.path.join(workdir, self.col_name + '_' + plottype + '.' + plotFormat), + format=plotFormat, + bbox_inches='tight', + ) plt.close() - return - def stats_analyses( fname, @@ -1793,12 +2820,12 @@ def stats_analyses( variogramplot, binnedvariogram, variogram_per_timeslice, - variogram_errlimit -): - ''' + variogram_errlimit, +) -> None: + """ Main workflow for generating a suite of plots to illustrate spatiotemporal distribution - and/or character of zenith delays - ''' + and/or character of zenith delays. + """ if verbose: logger.setLevel(logging.DEBUG) @@ -1824,239 +2851,575 @@ def stats_analyses( grid_seasonal_absolute_phase = True variogramplot = True - logger.info("***Stats Function:***") + logger.info('***Stats Function:***') # prep dataframe object for plotting/variogram analysis based off of user specifications - df_stats = RaiderStats(fname, col_name, unit, workdir, bbox, spacing, - timeinterval, seasonalinterval, obs_errlimit, time_lines, stationsongrids, station_seasonal_phase, cbounds, colorpercentile, - usr_colormap, grid_heatmap, grid_delay_mean, grid_delay_median, grid_delay_stdev, grid_seasonal_phase, - grid_delay_absolute_mean, grid_delay_absolute_median, grid_delay_absolute_stdev, - grid_seasonal_absolute_phase, grid_to_raster, min_span, period_limit, numCPUs, phaseamp_per_station) + df_stats = RaiderStats( + fname, + col_name, + unit, + workdir, + bbox, + spacing, + timeinterval, + seasonalinterval, + obs_errlimit, + time_lines, + stationsongrids, + station_seasonal_phase, + cbounds, + colorpercentile, + usr_colormap, + grid_heatmap, + grid_delay_mean, + grid_delay_median, + grid_delay_stdev, + grid_seasonal_phase, + grid_delay_absolute_mean, + grid_delay_absolute_median, + grid_delay_absolute_stdev, + grid_seasonal_absolute_phase, + grid_to_raster, + min_span, + period_limit, + numCPUs, + phaseamp_per_station, + ) # Station plots # Plot each individual station if station_distribution: - logger.info("- Plot spatial distribution of stations.") + logger.info('- Plot spatial distribution of stations.') unique_points = df_stats.df.groupby(['Lon', 'Lat']).size() - df_stats([unique_points.index.get_level_values('Lon').tolist(), unique_points.index.get_level_values('Lat').tolist( - )], 'station_distribution', workdir=os.path.join(workdir, 'figures'), plotFormat=plot_fmt, userTitle=user_title) + df_stats( + [ + unique_points.index.get_level_values('Lon').tolist(), + unique_points.index.get_level_values('Lat').tolist(), + ], + 'station_distribution', + workdir=os.path.join(workdir, 'figures'), + plotFormat=plot_fmt, + userTitle=user_title, + ) # Plot mean delay per station if station_delay_mean: - logger.info("- Plot mean delay for each station.") - unique_points = df_stats.df.groupby( - ['Lon', 'Lat'])[col_name].median() + logger.info('- Plot 
mean delay for each station.') + unique_points = df_stats.df.groupby(['Lon', 'Lat'])[col_name].median() unique_points.dropna(how='any', inplace=True) - df_stats([unique_points.index.get_level_values('Lon').tolist(), unique_points.index.get_level_values('Lat').tolist( - ), unique_points.values], 'station_delay_mean', workdir=os.path.join(workdir, 'figures'), plotFormat=plot_fmt, userTitle=user_title) + df_stats( + [ + unique_points.index.get_level_values('Lon').tolist(), + unique_points.index.get_level_values('Lat').tolist(), + unique_points.values, + ], + 'station_delay_mean', + workdir=os.path.join(workdir, 'figures'), + plotFormat=plot_fmt, + userTitle=user_title, + ) # Plot median delay per station if station_delay_median: - logger.info("- Plot median delay for each station.") - unique_points = df_stats.df.groupby( - ['Lon', 'Lat'])[col_name].mean() + logger.info('- Plot median delay for each station.') + unique_points = df_stats.df.groupby(['Lon', 'Lat'])[col_name].mean() unique_points.dropna(how='any', inplace=True) - df_stats([unique_points.index.get_level_values('Lon').tolist(), unique_points.index.get_level_values('Lat').tolist( - ), unique_points.values], 'station_delay_median', workdir=os.path.join(workdir, 'figures'), plotFormat=plot_fmt, userTitle=user_title) + df_stats( + [ + unique_points.index.get_level_values('Lon').tolist(), + unique_points.index.get_level_values('Lat').tolist(), + unique_points.values, + ], + 'station_delay_median', + workdir=os.path.join(workdir, 'figures'), + plotFormat=plot_fmt, + userTitle=user_title, + ) # Plot delay stdev per station if station_delay_stdev: - logger.info("- Plot delay stdev for each station.") - unique_points = df_stats.df.groupby( - ['Lon', 'Lat'])[col_name].std() + logger.info('- Plot delay stdev for each station.') + unique_points = df_stats.df.groupby(['Lon', 'Lat'])[col_name].std() unique_points.dropna(how='any', inplace=True) - df_stats([unique_points.index.get_level_values('Lon').tolist(), unique_points.index.get_level_values('Lat').tolist( - ), unique_points.values], 'station_delay_stdev', workdir=os.path.join(workdir, 'figures'), plotFormat=plot_fmt, userTitle=user_title) + df_stats( + [ + unique_points.index.get_level_values('Lon').tolist(), + unique_points.index.get_level_values('Lat').tolist(), + unique_points.values, + ], + 'station_delay_stdev', + workdir=os.path.join(workdir, 'figures'), + plotFormat=plot_fmt, + userTitle=user_title, + ) # Plot delay phase/amplitude per station if station_seasonal_phase: - logger.info("- Plot delay phase/amplitude for each station.") + logger.info('- Plot delay phase/amplitude for each station.') # phase - unique_points_phase = df_stats.df.groupby( - ['Lon', 'Lat'])['phsfit'].mean() + unique_points_phase = df_stats.df.groupby(['Lon', 'Lat'])['phsfit'].mean() unique_points_phase.dropna(how='any', inplace=True) - df_stats([unique_points_phase.index.get_level_values('Lon').tolist(), unique_points_phase.index.get_level_values('Lat').tolist( - ), unique_points_phase.values], 'station_seasonal_phase', workdir=os.path.join(workdir, 'figures'), - colorbarfmt='%.1i', plotFormat=plot_fmt, userTitle=user_title) + df_stats( + [ + unique_points_phase.index.get_level_values('Lon').tolist(), + unique_points_phase.index.get_level_values('Lat').tolist(), + unique_points_phase.values, + ], + 'station_seasonal_phase', + workdir=os.path.join(workdir, 'figures'), + colorbarfmt='%.1i', + plotFormat=plot_fmt, + userTitle=user_title, + ) # amplitude - unique_points_amplitude = df_stats.df.groupby( - ['Lon', 
'Lat'])['ampfit'].mean() + unique_points_amplitude = df_stats.df.groupby(['Lon', 'Lat'])['ampfit'].mean() unique_points_amplitude.dropna(how='any', inplace=True) - df_stats([unique_points_amplitude.index.get_level_values('Lon').tolist(), unique_points_amplitude.index.get_level_values('Lat').tolist( - ), unique_points_amplitude.values], 'station_seasonal_amplitude', workdir=os.path.join(workdir, 'figures'), - colorbarfmt='%.3f', plotFormat=plot_fmt, userTitle=user_title) + df_stats( + [ + unique_points_amplitude.index.get_level_values('Lon').tolist(), + unique_points_amplitude.index.get_level_values('Lat').tolist(), + unique_points_amplitude.values, + ], + 'station_seasonal_amplitude', + workdir=os.path.join(workdir, 'figures'), + colorbarfmt='%.3f', + plotFormat=plot_fmt, + userTitle=user_title, + ) # period - unique_points_period = df_stats.df.groupby( - ['Lon', 'Lat'])['periodfit'].mean() - df_stats([unique_points_period.index.get_level_values('Lon').tolist(), unique_points_period.index.get_level_values('Lat').tolist( - ), unique_points_period.values], 'station_delay_period', workdir=os.path.join(workdir, 'figures'), - colorbarfmt='%.2f', plotFormat=plot_fmt, userTitle=user_title) + unique_points_period = df_stats.df.groupby(['Lon', 'Lat'])['periodfit'].mean() + df_stats( + [ + unique_points_period.index.get_level_values('Lon').tolist(), + unique_points_period.index.get_level_values('Lat').tolist(), + unique_points_period.values, + ], + 'station_delay_period', + workdir=os.path.join(workdir, 'figures'), + colorbarfmt='%.2f', + plotFormat=plot_fmt, + userTitle=user_title, + ) # Gridded station plots # Plot density of stations for each gridcell if isinstance(df_stats.grid_heatmap, np.ndarray): - logger.info("- Plot density of stations per gridcell.") - df_stats(df_stats.grid_heatmap, 'grid_heatmap', workdir=os.path.join(workdir, 'figures'), drawgridlines=drawgridlines, - colorbarfmt='%.1i', stationsongrids=stationsongrids, plotFormat=plot_fmt, userTitle=user_title) + logger.info('- Plot density of stations per gridcell.') + df_stats( + df_stats.grid_heatmap, + 'grid_heatmap', + workdir=os.path.join(workdir, 'figures'), + drawgridlines=drawgridlines, + colorbarfmt='%.1i', + stationsongrids=stationsongrids, + plotFormat=plot_fmt, + userTitle=user_title, + ) # Plot mean of station-wise mean delay across each gridcell if isinstance(df_stats.grid_delay_mean, np.ndarray): - logger.info("- Plot mean of station-wise mean delay across each gridcell.") - df_stats(df_stats.grid_delay_mean, 'grid_delay_mean', workdir=os.path.join(workdir, 'figures'), - drawgridlines=drawgridlines, colorbarfmt='%.2f', stationsongrids=stationsongrids, plotFormat=plot_fmt, userTitle=user_title) + logger.info('- Plot mean of station-wise mean delay across each gridcell.') + df_stats( + df_stats.grid_delay_mean, + 'grid_delay_mean', + workdir=os.path.join(workdir, 'figures'), + drawgridlines=drawgridlines, + colorbarfmt='%.2f', + stationsongrids=stationsongrids, + plotFormat=plot_fmt, + userTitle=user_title, + ) # Plot mean of station-wise median delay across each gridcell if isinstance(df_stats.grid_delay_median, np.ndarray): - logger.info("- Plot mean of station-wise median delay across each gridcell.") - df_stats(df_stats.grid_delay_median, 'grid_delay_median', workdir=os.path.join(workdir, 'figures'), - drawgridlines=drawgridlines, colorbarfmt='%.2f', stationsongrids=stationsongrids, plotFormat=plot_fmt, userTitle=user_title) + logger.info('- Plot mean of station-wise median delay across each gridcell.') + df_stats( + 
df_stats.grid_delay_median, + 'grid_delay_median', + workdir=os.path.join(workdir, 'figures'), + drawgridlines=drawgridlines, + colorbarfmt='%.2f', + stationsongrids=stationsongrids, + plotFormat=plot_fmt, + userTitle=user_title, + ) # Plot mean of station-wise stdev delay across each gridcell if isinstance(df_stats.grid_delay_stdev, np.ndarray): - logger.info("- Plot mean of station-wise stdev delay across each gridcell.") - df_stats(df_stats.grid_delay_stdev, 'grid_delay_stdev', workdir=os.path.join(workdir, 'figures'), - drawgridlines=drawgridlines, colorbarfmt='%.2f', stationsongrids=stationsongrids, plotFormat=plot_fmt, userTitle=user_title) + logger.info('- Plot mean of station-wise stdev delay across each gridcell.') + df_stats( + df_stats.grid_delay_stdev, + 'grid_delay_stdev', + workdir=os.path.join(workdir, 'figures'), + drawgridlines=drawgridlines, + colorbarfmt='%.2f', + stationsongrids=stationsongrids, + plotFormat=plot_fmt, + userTitle=user_title, + ) # Plot mean of station-wise delay phase across each gridcell if isinstance(df_stats.grid_seasonal_phase, np.ndarray): - logger.info("- Plot mean of station-wise delay phase across each gridcell.") - df_stats(df_stats.grid_seasonal_phase, 'grid_seasonal_phase', workdir=os.path.join(workdir, 'figures'), - drawgridlines=drawgridlines, colorbarfmt='%.1i', stationsongrids=stationsongrids, plotFormat=plot_fmt, userTitle=user_title) + logger.info('- Plot mean of station-wise delay phase across each gridcell.') + df_stats( + df_stats.grid_seasonal_phase, + 'grid_seasonal_phase', + workdir=os.path.join(workdir, 'figures'), + drawgridlines=drawgridlines, + colorbarfmt='%.1i', + stationsongrids=stationsongrids, + plotFormat=plot_fmt, + userTitle=user_title, + ) # Plot mean of station-wise delay amplitude across each gridcell if isinstance(df_stats.grid_seasonal_amplitude, np.ndarray): - logger.info("- Plot mean of station-wise delay amplitude across each gridcell.") - df_stats(df_stats.grid_seasonal_amplitude, 'grid_seasonal_amplitude', workdir=os.path.join(workdir, 'figures'), - drawgridlines=drawgridlines, colorbarfmt='%.3f', stationsongrids=stationsongrids, plotFormat=plot_fmt, userTitle=user_title) + logger.info('- Plot mean of station-wise delay amplitude across each gridcell.') + df_stats( + df_stats.grid_seasonal_amplitude, + 'grid_seasonal_amplitude', + workdir=os.path.join(workdir, 'figures'), + drawgridlines=drawgridlines, + colorbarfmt='%.3f', + stationsongrids=stationsongrids, + plotFormat=plot_fmt, + userTitle=user_title, + ) # Plot mean of station-wise delay period across each gridcell if isinstance(df_stats.grid_seasonal_period, np.ndarray): - logger.info("- Plot mean of station-wise delay period across each gridcell.") - df_stats(df_stats.grid_seasonal_period, 'grid_seasonal_period', workdir=os.path.join(workdir, 'figures'), - drawgridlines=drawgridlines, colorbarfmt='%.2f', stationsongrids=stationsongrids, plotFormat=plot_fmt, userTitle=user_title) + logger.info('- Plot mean of station-wise delay period across each gridcell.') + df_stats( + df_stats.grid_seasonal_period, + 'grid_seasonal_period', + workdir=os.path.join(workdir, 'figures'), + drawgridlines=drawgridlines, + colorbarfmt='%.2f', + stationsongrids=stationsongrids, + plotFormat=plot_fmt, + userTitle=user_title, + ) # Plot mean stdev of station-wise delay phase across each gridcell if isinstance(df_stats.grid_seasonal_phase_stdev, np.ndarray): - logger.info("- Plot mean stdev of station-wise delay phase across each gridcell.") - 
df_stats(df_stats.grid_seasonal_phase_stdev, 'grid_seasonal_phase_stdev', workdir=os.path.join(workdir, 'figures'), - drawgridlines=drawgridlines, colorbarfmt='%.1i', stationsongrids=stationsongrids, plotFormat=plot_fmt, userTitle=user_title) + logger.info('- Plot mean stdev of station-wise delay phase across each gridcell.') + df_stats( + df_stats.grid_seasonal_phase_stdev, + 'grid_seasonal_phase_stdev', + workdir=os.path.join(workdir, 'figures'), + drawgridlines=drawgridlines, + colorbarfmt='%.1i', + stationsongrids=stationsongrids, + plotFormat=plot_fmt, + userTitle=user_title, + ) # Plot mean stdev of station-wise delay amplitude across each gridcell if isinstance(df_stats.grid_seasonal_amplitude_stdev, np.ndarray): - logger.info("- Plot mean stdev of station-wise delay amplitude across each gridcell.") - df_stats(df_stats.grid_seasonal_amplitude_stdev, 'grid_seasonal_amplitude_stdev', workdir=os.path.join(workdir, 'figures'), - drawgridlines=drawgridlines, colorbarfmt='%.3f', stationsongrids=stationsongrids, plotFormat=plot_fmt, userTitle=user_title) + logger.info('- Plot mean stdev of station-wise delay amplitude across each gridcell.') + df_stats( + df_stats.grid_seasonal_amplitude_stdev, + 'grid_seasonal_amplitude_stdev', + workdir=os.path.join(workdir, 'figures'), + drawgridlines=drawgridlines, + colorbarfmt='%.3f', + stationsongrids=stationsongrids, + plotFormat=plot_fmt, + userTitle=user_title, + ) # Plot mean stdev of station-wise delay period across each gridcell if isinstance(df_stats.grid_seasonal_period_stdev, np.ndarray): - logger.info("- Plot mean stdev of station-wise delay period across each gridcell.") - df_stats(df_stats.grid_seasonal_period_stdev, 'grid_seasonal_period_stdev', workdir=os.path.join(workdir, 'figures'), - drawgridlines=drawgridlines, colorbarfmt='%.2e', stationsongrids=stationsongrids, plotFormat=plot_fmt, userTitle=user_title) + logger.info('- Plot mean stdev of station-wise delay period across each gridcell.') + df_stats( + df_stats.grid_seasonal_period_stdev, + 'grid_seasonal_period_stdev', + workdir=os.path.join(workdir, 'figures'), + drawgridlines=drawgridlines, + colorbarfmt='%.2e', + stationsongrids=stationsongrids, + plotFormat=plot_fmt, + userTitle=user_title, + ) # Plot mean of seasonal fit RMSE across each gridcell if isinstance(df_stats.grid_seasonal_fit_rmse, np.ndarray): - logger.info("- Plot mean of seasonal fit RMSE across each gridcell.") - df_stats(df_stats.grid_seasonal_fit_rmse, 'grid_seasonal_fit_rmse', workdir=os.path.join(workdir, 'figures'), - drawgridlines=drawgridlines, colorbarfmt='%.3f', stationsongrids=stationsongrids, plotFormat=plot_fmt, userTitle=user_title) + logger.info('- Plot mean of seasonal fit RMSE across each gridcell.') + df_stats( + df_stats.grid_seasonal_fit_rmse, + 'grid_seasonal_fit_rmse', + workdir=os.path.join(workdir, 'figures'), + drawgridlines=drawgridlines, + colorbarfmt='%.3f', + stationsongrids=stationsongrids, + plotFormat=plot_fmt, + userTitle=user_title, + ) # Plot absolute mean delay for each gridcell if isinstance(df_stats.grid_delay_absolute_mean, np.ndarray): - logger.info("- Plot absolute mean delay per gridcell.") - df_stats(df_stats.grid_delay_absolute_mean, 'grid_delay_absolute_mean', workdir=os.path.join(workdir, 'figures'), - drawgridlines=drawgridlines, colorbarfmt='%.2f', stationsongrids=stationsongrids, plotFormat=plot_fmt, userTitle=user_title) + logger.info('- Plot absolute mean delay per gridcell.') + df_stats( + df_stats.grid_delay_absolute_mean, + 'grid_delay_absolute_mean', + 
workdir=os.path.join(workdir, 'figures'), + drawgridlines=drawgridlines, + colorbarfmt='%.2f', + stationsongrids=stationsongrids, + plotFormat=plot_fmt, + userTitle=user_title, + ) # Plot absolute median delay for each gridcell if isinstance(df_stats.grid_delay_absolute_median, np.ndarray): - logger.info("- Plot absolute median delay per gridcell.") - df_stats(df_stats.grid_delay_absolute_median, 'grid_delay_absolute_median', workdir=os.path.join(workdir, 'figures'), - drawgridlines=drawgridlines, colorbarfmt='%.2f', stationsongrids=stationsongrids, plotFormat=plot_fmt, userTitle=user_title) + logger.info('- Plot absolute median delay per gridcell.') + df_stats( + df_stats.grid_delay_absolute_median, + 'grid_delay_absolute_median', + workdir=os.path.join(workdir, 'figures'), + drawgridlines=drawgridlines, + colorbarfmt='%.2f', + stationsongrids=stationsongrids, + plotFormat=plot_fmt, + userTitle=user_title, + ) # Plot absolute stdev delay for each gridcell if isinstance(df_stats.grid_delay_absolute_stdev, np.ndarray): - logger.info("- Plot absolute delay stdev per gridcell.") - df_stats(df_stats.grid_delay_absolute_stdev, 'grid_delay_absolute_stdev', workdir=os.path.join(workdir, 'figures'), - drawgridlines=drawgridlines, colorbarfmt='%.2f', stationsongrids=stationsongrids, plotFormat=plot_fmt, userTitle=user_title) + logger.info('- Plot absolute delay stdev per gridcell.') + df_stats( + df_stats.grid_delay_absolute_stdev, + 'grid_delay_absolute_stdev', + workdir=os.path.join(workdir, 'figures'), + drawgridlines=drawgridlines, + colorbarfmt='%.2f', + stationsongrids=stationsongrids, + plotFormat=plot_fmt, + userTitle=user_title, + ) # Plot absolute delay phase for each gridcell if isinstance(df_stats.grid_seasonal_absolute_phase, np.ndarray): - logger.info("- Plot absolute delay phase per gridcell.") - df_stats(df_stats.grid_seasonal_absolute_phase, 'grid_seasonal_absolute_phase', workdir=os.path.join(workdir, 'figures'), - drawgridlines=drawgridlines, colorbarfmt='%.1i', stationsongrids=stationsongrids, plotFormat=plot_fmt, userTitle=user_title) + logger.info('- Plot absolute delay phase per gridcell.') + df_stats( + df_stats.grid_seasonal_absolute_phase, + 'grid_seasonal_absolute_phase', + workdir=os.path.join(workdir, 'figures'), + drawgridlines=drawgridlines, + colorbarfmt='%.1i', + stationsongrids=stationsongrids, + plotFormat=plot_fmt, + userTitle=user_title, + ) # Plot absolute delay amplitude for each gridcell if isinstance(df_stats.grid_seasonal_absolute_amplitude, np.ndarray): - logger.info("- Plot absolute delay amplitude per gridcell.") - df_stats(df_stats.grid_seasonal_absolute_amplitude, 'grid_seasonal_absolute_amplitude', workdir=os.path.join(workdir, 'figures'), - drawgridlines=drawgridlines, colorbarfmt='%.3f', stationsongrids=stationsongrids, plotFormat=plot_fmt, userTitle=user_title) + logger.info('- Plot absolute delay amplitude per gridcell.') + df_stats( + df_stats.grid_seasonal_absolute_amplitude, + 'grid_seasonal_absolute_amplitude', + workdir=os.path.join(workdir, 'figures'), + drawgridlines=drawgridlines, + colorbarfmt='%.3f', + stationsongrids=stationsongrids, + plotFormat=plot_fmt, + userTitle=user_title, + ) # Plot absolute delay period for each gridcell if isinstance(df_stats.grid_seasonal_absolute_period, np.ndarray): - logger.info("- Plot absolute delay period per gridcell.") - df_stats(df_stats.grid_seasonal_absolute_period, 'grid_seasonal_absolute_period', workdir=os.path.join(workdir, 'figures'), - drawgridlines=drawgridlines, colorbarfmt='%.2f', 
stationsongrids=stationsongrids, plotFormat=plot_fmt, userTitle=user_title) + logger.info('- Plot absolute delay period per gridcell.') + df_stats( + df_stats.grid_seasonal_absolute_period, + 'grid_seasonal_absolute_period', + workdir=os.path.join(workdir, 'figures'), + drawgridlines=drawgridlines, + colorbarfmt='%.2f', + stationsongrids=stationsongrids, + plotFormat=plot_fmt, + userTitle=user_title, + ) # Plot absolute delay phase stdev for each gridcell if isinstance(df_stats.grid_seasonal_absolute_phase_stdev, np.ndarray): - logger.info("- Plot absolute delay phase stdev per gridcell.") - df_stats(df_stats.grid_seasonal_absolute_phase_stdev, 'grid_seasonal_absolute_phase_stdev', workdir=os.path.join(workdir, 'figures'), - drawgridlines=drawgridlines, colorbarfmt='%.1i', stationsongrids=stationsongrids, plotFormat=plot_fmt, userTitle=user_title) + logger.info('- Plot absolute delay phase stdev per gridcell.') + df_stats( + df_stats.grid_seasonal_absolute_phase_stdev, + 'grid_seasonal_absolute_phase_stdev', + workdir=os.path.join(workdir, 'figures'), + drawgridlines=drawgridlines, + colorbarfmt='%.1i', + stationsongrids=stationsongrids, + plotFormat=plot_fmt, + userTitle=user_title, + ) # Plot absolute delay amplitude stdev for each gridcell if isinstance(df_stats.grid_seasonal_absolute_amplitude_stdev, np.ndarray): - logger.info("- Plot absolute delay amplitude stdev per gridcell.") - df_stats(df_stats.grid_seasonal_absolute_amplitude_stdev, 'grid_seasonal_absolute_amplitude_stdev', workdir=os.path.join(workdir, 'figures'), - drawgridlines=drawgridlines, colorbarfmt='%.3f', stationsongrids=stationsongrids, plotFormat=plot_fmt, userTitle=user_title) + logger.info('- Plot absolute delay amplitude stdev per gridcell.') + df_stats( + df_stats.grid_seasonal_absolute_amplitude_stdev, + 'grid_seasonal_absolute_amplitude_stdev', + workdir=os.path.join(workdir, 'figures'), + drawgridlines=drawgridlines, + colorbarfmt='%.3f', + stationsongrids=stationsongrids, + plotFormat=plot_fmt, + userTitle=user_title, + ) # Plot absolute delay period stdev for each gridcell if isinstance(df_stats.grid_seasonal_absolute_period_stdev, np.ndarray): - logger.info("- Plot absolute delay period stdev per gridcell.") - df_stats(df_stats.grid_seasonal_absolute_period_stdev, 'grid_seasonal_absolute_period_stdev', workdir=os.path.join(workdir, 'figures'), - drawgridlines=drawgridlines, colorbarfmt='%.2e', stationsongrids=stationsongrids, plotFormat=plot_fmt, userTitle=user_title) + logger.info('- Plot absolute delay period stdev per gridcell.') + df_stats( + df_stats.grid_seasonal_absolute_period_stdev, + 'grid_seasonal_absolute_period_stdev', + workdir=os.path.join(workdir, 'figures'), + drawgridlines=drawgridlines, + colorbarfmt='%.2e', + stationsongrids=stationsongrids, + plotFormat=plot_fmt, + userTitle=user_title, + ) # Plot absolute mean seasonal fit RMSE for each gridcell if isinstance(df_stats.grid_seasonal_absolute_fit_rmse, np.ndarray): - logger.info("- Plot absolute mean seasonal fit RMSE per gridcell.") - df_stats(df_stats.grid_seasonal_absolute_fit_rmse, 'grid_seasonal_absolute_fit_rmse', workdir=os.path.join(workdir, 'figures'), - drawgridlines=drawgridlines, colorbarfmt='%.2e', stationsongrids=stationsongrids, plotFormat=plot_fmt, userTitle=user_title) + logger.info('- Plot absolute mean seasonal fit RMSE per gridcell.') + df_stats( + df_stats.grid_seasonal_absolute_fit_rmse, + 'grid_seasonal_absolute_fit_rmse', + workdir=os.path.join(workdir, 'figures'), + drawgridlines=drawgridlines, + 
colorbarfmt='%.2e', + stationsongrids=stationsongrids, + plotFormat=plot_fmt, + userTitle=user_title, + ) # Perform variogram analysis - if variogramplot and not isinstance(df_stats.grid_range, np.ndarray) \ - and not isinstance(df_stats.grid_variance, np.ndarray) \ - and not isinstance(df_stats.grid_variogram_rmse, np.ndarray): - logger.info("***Variogram Analysis Function:***") + if ( + variogramplot + and not isinstance(df_stats.grid_range, np.ndarray) + and not isinstance(df_stats.grid_variance, np.ndarray) + and not isinstance(df_stats.grid_variogram_rmse, np.ndarray) + ): + logger.info('***Variogram Analysis Function:***') if unit in ['minute', 'hour', 'day', 'year']: unit = 'm' df_stats.unit = 'm' - logger.warning("Output unit {} specified for Variogram analysis. Reverted to meters".format(unit)) - make_variograms = VariogramAnalysis(df_stats.df, df_stats.gridpoints, col_name, unit, workdir, - df_stats.seasonalinterval, densitythreshold, binnedvariogram, - numCPUs, variogram_per_timeslice, variogram_errlimit) + logger.warning(f'Output unit {unit} specified for Variogram analysis. Reverted to meters') + make_variograms = VariogramAnalysis( + df_stats.df, + df_stats.gridpoints, + col_name, + unit, + workdir, + df_stats.seasonalinterval, + densitythreshold, + binnedvariogram, + numCPUs, + variogram_per_timeslice, + variogram_errlimit, + ) TOT_grids, TOT_res_robust_arr, TOT_res_robust_rmse = make_variograms.create_variograms() # get range - df_stats.grid_range = np.array([np.nan if i[0] not in TOT_grids else float(TOT_res_robust_arr[TOT_grids.index( - i[0])][0]) for i in enumerate(df_stats.gridpoints)]).reshape(df_stats.grid_dim).T + df_stats.grid_range = ( + np.array( + [ + np.nan if i[0] not in TOT_grids else float(TOT_res_robust_arr[TOT_grids.index(i[0])][0]) + for i in enumerate(df_stats.gridpoints) + ] + ) + .reshape(df_stats.grid_dim) + .T + ) # convert range to specified output unit df_stats.grid_range = convert_SI(df_stats.grid_range, 'm', unit) # get sill - df_stats.grid_variance = np.array([np.nan if i[0] not in TOT_grids else float(TOT_res_robust_arr[TOT_grids.index( - i[0])][1]) for i in enumerate(df_stats.gridpoints)]).reshape(df_stats.grid_dim).T + df_stats.grid_variance = ( + np.array( + [ + np.nan if i[0] not in TOT_grids else float(TOT_res_robust_arr[TOT_grids.index(i[0])][1]) + for i in enumerate(df_stats.gridpoints) + ] + ) + .reshape(df_stats.grid_dim) + .T + ) # convert sill to specified output unit df_stats.grid_range = convert_SI(df_stats.grid_range, 'm^2', unit.split('^2')[0] + '^2') # get variogram rmse - df_stats.grid_variogram_rmse = np.array([np.nan if i[0] not in TOT_grids else float(TOT_res_robust_rmse[TOT_grids.index( - i[0])]) for i in enumerate(df_stats.gridpoints)]).reshape(df_stats.grid_dim).T + df_stats.grid_variogram_rmse = ( + np.array( + [ + np.nan if i[0] not in TOT_grids else float(TOT_res_robust_rmse[TOT_grids.index(i[0])]) + for i in enumerate(df_stats.gridpoints) + ] + ) + .reshape(df_stats.grid_dim) + .T + ) # convert range to specified output unit df_stats.grid_variogram_rmse = convert_SI(df_stats.grid_variogram_rmse, 'm', unit) # If specified, save gridded array(s) if grid_to_raster: # write range gridfile_name = os.path.join(workdir, col_name + '_' + 'grid_range' + '.tif') - save_gridfile(df_stats.grid_range, 'grid_range', gridfile_name, df_stats.plotbbox, df_stats.spacing, - df_stats.unit, colorbarfmt='%1i', - stationsongrids=df_stats.stationsongrids, dtype='float32') + save_gridfile( + df_stats.grid_range, + 'grid_range', + 
gridfile_name, + df_stats.plotbbox, + df_stats.spacing, + df_stats.unit, + colorbarfmt='%1i', + stationsongrids=df_stats.stationsongrids, + dtype='float32', + ) # write sill gridfile_name = os.path.join(workdir, col_name + '_' + 'grid_variance' + '.tif') - save_gridfile(df_stats.grid_variance, 'grid_variance', gridfile_name, df_stats.plotbbox, df_stats.spacing, - df_stats.unit + '^2', colorbarfmt='%.3e', - stationsongrids=df_stats.stationsongrids, dtype='float32') + save_gridfile( + df_stats.grid_variance, + 'grid_variance', + gridfile_name, + df_stats.plotbbox, + df_stats.spacing, + df_stats.unit + '^2', + colorbarfmt='%.3e', + stationsongrids=df_stats.stationsongrids, + dtype='float32', + ) # write variogram rmse gridfile_name = os.path.join(workdir, col_name + '_' + 'grid_variogram_rmse' + '.tif') - save_gridfile(df_stats.grid_variogram_rmse, 'grid_variogram_rmse', gridfile_name, df_stats.plotbbox, df_stats.spacing, - df_stats.unit, colorbarfmt='%.2e', - stationsongrids=df_stats.stationsongrids, dtype='float32') + save_gridfile( + df_stats.grid_variogram_rmse, + 'grid_variogram_rmse', + gridfile_name, + df_stats.plotbbox, + df_stats.spacing, + df_stats.unit, + colorbarfmt='%.2e', + stationsongrids=df_stats.stationsongrids, + dtype='float32', + ) if isinstance(df_stats.grid_range, np.ndarray): # plot range heatmap - logger.info("- Plot variogram range per gridcell.") - df_stats(df_stats.grid_range, 'grid_range', workdir=os.path.join(workdir, 'figures'), - colorbarfmt='%1i', drawgridlines=drawgridlines, stationsongrids=stationsongrids, plotFormat=plot_fmt, userTitle=user_title) + logger.info('- Plot variogram range per gridcell.') + df_stats( + df_stats.grid_range, + 'grid_range', + workdir=os.path.join(workdir, 'figures'), + colorbarfmt='%1i', + drawgridlines=drawgridlines, + stationsongrids=stationsongrids, + plotFormat=plot_fmt, + userTitle=user_title, + ) if isinstance(df_stats.grid_variance, np.ndarray): # plot sill heatmap - logger.info("- Plot variogram sill per gridcell.") - df_stats(df_stats.grid_variance, 'grid_variance', workdir=os.path.join(workdir, 'figures'), drawgridlines=drawgridlines, - colorbarfmt='%.3e', stationsongrids=stationsongrids, plotFormat=plot_fmt, userTitle=user_title) + logger.info('- Plot variogram sill per gridcell.') + df_stats( + df_stats.grid_variance, + 'grid_variance', + workdir=os.path.join(workdir, 'figures'), + drawgridlines=drawgridlines, + colorbarfmt='%.3e', + stationsongrids=stationsongrids, + plotFormat=plot_fmt, + userTitle=user_title, + ) if isinstance(df_stats.grid_variogram_rmse, np.ndarray): # plot variogram rmse heatmap - logger.info("- Plot variogram RMSE per gridcell.") - df_stats(df_stats.grid_variogram_rmse, 'grid_variogram_rmse', workdir=os.path.join(workdir, 'figures'), drawgridlines=drawgridlines, - colorbarfmt='%.2e', stationsongrids=stationsongrids, plotFormat=plot_fmt, userTitle=user_title) - -def main(): + logger.info('- Plot variogram RMSE per gridcell.') + df_stats( + df_stats.grid_variogram_rmse, + 'grid_variogram_rmse', + workdir=os.path.join(workdir, 'figures'), + drawgridlines=drawgridlines, + colorbarfmt='%.2e', + stationsongrids=stationsongrids, + plotFormat=plot_fmt, + userTitle=user_title, + ) + + +def main() -> None: inps = cmd_line_parse() stats_analyses( @@ -2102,5 +3465,5 @@ def main(): inps.variogramplot, inps.binnedvariogram, inps.variogram_per_timeslice, - inps.variogram_errlimit + inps.variogram_errlimit, ) diff --git a/tools/RAiDER/cli/types.py b/tools/RAiDER/cli/types.py new file mode 100644 index 
000000000..73ad5ce19 --- /dev/null +++ b/tools/RAiDER/cli/types.py @@ -0,0 +1,209 @@ +import argparse +import dataclasses +import datetime as dt +import itertools +import time +from pathlib import Path +from typing import Literal, Optional, Union + +import numpy as np + +from RAiDER.constants import _CUBE_SPACING_IN_M +from RAiDER.llreader import AOI +from RAiDER.losreader import LOS +from RAiDER.models.weatherModel import WeatherModel +from RAiDER.types import BB, LookDir, TimeInterpolationMethod + + +LOSConvention = Literal['isce', 'hyp3'] + +@dataclasses.dataclass +class DateGroupUnparsed: + date_start: Optional[Union[int, str]] = None + date_end: Optional[Union[int, str]] = None + date_step: Optional[Union[int, str]] = None + date_list: Optional[Union[int, str]] = None + +@dataclasses.dataclass +class DateGroup: + # After the dates have been parsed, only the date list is valid, and all + # other fields of the date group should not be used. + date_list: list[dt.date] + + +class TimeGroup: + """Parse an input time (required to be ISO 8601).""" + _DEFAULT_ACQUISITION_WINDOW_SEC = 30 + TIME_FORMATS = ( + '', + 'T%H:%M:%S.%f', + 'T%H%M%S.%f', + '%H%M%S.%f', + 'T%H:%M:%S', + '%H:%M:%S', + 'T%H%M%S', + '%H%M%S', + 'T%H:%M', + 'T%H%M', + '%H:%M', + 'T%H', + ) + TIMEZONE_FORMATS = ( + '', + 'Z', + '%z', + ) + time: dt.time + end_time: dt.time + interpolate_time: Optional[TimeInterpolationMethod] + + def __init__( + self, + time: Optional[Union[str, dt.time]] = None, + end_time: Optional[Union[str, dt.time]] = None, + interpolate_time: Optional[TimeInterpolationMethod] = None, + ) -> None: + self.interpolate_time = interpolate_time + + if time is None: + raise ValueError('You must specify a "time" in the input config file') + if isinstance(time, dt.time): + self.time = time + else: + self.time = TimeGroup.coerce_into_time(time) + + if end_time is not None: + if isinstance(end_time, dt.time): + self.end_time = end_time + else: + self.end_time = TimeGroup.coerce_into_time(end_time) + if self.end_time < self.time: + raise ValueError( + 'Acquisition start time must be before end time. ' + f'Provided start time {self.time} is later than end time {self.end_time}' + ) + else: + sentinel_datetime = dt.datetime.combine(dt.date(1900, 1, 1), self.time) + new_end_time = sentinel_datetime + dt.timedelta(seconds=TimeGroup._DEFAULT_ACQUISITION_WINDOW_SEC) + self.end_time = new_end_time.time() + if self.end_time < self.time: + raise ValueError( + 'Acquisition start time must be before end time. ' + f'Provided start time {self.time} is later than end time {self.end_time} ' + f'(with default window of {TimeGroup._DEFAULT_ACQUISITION_WINDOW_SEC} seconds)' + ) + + @staticmethod + def coerce_into_time(val: Union[int, str]) -> dt.time: + val = str(val) + all_formats = map(''.join, itertools.product(TimeGroup.TIME_FORMATS, TimeGroup.TIMEZONE_FORMATS)) + for tf in all_formats: + try: + return dt.time(*time.strptime(val, tf)[3:6]) + except ValueError: + pass + raise ValueError(f'Unable to coerce "{val}" to a time. Try T%H:%M:%S') + +@dataclasses.dataclass +class AOIGroupUnparsed: + bounding_box: Optional[Union[str, list[Union[float, int]], BB.SNWE]] = None + geocoded_file: Optional[str] = None + lat_file: Optional[str] = None + lon_file: Optional[str] = None + station_file: Optional[str] = None + geo_cube: Optional[str] = None + +@dataclasses.dataclass +class AOIGroup: + # Once the AOI group is parsed, the members from the config file should not + # be read again. 
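As a quick usage sketch of the `TimeGroup` parser introduced above (the time values are arbitrary examples; the 30-second window is the class's `_DEFAULT_ACQUISITION_WINDOW_SEC` default):
```
import datetime as dt

from RAiDER.cli.types import TimeGroup

# A bare ISO-8601-ish time string is coerced into a datetime.time.
tg = TimeGroup(time='T12:30:00')
assert tg.time == dt.time(12, 30, 0)

# With no end_time given, a default 30-second acquisition window is applied.
assert tg.end_time == dt.time(12, 30, 30)
```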
Instead, the parsed AOI will be available on AOIGroup.aoi. + aoi: AOI + + +@dataclasses.dataclass +class HeightGroupUnparsed: + dem: Optional[str] = None + use_dem_latlon: bool = False + height_file_rdr: Optional[str] = None + height_levels: Optional[Union[str, list[Union[float, int]]]] = None + +@dataclasses.dataclass +class HeightGroup: + dem: Optional[str] + use_dem_latlon: bool + height_file_rdr: Optional[str] + height_levels: Optional[list[float]] + + +@dataclasses.dataclass +class LOSGroupUnparsed: + ray_trace: bool = False + los_file: Optional[str] = None + los_convention: LOSConvention = 'isce' + los_cube: Optional[str] = None + orbit_file: Optional[str] = None + zref: Optional[np.float64] = None + +@dataclasses.dataclass +class LOSGroup: + los: LOS + ray_trace: bool = False + los_file: Optional[str] = None + los_convention: LOSConvention = 'isce' + los_cube: Optional[str] = None + orbit_file: Optional[str] = None + zref: Optional[np.float64] = None + +class RuntimeGroup: + raster_format: str + file_format: str # TODO(garlic-os): redundant with raster_format? + verbose: bool + output_projection: str + cube_spacing_in_m: float + download_only: bool + output_directory: Path + weather_model_directory: Path + + def __init__( + self, + raster_format: str = 'GTiff', + file_format: str = 'GTiff', + verbose: bool = True, + output_projection: str = 'EPSG:4326', + cube_spacing_in_m: float = _CUBE_SPACING_IN_M, + download_only: bool = False, + output_directory: str = '.', + weather_model_directory: Optional[str] = None, + ): + self.raster_format = raster_format + self.file_format = file_format + self.verbose = verbose + self.output_projection = output_projection + self.cube_spacing_in_m = cube_spacing_in_m + self.download_only = download_only + self.output_directory = Path(output_directory) + if weather_model_directory is not None: + self.weather_model_directory = Path(weather_model_directory) + else: + self.weather_model_directory = self.output_directory / 'weather_files' + + +@dataclasses.dataclass +class RunConfig: + weather_model: WeatherModel + date_group: DateGroup + time_group: TimeGroup + aoi_group: AOIGroup + height_group: HeightGroup + los_group: LOSGroup + runtime_group: RuntimeGroup + look_dir: LookDir = 'right' + cube_spacing_in_m: Optional[float] = None # deprecated + wetFilenames: Optional[list[str]] = None + hydroFilenames: Optional[list[str]] = None + + +class RAiDERArgs(argparse.Namespace): + download_only: bool = False + generate_config: Optional[str] = None + run_config_file: Optional[Path] diff --git a/tools/RAiDER/cli/validators.py b/tools/RAiDER/cli/validators.py index 78c470248..dda62ec54 100755 --- a/tools/RAiDER/cli/validators.py +++ b/tools/RAiDER/cli/validators.py @@ -1,165 +1,181 @@ -from argparse import Action, ArgumentError, ArgumentTypeError - +import argparse +import datetime as dt import importlib -import itertools -import os import re +import sys +from pathlib import Path +from typing import Any, Optional, Union -import pandas as pd import numpy as np +import pandas as pd -from datetime import time, timedelta, datetime, date -from textwrap import dedent -from time import strptime -from RAiDER.llreader import BoundingBox, Geocube, RasterRDR, StationFile, GeocodedFile, Geocube -from RAiDER.losreader import Zenith, Conventional -from RAiDER.utilFcns import rio_extents, rio_profile +if sys.version_info >= (3,11): + from typing import Self +else: + Self = Any + +from RAiDER.cli.types import ( + AOIGroupUnparsed, + DateGroup, + DateGroupUnparsed, + HeightGroup, + 
HeightGroupUnparsed, + LOSGroupUnparsed, + RuntimeGroup, +) +from RAiDER.llreader import AOI, BoundingBox, GeocodedFile, Geocube, RasterRDR, StationFile from RAiDER.logger import logger +from RAiDER.losreader import LOS, Conventional, Zenith +from RAiDER.models.weatherModel import WeatherModel +from RAiDER.types import BB +from RAiDER.utilFcns import rio_extents, rio_profile -_BUFFER_SIZE = 0.2 # default buffer size in lat/lon degrees -def enforce_wm(value, aoi): - model = value.upper().replace("-", "") +_BUFFER_SIZE = 0.2 # default buffer size in lat/lon degrees + + +def parse_weather_model(weather_model_name: str, aoi: AOI) -> WeatherModel: + weather_model_name = weather_model_name.upper().replace('-', '') try: - _, model_obj = modelName2Module(model) + _, Model = get_wm_by_name(weather_model_name) except ModuleNotFoundError: raise NotImplementedError( - dedent(''' - Model {} is not yet fully implemented, - please contribute! - '''.format(model)) + f'Model {weather_model_name} is not yet fully implemented, please contribute!' ) - ## check the user requsted bounding box is within the weather model domain - modObj = model_obj().checkValidBounds(aoi.bounds()) + # Check that the user-requested bounding box is within the weather model domain + model: WeatherModel = Model() + model.checkValidBounds(aoi.bounds()) - return modObj + return model -def get_los(args): - if args.get('orbit_file'): - if args.get('ray_trace'): +def get_los(los_group: LOSGroupUnparsed) -> LOS: + if los_group.orbit_file is not None: + if los_group.ray_trace: from RAiDER.losreader import Raytracing - los = Raytracing(args.orbit_file) + los = Raytracing(los_group.orbit_file) else: - los = Conventional(args.orbit_file) - elif args.get('los_file'): - if args.ray_trace: + los = Conventional(los_group.orbit_file) + + elif los_group.los_file is not None: + if los_group.ray_trace: from RAiDER.losreader import Raytracing - los = Raytracing(args.los_file, args.los_convention) + los = Raytracing(los_group.los_file, los_group.los_convention) else: - los = Conventional(args.los_file, args.los_convention) + los = Conventional(los_group.los_file, los_group.los_convention) - elif args.get('los_cube'): + elif los_group.los_cube is not None: raise NotImplementedError('LOS_cube is not yet implemented') -# if args.ray_trace: -# los = Raytracing(args.los_cube) -# else: -# los = Conventional(args.los_cube) + # if los_group.ray_trace: + # los = Raytracing(los_group.los_cube) + # else: + # los = Conventional(los_group.los_cube) else: los = Zenith() return los -def get_heights(args, out, station_file, bounding_box=None): - ''' - Parse the Height info and download a DEM if needed - ''' - dem_path = out - - out = { - 'dem': args.get('dem'), - 'height_file_rdr': None, - 'height_levels': None, - } - - if args.get('dem'): - if (station_file is not None): - if 'Hgt_m' not in pd.read_csv(station_file): - out['dem'] = os.path.join(dem_path, 'GLO30.dem') - elif os.path.exists(args.dem): - out['dem'] = args.dem +def get_heights(height_group: HeightGroupUnparsed, aoi_group: AOIGroupUnparsed, runtime_group: RuntimeGroup) -> HeightGroup: + """Parse the Height info and download a DEM if needed.""" + result = HeightGroup( + dem=height_group.dem, + use_dem_latlon=height_group.use_dem_latlon, + height_file_rdr=height_group.height_file_rdr, + height_levels=None, + ) + + if height_group.dem is not None: + if aoi_group.station_file is not None: + station_data = pd.read_csv(aoi_group.station_file) + if 'Hgt_m' not in station_data: + result.dem = 
runtime_group.output_directory / 'GLO30.dem' + elif Path(height_group.dem).exists(): # crop the DEM - if bounding_box is not None: - dem_bounds = rio_extents(rio_profile(args.dem)) - lats = dem_bounds[:2] - lons = dem_bounds[2:] + if aoi_group.bounding_box is not None: + dem_bounds = rio_extents(rio_profile(height_group.dem)) + lats: BB.SN = dem_bounds[:2] + lons: BB.WE = dem_bounds[2:] if isOutside( - bounding_box, + parse_bbox(aoi_group.bounding_box), getBufferedExtent( lats, lons, - buf=_BUFFER_SIZE, - ) + buffer_size=_BUFFER_SIZE, + ), ): raise ValueError( - 'Existing DEM does not cover the area of the input lat/lon ' - 'points; either move the DEM, delete it, or change the input ' - 'points.' - ) - else: - pass # will download the dem later - - elif args.get('height_file_rdr'): - out['height_file_rdr'] = args.height_file_rdr + 'Existing DEM does not cover the area of the input lat/lon points; either move the DEM, delete ' + 'it, or change the input points.' + ) + # else: will download the dem later - else: + elif height_group.height_file_rdr is None: # download the DEM if needed - out['dem'] = os.path.join(dem_path, 'GLO30.dem') + result.dem = runtime_group.output_directory / 'GLO30.dem' - if args.get('height_levels'): - if isinstance(args.height_levels, str): - l = re.findall('[-0-9]+', args.height_levels) + if height_group.height_levels is not None: + if isinstance(height_group.height_levels, str): + levels = re.findall('[-0-9]+', height_group.height_levels) else: - l = args.height_levels + levels = height_group.height_levels - out['height_levels'] = np.array([float(ll) for ll in l]) - if np.any(out['height_levels'] < 0): - logger.warning('Weather model only extends to the surface topography; ' - 'height levels below the topography will be interpolated from the surface ' - 'and may be inaccurate.') + levels = np.array([float(level) for level in levels]) + if np.any(levels < 0): + logger.warning( + 'Weather model only extends to the surface topography; ' + 'height levels below the topography will be interpolated from the surface and may be inaccurate.' + ) + result.height_levels = list(levels) - return out + return result -def get_query_region(args): - ''' - Parse the query region from inputs - ''' +def get_query_region(aoi_group: AOIGroupUnparsed, height_group: HeightGroupUnparsed, cube_spacing_in_m: float) -> AOI: + """Parse the query region from inputs. + + This function determines the query region from the input parameters. It will return an AOI object that can be used + to query the weather model. + Note: both an AOI group and a height group are necessary in case a DEM is needed. 
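For reference, the `height_levels` handling in `get_heights` above reduces to the following standalone sketch (the input string is a hypothetical run-config value):
```
import re

import numpy as np

# Space-separated height levels, as they might appear in a run config (hypothetical value).
height_levels = '0 50 100 500 1000'

# Same parsing as get_heights: pull out the integers, then cast to float.
levels = np.array([float(level) for level in re.findall('[-0-9]+', height_levels)])
if np.any(levels < 0):
    # Negative levels fall below the topography and are interpolated from the surface.
    print('warning: height levels below the topography may be inaccurate')
print(levels.tolist())  # [0.0, 50.0, 100.0, 500.0, 1000.0]
```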
+ """ # Get bounds from the inputs # make sure this is first - if args.get('use_dem_latlon'): - query = GeocodedFile(args.dem, is_dem=True) - - elif args.get('lat_file'): - hgt_file = args.get('height_file_rdr') # only get it if exists - dem_file = args.get('dem') - query = RasterRDR(args.lat_file, args.lon_file, hgt_file, dem_file) + if height_group.use_dem_latlon: + query = GeocodedFile(Path(height_group.dem), is_dem=True, cube_spacing_in_m=cube_spacing_in_m) + + elif aoi_group.lat_file is not None or aoi_group.lon_file is not None: + if aoi_group.lat_file is None or aoi_group.lon_file is None: + raise ValueError('A lon_file must be specified if a lat_file is specified') + query = RasterRDR( + aoi_group.lat_file, aoi_group.lon_file, + height_group.height_file_rdr, height_group.dem, + cube_spacing_in_m=cube_spacing_in_m + ) - elif args.get('station_file'): - query = StationFile(args.station_file) + elif aoi_group.station_file is not None: + query = StationFile(aoi_group.station_file, cube_spacing_in_m=cube_spacing_in_m) - elif args.get('bounding_box'): - bbox = enforce_bbox(args.bounding_box) - if (np.min(bbox[0]) < -90) | (np.max(bbox[1]) > 90): + elif aoi_group.bounding_box is not None: + bbox = parse_bbox(aoi_group.bounding_box) + if np.min(bbox[0]) < -90 or np.max(bbox[1]) > 90: raise ValueError('Lats are out of N/S bounds; are your lat/lon coordinates switched? Should be SNWE') - query = BoundingBox(bbox) + query = BoundingBox(bbox, cube_spacing_in_m=cube_spacing_in_m) - elif args.get('geocoded_file'): - gfile = os.path.basename(args.geocoded_file).upper() - if (gfile.startswith('SRTM') or gfile.startswith('GLO')): - logger.debug('Using user DEM: %s', gfile) + elif aoi_group.geocoded_file is not None: + geocoded_file_path = Path(aoi_group.geocoded_file) + filename = geocoded_file_path.name.upper() + if filename.startswith('SRTM') or filename.startswith('GLO'): + logger.debug('Using user DEM: %s', filename) is_dem = True else: is_dem = False + query = GeocodedFile(geocoded_file_path, is_dem=is_dem, cube_spacing_in_m=cube_spacing_in_m) - query = GeocodedFile(args.geocoded_file, is_dem=is_dem) - - ## untested - elif args.get('geo_cube'): - query = Geocube(args.geo_cube) + # untested + elif aoi_group.geo_cube is not None: + query = Geocube(aoi_group.geo_cube, cube_spacing_in_m) else: # TODO: Need to incorporate the cube @@ -168,10 +184,8 @@ def get_query_region(args): return query -def enforce_bbox(bbox): - """ - Enforce a valid bounding box - """ +def parse_bbox(bbox: Union[str, list[Union[int, float]], tuple]) -> BB.SNWE: + """Parse a bounding box string input and ensure it is valid.""" if isinstance(bbox, str): bbox = [float(d) for d in bbox.strip().split()] else: @@ -179,7 +193,7 @@ def enforce_bbox(bbox): # Check the bbox if len(bbox) != 4: - raise ValueError("bounding box must have 4 elements!") + raise ValueError('bounding box must have 4 elements!') S, N, W, E = bbox if N <= S or E <= W: @@ -191,51 +205,51 @@ def enforce_bbox(bbox): for we in (W, E): if we < -180 or we > 180: - raise ValueError('Lons are out of W/E bounds (-180 to 180); Lons in the format of (0 to 360) are not supported.') - - return bbox + raise ValueError( + 'Lons are out of W/E bounds (-180 to 180); Lons in the format of (0 to 360) are not supported.' 
+ ) + return S, N, W, E -def parse_dates(arg_dict): - ''' - Determine the requested dates from the input parameters - ''' - if arg_dict.get('date_list'): - l = arg_dict['date_list'] - if isinstance(l, str): - l = re.findall('[0-9]+', l) - elif isinstance(l, int): - l = [l] - L = [enforce_valid_dates(d) for d in l] +def parse_dates(date_group: DateGroupUnparsed) -> DateGroup: + """Determine the requested dates from the input parameters.""" + if date_group.date_list is not None: + if isinstance(date_group.date_list, str): + unparsed_dates = re.findall('[0-9]+', date_group.date_list) + elif isinstance(date_group.date_list, int): + unparsed_dates = [date_group.date_list] + else: + unparsed_dates = date_group.date_list + date_list = [coerce_into_date(d) for d in unparsed_dates] else: - try: - start = arg_dict['date_start'] - except KeyError: + if date_group.date_start is None: raise ValueError('Inputs must include either date_list or date_start') - start = enforce_valid_dates(start) + start = coerce_into_date(date_group.date_start) - if arg_dict.get('date_end'): - end = arg_dict['date_end'] - end = enforce_valid_dates(end) + if date_group.date_end is not None: + end = coerce_into_date(date_group.date_end) else: - end = start + end = start - if arg_dict.get('date_step'): - step = int(arg_dict['date_step']) + if date_group.date_step: + step = int(date_group.date_step) else: step = 1 - L = [start + timedelta(days=step) for step in range(0, (end - start).days + 1, step)] - - return L + date_list = [ + start + dt.timedelta(days=step) + for step in range(0, (end - start).days + 1, step) + ] + + return DateGroup( + date_list=date_list, + ) -def enforce_valid_dates(arg): - """ - Parse a date from a string in pseudo-ISO 8601 format. - """ +def coerce_into_date(val: Union[int, str]) -> dt.date: + """Parse a date from a string in pseudo-ISO 8601 format.""" year_formats = ( '%Y-%m-%d', '%Y%m%d', @@ -245,166 +259,80 @@ def enforce_valid_dates(arg): for yf in year_formats: try: - return datetime.strptime(str(arg), yf) + return dt.datetime.strptime(str(val), yf).date() except ValueError: pass + raise ValueError(f'Unable to coerce {val} to a date. Try %Y-%m-%d') - raise ValueError( - 'Unable to coerce {} to a date. Try %Y-%m-%d'.format(arg) - ) - - -def enforce_time(arg_dict): - ''' - Parse an input time (required to be ISO 8601) - ''' - try: - arg_dict['time'] = convert_time(arg_dict['time']) - except KeyError: - raise ValueError('You must specify a "time" in the input config file') - - if 'end_time' in arg_dict.keys(): - arg_dict['end_time'] = convert_time(arg_dict['end_time']) - return arg_dict - - -def convert_time(inp): - time_formats = ( - '', - 'T%H:%M:%S.%f', - 'T%H%M%S.%f', - '%H%M%S.%f', - 'T%H:%M:%S', - '%H:%M:%S', - 'T%H%M%S', - '%H%M%S', - 'T%H:%M', - 'T%H%M', - '%H:%M', - 'T%H', - ) - timezone_formats = ( - '', - 'Z', - '%z', - ) - all_formats = map( - ''.join, - itertools.product(time_formats, timezone_formats) - ) - - for tf in all_formats: - try: - return time(*strptime(inp, tf)[3:6]) - except ValueError: - pass - - raise ValueError( - 'Unable to coerce {} to a time.'+ - 'Try T%H:%M:%S'.format(inp) - ) +def get_wm_by_name(model_name: str) -> tuple[str, WeatherModel]: + """ + Turn an arbitrary string into a module name. -def modelName2Module(model_name): - """Turn an arbitrary string into a module name. Takes as input a model name, which hopefully looks like ERA-I, and - converts it to a module name, which will look like erai. I doesn't + converts it to a module name, which will look like erai. 
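A small usage sketch of the date helpers above, i.e. `coerce_into_date` plus the start/end/step expansion done by `parse_dates` (the dates are arbitrary examples):
```
import datetime as dt

from RAiDER.cli.validators import coerce_into_date

start = coerce_into_date('20200101')    # %Y%m%d
end = coerce_into_date('2020-01-07')    # %Y-%m-%d
step = 2                                # days

# Expand into an inclusive date list, as parse_dates does for date_start/date_end/date_step.
date_list = [start + dt.timedelta(days=k) for k in range(0, (end - start).days + 1, step)]
print([d.isoformat() for d in date_list])  # ['2020-01-01', '2020-01-03', '2020-01-05', '2020-01-07']
```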
It doesn't always produce a valid module name, but that's not the goal. The goal is just to handle common cases. Inputs: model_name - Name of an allowed weather model (e.g., 'era-5') Outputs: module_name - Name of the module - wmObject - callable, weather model object + wmObject - callable, weather model object. """ module_name = 'RAiDER.models.' + model_name.lower().replace('-', '') - model_module = importlib.import_module(module_name) - wmObject = getattr(model_module, model_name.upper().replace('-', '')) - return module_name, wmObject + module = importlib.import_module(module_name) + Model = getattr(module, model_name.upper().replace('-', '')) + return module_name, Model + + +def getBufferedExtent(lats: BB.SN, lons: BB.WE, buffer_size: float=0.0) -> BB.SNWE: + """Get the bounding box around a set of lats/lons.""" + return ( + min(lats) - buffer_size, + max(lats) + buffer_size, + min(lons) - buffer_size, + max(lons) + buffer_size + ) -def getBufferedExtent(lats, lons=None, buf=0.): - ''' - get the bounding box around a set of lats/lons - ''' - if lons is None: - lats, lons = lats[..., 0], lons[..., 1] +def isOutside(extent1: BB.SNWE, extent2: BB.SNWE) -> bool: + """Determine whether any of extent1 lies outside extent2. - try: - if (lats.size == 1) & (lons.size == 1): - out = [lats - buf, lats + buf, lons - buf, lons + buf] - elif (lats.size > 1) & (lons.size > 1): - out = [np.nanmin(lats), np.nanmax(lats), np.nanmin(lons), np.nanmax(lons)] - elif lats.size == 1: - out = [lats - buf, lats + buf, np.nanmin(lons), np.nanmax(lons)] - elif lons.size == 1: - out = [np.nanmin(lats), np.nanmax(lats), lons - buf, lons + buf] - except AttributeError: - if (isinstance(lats, tuple) or isinstance(lats, list)) and len(lats) == 2: - out = [min(lats) - buf, max(lats) + buf, min(lons) - buf, max(lons) + buf] - except Exception as e: - raise RuntimeError('Not a valid lat/lon shape or variable') - - return np.array(out) - - -def isOutside(extent1, extent2): - ''' - Determine whether any of extent1 lies outside extent2 - extent1/2 should be a list containing [lower_lat, upper_lat, left_lon, right_lon] - Equal extents are considered "inside" - ''' + extent1/2 should be a list containing [lower_lat, upper_lat, left_lon, right_lon] (SNWE). + Equal extents are considered "inside". + """ t1 = extent1[0] < extent2[0] t2 = extent1[1] > extent2[1] t3 = extent1[2] < extent2[2] t4 = extent1[3] > extent2[3] - if np.any([t1, t2, t3, t4]): - return True - return False + return any((t1, t2, t3, t4)) + +def isInside(extent1: BB.SNWE, extent2: BB.SNWE) -> bool: + """Determine whether all of extent1 lies inside extent2. -def isInside(extent1, extent2): - ''' - Determine whether all of extent1 lies inside extent2 - extent1/2 should be a list containing [lower_lat, upper_lat, left_lon, right_lon]. - Equal extents are considered "inside" - ''' + extent1/2 should be a list containing [lower_lat, upper_lat, left_lon, right_lon] (SNWE). + Equal extents are considered "inside". + """ t1 = extent1[0] <= extent2[0] t2 = extent1[1] >= extent2[1] t3 = extent1[2] <= extent2[2] t4 = extent1[3] >= extent2[3] - if np.all([t1, t2, t3, t4]): - return True - return False - - -## below are for downloadGNSSDelays -def date_type(arg): - """ - Parse a date from a string in pseudo-ISO 8601 format. 
- """ - year_formats = ( - '%Y-%m-%d', - '%Y%m%d', - '%d', - '%j', - ) + return all((t1, t2, t3, t4)) - for yf in year_formats: - try: - return date(*strptime(arg, yf)[0:3]) - except ValueError: - pass - raise ArgumentTypeError( - 'Unable to coerce {} to a date. Try %Y-%m-%d'.format(arg) - ) +# below are for downloadGNSSDelays +def date_type(val: Union[int, str]) -> dt.date: + """Parse a date from a string in pseudo-ISO 8601 format.""" + try: + return coerce_into_date(val) + except ValueError as exc: + raise argparse.ArgumentTypeError(str(exc)) -class MappingType(object): - """ - A type that maps arguments to constants. +class MappingType: + """A type that maps arguments to constants. # Example ``` @@ -414,34 +342,31 @@ class MappingType(object): assert mapping("hello") is None ``` """ + UNSET = object() + _default: Union[object, Any] - def __init__(self, **kwargs): + def __init__(self, **kwargs: dict[str, Any]) -> None: self.mapping = kwargs self._default = self.UNSET - def default(self, default): - """Set a default value if no mapping is found""" + def default(self, default: Any) -> Self: # noqa: ANN401 + """Set a default value if no mapping is found.""" self._default = default return self - def __call__(self, arg): + def __call__(self, arg: str) -> Any: # noqa: ANN401 if arg in self.mapping: return self.mapping[arg] if self._default is self.UNSET: - raise KeyError( - "Invalid choice '{}', must be one of {}".format( - arg, list(self.mapping.keys()) - ) - ) + raise KeyError(f"Invalid choice '{arg}', must be one of {list(self.mapping.keys())}") return self._default -class IntegerType(object): - """ - A type that converts arguments to integers. +class IntegerOnRangeType: + """A type that converts arguments to integers and enforces that they are on a certain range. # Example ``` @@ -452,24 +377,23 @@ class IntegerType(object): ``` """ - def __init__(self, lo=None, hi=None): + def __init__(self, lo: Optional[int]=None, hi: Optional[int]=None) -> None: self.lo = lo self.hi = hi - def __call__(self, arg): + def __call__(self, arg: Any) -> int: # noqa: ANN401 integer = int(arg) if self.lo is not None and integer < self.lo: - raise ArgumentTypeError("Must be greater than {}".format(self.lo)) + raise argparse.ArgumentTypeError(f'Must be greater than {self.lo}') if self.hi is not None and integer > self.hi: - raise ArgumentTypeError("Must be less than {}".format(self.hi)) + raise argparse.ArgumentTypeError(f'Must be less than {self.hi}') return integer -class IntegerMappingType(MappingType, IntegerType): - """ - An integer type that converts non-integer types through a mapping. +class IntegerMappingType(MappingType, IntegerOnRangeType): + """An integer type that converts non-integer types through a mapping. 
# Example ``` @@ -480,36 +404,36 @@ class IntegerMappingType(MappingType, IntegerType): ``` """ - def __init__(self, lo=None, hi=None, mapping={}, **kwargs): - IntegerType.__init__(self, lo, hi) + def __init__(self, lo: Optional[int]=None, hi: Optional[int]=None, mapping: Optional[dict[str, Any]]={}, **kwargs: dict[str, Any]) -> None: + IntegerOnRangeType.__init__(self, lo, hi) kwargs.update(mapping) MappingType.__init__(self, **kwargs) - def __call__(self, arg): + def __call__(self, arg: Any) -> Union[int, Any]: # noqa: ANN401 try: - return IntegerType.__call__(self, arg) + return IntegerOnRangeType.__call__(self, arg) except ValueError: return MappingType.__call__(self, arg) -class DateListAction(Action): - """An Action that parses and stores a list of dates""" +class DateListAction(argparse.Action): + """An Action that parses and stores a list of dates.""" def __init__( self, - option_strings, - dest, - nargs=None, - const=None, - default=None, - type=None, - choices=None, - required=False, - help=None, - metavar=None - ): + option_strings, # noqa: ANN001 -- see argparse.Action.__init__ + dest, # noqa: ANN001 + nargs=None, # noqa: ANN001 + const=None, # noqa: ANN001 + default=None, # noqa: ANN001 + type=None, # noqa: ANN001 + choices=None, # noqa: ANN001 + required=False, # noqa: ANN001 + help=None, # noqa: ANN001 + metavar=None, # noqa: ANN001 + ) -> None: if type is not date_type: - raise ValueError("type must be `date_type`!") + raise ValueError('type must be `date_type`!') super().__init__( option_strings=option_strings, @@ -521,49 +445,48 @@ def __init__( choices=choices, required=required, help=help, - metavar=metavar + metavar=metavar, ) - def __call__(self, parser, namespace, values, option_string=None): + def __call__(self, _, namespace, values, __=None): # noqa: ANN001, ANN204 -- see argparse.Action.__call__ if len(values) > 3 or not values: - raise ArgumentError(self, "Only 1, 2 dates, or 2 dates and interval may be supplied") + raise argparse.ArgumentError(self, 'Only 1, 2 dates, or 2 dates and interval may be supplied') if len(values) == 2: start, end = values - values = [start + timedelta(days=k) for k in range(0, (end - start).days + 1, 1)] + values = [start + dt.timedelta(days=k) for k in range(0, (end - start).days + 1, 1)] elif len(values) == 3: start, end, stepsize = values if not isinstance(stepsize.day, int): - raise ArgumentError(self, "The stepsize should be in integer days") + raise argparse.ArgumentError(self, 'The stepsize should be in integer days') - new_year = date(year=stepsize.year, month=1, day=1) + new_year = dt.date(year=stepsize.year, month=1, day=1) stepsize = (stepsize - new_year).days + 1 - values = [start + timedelta(days=k) - for k in range(0, (end - start).days + 1, stepsize)] + values = [start + dt.timedelta(days=k) for k in range(0, (end - start).days + 1, stepsize)] setattr(namespace, self.dest, values) -class BBoxAction(Action): - """An Action that parses and stores a valid bounding box""" +class BBoxAction(argparse.Action): + """An Action that parses and stores a valid bounding box.""" def __init__( self, - option_strings, - dest, - nargs=None, - const=None, - default=None, - type=None, - choices=None, - required=False, - help=None, - metavar=None - ): + option_strings, # noqa: ANN001 -- see argparse.Action.__init__ + dest, # noqa: ANN001 + nargs=None, # noqa: ANN001 + const=None, # noqa: ANN001 + default=None, # noqa: ANN001 + type=None, # noqa: ANN001 + choices=None, # noqa: ANN001 + required=False, # noqa: ANN001 + help=None, # noqa: 
ANN001 + metavar=None, # noqa: ANN001 + ) -> None: if nargs != 4: - raise ValueError("nargs must be 4!") + raise ValueError('nargs must be 4!') super().__init__( option_strings=option_strings, @@ -575,21 +498,24 @@ def __init__( choices=choices, required=required, help=help, - metavar=metavar + metavar=metavar, ) - def __call__(self, parser, namespace, values, option_string=None): + def __call__(self, _, namespace, values, __=None): # noqa: ANN001, ANN204 -- see argparse.Action.__call__ S, N, W, E = values if N <= S or E <= W: - raise ArgumentError(self, 'Bounding box has no size; make sure you use "S N W E"') + raise argparse.ArgumentError(self, 'Bounding box has no size; make sure you use "S N W E"') for sn in (S, N): if sn < -90 or sn > 90: - raise ArgumentError(self, 'Lats are out of S/N bounds (-90 to 90).') + raise argparse.ArgumentError(self, 'Lats are out of S/N bounds (-90 to 90).') for we in (W, E): if we < -180 or we > 180: - raise ArgumentError(self, 'Lons are out of W/E bounds (-180 to 180); Lons in the format of (0 to 360) are not supported.') + raise argparse.ArgumentError( + self, + 'Lons are out of W/E bounds (-180 to 180); Lons in the format of (0 to 360) are not supported.', + ) setattr(namespace, self.dest, values) diff --git a/tools/RAiDER/constants.py b/tools/RAiDER/constants.py index 00673c1bc..0bfa37bef 100644 --- a/tools/RAiDER/constants.py +++ b/tools/RAiDER/constants.py @@ -7,18 +7,17 @@ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ import numpy as np -_ZMIN = np.float64(-100) # minimum required height + +_ZMIN = np.float64(-100) # minimum required height _ZREF = np.float64(26000) # maximum integration height when not specified by user -_STEP = np.float64(15.0) # integration step size in meters +_STEP = np.float64(15.0) # integration step size in meters -_g0 = np.float64(9.80665) # Standard gravitational constant -_g1 = np.float64(9.80616) # Gravitational constant @ 45° latitude used for corrections of earth's centrifugal force +_g0 = np.float64(9.80665) # Standard gravitational constant +_g1 = np.float64(9.80616) # Gravitational constant @ 45° latitude used for corrections of earth's centrifugal force _RE = np.float64(6371008.7714) R_EARTH_MAX_WGS84 = 6378137 R_EARTH_MIN_WGS84 = 6356752 _CUBE_SPACING_IN_M = float(2000) # Horizontal spacing of cube -_THRESHOLD_SECONDS = 1 * 60 # Threshold delta_time in seconds - - +_THRESHOLD_SECONDS = 1 * 60 # Threshold delta_time in seconds diff --git a/tools/RAiDER/delay.py b/tools/RAiDER/delay.py index bd1097baf..ff7f34368 100755 --- a/tools/RAiDER/delay.py +++ b/tools/RAiDER/delay.py @@ -12,39 +12,44 @@ models are accessed as NETCDF files and should have "wet" "hydro" "wet_total" and "hydro_total" fields specified. 
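For context, a toy sketch of the dataset layout this module expects from a pre-processed weather model file, using the variable names listed in the docstring above (the grid values and file name are placeholders, not RAiDER output):
```
import numpy as np
import xarray as xr

z = np.linspace(0.0, 25000.0, 5)        # height levels (m)
y = np.linspace(33.0, 34.0, 4)          # latitude or projected y
x = np.linspace(-118.0, -117.0, 4)      # longitude or projected x
blank = np.zeros((z.size, y.size, x.size))

# The delay functions look for 'wet'/'hydro' (pointwise) and 'wet_total'/'hydro_total' (integrated).
ds = xr.Dataset(
    data_vars={name: (('z', 'y', 'x'), blank.copy())
               for name in ('wet', 'hydro', 'wet_total', 'hydro_total')},
    coords={'z': z, 'y': y, 'x': x},
)
# A real file also carries a 'proj' variable with CRS metadata; without it, WGS84 is assumed.
ds.to_netcdf('toy_weather_model.nc')  # hypothetical file name; requires a netCDF backend
```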
""" -import os -import pyproj -import xarray -from datetime import datetime, timezone -from pyproj import CRS, Transformer -from typing import List, Union +import datetime as dt +import os +from typing import Optional, Union import numpy as np +import pyproj +import xarray as xr +from pyproj import CRS, Transformer from RAiDER.constants import _ZREF from RAiDER.delayFcns import getInterpolators +from RAiDER.llreader import AOI, BoundingBox, Geocube from RAiDER.logger import logger -from RAiDER.losreader import build_ray +from RAiDER.losreader import LOS, build_ray +from RAiDER.types import CRSLike +from RAiDER.utilFcns import parse_crs + ############################################################################### def tropo_delay( - dt, - weather_model_file: str, - aoi, - los, - height_levels: List[float]=None, - out_proj: Union[int, str] =4326, - zref: Union[int, float]=_ZREF, - ): - """ - Calculate integrated delays on query points. Options are: + datetime: dt.datetime, + weather_model_file: str, + aoi: AOI, + los: LOS, + height_levels: Optional[list[float]] = None, + out_proj: Union[int, str] = 4326, + zref: Optional[np.float64] = None, +): + """Calculate integrated delays on query points. + + Options are: 1. Zenith delays (ZTD) 2. Zenith delays projected to the line-of-sight (STD-projected) 3. Slant delays integrated along the raypath (STD-raytracing) Args: - dt: Datetime - Datetime object for determining when to calculate delays + datetime: Datetime - Datetime object for determining when to calculate delays weather_model_File: string - Name of the NETCDF file containing a pre-processed weather model aoi: AOI object - AOI object los: LOS object - LOS object @@ -52,45 +57,45 @@ def tropo_delay( out_proj: int,str - (optional) EPSG code for output projection zref: int,float - (optional) maximum height to integrate up to during raytracing - Returns: xarray Dataset *or* ndarrays: - wet and hydrostatic delays at the grid nodes / query points. """ crs = CRS(out_proj) # Load CRS from weather model file - with xarray.load_dataset(weather_model_file) as ds: + with xr.load_dataset(weather_model_file) as ds: try: wm_proj = CRS.from_wkt(ds['proj'].attrs['crs_wkt']) except KeyError: - logger.warning("WARNING: I can't find a CRS in the weather model file, so I will assume you are using WGS84") + logger.warning( + "WARNING: I can't find a CRS in the weather model file, so I will assume you are using WGS84" + ) wm_proj = CRS.from_epsg(4326) # get heights - with xarray.load_dataset(weather_model_file) as ds: + with xr.load_dataset(weather_model_file) as ds: wm_levels = ds.z.values - toa = wm_levels.max() - 1 + toa = wm_levels.max() - 1 if height_levels is None: - if aoi.type() == 'Geocube': + if isinstance(aoi, Geocube): height_levels = aoi.readZ() else: height_levels = wm_levels - if not zref: + if zref is None: zref = toa if zref > toa: zref = toa - logger.warning('Requested integration height (zref) is higher than top of weather model. Forcing to top ({toa}).') + logger.warning( + 'Requested integration height (zref) is higher than top of weather model. Forcing to top ({toa}).' 
+ ) + # TODO: expose this as library function + ds = _get_delays_on_cube(datetime, weather_model_file, wm_proj, aoi, height_levels, los, crs, zref) - #TODO: expose this as library function - ds = _get_delays_on_cube( - dt, weather_model_file, wm_proj, aoi, height_levels, los, crs, zref - ) - - if (aoi.type() == 'bounding_box') or (aoi.type() == 'Geocube'): + if isinstance(aoi, (BoundingBox, Geocube)): return ds, None else: @@ -98,7 +103,7 @@ def tropo_delay( try: out_proj = CRS.from_epsg(out_proj) except pyproj.exceptions.CRSError: - out_proj = out_proj + pass pnt_proj = CRS.from_epsg(4326) lats, lons = aoi.readLL() @@ -106,33 +111,31 @@ def tropo_delay( pnts = transformPoints(lats, lons, hgts, pnt_proj, out_proj) try: - ifWet, ifHydro = getInterpolators(ds, "ztd") + ifWet, ifHydro = getInterpolators(ds, 'ztd') except RuntimeError: - logger.exception('Failed to get weather model %s interpolators.', weather_model_file) + raise RuntimeError(f'Failed to get weather model {weather_model_file} interpolators.') wetDelay = ifWet(pnts) hydroDelay = ifHydro(pnts) # return the delays (ZTD or STD) if los.is_Projected(): - los.setTime(dt) + los.setTime(datetime) los.setPoints(lats, lons, hgts) - wetDelay = los(wetDelay) + wetDelay = los(wetDelay) hydroDelay = los(hydroDelay) return wetDelay, hydroDelay -def _get_delays_on_cube(dt, weather_model_file, wm_proj, aoi, heights, los, crs, zref, nproc=1): - """ - raider cube generation function. - """ +def _get_delays_on_cube(datetime: dt.datetime, weather_model_file, wm_proj, aoi, heights, los, crs, zref, nproc=1): + """Raider cube generation function.""" zpts = np.array(heights) try: aoi.xpts except AttributeError: - with xarray.load_dataset(weather_model_file) as ds: + with xr.load_dataset(weather_model_file) as ds: x_spacing = ds.x.diff(dim='x').values.mean() y_spacing = ds.y.diff(dim='y').values.mean() aoi.set_output_spacing(ll_res=np.min([x_spacing, y_spacing])) @@ -140,28 +143,25 @@ def _get_delays_on_cube(dt, weather_model_file, wm_proj, aoi, heights, los, crs, # If no orbit is provided if los.is_Zenith() or los.is_Projected(): - out_type = ["zenith" if los.is_Zenith() else 'slant - projected'][0] + out_type = ['zenith' if los.is_Zenith() else 'slant - projected'][0] # Get ZTD interpolators try: - ifWet, ifHydro = getInterpolators(weather_model_file, "total") + ifWet, ifHydro = getInterpolators(weather_model_file, 'total') except RuntimeError: logger.exception('Failed to get weather model %s interpolators.', weather_model_file) - # Build cube - wetDelay, hydroDelay = _build_cube( - aoi.xpts, aoi.ypts, zpts, - wm_proj, crs, [ifWet, ifHydro]) + wetDelay, hydroDelay = _build_cube(aoi.xpts, aoi.ypts, zpts, wm_proj, crs, [ifWet, ifHydro]) else: - out_type = "slant - raytracing" + out_type = 'slant - raytracing' # Get pointwise interpolators try: ifWet, ifHydro = getInterpolators( weather_model_file, - kind="pointwise", + kind='pointwise', shared=(nproc > 1), ) except RuntimeError: @@ -170,9 +170,8 @@ def _get_delays_on_cube(dt, weather_model_file, wm_proj, aoi, heights, los, crs, # Build cube if nproc == 1: wetDelay, hydroDelay = _build_cube_ray( - aoi.xpts, aoi.ypts, zpts, los, - wm_proj, crs, - [ifWet, ifHydro], MAX_TROPO_HEIGHT=zref) + aoi.xpts, aoi.ypts, zpts, los, wm_proj, crs, [ifWet, ifHydro], MAX_TROPO_HEIGHT=zref + ) ### Use multi-processing here else: @@ -187,48 +186,48 @@ def _get_delays_on_cube(dt, weather_model_file, wm_proj, aoi, heights, los, crs, logger.critical('There are missing delay values. 
Check your inputs.') # Write output file - ds = writeResultsToXarray(dt, aoi.xpts, aoi.ypts, zpts, crs, wetDelay, - hydroDelay, weather_model_file, out_type) + ds = writeResultsToXarray(datetime, aoi.xpts, aoi.ypts, zpts, crs, wetDelay, hydroDelay, weather_model_file, out_type) return ds def _build_cube(xpts, ypts, zpts, model_crs, pts_crs, interpolators): - """ - Iterate over interpolators and build a cube using Zenith - """ + """Iterate over interpolators and build a cube using Zenith.""" # Create a regular 2D grid xx, yy = np.meshgrid(xpts, ypts) # Output arrays - outputArrs = [np.zeros((zpts.size, ypts.size, xpts.size)) - for mm in range(len(interpolators))] - + outputArrs = [np.zeros((zpts.size, ypts.size, xpts.size)) for mm in range(len(interpolators))] # Loop over heights and compute delays for ii, ht in enumerate(zpts): - # pts is in weather model system; if model_crs != pts_crs: # lat / lon / height for hrrr - pts = transformPoints(yy, xx, np.full(yy.shape, ht), - pts_crs, model_crs) + pts = transformPoints(yy, xx, np.full(yy.shape, ht), pts_crs, model_crs) else: pts = np.stack([yy, xx, np.full(yy.shape, ht)], axis=-1) for mm, intp in enumerate(interpolators): - outputArrs[mm][ii,...] = intp(pts) + outputArrs[mm][ii, ...] = intp(pts) return outputArrs def _build_cube_ray( - xpts, ypts, zpts, los, model_crs, - pts_crs, interpolators, outputArrs=None, MAX_SEGMENT_LENGTH=1000., - MAX_TROPO_HEIGHT=_ZREF, - ): + xpts, + ypts, + zpts, + los, + model_crs, + pts_crs, + interpolators, + outputArrs=None, + MAX_SEGMENT_LENGTH=1000.0, + MAX_TROPO_HEIGHT=_ZREF, +): """ - Iterate over interpolators and build a cube using raytracing + Iterate over interpolators and build a cube using raytracing. MAX_TROPO_HEIGHT should not extend above the top of the weather model """ @@ -244,8 +243,7 @@ def _build_cube_ray( output_created_here = False if outputArrs is None: output_created_here = True - outputArrs = [np.zeros((zpts.size, ypts.size, xpts.size)) - for mm in range(len(interpolators))] + outputArrs = [np.zeros((zpts.size, ypts.size, xpts.size)) for mm in range(len(interpolators))] # Various transformers needed here epsg4326 = CRS.from_epsg(4326) @@ -254,7 +252,7 @@ def _build_cube_ray( # Loop over heights of output cube and compute delays for hh, ht in enumerate(zpts): - logger.info(f"Processing slice {hh+1} / {len(zpts)}: {ht}") + logger.info(f'Processing slice {hh+1} / {len(zpts)}: {ht}') # Slices to fill on output outSubs = [x[hh, ...] for x in outputArrs] @@ -270,22 +268,21 @@ def _build_cube_ray( LOS = los.getLookVectors(ht, llh, xyz, yy) # Step 3 - Determine delays between each model height per ray - ray_lengths, low_xyzs, high_xyzs = \ - build_ray(model_zs, ht, xyz, LOS, MAX_TROPO_HEIGHT) + ray_lengths, low_xyzs, high_xyzs = build_ray(model_zs, ht, xyz, LOS, MAX_TROPO_HEIGHT) # if the top most height layer doesnt contribute to the integral, skip it if ray_lengths is None and ht == zpts[-1]: continue elif np.isnan(ray_lengths).all(): - raise ValueError("geo2rdr did not converge. Check orbit coverage") + raise ValueError('geo2rdr did not converge. 
Check orbit coverage') # Determine number of parts to break ray into (this is what gets integrated over) - nParts = np.ceil(ray_lengths.max((1,2)) / MAX_SEGMENT_LENGTH).astype(int) + 1 + nParts = np.ceil(ray_lengths.max((1, 2)) / MAX_SEGMENT_LENGTH).astype(int) + 1 # iterate over weather model height levels for zz, nparts in enumerate(nParts): - fracs = np.linspace(0., 1., num=nparts) + fracs = np.linspace(0.0, 1.0, num=nparts) # Integrate over chunks of ray for findex, ff in enumerate(fracs): @@ -293,11 +290,7 @@ def _build_cube_ray( pts_xyz = low_xyzs[zz] + ff * (high_xyzs[zz] - low_xyzs[zz]) # Ray point in model coordinates (x, y, z) - pts = ecef_to_model.transform( - pts_xyz[..., 0], - pts_xyz[..., 1], - pts_xyz[..., 2] - ) + pts = ecef_to_model.transform(pts_xyz[..., 0], pts_xyz[..., 1], pts_xyz[..., 2]) # Order for the interpolator (from xyz to yxz) pts = np.stack((pts[1], pts[0], pts[2]), axis=-1) @@ -316,12 +309,12 @@ def _build_cube_ray( pts[:, :, -1] = np.array(model_zs).max() # Trapezoidal integration with scaling - wt = 0.5 if findex in [0, fracs.size-1] else 1.0 - wt *= ray_lengths[zz] *1.0e-6 / (nparts - 1.0) + wt = 0.5 if findex in [0, fracs.size - 1] else 1.0 + wt *= ray_lengths[zz] * 1.0e-6 / (nparts - 1.0) # For each interpolator, integrate between levels for mm, out in enumerate(outSubs): - val = interpolators[mm](pts) + val = interpolators[mm](pts) # TODO - This should not occur if there is enough padding in model # val[np.isnan(val)] = 0.0 @@ -331,83 +324,90 @@ def _build_cube_ray( return outputArrs -def writeResultsToXarray(dt, xpts, ypts, zpts, crs, wetDelay, hydroDelay, weather_model_file, out_type): - ''' - write a 1-D array to a NETCDF5 file - ''' - # Modify this as needed for NISAR / other projects - ds = xarray.Dataset( +def writeResultsToXarray(datetime: dt.datetime, xpts, ypts, zpts, crs, wetDelay, hydroDelay, weather_model_file, out_type): + """Write a 1-D array to a NETCDF5 file.""" + # Modify this as needed for NISAR / other projects + ds = xr.Dataset( data_vars=dict( - wet=(["z", "y", "x"], - wetDelay, - {"units" : "m", - "description": f"wet {out_type} delay", - # 'crs': crs.to_epsg(), - "grid_mapping": "crs", - - }), - hydro=(["z", "y", "x"], - hydroDelay, - {"units": "m", + wet=( + ['z', 'y', 'x'], + wetDelay, + { + 'units': 'm', + 'description': f'wet {out_type} delay', + # 'crs': crs.to_epsg(), + 'grid_mapping': 'crs', + }, + ), + hydro=( + ['z', 'y', 'x'], + hydroDelay, + { + 'units': 'm', # 'crs': crs.to_epsg(), - "description": f"hydrostatic {out_type} delay", - "grid_mapping": "crs", - }), + 'description': f'hydrostatic {out_type} delay', + 'grid_mapping': 'crs', + }, + ), ), coords=dict( - x=(["x"], xpts), - y=(["y"], ypts), - z=(["z"], zpts), + x=(['x'], xpts), + y=(['y'], ypts), + z=(['z'], zpts), ), attrs=dict( - Conventions="CF-1.7", - title="RAiDER geo cube", + Conventions='CF-1.7', + title='RAiDER geo cube', source=os.path.basename(weather_model_file), - history=str(datetime.now(tz=timezone.utc)) + " RAiDER", - description=f"RAiDER geo cube - {out_type}", - reference_time=dt.strftime("%Y%m%dT%H:%M:%S"), + history=str(dt.datetime.now(tz=dt.timezone.utc)) + ' RAiDER', + description=f'RAiDER geo cube - {out_type}', + reference_time=datetime.strftime('%Y%m%dT%H:%M:%S'), ), ) # Write projection system mapping - ds["crs"] = int(-2147483647) # dummy placeholder + ds['crs'] = -2147483647 # dummy placeholder for k, v in crs.to_cf().items(): ds.crs.attrs[k] = v # Write z-axis information - ds.z.attrs["axis"] = "Z" - ds.z.attrs["units"] = "m" - 
ds.z.attrs["description"] = "height above ellipsoid" + ds.z.attrs['axis'] = 'Z' + ds.z.attrs['units'] = 'm' + ds.z.attrs['description'] = 'height above ellipsoid' # If in degrees - if crs.axis_info[0].unit_name == "degree": - ds.y.attrs["units"] = "degrees_north" - ds.y.attrs["standard_name"] = "latitude" - ds.y.attrs["long_name"] = "latitude" + if crs.axis_info[0].unit_name == 'degree': + ds.y.attrs['units'] = 'degrees_north' + ds.y.attrs['standard_name'] = 'latitude' + ds.y.attrs['long_name'] = 'latitude' - ds.x.attrs["units"] = "degrees_east" - ds.x.attrs["standard_name"] = "longitude" - ds.x.attrs["long_name"] = "longitude" + ds.x.attrs['units'] = 'degrees_east' + ds.x.attrs['standard_name'] = 'longitude' + ds.x.attrs['long_name'] = 'longitude' else: - ds.y.attrs["axis"] = "Y" - ds.y.attrs["standard_name"] = "projection_y_coordinate" - ds.y.attrs["long_name"] = "y-coordinate in projected coordinate system" - ds.y.attrs["units"] = "m" + ds.y.attrs['axis'] = 'Y' + ds.y.attrs['standard_name'] = 'projection_y_coordinate' + ds.y.attrs['long_name'] = 'y-coordinate in projected coordinate system' + ds.y.attrs['units'] = 'm' - ds.x.attrs["axis"] = "X" - ds.x.attrs["standard_name"] = "projection_x_coordinate" - ds.x.attrs["long_name"] = "x-coordinate in projected coordinate system" - ds.x.attrs["units"] = "m" + ds.x.attrs['axis'] = 'X' + ds.x.attrs['standard_name'] = 'projection_x_coordinate' + ds.x.attrs['long_name'] = 'x-coordinate in projected coordinate system' + ds.x.attrs['units'] = 'm' return ds - -def transformPoints(lats: np.ndarray, lons: np.ndarray, hgts: np.ndarray, old_proj: CRS, new_proj: CRS) -> np.ndarray: - ''' - Transform lat/lon/hgt data to an array of points in a new - projection +def transformPoints( + lats: np.ndarray, + lons: np.ndarray, + hgts: np.ndarray, + old_proj: CRSLike, + new_proj: CRSLike, +) -> np.ndarray: + """ + Transform lat/lon/hgt data to an array of points in a new projection. 
Args: lats: ndarray - WGS-84 latitude (EPSG: 4326) @@ -418,19 +418,17 @@ def transformPoints(lats: np.ndarray, lons: np.ndarray, hgts: np.ndarray, old_pr Returns: ndarray: the array of query points in the weather model coordinate system (YX) - ''' + """ # Flags for flipping inputs or outputs - if not isinstance(new_proj, CRS): - new_proj = CRS.from_epsg(new_proj.lstrip('EPSG:')) - if not isinstance(old_proj, CRS): - old_proj = CRS.from_epsg(old_proj.lstrip('EPSG:')) + old_proj = parse_crs(old_proj) + new_proj = parse_crs(new_proj) t = Transformer.from_crs(old_proj, new_proj, always_xy=True) # in_flip = old_proj.axis_info[0].direction # out_flip = new_proj.axis_info[0].direction - res = t.transform(lons, lats, hgts) + res = t.transform(lons, lats, hgts) # lat/lon/height - return np.stack([res[1], res[0], res[2]], axis=-1) + return np.stack([res[1], res[0], res[2]], axis=-1) diff --git a/tools/RAiDER/delayFcns.py b/tools/RAiDER/delayFcns.py index a71de3b21..e55823fae 100755 --- a/tools/RAiDER/delayFcns.py +++ b/tools/RAiDER/delayFcns.py @@ -10,40 +10,39 @@ except ImportError: mp = None -import xarray +from pathlib import Path +from typing import Union import numpy as np - +import xarray as xr from scipy.interpolate import RegularGridInterpolator as Interpolator from RAiDER.logger import logger -def getInterpolators(wm_file, kind='pointwise', shared=False): - ''' +# TODO(garlic-os): type annotate the choices for kind +def getInterpolators(wm_file: Union[xr.Dataset, Path, str], kind: str='pointwise', shared: bool=False) -> tuple[Interpolator, Interpolator]: + """ Read 3D gridded data from a processed weather model file and wrap it with - the scipy RegularGridInterpolator + the scipy RegularGridInterpolator. The interpolator grid is (y, x, z) - ''' + """ # Get the weather model data - try: - ds = xarray.load_dataset(wm_file) - except ValueError: - ds = wm_file + ds = wm_file if isinstance(wm_file, xr.Dataset) else xr.load_dataset(wm_file) xs_wm = np.array(ds.variables['x'][:]) ys_wm = np.array(ds.variables['y'][:]) zs_wm = np.array(ds.variables['z'][:]) - wet = ds.variables['wet_total' if kind=='total' else 'wet'][:] - hydro = ds.variables['hydro_total' if kind=='total' else 'hydro'][:] + wet = ds.variables['wet_total' if kind == 'total' else 'wet'][:] + hydro = ds.variables['hydro_total' if kind == 'total' else 'hydro'][:] wet = np.array(wet).transpose(1, 2, 0) hydro = np.array(hydro).transpose(1, 2, 0) if np.any(np.isnan(wet)) or np.any(np.isnan(hydro)): - logger.critical(f'Weather model contains NaNs!') + logger.critical('Weather model contains NaNs!') # If shared interpolators are requested # The arrays are not modified - so turning off lock for performance @@ -51,31 +50,25 @@ def getInterpolators(wm_file, kind='pointwise', shared=False): xs_wm = make_shared_raw(xs_wm) ys_wm = make_shared_raw(ys_wm) zs_wm = make_shared_raw(zs_wm) - wet = make_shared_raw(wet) + wet = make_shared_raw(wet) hydro = make_shared_raw(hydro) - - ifWet = Interpolator((ys_wm, xs_wm, zs_wm), wet, fill_value=np.nan, bounds_error = False) - ifHydro = Interpolator((ys_wm, xs_wm, zs_wm), hydro, fill_value=np.nan, bounds_error = False) + ifWet = Interpolator((ys_wm, xs_wm, zs_wm), wet, fill_value=np.nan, bounds_error=False) + ifHydro = Interpolator((ys_wm, xs_wm, zs_wm), hydro, fill_value=np.nan, bounds_error=False) return ifWet, ifHydro def make_shared_raw(inarr): - """ - Make numpy view array of mp.Array - """ + """Make numpy view array of mp.Array.""" # Create flat shared array if mp is None: raise 
ImportError('multiprocessing is not available') - + shared_arr = mp.RawArray('d', inarr.size) # Create a numpy view of it - shared_arr_np = np.ndarray(inarr.shape, dtype=np.float64, - buffer=shared_arr) + shared_arr_np = np.ndarray(inarr.shape, dtype=np.float64, buffer=shared_arr) # Copy data to shared array np.copyto(shared_arr_np, inarr) return shared_arr_np - - diff --git a/tools/RAiDER/dem.py b/tools/RAiDER/dem.py index 7d6099c46..53922ed8b 100644 --- a/tools/RAiDER/dem.py +++ b/tools/RAiDER/dem.py @@ -5,71 +5,70 @@ # RESERVED. United States Government Sponsorship acknowledged. # # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -from logging import warn -import os +from pathlib import Path +from typing import Optional, cast import numpy as np -import pandas as pd - import rasterio from dem_stitcher.stitcher import stitch_dem from RAiDER.logger import logger +from RAiDER.types import BB, RIO from RAiDER.utilFcns import rio_open def download_dem( - ll_bounds=None, - demName='warpedDEM.dem', - overwrite=False, - writeDEM=False, - buf=0.02, - ): - """ - Download a DEM if one is not already present. + ll_bounds=None, + dem_path: Path=Path('warpedDEM.dem'), + overwrite: bool=False, + writeDEM: bool=False, + buf: float=0.02, +) -> tuple[np.ndarray, Optional[RIO.Profile]]: + """Download a DEM if one is not already present. + Args: - llbounds: list/ndarry of floats -lat/lon bounds of the area to download. Values should be ordered in the following way: [S, N, W, E] - writeDEM: boolean -write the DEM to file - outName: string -name of the DEM file - buf: float -buffer to add to the bounds - overwrite: boolean -overwrite existing DEM + llbounds: list/ndarry of floats - lat/lon bounds of the area to download. Values should be ordered in the following way: [S, N, W, E] + writeDEM: bool - write the DEM to file + outName: string - name of the DEM file + buf: float - buffer to add to the bounds + overwrite: bool - overwrite existing DEM Returns: - zvals: np.array -DEM heights - metadata: -metadata for the DEM + zvals: np.array - DEM heights + metadata: - metadata for the DEM """ - if os.path.exists(demName): - if overwrite: - download = True - else: - download = False + if dem_path.exists(): + download = overwrite else: download = True - if download and (ll_bounds is None): + if download and ll_bounds is None: raise ValueError('download_dem: Either an existing file or lat/lon bounds must be passed') if not download: - logger.info('Using existing DEM: %s', demName) - zvals, metadata = rio_open(demName, returnProj=True) + logger.info('Using existing DEM: %s', dem_path) + zvals, metadata = rio_open(dem_path) else: # download the dem # inExtent is SNWE # dem-stitcher wants WSEN - bounds = [ - np.floor(ll_bounds[2]) - buf, np.floor(ll_bounds[0]) - buf, - np.ceil(ll_bounds[3]) + buf, np.ceil(ll_bounds[1]) + buf - ] + bounds: BB.WSEN = ( + np.floor(ll_bounds[2]) - buf, + np.floor(ll_bounds[0]) - buf, + np.ceil(ll_bounds[3]) + buf, + np.ceil(ll_bounds[1]) + buf, + ) zvals, metadata = stitch_dem( - bounds, + list(bounds), dem_name='glo_30', dst_ellipsoidal_height=True, dst_area_or_point='Area', ) + metadata = cast(RIO.Profile, metadata) if writeDEM: - with rasterio.open(demName, 'w', **metadata) as ds: + with rasterio.open(dem_path, 'w', **metadata) as ds: ds.write(zvals, 1) ds.update_tags(AREA_OR_POINT='Point') - logger.info('Wrote DEM: %s', demName) + logger.info('Wrote DEM: %s', dem_path) return zvals, metadata diff --git a/tools/RAiDER/getStationDelays.py 
b/tools/RAiDER/getStationDelays.py index f61f8b8db..9e77d07ee 100644 --- a/tools/RAiDER/getStationDelays.py +++ b/tools/RAiDER/getStationDelays.py @@ -8,7 +8,7 @@ import datetime as dt import gzip import io -import multiprocessing +import multiprocessing as mp import os import zipfile @@ -19,21 +19,21 @@ from RAiDER.logger import logger -def get_delays_UNR(stationFile, filename, dateList, returnTime=None): - ''' +def get_delays_UNR(stationFile, filename, dateList, returnTime=None) -> None: + """ Parses and returns a dictionary containing either (1) all the GPS delays, if returnTime is None, or (2) only the delay at the closest times to to returnTime. - + Args: stationFile: binary - a .gz station delay file - filename: ? - ? + filename: ? - ? dateList: list of datetime - ? returnTime: datetime - specified time of GPS delay (default all times) Returns: None - + The function writes a CSV file containing the times and delay information (delay in mm, delay uncertainty, delay gradients) @@ -43,23 +43,23 @@ def get_delays_UNR(stationFile, filename, dateList, returnTime=None): Wet and hydrostratic delays were derived as so: Constants —> k1 = 0.704, k2 = 0.776, k3 = 3739.0, m = 18.0152/28.9644, k2' = k2-(k1*m) = 0.33812796398337275, Rv = 461.5 J/(kg·K), ρl = 997 kg/m^3 - - *NOTE: wet delays passed here are computed using - PMV = precipitable water vapor, - P = total atm pressure, + + *NOTE: wet delays passed here are computed using + PMV = precipitable water vapor, + P = total atm pressure, Tm = mean temp of the column, as: Wet zenith delay = 10^-6 ρlRv(k2' + k3/Tm) PMV Hydrostatic zenith delay = Total zenith delay - wet zenith delay = k1*(P/Tm) - + Source —> Hanssen, R. F. (2001) eqns. 6.2.7-10 - *NOTE: Due to a formatting error in the tropo SINEX files, the two - tropospheric gradient columns (TGNTOT and TGETOT) are interchanged, + *NOTE: Due to a formatting error in the tropo SINEX files, the two + tropospheric gradient columns (TGNTOT and TGETOT) are interchanged, as are the formal error columns (_SIG). Source —> http://geodesy.unr.edu/gps_timeseries/README_trop2.txt) - ''' + """ # sort through station zip files allstationTarfiles = [] # if URL @@ -97,19 +97,21 @@ def get_delays_UNR(stationFile, filename, dateList, returnTime=None): try: split_lines = line.split() # units: mm, mm, mm, deg, deg, deg, deg, mm, mm, K - trotot, trototSD, trwet, tgetot, tgetotSD, tgntot, tgntotSD, wvapor, wvaporSD, mtemp = \ - [float(t) for t in split_lines[2:]] - except BaseException: # TODO: What error(s)? + trotot, trototSD, trwet, tgetot, tgetotSD, tgntot, tgntotSD, wvapor, wvaporSD, mtemp = ( + float(t) for t in split_lines[2:] + ) + except: # TODO: What error(s)? continue site = split_lines[0] - year, doy, seconds = [int(n) - for n in split_lines[1].split(':')] + year, doy, seconds = (int(n) for n in split_lines[1].split(':')) # Break iteration if time from line in file does not match date reported in filename if doy != doyFromFile: logger.warning( 'time %s from line in conflict with time %s from file ' '%s, will continue reading next tarfile(s)', - doy, doyFromFile, j + doy, + doyFromFile, + j, ) continue # convert units from mm to m @@ -124,16 +126,15 @@ def get_delays_UNR(stationFile, filename, dateList, returnTime=None): # Break iteration if file contains no data. 
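The docstring above quotes the constants and relations used to derive the UNR wet and hydrostatic zenith delays. A minimal worked sketch that simply mirrors those quoted relations (input units and the helper name are assumptions; see Hanssen, 2001 for the derivation):

```
# Constants as quoted in the docstring above
k1 = 0.704
k2 = 0.776
k3 = 3739.0
m = 18.0152 / 28.9644
k2_prime = k2 - k1 * m   # ~0.33812796398337275
Rv = 461.5               # J/(kg*K)
rho_l = 997.0            # kg/m^3

def zenith_delays(pwv, p_total, t_mean):
    """Wet and hydrostatic zenith delays from PWV, total pressure, and mean column temperature.

    Mirrors the relations quoted above:
      wet zenith delay   = 1e-6 * rho_l * Rv * (k2' + k3/Tm) * PWV
      hydro zenith delay = total - wet = k1 * (P / Tm)
    """
    wet = 1e-6 * rho_l * Rv * (k2_prime + k3 / t_mean) * pwv
    hydro = k1 * (p_total / t_mean)
    return wet, hydro
```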
if d == []: logger.warning( - 'file %s for station %s is empty, will continue reading next ' - 'tarfile(s)', j, j.split('.')[0] + 'file %s for station %s is empty, will continue reading next tarfile(s)', + j, j.split('.')[0] ) continue # check for missing times true_times = list(range(0, 86400, 300)) if len(timesList) != len(true_times): - missing = [ - True if t not in timesList else False for t in true_times] + missing = [t not in timesList for t in true_times] mask = np.array(missing) delay, sig, wet_delay, hydro_delay = [np.full((288,), np.nan)] * 4 delay[~mask] = d @@ -150,14 +151,29 @@ def get_delays_UNR(stationFile, filename, dateList, returnTime=None): # if time not specified, pass all times if returnTime is None: - filtoutput = {'ID': [site] * len(wet_delay), 'Date': [time] * len(wet_delay), 'ZTD': delay, 'wet_delay': wet_delay, - 'hydrostatic_delay': hydro_delay, 'times': times, 'sigZTD': sig} - filtoutput = [{key: value[k] for key, value in filtoutput.items()} - for k in range(len(filtoutput['ID']))] + filtoutput = { + 'ID': [site] * len(wet_delay), + 'Date': [time] * len(wet_delay), + 'ZTD': delay, + 'wet_delay': wet_delay, + 'hydrostatic_delay': hydro_delay, + 'times': times, + 'sigZTD': sig, + } + filtoutput = [{key: value[k] for key, value in filtoutput.items()} for k in range(len(filtoutput['ID']))] else: index = np.argmin(np.abs(np.array(timesList) - returnTime)) - filtoutput = [{'ID': site, 'Date': time, 'ZTD': delay[index], 'wet_delay': wet_delay[index], - 'hydrostatic_delay': hydro_delay[index], 'times': times[index], 'sigZTD': sig[index]}] + filtoutput = [ + { + 'ID': site, + 'Date': time, + 'ZTD': delay[index], + 'wet_delay': wet_delay[index], + 'hydrostatic_delay': hydro_delay[index], + 'times': times[index], + 'sigZTD': sig[index], + } + ] # setup pandas array and write output to CSV, making sure to update existing CSV. 
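A small sketch of the nearest-time selection used in the `returnTime` branch above; the requested seconds-of-day value is hypothetical:

```
import numpy as np

# Times of day (seconds) at which delays exist, one value every 300 s
times_list = np.arange(0, 86400, 300)
return_time = 64810  # requested seconds-of-day (hypothetical)

# Index of the observation closest to the requested time, as in the branch above
index = int(np.argmin(np.abs(times_list - return_time)))
closest_time = times_list[index]  # 64800 -> 18:00:00
```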
filtoutput = pd.DataFrame(filtoutput) if os.path.exists(filename): @@ -166,18 +182,13 @@ def get_delays_UNR(stationFile, filename, dateList, returnTime=None): filtoutput.to_csv(filename, index=False) # record all used tar files - allstationTarfiles.extend([os.path.join(stationFile, k) - for k in stationTarlist]) + allstationTarfiles.extend([os.path.join(stationFile, k) for k in stationTarlist]) allstationTarfiles.sort() del ziprepo - return - -def get_station_data(inFile, dateList, gps_repo=None, numCPUs=8, outDir=None, returnTime=None): - ''' - Pull tropospheric delay data for a given station name - ''' +def get_station_data(inFile, dateList, gps_repo=None, numCPUs=8, outDir=None, returnTime=None) -> None: + """Pull tropospheric delay data for a given station name.""" if outDir is None: outDir = os.getcwd() @@ -188,13 +199,12 @@ def get_station_data(inFile, dateList, gps_repo=None, numCPUs=8, outDir=None, re returnTime = seconds_of_day(returnTime) # print warning if not divisible by 3 seconds if returnTime % 3 != 0: - index = np.argmin( - np.abs(np.array(list(range(0, 86400, 300))) - returnTime)) - updatedreturnTime = str(dt.timedelta( - seconds=list(range(0, 86400, 300))[index])) + index = np.argmin(np.abs(np.array(list(range(0, 86400, 300))) - returnTime)) + updatedreturnTime = str(dt.timedelta(seconds=list(range(0, 86400, 300))[index])) logger.warning( - 'input time %s not divisible by 3 seconds, so next closest time %s ' - 'will be chosen', returnTime, updatedreturnTime + 'input time %s not divisible by 3 seconds, so next closest time %s will be chosen', + returnTime, + updatedreturnTime, ) returnTime = updatedreturnTime @@ -214,26 +224,28 @@ def get_station_data(inFile, dateList, gps_repo=None, numCPUs=8, outDir=None, re args.append((sf, name, dateList, returnTime)) outputfiles.append(name) # Parallelize remote querying of zenith delays - with multiprocessing.Pool(numCPUs) as multipool: + with mp.Pool(numCPUs) as multipool: multipool.starmap(get_delays_UNR, args) # confirm file exists (i.e. valid delays exists for specified time/region). 
outputfiles = [i for i in outputfiles if os.path.exists(i)] # Consolidate all CSV files into one object - if outputfiles == []: + if len(outputfiles) == 0: raise Exception('No valid delays found for specified time/region.') - name = os.path.join(outDir, '{}combinedGPS_ztd.csv'.format(gps_repo)) + name = os.path.join(outDir, f'{gps_repo}combinedGPS_ztd.csv') statsFile = pd.concat([pd.read_csv(i) for i in outputfiles]) # drop all duplicate lines statsFile.drop_duplicates(inplace=True) # Convert the above object into a csv file and export - statsFile.to_csv(name, index=False, encoding="utf-8") + statsFile.to_csv(name, index=False, encoding='utf-8') del statsFile # Add lat/lon/height info origstatsFile = pd.read_csv(inFile) statsFile = pd.read_csv(name) - statsFile = pd.merge(left=statsFile, right=origstatsFile[['ID', 'Lat', 'Lon', 'Hgt_m']], how='left', left_on='ID', right_on='ID') + statsFile = pd.merge( + left=statsFile, right=origstatsFile[['ID', 'Lat', 'Lon', 'Hgt_m']], how='left', left_on='ID', right_on='ID' + ) # drop all lines with nans and sort by station ID and year statsFile.dropna(how='any', inplace=True) # drop all duplicate lines @@ -244,25 +256,18 @@ def get_station_data(inFile, dateList, gps_repo=None, numCPUs=8, outDir=None, re def get_date(stationFile): - ''' - extract the date from a station delay file - ''' - + """Extract the date from a station delay file.""" # find the date info year = int(stationFile[1]) doy = int(stationFile[2]) date = dt.datetime(year, 1, 1) + dt.timedelta(doy - 1) - return date, year, doy def seconds_of_day(returnTime): - ''' - Convert HH:MM:SS format time-tag to seconds of day. - ''' + """Convert HH:MM:SS format time-tag to seconds of day.""" if isinstance(returnTime, dt.time): h, m, s = returnTime.hour, returnTime.minute, returnTime.second else: - h, m, s = map(int, returnTime.split(":")) - - return h * 3600 + m * 60 + s + h, m, s = map(int, returnTime.split(':')) + return h * 3600 + m * 60 + s diff --git a/tools/RAiDER/gnss/downloadGNSSDelays.py b/tools/RAiDER/gnss/downloadGNSSDelays.py index 1713ce022..67c97c9f6 100755 --- a/tools/RAiDER/gnss/downloadGNSSDelays.py +++ b/tools/RAiDER/gnss/downloadGNSSDelays.py @@ -6,31 +6,33 @@ # # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ import itertools -import multiprocessing +import multiprocessing as mp import os + import pandas as pd -from RAiDER.logger import logger, logging from RAiDER.getStationDelays import get_station_data -from RAiDER.utilFcns import requests_retry_session +from RAiDER.logger import logger, logging from RAiDER.models.customExceptions import NoStationDataFoundError +from RAiDER.utilFcns import requests_retry_session + # base URL for UNR repository -_UNR_URL = "http://geodesy.unr.edu/" +_UNR_URL = 'http://geodesy.unr.edu/' def get_station_list( - bbox=None, - stationFile=None, - writeLoc=None, - name_appendix='', - writeStationFile=True, - ): - ''' - Creates a list of stations inside a lat/lon bounding box from a source + bbox=None, + stationFile=None, + writeLoc=None, + name_appendix='', + writeStationFile=True, +): + """ + Creates a list of stations inside a lat/lon bounding box from a source. Args: - bbox: list of float - length-4 list of floats that describes a bounding box. + bbox: list of float - length-4 list of floats that describes a bounding box. 
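The station-holdings table read in `get_stats_by_llh` above carries longitudes on [0, 360], which are wrapped onto [-180, 180] before use. A tiny sketch of that wrap on a hypothetical station table laid out like UNR's `llh.out`:

```
import pandas as pd

# Hypothetical stations with longitudes on [0, 360]
stations = pd.DataFrame({
    'ID': ['P123', 'Q456'],
    'Lat': [34.2, -21.5],
    'Lon': [241.8, 55.7],
    'Hgt_m': [120.0, 15.0],
})

# Same wrap used above to move longitudes onto [-180, 180]
stations['Lon'] = ((stations['Lon'] + 180) % 360) - 180
# 241.8 -> -118.2, while 55.7 is unchanged
```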
Format is S N W E station_file: str - Name of a .csv or .txt file to read containing station IDs writeStationFile: bool - Whether to write out the station dataframe to a .csv file @@ -40,7 +42,7 @@ def get_station_list( Returns: stations: list of strings - station IDs to access output_file: string or dataframe - file to write delays - ''' + """ if bbox is not None: station_data = get_stats_by_llh(llhBox=bbox) else: @@ -48,9 +50,9 @@ def get_station_list( station_data = pd.read_csv(stationFile) except: stations = [] - with open(stationFile, 'r') as f: + with open(stationFile) as f: for k, line in enumerate(f): - if k ==0: + if k == 0: names = line.strip().split() else: stations.append([line.strip().split()]) @@ -58,35 +60,27 @@ def get_station_list( # write to file and pass final stations list if writeStationFile: - output_file = os.path.join( - writeLoc or os.getcwd(), - 'gnssStationList_overbbox' + name_appendix + '.csv' - ) + output_file = os.path.join(writeLoc or os.getcwd(), 'gnssStationList_overbbox' + name_appendix + '.csv') station_data.to_csv(output_file, index=False) return list(station_data['ID'].values), [output_file if writeStationFile else station_data][0] def get_stats_by_llh(llhBox=None, baseURL=_UNR_URL): - ''' + """ Function to pull lat, lon, height, beginning date, end date, and number of solutions for stations inside the bounding box llhBox. llhBox should be a tuple in SNWE format. - ''' + """ if llhBox is None: llhBox = [-90, 90, 0, 360] S, N, W, E = llhBox if (W < 0) or (E < 0): - raise ValueError( - 'get_stats_by_llh: bounding box must be on lon range [0, 360]') + raise ValueError('get_stats_by_llh: bounding box must be on lon range [0, 360]') - stationHoldings = '{}NGLStationPages/llh.out'.format(baseURL) + stationHoldings = f'{baseURL}NGLStationPages/llh.out' # it's a file like object and works just like a file - stations = pd.read_csv( - stationHoldings, - sep=r'\s+', - names=['ID', 'Lat', 'Lon', 'Hgt_m'] - ) + stations = pd.read_csv(stationHoldings, sep=r'\s+', names=['ID', 'Lat', 'Lon', 'Hgt_m']) # convert lons from [0, 360] to [-180, 180] stations['Lon'] = ((stations['Lon'].values + 180) % 360) - 180 @@ -97,15 +91,16 @@ def get_stats_by_llh(llhBox=None, baseURL=_UNR_URL): def download_tropo_delays( - stats, years, + stats, + years, gps_repo='UNR', writeDir='.', numCPUs=8, download=False, -): - ''' - Check for and download GNSS tropospheric delays from an archive. If - download is True then files will be physically downloaded, but this +) -> None: + """ + Check for and download GNSS tropospheric delays from an archive. If + download is True then files will be physically downloaded, but this is not necessary as data can be virtually accessed. 
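`download_UNR` above builds its request from a fixed URL layout. A short sketch with a hypothetical station and year, using plain `requests` for the streamed download rather than the module's retry session:

```
import requests

base_url = 'http://geodesy.unr.edu/'
stat_id = 'torp'   # hypothetical 4-character station ID
year = 2020

# URL layout shown in download_UNR above
url = '{0}gps_timeseries/trop/{1}/{1}.{2}.trop.zip'.format(base_url, stat_id.upper(), year)
# -> http://geodesy.unr.edu/gps_timeseries/trop/TORP/TORP.2020.trop.zip

resp = requests.get(url, stream=True)
with open(f'{stat_id.upper()}.{year}.trop.zip', 'wb') as fd:
    for chunk in resp.iter_content(chunk_size=2048):
        fd.write(chunk)
```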
Args: @@ -118,8 +113,7 @@ def download_tropo_delays( Returns: None - ''' - + """ # argument checking if not isinstance(stats, (list, str)): raise TypeError('stats should be a string or a list of strings') @@ -132,42 +126,35 @@ def download_tropo_delays( # Parallelize remote querying of station locations results = [] - with multiprocessing.Pool(numCPUs) as multipool: + with mp.Pool(numCPUs) as multipool: # only record valid path if gps_repo == 'UNR': - results = [ - fileurl for fileurl in multipool.starmap(download_UNR, stat_year_tup) - if fileurl['path'] - ] + results = [fileurl for fileurl in multipool.starmap(download_UNR, stat_year_tup) if fileurl['path']] else: - raise NotImplementedError( - 'download_tropo_delays: gps_repo "{}" not yet implemented'.format(gps_repo)) + raise NotImplementedError(f'download_tropo_delays: gps_repo "{gps_repo}" not yet implemented') # Write results to file if len(results) == 0: - raise NoStationDataFoundError( - station_list=stats['ID'].to_list(), years=years) + raise NoStationDataFoundError(station_list=stats['ID'].to_list(), years=years) statDF = pd.DataFrame(results).set_index('ID') - statDF.to_csv(os.path.join( - writeDir, '{}gnssStationList_overbbox_withpaths.csv'.format(gps_repo))) + statDF.to_csv(os.path.join(writeDir, f'{gps_repo}gnssStationList_overbbox_withpaths.csv')) def download_UNR(statID, year, writeDir='.', download=False, baseURL=_UNR_URL): - ''' - Download a zip file containing tropospheric delays for a given station and year + """ + Download a zip file containing tropospheric delays for a given station and year. The URL format is http://geodesy.unr.edu/gps_timeseries/trop//..trop.zip Inputs: statID - 4-character station identifier year - 4-numeral year - ''' + """ if baseURL not in [_UNR_URL]: - raise NotImplementedError('Data repository {} has not yet been implemented'.format(baseURL)) - - URL = "{0}gps_timeseries/trop/{1}/{1}.{2}.trop.zip".format( - baseURL, statID.upper(), year) + raise NotImplementedError(f'Data repository {baseURL} has not yet been implemented') + + URL = '{0}gps_timeseries/trop/{1}/{1}.{2}.trop.zip'.format(baseURL, statID.upper(), year) logger.debug('Currently checking station %s in %s', statID, year) if download: - saveLoc = os.path.abspath(os.path.join(writeDir, '{0}.{1}.trop.zip'.format(statID.upper(), year))) + saveLoc = os.path.abspath(os.path.join(writeDir, f'{statID.upper()}.{year}.trop.zip')) filepath = download_url(URL, saveLoc) if filepath == '': raise ValueError('Year or station ID does not exist') @@ -177,10 +164,10 @@ def download_UNR(statID, year, writeDir='.', download=False, baseURL=_UNR_URL): def download_url(url, save_path, chunk_size=2048): - ''' + """ Download a file from a URL. Modified from - https://stackoverflow.com/questions/9419162/download-returned-zip-file-from-url - ''' + https://stackoverflow.com/questions/9419162/download-returned-zip-file-from-url. + """ session = requests_retry_session() r = session.get(url, stream=True) @@ -196,10 +183,10 @@ def download_url(url, save_path, chunk_size=2048): def check_url(url): - ''' + """ Check whether a file exists at a URL. Modified from - https://stackoverflow.com/questions/9419162/download-returned-zip-file-from-url - ''' + https://stackoverflow.com/questions/9419162/download-returned-zip-file-from-url. 
+ """ session = requests_retry_session() r = session.head(url) if r.status_code == 404: @@ -208,16 +195,12 @@ def check_url(url): def in_box(lat, lon, llhbox): - ''' - Checks whether the given lat, lon pair are inside the bounding box llhbox - ''' + """Checks whether the given lat, lon pair are inside the bounding box llhbox.""" return lat < llhbox[1] and lat > llhbox[0] and lon < llhbox[3] and lon > llhbox[2] def fix_lons(lon): - """ - Fix the given longitudes into the range `[-180, 180]`. - """ + """Fix the given longitudes into the range `[-180, 180]`.""" fixed_lon = ((lon + 180) % 360) - 180 # Make the positive 180s positive again. if fixed_lon == -180 and lon > 0: @@ -226,17 +209,13 @@ def fix_lons(lon): def get_ID(line): - ''' - Pulls the station ID, lat, lon, and height for a given entry in the UNR text file - ''' + """Pulls the station ID, lat, lon, and height for a given entry in the UNR text file.""" stat_id, lat, lon, height = line.split()[:4] return stat_id, float(lat), float(lon), float(height) -def main(inps=None): - """ - Main workflow for querying supported GPS repositories for zenith delay information. - """ +def main(inps=None) -> None: + """Main workflow for querying supported GPS repositories for zenith delay information.""" try: dateList = inps.date_list returnTime = inps.time @@ -272,47 +251,40 @@ def main(inps=None): # iterate over years years = list(set([i.year for i in dateList])) - download_tropo_delays( - stats, years, gps_repo=gps_repo, writeDir=out, download=download - ) + download_tropo_delays(stats, years, gps_repo=gps_repo, writeDir=out, download=download) # Combine station data with URL info - pathsdf = pd.read_csv(os.path.join(out, '{}gnssStationList_overbbox_withpaths.csv'.format(gps_repo))) + pathsdf = pd.read_csv(os.path.join(out, f'{gps_repo}gnssStationList_overbbox_withpaths.csv')) pathsdf = pd.merge(left=pathsdf, right=statdf, how='left', left_on='ID', right_on='ID') - pathsdf.to_csv(os.path.join(out, '{}gnssStationList_overbbox_withpaths.csv'.format(gps_repo)), index=False) + pathsdf.to_csv(os.path.join(out, f'{gps_repo}gnssStationList_overbbox_withpaths.csv'), index=False) del statdf, pathsdf # Extract delays for each station dateList = [k.strftime('%Y-%m-%d') for k in dateList] get_station_data( - os.path.join( - out, '{}gnssStationList_overbbox_withpaths.csv'.format(gps_repo)), + os.path.join(out, f'{gps_repo}gnssStationList_overbbox_withpaths.csv'), dateList, gps_repo=gps_repo, numCPUs=cpus, outDir=out, - returnTime=returnTime + returnTime=returnTime, ) logger.debug('Completed processing') def parse_bbox(bounding_box): - ''' - Parse bounding box arguments - ''' + """Parse bounding box arguments.""" if isinstance(bounding_box, str) and not os.path.isfile(bounding_box): try: bbox = [float(val) for val in bounding_box.split()] except ValueError: - raise Exception( - 'Cannot understand the --bbox argument. String input is incorrect or path does not exist.') + raise Exception('Cannot understand the --bbox argument. 
String input is incorrect or path does not exist.') elif isinstance(bounding_box, list): bbox = bounding_box else: - raise Exception( - 'Passing a file with a bounding box not yet supported.') + raise Exception('Passing a file with a bounding box not yet supported.') long_cross_zero = 1 if bbox[2] * bbox[3] < 0 else 0 @@ -327,9 +299,7 @@ def parse_bbox(bounding_box): def get_stats(bbox, long_cross_zero, out, station_file): - ''' - Pull the stations needed - ''' + """Pull the stations needed.""" if long_cross_zero == 1: bbox1 = bbox.copy() bbox2 = bbox.copy() @@ -347,18 +317,16 @@ def get_stats(bbox, long_cross_zero, out, station_file): else: if bbox[3] < bbox[2]: bbox[3] = 360.0 - stats, statdata = get_station_list( - bbox=bbox, stationFile=station_file, writeStationFile=False - ) - + stats, statdata = get_station_list(bbox=bbox, stationFile=station_file, writeStationFile=False) + statdata.to_csv(station_file, index=False) return stats, statdata def filterToBBox(stations, llhBox): - ''' + """ Filter a dataframe by lat/lon. - *NOTE: llhBox longitude format should be [0, 360] + *NOTE: llhBox longitude format should be [0, 360]. Args: stations: DataFrame - a pandas dataframe with "Lat" and "Lon" columns @@ -366,7 +334,7 @@ def filterToBBox(stations, llhBox): Returns: a Pandas Dataframe with stations removed that are not inside llhBox - ''' + """ S, N, W, E = llhBox if (W < 0) or (E < 0): raise ValueError('llhBox longitude format should 0-360') @@ -381,15 +349,13 @@ def filterToBBox(stations, llhBox): index = k break if index is None: - raise KeyError( - 'filterToBBox: No valid column names found for latitude and longitude') + raise KeyError('filterToBBox: No valid column names found for latitude and longitude') lon_key = lon_keys[k] lat_key = lat_keys[k] if stations[lon_key].min() < 0: # convert lon format to -180 to 180 - W, E = [((D + 180) % 360) - 180 for D in [W, E]] + W, E = (((D + 180) % 360) - 180 for D in [W, E]) - mask = (stations[lat_key] > S) & (stations[lat_key] < N) & ( - stations[lon_key] < E) & (stations[lon_key] > W) + mask = (stations[lat_key] > S) & (stations[lat_key] < N) & (stations[lon_key] < E) & (stations[lon_key] > W) return stations[mask] diff --git a/tools/RAiDER/gnss/processDelayFiles.py b/tools/RAiDER/gnss/processDelayFiles.py index bd80c829e..4872eb525 100644 --- a/tools/RAiDER/gnss/processDelayFiles.py +++ b/tools/RAiDER/gnss/processDelayFiles.py @@ -1,154 +1,135 @@ -from textwrap import dedent import argparse -import datetime -import glob -import os -import re +import datetime as dt import math +import re +from pathlib import Path +from textwrap import dedent +from typing import Optional +import pandas as pd from tqdm import tqdm -import pandas as pd + pd.options.mode.chained_assignment = None # default='warn' -def combineDelayFiles(outName, loc=os.getcwd(), source='model', ext='.csv', ref=None, col_name='ZTD'): - files = glob.glob(os.path.join(loc, '*' + ext)) +def combineDelayFiles( + out_path: Path, + loc: Path=Path.cwd(), + source: str='model', + ext: str='.csv', + ref: Optional[Path]=None, + col_name: str='ZTD' +) -> None: + file_paths = list(loc.glob('*' + ext)) if source == 'model': print('Ensuring that "Datetime" column exists in files') - addDateTimeToFiles(files) + addDateTimeToFiles(file_paths) # If single file, just copy source - if len(files) == 1: + if len(file_paths) == 1: if source == 'model': import shutil - shutil.copy(files[0], outName) + shutil.copy(file_paths[0], out_path) else: - files = readZTDFile(files[0], col_name=col_name) + file_paths 
= readZTDFile(file_paths[0], col_name=col_name) # drop all lines with nans - files.dropna(how='any', inplace=True) + file_paths.dropna(how='any', inplace=True) # drop all duplicate lines - files.drop_duplicates(inplace=True) - files.to_csv(outName, index=False) + file_paths.drop_duplicates(inplace=True) + file_paths.to_csv(out_path, index=False) return - print('Combining {} delay files'.format(source)) + print(f'Combining {source} delay files') try: - concatDelayFiles( - files, - sort_list=['ID', 'Datetime'], - outName=outName, - source=source - ) - except BaseException: - concatDelayFiles( - files, - sort_list=['ID', 'Date'], - outName=outName, - source=source, - ref=ref, - col_name=col_name - ) - + concatDelayFiles(file_paths, sort_list=['ID', 'Datetime'], outName=out_path, source=source) + except: + concatDelayFiles(file_paths, sort_list=['ID', 'Date'], outName=out_path, source=source, ref=ref, col_name=col_name) -def addDateTimeToFiles(fileList, force=False, verbose=False): - ''' Run through a list of files and add the datetime of each file as a column ''' +def addDateTimeToFiles(file_paths: list[Path], force: bool=False, verbose: bool=False) -> None: + """Run through a list of files and add the datetime of each file as a column.""" print('Adding Datetime to delay files') - for f in tqdm(fileList): - data = pd.read_csv(f) + for path in tqdm(file_paths): + data = pd.read_csv(path) if 'Datetime' in data.columns and not force: if verbose: print( - 'File {} already has a "Datetime" column, pass' + f'File {path} already has a "Datetime" column, pass' '"force = True" if you want to override and ' - 're-process'.format(f) + 're-process' ) else: try: - dt = getDateTime(f) - data['Datetime'] = dt + data['Datetime'] = getDateTime(path) # drop all lines with nans data.dropna(how='any', inplace=True) # drop all duplicate lines data.drop_duplicates(inplace=True) - data.to_csv(f, index=False) + data.to_csv(path, index=False) except (AttributeError, ValueError): - print( - 'File {} does not contain datetime info, skipping' - .format(f) - ) + print(f'File {path} does not contain datetime info, skipping') del data -def getDateTime(filename): - ''' Parse a datetime from a RAiDER delay filename ''' - filename = os.path.basename(filename) - dtr = re.compile(r'\d{8}T\d{6}') - dt = dtr.search(filename) - return datetime.datetime.strptime( - dt.group(), - '%Y%m%dT%H%M%S' - ) +def getDateTime(path: Path) -> dt.datetime: + """Parse a datetime from a RAiDER delay filename.""" + datetime_pattern = re.compile(r'\d{8}T\d{6}') + match = datetime_pattern.search(path.name) + return dt.datetime.strptime(match.group(), '%Y%m%dT%H%M%S') def update_time(row, localTime_hrs): - '''Update with local origin time''' - localTime_estimate = row['Datetime'].replace(hour=localTime_hrs, - minute=0, second=0) + """Update with local origin time.""" + localTime_estimate = row['Datetime'].replace(hour=localTime_hrs, minute=0, second=0) # determine if you need to shift days - time_shift = datetime.timedelta(days=0) + time_shift = dt.timedelta(days=0) # round to nearest hour - days_diff = (row['Datetime'] - - datetime.timedelta(seconds=math.floor( - row['Localtime']) * 3600)).day - \ - localTime_estimate.day + days_diff = ( + row['Datetime'] - dt.timedelta(seconds=math.floor(row['Localtime']) * 3600) + ).day - localTime_estimate.day # if lon <0, check if you need to add day if row['Lon'] < 0: # add day if days_diff != 0: - time_shift = datetime.timedelta(days=1) + time_shift = dt.timedelta(days=1) # if lon >0, check if you need to 
subtract day if row['Lon'] > 0: # subtract day if days_diff != 0: - time_shift = -datetime.timedelta(days=1) - return localTime_estimate + datetime.timedelta(seconds=row['Localtime'] - * 3600) + time_shift + time_shift = -dt.timedelta(days=1) + return localTime_estimate + dt.timedelta(seconds=row['Localtime'] * 3600) + time_shift def pass_common_obs(reference, target, localtime=None): - '''Pass only observations in target spatiotemporally common to reference''' + """Pass only observations in target spatiotemporally common to reference.""" if isinstance(target['Datetime'].iloc[0], str): - target['Datetime'] = target['Datetime'].apply(lambda x: - datetime.datetime.strptime(x, '%Y-%m-%d %H:%M:%S')) + target['Datetime'] = target['Datetime'].apply( + lambda x: dt.datetime.strptime(x, '%Y-%m-%d %H:%M:%S') + ) if localtime: - return target[target['Datetime'].dt.date.isin(reference['Datetime'] - .dt.date) & - target['ID'].isin(reference['ID']) & - target[localtime].isin(reference[localtime])] + return target[ + target['Datetime'].dt.date.isin(reference['Datetime'].dt.date) + & target['ID'].isin(reference['ID']) + & target[localtime].isin(reference[localtime]) + ] else: - return target[target['Datetime'].dt.date.isin(reference['Datetime'] - .dt.date) & - target['ID'].isin(reference['ID'])] + return target[ + target['Datetime'].dt.date.isin(reference['Datetime'].dt.date) & + target['ID'].isin(reference['ID']) + ] def concatDelayFiles( - fileList, - sort_list=['ID', 'Datetime'], - return_df=False, - outName=None, - source='model', - ref=None, - col_name='ZTD' + fileList, sort_list=['ID', 'Datetime'], return_df=False, outName=None, source='model', ref=None, col_name='ZTD' ): - ''' + """ Read a list of .csv files containing the same columns and append them - together, sorting by specified columns - ''' + together, sorting by specified columns. 
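`getDateTime` above relies on RAiDER delay filenames carrying a `YYYYMMDDTHHMMSS` stamp. A quick sketch of that parse on a made-up filename:

```
import datetime as dt
import re
from pathlib import Path

# Hypothetical RAiDER delay filename containing a YYYYMMDDTHHMMSS stamp
path = Path('ERA5_tropo_20200103T120000_ztd.csv')

match = re.search(r'\d{8}T\d{6}', path.name)
when = dt.datetime.strptime(match.group(), '%Y%m%dT%H%M%S')
# -> datetime.datetime(2020, 1, 3, 12, 0)
```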
+ """ dfList = [] print('Concatenating delay files') @@ -165,17 +146,11 @@ def concatDelayFiles( dfList[i[0]] = pass_common_obs(dfr, i[1]) del dfr - df_c = pd.concat( - dfList, - ignore_index=True - ).drop_duplicates().reset_index(drop=True) + df_c = pd.concat(dfList, ignore_index=True).drop_duplicates().reset_index(drop=True) df_c.sort_values(by=sort_list, inplace=True) - print('Total number of rows in the concatenated file: {}'.format(df_c.shape[0])) - print('Total number of rows containing NaNs: {}'.format( - df_c[df_c.isna().any(axis=1)].shape[0] - ) - ) + print(f'Total number of rows in the concatenated file: {df_c.shape[0]}') + print(f'Total number of rows containing NaNs: {df_c[df_c.isna().any(axis=1)].shape[0]}') if return_df or outName is None: return df_c @@ -188,48 +163,38 @@ def concatDelayFiles( def local_time_filter(raiderFile, ztdFile, dfr, dfz, localTime): - ''' - Convert to local-time reference frame WRT 0 longitude - ''' + """Convert to local-time reference frame WRT 0 longitude.""" localTime_hrs = int(localTime.split(' ')[0]) localTime_hrthreshold = int(localTime.split(' ')[1]) # with rotation rate and distance to 0 lon, get localtime shift WRT 00 UTC at 0 lon # *rotation rate at given point = (360deg/23.9333333333hr) = 15.041782729825965 deg/hr - dfr['Localtime'] = (dfr['Lon'] / 15.041782729825965) - dfz['Localtime'] = (dfz['Lon'] / 15.041782729825965) + dfr['Localtime'] = dfr['Lon'] / 15.041782729825965 + dfz['Localtime'] = dfz['Lon'] / 15.041782729825965 # estimate local-times - dfr['Localtime'] = dfr.apply(lambda r: update_time(r, localTime_hrs), - axis=1) - dfz['Localtime'] = dfz.apply(lambda r: update_time(r, localTime_hrs), - axis=1) + dfr['Localtime'] = dfr.apply(lambda r: update_time(r, localTime_hrs), axis=1) + dfz['Localtime'] = dfz.apply(lambda r: update_time(r, localTime_hrs), axis=1) # filter out data outside of --localtime hour threshold - dfr['Localtime_u'] = dfr['Localtime'] + \ - datetime.timedelta(hours=localTime_hrthreshold) - dfr['Localtime_l'] = dfr['Localtime'] - \ - datetime.timedelta(hours=localTime_hrthreshold) + dfr['Localtime_u'] = dfr['Localtime'] + dt.timedelta(hours=localTime_hrthreshold) + dfr['Localtime_l'] = dfr['Localtime'] - dt.timedelta(hours=localTime_hrthreshold) OG_total = dfr.shape[0] - dfr = dfr[(dfr['Datetime'] >= dfr['Localtime_l']) & - (dfr['Datetime'] <= dfr['Localtime_u'])] + dfr = dfr[(dfr['Datetime'] >= dfr['Localtime_l']) & (dfr['Datetime'] <= dfr['Localtime_u'])] # only keep observation closest to Localtime - print('Total number of datapoints dropped in {} for not being within ' - '{} hrs of specified local-time {}: {} out of {}'.format( - raiderFile, localTime.split(' ')[1], localTime.split(' ')[0], - dfr.shape[0], OG_total)) - dfz['Localtime_u'] = dfz['Localtime'] + \ - datetime.timedelta(hours=localTime_hrthreshold) - dfz['Localtime_l'] = dfz['Localtime'] - \ - datetime.timedelta(hours=localTime_hrthreshold) + print( + f'Total number of datapoints dropped in {raiderFile} for not being within {localTime.split(" ")[1]} hrs of ' + f'specified local-time {localTime.split(" ")[0]}: {dfr.shape[0]} out of {OG_total}' + ) + dfz['Localtime_u'] = dfz['Localtime'] + dt.timedelta(hours=localTime_hrthreshold) + dfz['Localtime_l'] = dfz['Localtime'] - dt.timedelta(hours=localTime_hrthreshold) OG_total = dfz.shape[0] - dfz = dfz[(dfz['Datetime'] >= dfz['Localtime_l']) & - (dfz['Datetime'] <= dfz['Localtime_u'])] + dfz = dfz[(dfz['Datetime'] >= dfz['Localtime_l']) & (dfz['Datetime'] <= dfz['Localtime_u'])] # only keep observation 
closest to Localtime - print('Total number of datapoints dropped in {} for not being within ' - '{} hrs of specified local-time {}: {} out of {}'.format( - ztdFile, localTime.split(' ')[1], localTime.split(' ')[0], - dfz.shape[0], OG_total)) + print( + f'Total number of datapoints dropped in {ztdFile} for not being within {localTime.split(" ")[1]} hrs of ' + f'specified local-time {localTime.split(" ")[0]}: {dfz.shape[0]} out of {OG_total}' + ) # drop all lines with nans dfr.dropna(how='any', inplace=True) @@ -245,12 +210,10 @@ def local_time_filter(raiderFile, ztdFile, dfr, dfz, localTime): def readZTDFile(filename, col_name='ZTD'): - ''' - Read and parse a GPS zenith delay file - ''' + """Read and parse a GPS zenith delay file.""" try: data = pd.read_csv(filename, parse_dates=['Date']) - times = data['times'].apply(lambda x: datetime.timedelta(seconds=x)) + times = data['times'].apply(lambda x: dt.timedelta(seconds=x)) data['Datetime'] = data['Date'] + times except (KeyError, ValueError): data = pd.read_csv(filename, parse_dates=['Datetime']) @@ -259,7 +222,20 @@ def readZTDFile(filename, col_name='ZTD'): return data -def create_parser(): +def file_choices(p: argparse.ArgumentParser, choices: tuple[str], s: str) -> Path: + path = Path(s) + if path.suffix not in choices: + p.error(f"File must end with one of {choices}") + return path + +def parse_dir(p: argparse.ArgumentParser, s: str) -> Path: + path = Path(s) + if not path.is_dir(): + p.error("Path must be a directory") + return path + + +def create_parser() -> argparse.ArgumentParser: """Parse command line arguments using argparse.""" p = argparse.ArgumentParser( formatter_class=argparse.RawDescriptionHelpFormatter, @@ -269,44 +245,54 @@ def create_parser(): raiderCombine.py --raiderDir './*' --raider 'combined_raider_delays.csv' raiderCombine.py --raiderDir ERA5/ --raider ERA5_combined_delays.csv --raider_column totalDelay --gnssDir GNSS/ --gnss UNRCombined_gnss.csv --column ZTD -o Combined_delays.csv raiderCombine.py --raiderDir ERA5_2019/ --raider ERA5_combined_delays_2019.csv --raider_column totalDelay --gnssDir GNSS_2019/ --gnss UNRCombined_gnss_2019.csv --column ZTD -o Combined_delays_2019_UTTC18.csv --localtime '18:00:00 1' - """) + """), ) p.add_argument( - '--raider', dest='raider_file', + '--raider', + dest='raider_file', help=dedent("""\ .csv file containing RAiDER-derived Zenith Delays. Should contain columns "ID" and "Datetime" in addition to the delay column If the file does not exist, I will attempt to create it from a directory of delay files. """), - required=True + required=True, + type=lambda s: file_choices(p, ('csv',), s), ) p.add_argument( - '--raiderDir', '-d', dest='raider_folder', + '--raiderDir', + '-d', + dest='raider_folder', help=dedent("""\ Directory containing RAiDER-derived Zenith Delay files. Files should be named with a Datetime in the name and contain the column "ID" as the delay column names. """), - default=os.getcwd() + type=lambda s: parse_dir(p, s), + default=Path.cwd(), ) p.add_argument( - '--gnssDir', '-gd', dest='gnss_folder', + '--gnssDir', + '-gd', + dest='gnss_folder', help=dedent("""\ Directory containing GNSS-derived Zenith Delay files. Files should contain the column "ID" as the delay column names and times should be denoted by the "Date" key. """), - default=os.getcwd() + type=lambda s: parse_dir(p, s), + default=Path.cwd(), ) p.add_argument( - '--gnss', dest='gnss_file', + '--gnss', + dest='gnss_file', help=dedent("""\ Optional .csv file containing GPS Zenith Delays. 
Should contain columns "ID", "ZTD", and "Datetime" """), - default=None + default=None, + type=lambda s: file_choices(p, ('csv',), s), ) p.add_argument( @@ -316,7 +302,7 @@ def create_parser(): help=dedent("""\ Name of the column containing RAiDER delays. Only used with the "--gnss" option """), - default='totalDelay' + default='totalDelay', ) p.add_argument( '--column', @@ -326,7 +312,7 @@ def create_parser(): Name of the column containing GPS Zenith delays. Only used with the "--gnss" option """), - default='ZTD' + default='ZTD', ) p.add_argument( @@ -335,9 +321,9 @@ def create_parser(): dest='out_name', help=dedent("""\ Name to use for the combined delay file. Only used with the "--gnss" option - """), - default='Combined_delays.csv' + type=Path, + default=Path('Combined_delays.csv'), ) p.add_argument( @@ -349,42 +335,52 @@ def create_parser(): and within +/- specified hour threshold (2nd argument). By default UTC is passed as is without local-time conversions. Input in 'HH H', e.g. '16 1'" - """), - default=None + default=None, ) return p -def main(raiderFile, ztdFile, col_name='ZTD', raider_delay='totalDelay', outName=None, localTime=None): - ''' - Merge a combined RAiDER delays file with a GPS ZTD delay file - ''' - print('Merging delay files {} and {}'.format(raiderFile, ztdFile)) - dfr = pd.read_csv(raiderFile, parse_dates=['Datetime']) +def main( + raider_file: Path, + ztd_file: Path, + col_name: str='ZTD', + raider_delay: str='totalDelay', + out_path: Optional[Path]=None, + local_time=None +): + """Merge a combined RAiDER delays file with a GPS ZTD delay file.""" + print(f'Merging delay files {raider_file} and {ztd_file}') + dfr = pd.read_csv(raider_file, parse_dates=['Datetime']) # drop extra columns - expected_data_columns = ['ID', 'Lat', 'Lon', 'Hgt_m', 'Datetime', 'wetDelay', - 'hydroDelay', raider_delay] - dfr = dfr.drop(columns=[col for col in dfr if col not in - expected_data_columns]) - dfz = pd.read_csv(ztdFile, parse_dates=['Date']) - if not 'Datetime' in dfz.keys(): + expected_data_columns = ['ID', 'Lat', 'Lon', 'Hgt_m', 'Datetime', 'wetDelay', 'hydroDelay', raider_delay] + dfr = dfr.drop(columns=[col for col in dfr if col not in expected_data_columns]) + dfz = pd.read_csv(ztd_file, parse_dates=['Date']) + if 'Datetime' not in dfz.keys(): dfz.rename(columns={'Date': 'Datetime'}, inplace=True) # drop extra columns - expected_data_columns = ['ID', 'Datetime', 'wet_delay', 'hydrostatic_delay', - 'times', 'sigZTD', 'Lat', 'Lon', 'Hgt_m', - col_name] - dfz = dfz.drop(columns=[col for col in dfz if col not in - expected_data_columns]) + expected_data_columns = [ + 'ID', + 'Datetime', + 'wet_delay', + 'hydrostatic_delay', + 'times', + 'sigZTD', + 'Lat', + 'Lon', + 'Hgt_m', + col_name, + ] + dfz = dfz.drop(columns=[col for col in dfz if col not in expected_data_columns]) # only pass common locations and times dfz = pass_common_obs(dfr, dfz) dfr = pass_common_obs(dfz, dfr) # If specified, convert to local-time reference frame WRT 0 longitude common_keys = ['Datetime', 'ID'] - if localTime is not None: - dfr, dfz = local_time_filter(raiderFile, ztdFile, dfr, dfz, localTime) + if local_time is not None: + dfr, dfz = local_time_filter(raider_file, ztd_file, dfr, dfz, local_time) common_keys.append('Localtime') # only pass common locations and times dfz = pass_common_obs(dfr, dfz, localtime='Localtime') @@ -400,37 +396,27 @@ def main(raiderFile, ztdFile, col_name='ZTD', raider_delay='totalDelay', outName print('Beginning merge') dfc = dfr.merge( - dfz[common_keys + ['ZTD', 
'sigZTD']], - how='left', - left_on=common_keys, - right_on=common_keys, - sort=True + dfz[common_keys + ['ZTD', 'sigZTD']], how='left', left_on=common_keys, right_on=common_keys, sort=True ) # only keep observation closest to Localtime if 'Localtime' in dfc.keys(): - dfc['Localtimediff'] = abs((dfc['Datetime'] - - dfc['Localtime']).dt.total_seconds() / 3600) - dfc = dfc.loc[dfc.groupby(['ID', 'Localtime']).Localtimediff.idxmin() - ].reset_index(drop=True) + dfc['Localtimediff'] = abs((dfc['Datetime'] - dfc['Localtime']).dt.total_seconds() / 3600) + dfc = dfc.loc[dfc.groupby(['ID', 'Localtime']).Localtimediff.idxmin()].reset_index(drop=True) dfc.drop(columns=['Localtimediff'], inplace=True) # estimate residual dfc['ZTD_minus_RAiDER'] = dfc['ZTD'] - dfc[raider_delay] - print('Total number of rows in the concatenated file: ' - '{}'.format(dfc.shape[0])) - print('Total number of rows containing NaNs: {}'.format( - dfc[dfc.isna().any(axis=1)].shape[0] - ) - ) + print('Total number of rows in the concatenated file: ' f'{dfc.shape[0]}') + print(f'Total number of rows containing NaNs: {dfc[dfc.isna().any(axis=1)].shape[0]}') print('Merge finished') - - if outName is None: + + if out_path is None: return dfc else: # drop all lines with nans dfc.dropna(how='any', inplace=True) # drop all duplicate lines dfc.drop_duplicates(inplace=True) - dfc.to_csv(outName, index=False) + dfc.to_csv(out_path, index=False) diff --git a/tools/RAiDER/gnss/types.py b/tools/RAiDER/gnss/types.py new file mode 100644 index 000000000..bf537fa81 --- /dev/null +++ b/tools/RAiDER/gnss/types.py @@ -0,0 +1,14 @@ +import argparse +from pathlib import Path +from typing import Optional + + +class RAiDERCombineArgs(argparse.Namespace): + raider_file: Path + raider_folder: Path + gnss_folder: Path + gnss_file: Optional[Path] + raider_column_name: str + column_name: str + out_name: Path + local_time: Optional[str] diff --git a/tools/RAiDER/interpolator.py b/tools/RAiDER/interpolator.py index 1b02f0030..08ed2218c 100644 --- a/tools/RAiDER/interpolator.py +++ b/tools/RAiDER/interpolator.py @@ -5,28 +5,28 @@ # RESERVED. United States Government Sponsorship acknowledged. # # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +from pathlib import Path +from typing import Union import numpy as np import pandas as pd - -from scipy.interpolate import interp1d, RegularGridInterpolator as rgi +from scipy.interpolate import interp1d from RAiDER.interpolate import interpolate -class RegularGridInterpolator(object): +class RegularGridInterpolator: """ Provides a wrapper around RAiDER.interpolate.interpolate with a similar interface to scipy.interpolate.RegularGridInterpolator. """ - def __init__( self, grid, values, fill_value=None, assume_sorted=False, - max_threads=8 - ): + max_threads=8, + ) -> None: self.grid = grid self.values = values self.fill_value = fill_value @@ -37,7 +37,7 @@ def __call__(self, points): if isinstance(points, tuple): shape = points[0].shape for arr in points: - assert arr.shape == shape, "All dimensions must contain the same number of points!" + assert arr.shape == shape, 'All dimensions must contain the same number of points!' 
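The merge and residual step in `main` above can be sketched on two tiny illustrative frames standing in for the RAiDER and GNSS tables (station ID, times, and delay values are made up; `on=` is used here as shorthand for the equal `left_on`/`right_on` keys):

```
import pandas as pd

dfr = pd.DataFrame({
    'ID': ['STN1'],
    'Datetime': pd.to_datetime(['2020-01-03 12:00:00']),
    'totalDelay': [2.412],
})
dfz = pd.DataFrame({
    'ID': ['STN1'],
    'Datetime': pd.to_datetime(['2020-01-03 12:00:00']),
    'ZTD': [2.405],
    'sigZTD': [0.004],
})

common_keys = ['Datetime', 'ID']
dfc = dfr.merge(dfz[common_keys + ['ZTD', 'sigZTD']], how='left', on=common_keys, sort=True)
dfc['ZTD_minus_RAiDER'] = dfc['ZTD'] - dfc['totalDelay']  # residual column added in main()
```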
interp_points = np.stack(points, axis=-1) in_shape = interp_points.shape elif points.ndim > 2: @@ -53,62 +53,60 @@ def __call__(self, points): interp_points, fill_value=self.fill_value, assume_sorted=self.assume_sorted, - max_threads=self.max_threads + max_threads=self.max_threads, ) return out.reshape(in_shape[:-1]) def interp_along_axis(oldCoord, newCoord, data, axis=2, pad=False): - ''' + """ DEPRECATED: Use RAiDER.interpolate.interpolate_along_axis instead (it is much faster). This function now primarily exists to verify the behavior of the new one. Interpolate an array of 3-D data along one axis. This function assumes that the x-coordinate increases monotonically. - ''' + """ if oldCoord.ndim > 1: stackedData = np.concatenate([oldCoord, data, newCoord], axis=axis) out = np.apply_along_axis(interpVector, axis=axis, arr=stackedData, Nx=oldCoord.shape[axis]) else: - out = np.apply_along_axis(interpV, axis=axis, arr=data, old_x=oldCoord, new_x=newCoord, - left=np.nan, right=np.nan) + out = np.apply_along_axis( + interpV, axis=axis, arr=data, old_x=oldCoord, new_x=newCoord, left=np.nan, right=np.nan + ) return out def interpV(y, old_x, new_x, left=None, right=None, period=None): - ''' - Rearrange np.interp's arguments - ''' + """Rearrange np.interp's arguments.""" return np.interp(new_x, old_x, y, left=left, right=right, period=period) def interpVector(vec, Nx): - ''' + """ Interpolate data from a single vector containing the original x, the original y, and the new x, in that order. Nx tells the number of original x-points. - ''' + """ x = vec[:Nx] - y = vec[Nx:2 * Nx] - xnew = vec[2 * Nx:] + y = vec[Nx : 2 * Nx] + xnew = vec[2 * Nx :] f = interp1d(x, y, bounds_error=False, copy=False, assume_sorted=True) return f(xnew) -def fillna3D(array, axis=-1, fill_value=0.): - ''' +def fillna3D(array, axis=-1, fill_value=0.0): + """ This function fills in NaNs in 3D arrays, specifically using the nearest non-nan value - for "low" NaNs and 0s for "high" NaNs. + for "low" NaNs and 0s for "high" NaNs. - Arguments: + Arguments: array - 3D array, where the last axis is the "z" dimension - - Returns: - 3D array with low NaNs filled as nearest neighbors and high NaNs filled as 0s - ''' + Returns: + 3D array with low NaNs filled as nearest neighbors and high NaNs filled as 0s + """ # fill lower NaNs with nearest neighbor narr = np.moveaxis(array, axis, -1) nars = narr.reshape((np.prod(narr.shape[:-1]),) + (narr.shape[-1],)) @@ -121,16 +119,20 @@ def fillna3D(array, axis=-1, fill_value=0.): return outmat -def interpolateDEM(demFile, outLL, method='nearest'): - """ Interpolate a DEM raster to a set of lat/lon query points using rioxarray +def interpolateDEM(dem_path: Union[Path, str], outLL: tuple[np.ndarray, np.ndarray], method='nearest') -> np.ndarray: + """Interpolate a DEM raster to a set of lat/lon query points using rioxarray. outLL will be a tuple of (lats, lons). 
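`interpVector` above expects a single stacked vector holding the original x, the original y, and the new x, in that order. A small self-contained sketch of that layout with made-up values:

```
import numpy as np
from scipy.interpolate import interp1d

Nx = 4
old_x = np.array([0.0, 1.0, 2.0, 3.0])
old_y = old_x ** 2
new_x = np.array([0.5, 1.5, 2.5])
vec = np.concatenate([old_x, old_y, new_x])  # layout interpVector expects

x, y, xnew = vec[:Nx], vec[Nx:2 * Nx], vec[2 * Nx:]
f = interp1d(x, y, bounds_error=False, copy=False, assume_sorted=True)
f(xnew)  # interpolated values of y = x**2 at the new points
```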
lats/lons can either be 1D arrays or 2 For now will only use first row/col of 2D """ import rioxarray as xrr - da_dem = xrr.open_rasterio(demFile, band_as_variable=True)['band_1'] + from xarray import Dataset + + data = xrr.open_rasterio(dem_path, band_as_variable=True) + assert isinstance(data, Dataset), 'DEM could not be opened as a rioxarray dataset' + da_dem = data['band_1'] lats, lons = outLL - lats = lats[:, 0] if lats.ndim==2 else lats - lons = lons[0, :] if lons.ndim==2 else lons - z_out = da_dem.interp(y=np.sort(lats)[::-1], x=lons).data - return z_out \ No newline at end of file + lats = lats[:, 0] if lats.ndim == 2 else lats + lons = lons[0, :] if lons.ndim == 2 else lons + z_out: np.ndarray = da_dem.interp(y=np.sort(lats)[::-1], x=lons).data + return z_out diff --git a/tools/RAiDER/llreader.py b/tools/RAiDER/llreader.py index a92b59745..bd6220dc7 100644 --- a/tools/RAiDER/llreader.py +++ b/tools/RAiDER/llreader.py @@ -6,10 +6,13 @@ # # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ import os -import pyproj -import xarray +from pathlib import Path +from typing import Optional, Union import numpy as np +import pyproj +import xarray as xr + try: import pandas as pd @@ -18,45 +21,42 @@ from pyproj import CRS -from RAiDER.utilFcns import rio_open, rio_stats from RAiDER.logger import logger +from RAiDER.types import BB, RIO +from RAiDER.utilFcns import rio_open, rio_stats -class AOI(object): - ''' +class AOI: + """ This instantiates a generic AOI class object. Attributes: _bounding_box - S N W E bounding box _proj - pyproj-compatible CRS _type - Type of AOI - ''' - def __init__(self): - self._output_directory = os.getcwd() - self._bounding_box = None - self._proj = CRS.from_epsg(4326) - self._geotransform = None - self._cube_spacing_m = None + """ + def __init__(self, cube_spacing_in_m: Optional[float]=None) -> None: + self._output_directory = os.getcwd() + self._bounding_box = None + self._proj = CRS.from_epsg(4326) + self._geotransform = None + self._cube_spacing_m = cube_spacing_in_m def type(self): return self._type - def bounds(self): return list(self._bounding_box).copy() - def geotransform(self): return self._geotransform - def projection(self): return self._proj - def get_output_spacing(self, crs=4326): - """ Return the output spacing in desired units """ + """Return the output spacing in desired units.""" output_spacing_deg = self._output_spacing if not isinstance(crs, CRS): crs = CRS.from_epsg(crs) @@ -65,30 +65,26 @@ def get_output_spacing(self, crs=4326): if all(axis_info.unit_name == 'degree' for axis_info in crs.axis_info): output_spacing = output_spacing_deg else: - output_spacing = output_spacing_deg*1e5 + output_spacing = output_spacing_deg * 1e5 return output_spacing - - def set_output_spacing(self, ll_res=None): - """ Calculate the spacing for the output grid and weather model + def set_output_spacing(self, ll_res=None) -> None: + """Calculate the spacing for the output grid and weather model. Use the requested spacing if exists or the weather model grid itself Returns: None. 
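As a rough usage sketch of the rioxarray pattern `interpolateDEM` uses above (the DEM path is hypothetical and must exist on disk; query coordinates are made up):

```
import numpy as np
import rioxarray as xrr

# Open the DEM with each band as a variable and take the first band, as above
da_dem = xrr.open_rasterio('GLO30_fullres_dem.tif', band_as_variable=True)['band_1']

lats = np.array([34.1, 34.2, 34.3])
lons = np.array([-118.3, -118.2, -118.1])

# xarray's .interp samples the raster onto the query grid (y descending, as above)
z_out = da_dem.interp(y=np.sort(lats)[::-1], x=lons).data
```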
Sets self._output_spacing """ - assert ll_res or self._cube_spacing_m, \ - 'Must pass lat/lon resolution if _cube_spacing_m is None' + assert ll_res or self._cube_spacing_m, 'Must pass lat/lon resolution if _cube_spacing_m is None' - out_spacing = self._cube_spacing_m / 1e5 \ - if self._cube_spacing_m else ll_res + out_spacing = self._cube_spacing_m / 1e5 if self._cube_spacing_m else ll_res logger.debug(f'Output cube spacing: {out_spacing} degrees') self._output_spacing = out_spacing - - def add_buffer(self, ll_res, digits=2): + def add_buffer(self, ll_res, digits=2) -> None: """ Add a fixed buffer to the AOI, accounting for the cube spacing. @@ -114,21 +110,19 @@ def add_buffer(self, ll_res, digits=2): ## add an extra buffer around the user specified region S, N, W, E = self.bounds() - buffer = (1.5 * ll_res) - S, N = np.max([S-buffer, -90]), np.min([N+buffer, 90]) - W, E = W-buffer, E+buffer # TODO: handle dateline crossings + buffer = 1.5 * ll_res + S, N = np.max([S - buffer, -90]), np.min([N + buffer, 90]) + W, E = W - buffer, E + buffer # TODO: handle dateline crossings ## clip the buffered region to a multiple of the spacing self.set_output_spacing(ll_res) - S, N, W, E = clip_bbox([S,N,W,E], self._output_spacing) + S, N, W, E = clip_bbox([S, N, W, E], self._output_spacing) if np.max([np.abs(W), np.abs(E)]) > 180: logger.warning('Bounds extend past +/- 180. Results may be incorrect.') self._bounding_box = [np.round(a, digits) for a in (S, N, W, E)] - return - def calc_buffer_ray(self, direction, lookDir='right', incAngle=30, maxZ=80, digits=2): """ @@ -148,20 +142,18 @@ def calc_buffer_ray(self, direction, lookDir='right', incAngle=30, maxZ=80, digi except AttributeError: lookDir = lookDir.lower() - assert direction in 'asc desc'.split(), \ - f'Incorrection orbital direction: {direction}. Choose asc or desc.' - assert lookDir in 'right light'.split(), \ - f'Incorrection look direction: {lookDir}. Choose right or left.' + assert direction in 'asc desc'.split(), f'Incorrection orbital direction: {direction}. Choose asc or desc.' + assert lookDir in 'right light'.split(), f'Incorrection look direction: {lookDir}. Choose right or left.' S, N, W, E = self.bounds() # use a small look angle to calculate near range lat_max = np.max([np.abs(S), np.abs(N)]) - near = maxZ * np.tan(np.deg2rad(incAngle)) - buffer = near / (np.cos(np.deg2rad(lat_max)) * 100) + near = maxZ * np.tan(np.deg2rad(incAngle)) + buffer = near / (np.cos(np.deg2rad(lat_max)) * 100) # buffer on the side nearest the sensor - if ((lookDir == 'right') and (direction == 'asc')) or ((lookDir == 'left') and (direction == 'desc')): + if (lookDir == 'right' and direction == 'asc') or (lookDir == 'left' and direction == 'desc'): W = W - buffer else: E = E + buffer @@ -171,14 +163,11 @@ def calc_buffer_ray(self, direction, lookDir='right', incAngle=30, maxZ=80, digi logger.warning('Bounds extend past +/- 180. 
Results may be incorrect.') return bounds - - def set_output_directory(self, output_directory): + def set_output_directory(self, output_directory) -> None: self._output_directory = output_directory - return - - def set_output_xygrid(self, dst_crs=4326): - """ Define the locations where the delays will be returned """ + def set_output_xygrid(self, dst_crs: Union[int, str]=4326) -> None: + """Define the locations where the delays will be returned.""" from RAiDER.utilFcns import transform_bbox try: @@ -189,54 +178,52 @@ def set_output_xygrid(self, dst_crs=4326): except pyproj.exceptions.CRSError: out_proj = dst_crs - out_snwe = transform_bbox(self.bounds(), src_crs=4326, dest_crs=out_proj) - logger.debug(f"Output SNWE: {out_snwe}") + logger.debug(f'Output SNWE: {out_snwe}') # Build the output grid out_spacing = self.get_output_spacing(out_proj) self.xpts = np.arange(out_snwe[2], out_snwe[3] + out_spacing, out_spacing) self.ypts = np.arange(out_snwe[1], out_snwe[0] - out_spacing, -out_spacing) - return class StationFile(AOI): - '''Use a .csv file containing at least Lat, Lon, and optionally Hgt_m columns''' - def __init__(self, station_file, demFile=None): - super().__init__() + """Use a .csv file containing at least Lat, Lon, and optionally Hgt_m columns.""" + + def __init__(self, station_file, demFile=None, cube_spacing_in_m: Optional[float]=None) -> None: + super().__init__(cube_spacing_in_m) self._filename = station_file - self._demfile = demFile + self._demfile = demFile self._bounding_box = bounds_from_csv(station_file) self._type = 'station_file' - - def readLL(self): - '''Read the station lat/lons from the csv file''' - df = pd.read_csv(self._filename).drop_duplicates(subset=["Lat", "Lon"]) - return df['Lat'].values, df['Lon'].values - + def readLL(self) -> tuple[np.ndarray, np.ndarray]: + """Read the station lat/lons from the csv file.""" + df = pd.read_csv(self._filename).drop_duplicates(subset=['Lat', 'Lon']) + return df['Lat'].to_numpy(), df['Lon'].to_numpy() def readZ(self): - ''' - Read the station heights from the file, or download a DEM if not present - ''' - df = pd.read_csv(self._filename).drop_duplicates(subset=["Lat", "Lon"]) + """Read the station heights from the file, or download a DEM if not present.""" + df = pd.read_csv(self._filename).drop_duplicates(subset=['Lat', 'Lon']) if 'Hgt_m' in df.columns: return df['Hgt_m'].values else: # Download the DEM from RAiDER.dem import download_dem from RAiDER.interpolator import interpolateDEM - - demFile = os.path.join(self._output_directory, 'GLO30_fullres_dem.tif') \ - if self._demfile is None else self._demfile - _, _ = download_dem( + demFile = ( + os.path.join(self._output_directory, 'GLO30_fullres_dem.tif') + if self._demfile is None + else self._demfile + ) + + download_dem( self._bounding_box, writeDEM=True, - demName=demFile, + dem_path=Path(demFile), ) - + ## interpolate the DEM to the query points z_out0 = interpolateDEM(demFile, self.readLL()) if np.isnan(z_out0).all(): @@ -251,11 +238,10 @@ def readZ(self): class RasterRDR(AOI): - ''' - Use a 2-band raster file containing lat/lon coordinates. 
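`set_output_xygrid` above builds the output grid with `np.arange` from SNWE bounds and the output spacing. A short sketch with hypothetical bounds and spacing in degrees:

```
import numpy as np

# Hypothetical SNWE bounds and output spacing (degrees)
S, N, W, E = 33.5, 34.5, -119.0, -117.5
out_spacing = 0.05

xpts = np.arange(W, E + out_spacing, out_spacing)    # west -> east
ypts = np.arange(N, S - out_spacing, -out_spacing)   # north -> south (descending)
```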
- ''' - def __init__(self, lat_file, lon_file=None, hgt_file=None, dem_file=None, convention='isce'): - super().__init__() + """Use a 2-band raster file containing lat/lon coordinates.""" + + def __init__(self, lat_file, lon_file=None, hgt_file=None, dem_file=None, convention='isce', cube_spacing_in_m: Optional[float]=None) -> None: + super().__init__(cube_spacing_in_m) self._type = 'radar_rasters' self._latfile = lat_file self._lonfile = lon_file @@ -277,37 +263,38 @@ def __init__(self, lat_file, lon_file=None, hgt_file=None, dem_file=None, conven self._demfile = dem_file self._convention = convention - - def readLL(self): + def readLL(self) -> tuple[np.ndarray, Optional[np.ndarray]]: # allow for 2-band lat/lon raster - lats = rio_open(self._latfile) + lats, _ = rio_open(Path(self._latfile)) if self._lonfile is None: - return lats + return lats, None else: - return lats, rio_open(self._lonfile) - + lons, _ = rio_open(Path(self._lonfile)) + return lats, lons - def readZ(self): - ''' - Read the heights from the raster file, or download a DEM if not present - ''' + def readZ(self) -> np.ndarray: + """Read the heights from the raster file, or download a DEM if not present.""" if self._hgtfile is not None and os.path.exists(self._hgtfile): logger.info('Using existing heights at: %s', self._hgtfile) - return rio_open(self._hgtfile) + hgts, _ = rio_open(self._hgtfile) + return hgts else: # Download the DEM from RAiDER.dem import download_dem from RAiDER.interpolator import interpolateDEM - - demFile = os.path.join(self._output_directory, 'GLO30_fullres_dem.tif') \ - if self._demfile is None else self._demfile - _, _ = download_dem( + demFile = ( + os.path.join(self._output_directory, 'GLO30_fullres_dem.tif') + if self._demfile is None + else self._demfile + ) + + download_dem( self._bounding_box, writeDEM=True, - demName=demFile, + dem_path=Path(demFile), ) z_out = interpolateDEM(demFile, self.readLL()) @@ -315,104 +302,105 @@ def readZ(self): class BoundingBox(AOI): - '''Parse a bounding box AOI''' - def __init__(self, bbox): - AOI.__init__(self) + """Parse a bounding box AOI.""" + + def __init__(self, bbox, cube_spacing_in_m: Optional[float]=None) -> None: + super().__init__(cube_spacing_in_m) self._bounding_box = bbox self._type = 'bounding_box' class GeocodedFile(AOI): - '''Parse a Geocoded file for coordinates''' - def __init__(self, filename, is_dem=False): - super().__init__() + """Parse a Geocoded file for coordinates.""" + + p: RIO.Profile + _bounding_box: BB.SNWE + _is_dem: bool - from RAiDER.utilFcns import rio_profile, rio_extents + def __init__(self, path: Path, is_dem=False, cube_spacing_in_m: Optional[float]=None) -> None: + super().__init__(cube_spacing_in_m) - self._filename = filename - self.p = rio_profile(filename) + from RAiDER.utilFcns import rio_extents, rio_profile + + self._filename = path + self.p = rio_profile(path) self._bounding_box = rio_extents(self.p) - self._is_dem = is_dem - _, self._proj, self._geotransform = rio_stats(filename) + self._is_dem = is_dem + _, self._proj, self._geotransform = rio_stats(path) self._type = 'geocoded_file' try: self.crs = self.p['crs'] except KeyError: self.crs = None - - def readLL(self): + def readLL(self) -> tuple[np.ndarray, np.ndarray]: # ll_bounds are SNWE S, N, W, E = self._bounding_box w, h = self.p['width'], self.p['height'] - px = (E - W) / w - py = (N - S) / h + px = (E - W) / w + py = (N - S) / h x = np.array([W + (t * px) for t in range(w)]) y = np.array([S + (t * py) for t in range(h)]) - X, Y = np.meshgrid(x,y) - 
return Y, X # lats, lons - + X, Y = np.meshgrid(x, y) + return Y, X # lats, lons def readZ(self): - ''' - Download a DEM for the file - ''' + """Download a DEM for the file.""" from RAiDER.dem import download_dem from RAiDER.interpolator import interpolateDEM demFile = self._filename if self._is_dem else 'GLO30_fullres_dem.tif' - bbox = self._bounding_box - _, _ = download_dem(bbox, writeDEM=True, demName=demFile) + bbox = self._bounding_box + _, _ = download_dem(bbox, writeDEM=True, dem_path=Path(demFile)) z_out = interpolateDEM(demFile, self.readLL()) return z_out class Geocube(AOI): - """ Pull lat/lon/height from a georeferenced data cube """ - def __init__(self, path_cube): - super().__init__() - self.path = path_cube + """Pull lat/lon/height from a georeferenced data cube.""" + + def __init__(self, path_cube, cube_spacing_in_m: Optional[float]=None) -> None: + super().__init__(cube_spacing_in_m) + self.path = path_cube self._type = 'Geocube' self._bounding_box = self.get_extent() _, self._proj, self._geotransform = rio_stats(path_cube) def get_extent(self): - with xarray.open_dataset(self.path) as ds: + with xr.open_dataset(self.path) as ds: S, N = ds.latitude.min().item(), ds.latitude.max().item() W, E = ds.longitude.min().item(), ds.longitude.max().item() return [S, N, W, E] ## untested - def readLL(self): - with xarray.open_dataset(self.path) as ds: + def readLL(self) -> tuple[np.ndarray, np.ndarray]: + with xr.open_dataset(self.path) as ds: lats = ds.latitutde.data() lons = ds.longitude.data() Lats, Lons = np.meshgrid(lats, lons) return Lats, Lons def readZ(self): - with xarray.open_dataset(self.path) as ds: + with xr.open_dataset(self.path) as ds: heights = ds.heights.data return heights -def bounds_from_latlon_rasters(latfile, lonfile): - ''' +def bounds_from_latlon_rasters(lat_filestr: str, lon_filestr: str) -> tuple[BB.SNWE, CRS, RIO.GDAL]: + """ Parse lat/lon/height inputs and return - the appropriate outputs - ''' + the appropriate outputs. + """ from RAiDER.utilFcns import get_file_and_band - latinfo = get_file_and_band(latfile) - loninfo = get_file_and_band(lonfile) + + latinfo = get_file_and_band(lat_filestr) + loninfo = get_file_and_band(lon_filestr) lat_stats, lat_proj, lat_gt = rio_stats(latinfo[0], band=latinfo[1]) lon_stats, lon_proj, lon_gt = rio_stats(loninfo[0], band=loninfo[1]) - if lat_proj != lon_proj: - raise ValueError('Projection information for Latitude and Longitude files does not match') - - if lat_gt != lon_gt: - raise ValueError('Affine transform for Latitude and Longitude files does not match') + assert lat_proj == lon_proj, 'Projection information for Latitude and Longitude files does not match' + assert lat_gt == lon_gt, 'Affine transform for Latitude and Longitude files does not match' # TODO - handle dateline crossing here snwe = (lat_stats.min, lat_stats.max, @@ -426,10 +414,10 @@ def bounds_from_latlon_rasters(latfile, lonfile): def bounds_from_csv(station_file): - ''' + """ station_file should be a comma-delimited file with at least "Lat" - and "Lon" columns, which should be EPSG: 4326 projection (i.e WGS84) - ''' - stats = pd.read_csv(station_file).drop_duplicates(subset=["Lat", "Lon"]) + and "Lon" columns, which should be EPSG: 4326 projection (i.e WGS84). 
+ """ + stats = pd.read_csv(station_file).drop_duplicates(subset=['Lat', 'Lon']) snwe = [stats['Lat'].min(), stats['Lat'].max(), stats['Lon'].min(), stats['Lon'].max()] return snwe diff --git a/tools/RAiDER/logger.py b/tools/RAiDER/logger.py index cb3951a3f..3e7699dea 100644 --- a/tools/RAiDER/logger.py +++ b/tools/RAiDER/logger.py @@ -5,13 +5,13 @@ # RESERVED. United States Government Sponsorship acknowledged. # # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -""" -Global logging configuration -""" +"""Global logging configuration.""" + import logging import os import sys -from logging import FileHandler, Formatter, StreamHandler +from logging import FileHandler, Formatter, LogRecord, StreamHandler +from pathlib import Path import RAiDER.cli.conf as conf @@ -19,10 +19,10 @@ # Inspired by # https://stackoverflow.com/questions/384076/how-can-i-color-python-logging-output class UnixColorFormatter(Formatter): - yellow = "\x1b[33;21m" - red = "\x1b[31;21m" - bold_red = "\x1b[31;1m" - reset = "\x1b[0m" + yellow = '\x1b[33;21m' + red = '\x1b[31;21m' + bold_red = '\x1b[31;1m' + reset = '\x1b[0m' COLORS = { logging.WARNING: yellow, @@ -30,57 +30,54 @@ class UnixColorFormatter(Formatter): logging.CRITICAL: bold_red } - def __init__(self, fmt=None, datefmt=None, style="%", use_color=True): + def __init__(self, fmt: str = None, datefmt: str = None, style: str = '%', use_color: bool=True) -> None: super().__init__(fmt, datefmt, style) # Save the old function so we can call it later self.__formatMessage = self.formatMessage if use_color: self.formatMessage = self.formatMessageColor - def formatMessageColor(self, record): + def formatMessageColor(self, record: LogRecord) -> str: message = self.__formatMessage(record) color = self.COLORS.get(record.levelno) if color: - message = "".join([color, message, self.reset]) + message = ''.join([color, message, self.reset]) return message class CustomFormatter(UnixColorFormatter): """Adds levelname prefixes to the message on warning or above.""" - - def formatMessage(self, record): + def formatMessage(self, record: LogRecord) -> str: message = super().formatMessage(record) if record.levelno >= logging.WARNING: - message = ": ".join((record.levelname, message)) + message = ': '.join((record.levelname, message)) return message ##################################### -## DEFINE THE LOGGER -if conf.LOGGER_PATH is None: - logger_path = os.getcwd() -else: +# DEFINE THE LOGGER +if conf.LOGGER_PATH is not None: logger_path = conf.LOGGER_PATH +else: + logger_path = Path.cwd() -logger = logging.getLogger("RAiDER") +logger = logging.getLogger('RAiDER') logger.setLevel(logging.DEBUG) stdout_handler = StreamHandler(sys.stdout) -stdout_handler.setFormatter(CustomFormatter(use_color=os.name != "nt")) +stdout_handler.setFormatter(CustomFormatter(use_color=os.name != 'nt')) stdout_handler.setLevel(logging.DEBUG) -debugfile_handler = FileHandler(os.path.join(logger_path, "debug.log")) -debugfile_handler.setFormatter(Formatter( - "[{asctime}] {levelname:<10} {module} {exc_info} {funcName:>20}:{lineno:<5} {message}", - style="{" -)) +debugfile_handler = FileHandler(logger_path / 'debug.log') +debugfile_handler.setFormatter( + Formatter('[{asctime}] {levelname:<10} {module} {exc_info} {funcName:>20}:{lineno:<5} {message}', style='{') +) debugfile_handler.setLevel(logging.DEBUG) -errorfile_handler = FileHandler(os.path.join(logger_path, "error.log")) -errorfile_handler.setFormatter(Formatter( - "[{asctime}] {levelname:<10} {module:<10} {exc_info} 
{funcName:>20}:{lineno:<5} {message}", - style="{" -)) +errorfile_handler = FileHandler(logger_path / 'error.log') +errorfile_handler.setFormatter( + Formatter('[{asctime}] {levelname:<10} {module:<10} {exc_info} {funcName:>20}:{lineno:<5} {message}', style='{') +) # , , , , , , errorfile_handler.setLevel(logging.WARNING) diff --git a/tools/RAiDER/losreader.py b/tools/RAiDER/losreader.py index 3faaa7ef5..7d5e44f13 100644 --- a/tools/RAiDER/losreader.py +++ b/tools/RAiDER/losreader.py @@ -6,14 +6,16 @@ # # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +import datetime as dt import os -import datetime import shelve from abc import ABC -from typing import Union from pathlib import PosixPath +from typing import Literal, NoReturn, Union import numpy as np + + try: import xml.etree.ElementTree as ET except ImportError: @@ -24,25 +26,21 @@ isce = None from RAiDER.constants import _ZREF -from RAiDER.utilFcns import ( - cosd, sind, rio_open, lla2ecef, ecef2lla -) +from RAiDER.utilFcns import cosd, ecef2lla, lla2ecef, rio_open, sind class LOS(ABC): - ''' - LOS Class definition for handling look vectors - ''' - def __init__(self): + """LOS Class definition for handling look vectors.""" + + def __init__(self) -> None: self._lats, self._lons, self._heights = None, None, None self._look_vecs = None self._ray_trace = False self._is_zenith = False self._is_projected = False - - def setPoints(self, lats, lons=None, heights=None): - '''Set the pixel locations''' + def setPoints(self, lats, lons=None, heights=None) -> None: + """Set the pixel locations.""" if (lats is None) and (self._lats is None): raise RuntimeError("You haven't given any point locations yet") @@ -61,42 +59,35 @@ def setPoints(self, lats, lons=None, heights=None): self._lons = lons self._heights = heights - - def setTime(self, dt): - self._time = dt - + def setTime(self, datetime) -> None: + self._time = datetime def is_Zenith(self): return self._is_zenith - def is_Projected(self): return self._is_projected - def ray_trace(self): return self._ray_trace class Zenith(LOS): - """ - Class definition for a "Zenith" object. - """ - def __init__(self): + """Class definition for a "Zenith" object.""" + + def __init__(self) -> None: super().__init__() self._is_zenith = True - - def setLookVectors(self): - '''Set point locations and calculate Zenith look vectors''' + def setLookVectors(self) -> None: + """Set point locations and calculate Zenith look vectors.""" if self._lats is None: raise ValueError('Target points not set') if self._look_vecs is None: self._look_vecs = getZenithLookVecs(self._lats, self._lons, self._heights) - def __call__(self, delays): - '''Placeholder method for consistency with the other classes''' + """Placeholder method for consistency with the other classes.""" return delays @@ -105,21 +96,19 @@ class Conventional(LOS): Special value indicating that the zenith delay will be projected using the standard cos(inc) scaling. 
""" - def __init__(self, filename=None, los_convention='isce', time=None, pad=600): + + def __init__(self, filename=None, los_convention='isce', time=None, pad=600) -> None: super().__init__() self._file = filename self._time = time - self._pad = pad + self._pad = pad self._is_projected = True - self._convention = los_convention + self._convention = los_convention if self._convention.lower() != 'isce': raise NotImplementedError() - def __call__(self, delays): - ''' - Read the LOS file and convert it to look vectors - ''' + """Read the LOS file and convert it to look vectors.""" if self._lats is None: raise ValueError('Target points not set') if self._file is None: @@ -127,16 +116,16 @@ def __call__(self, delays): try: # if an ISCE-style los file is passed open it with GDAL - LOS_enu = inc_hd_to_enu(*rio_open(self._file)) + data, _ = rio_open(self._file) + LOS_enu = inc_hd_to_enu(*data) except (OSError, TypeError): # Otherwise, treat it as an orbit / statevector file - svs = np.stack( - get_sv(self._file, self._time, self._pad), axis=-1 + svs = np.stack(get_sv(self._file, self._time, self._pad), axis=-1) + LOS_enu = state_to_los( + svs, + [self._lats, self._lons, self._heights], ) - LOS_enu = state_to_los(svs, - [self._lats, self._lons, self._heights], - ) if delays.shape == LOS_enu.shape: return delays / LOS_enu @@ -182,16 +171,16 @@ class Raytracing(LOS): >>> import numpy as np """ - def __init__(self, filename=None, los_convention='isce', time=None, look_dir = 'right', pad=600): - '''read in and parse a statevector file''' + def __init__(self, filename=None, los_convention='isce', time=None, look_dir='right', pad=600) -> None: + """Read in and parse a statevector file.""" if isce is None: - raise ImportError(f'isce3 is required for this class. Use conda to install isce3`') + raise ImportError('isce3 is required for this class. Use conda to install isce3`') super().__init__() self._ray_trace = True self._file = filename self._time = time - self._pad = pad + self._pad = pad self._convention = los_convention self._orbit = None if self._convention.lower() != 'isce': @@ -203,42 +192,34 @@ def __init__(self, filename=None, los_convention='isce', time=None, look_dir = ' self._orbit = get_orbit(self._file, self._time, pad=pad) self._elp = isce.core.Ellipsoid() self._dop = isce.core.LUT2d() - if look_dir.lower() == "right": + if look_dir.lower() == 'right': self._look_dir = isce.core.LookSide.Right - elif look_dir.lower() == "left": + elif look_dir.lower() == 'left': self._look_dir = isce.core.LookSide.Left else: - raise RuntimeError(f"Unknown look direction: {look_dir}") + raise RuntimeError(f'Unknown look direction: {look_dir}') - - def getSensorDirection(self): + def getSensorDirection(self) -> Literal['desc', 'asc']: if self._orbit is None: raise ValueError('The orbit has not been set') - z = self._orbit.position[:,2] + z = self._orbit.position[:, 2] t = self._orbit.time start = np.argmin(t) end = np.argmax(t) - if z[start] > z[end]: - return 'desc' - else: - return 'asc' - + return 'desc' if z[start] > z[end] else 'asc' def getLookDirection(self): return self._look_dir # Called in checkArgs - def setTime(self, time, pad=600): + def setTime(self, time, pad=600) -> None: self._time = time self._orbit = get_orbit(self._file, self._time, pad=pad) - def getLookVectors(self, ht, llh, xyz, yy): - ''' - Calculate look vectors for raytracing - ''' + """Calculate look vectors for raytracing.""" if isce is None: - raise ImportError(f'isce3 is required for this method. 
Use conda to install isce3`') + raise ImportError('isce3 is required for this method. Use conda to install isce3`') # TODO - Modify when isce3 vectorization is available los = np.full(yy.shape + (3,), np.nan) @@ -257,26 +238,30 @@ def getLookVectors(self, ht, llh, xyz, yy): # Wavelength does not matter for try: aztime, slant_range = isce.geometry.geo2rdr( - inp, self._elp, self._orbit, self._dop, 0.06, self._look_dir, + inp, + self._elp, + self._orbit, + self._dop, + 0.06, + self._look_dir, threshold=1.0e-7, maxiter=30, - delta_range=10.0) + delta_range=10.0, + ) sat_xyz, _ = self._orbit.interpolate(aztime) los[ii, jj, :] = (sat_xyz - inp_xyz) / slant_range - except Exception as e: + except: los[ii, jj, :] = np.nan return los - def getIntersectionWithHeight(self, height): """ This function computes the intersection point of a ray at a height - level + level. """ # We just leverage the same code as finding top of atmosphere here return getTopOfAtmosphere(self._xyz, self._look_vecs, height) - def getIntersectionWithLevels(self, levels): """ This function returns the points at which rays intersect the @@ -302,12 +287,11 @@ def getIntersectionWithLevels(self, levels): return rays - - def calculateDelays(self, delays): - ''' + def calculateDelays(self, delays) -> NoReturn: + """ Here "delays" is point-wise delays (i.e. refractivities), not integrated ZTD/STD. - ''' + """ # Create rays (Use getIntersectionWithLevels above) # Interpolate delays to rays # Integrate along rays @@ -316,7 +300,7 @@ def calculateDelays(self, delays): def getZenithLookVecs(lats, lons, heights): - ''' + """ Returns look vectors when Zenith is used. Args: @@ -324,7 +308,7 @@ def getZenithLookVecs(lats, lons, heights): Returns: zenLookVecs (ndarray): - (in_shape) x 3 unit look vectors in an ECEF reference frame - ''' + """ x = np.cos(np.radians(lats)) * np.cos(np.radians(lons)) y = np.cos(np.radians(lats)) * np.sin(np.radians(lons)) z = np.sin(np.radians(lats)) @@ -332,11 +316,9 @@ def getZenithLookVecs(lats, lons, heights): return np.stack([x, y, z], axis=-1) -def get_sv(los_file: Union[str, list, PosixPath], - ref_time: datetime.datetime, - pad: int): +def get_sv(los_file: Union[str, list, PosixPath], ref_time: dt.datetime, pad: int): """ - Read an LOS file and return orbital state vectors + Read an LOS file and return orbital state vectors. Args: los_file (str, Path, list): - user-passed file containing either look @@ -366,6 +348,7 @@ def get_sv(los_file: Union[str, list, PosixPath], def filter_ESA_orbit_file_p(path: str) -> bool: return filter_ESA_orbit_file(path, ref_time) + los_files = list(filter(filter_ESA_orbit_file_p, los_files)) if not los_files: raise ValueError('There are no valid orbit files provided') @@ -373,17 +356,14 @@ def filter_ESA_orbit_file_p(path: str) -> bool: for orb_path in los_files: svs.extend(read_ESA_Orbit_file(orb_path)) - except BaseException: + except: try: svs = read_shelve(los_file) - except BaseException: - raise ValueError( - f'get_sv: I cannot parse the statevector file {los_file}' - ) + except: + raise ValueError(f'get_sv: I cannot parse the statevector file {los_file}') except: raise ValueError(f'get_sv: I cannot parse the statevector file {los_file}') - if ref_time: idx = cut_times(svs[0], ref_time, pad=pad) svs = [d[idx] for d in svs] @@ -392,7 +372,7 @@ def filter_ESA_orbit_file_p(path: str) -> bool: def inc_hd_to_enu(incidence, heading): - ''' + """ Convert incidence and heading to line-of-sight vectors from the ground to the top of the troposphere. 
@@ -405,7 +385,7 @@ def inc_hd_to_enu(incidence, heading): LOS: ndarray - (input_shape) x 3 array of unit look vectors in local ENU Algorithm referenced from http://earthdef.caltech.edu/boards/4/topics/327 - ''' + """ if np.any(incidence < 0): raise ValueError('inc_hd_to_enu: Incidence angle cannot be less than 0') @@ -447,7 +427,7 @@ def read_shelve(filename): def read_txt_file(filename): - ''' + """ Read a 7-column text file containing orbit statevectors. Time should be denoted as integer time in seconds since the reference epoch (user-requested time). @@ -461,7 +441,7 @@ def read_txt_file(filename): Returns: svs (list): - a length-7 list of numpy vectors containing the above variables - ''' + """ t = list() x = list() y = list() @@ -469,17 +449,18 @@ def read_txt_file(filename): vx = list() vy = list() vz = list() - with open(filename, 'r') as f: + with open(filename) as f: for line in f: try: parts = line.strip().split() - t_ = datetime.datetime.fromisoformat(parts[0]) - x_, y_, z_, vx_, vy_, vz_ = [float(t) for t in parts[1:]] + t_ = dt.datetime.fromisoformat(parts[0]) + x_, y_, z_, vx_, vy_, vz_ = (float(t) for t in parts[1:]) except ValueError: raise ValueError( - "I need {} to be a 7 column text file, with ".format(filename) + - "columns t, x, y, z, vx, vy, vz (Couldn't parse line " + - "{})".format(repr(line))) + f'I need {filename} to be a 7 column text file, with ' + + "columns t, x, y, z, vx, vy, vz (Couldn't parse line " + + f'{repr(line)})' + ) t.append(t_) x.append(x_) y.append(y_) @@ -489,14 +470,14 @@ def read_txt_file(filename): vz.append(vz_) if len(t) < 4: - raise ValueError('read_txt_file: File {} does not have enough statevectors'.format(filename)) + raise ValueError(f'read_txt_file: File {filename} does not have enough statevectors') return [np.array(a) for a in [t, x, y, z, vx, vy, vz]] def read_ESA_Orbit_file(filename): - ''' - Read orbit data from an orbit file supplied by ESA + """ + Read orbit data from an orbit file supplied by ESA. 
Args: ---------- @@ -508,7 +489,7 @@ def read_ESA_Orbit_file(filename): in python datetime x, y, z: Nt x 1 ndarrays - x/y/z positions of the sensor at the times t vx, vy, vz: Nt x 1 ndarrays - x/y/z velocities of the sensor at the times t - ''' + """ if ET is None: raise ImportError('read_ESA_Orbit_file: cannot import xml.etree.ElementTree') tree = ET.parse(filename) @@ -525,12 +506,7 @@ def read_ESA_Orbit_file(filename): vz = np.ones(numOSV) for i, st in enumerate(data_block[0]): - t.append( - datetime.datetime.strptime( - st[1].text, - 'UTC=%Y-%m-%dT%H:%M:%S.%f' - ) - ) + t.append(dt.datetime.strptime(st[1].text, 'UTC=%Y-%m-%dT%H:%M:%S.%f')) x[i] = float(st[4].text) y[i] = float(st[5].text) @@ -542,13 +518,13 @@ def read_ESA_Orbit_file(filename): return [t, x, y, z, vx, vy, vz] -def pick_ESA_orbit_file(list_files:list, ref_time:datetime.datetime): - """ From list of .EOF orbit files, pick the one that contains 'ref_time' """ +def pick_ESA_orbit_file(list_files: list, ref_time: dt.datetime): + """From list of .EOF orbit files, pick the one that contains 'ref_time'.""" orb_file = None for path in list_files: - f = os.path.basename(path) - t0 = datetime.datetime.strptime(f.split('_')[6].lstrip('V'), '%Y%m%dT%H%M%S') - t1 = datetime.datetime.strptime(f.split('_')[7].rstrip('.EOF'), '%Y%m%dT%H%M%S') + f = os.path.basename(path) + t0 = dt.datetime.strptime(f.split('_')[6].lstrip('V'), '%Y%m%dT%H%M%S') + t1 = dt.datetime.strptime(f.split('_')[7].rstrip('.EOF'), '%Y%m%dT%H%M%S') if t0 < ref_time < t1: orb_file = path break @@ -558,33 +534,33 @@ def pick_ESA_orbit_file(list_files:list, ref_time:datetime.datetime): return path -def filter_ESA_orbit_file(orbit_xml: str, - ref_time: datetime.datetime) -> bool: - """Returns true or false depending on whether orbit file contains ref time +def filter_ESA_orbit_file(orbit_xml: str, ref_time: dt.datetime) -> bool: + """Returns true or false depending on whether orbit file contains ref time. Parameters ---------- orbit_xml : str ESA orbit xml - ref_time : datetime.datetime + ref_time : dt.datetime - Returns + Returns: ------- bool True if ref time is within orbit_xml """ f = os.path.basename(orbit_xml) - t0 = datetime.datetime.strptime(f.split('_')[6].lstrip('V'), '%Y%m%dT%H%M%S') - t1 = datetime.datetime.strptime(f.split('_')[7].rstrip('.EOF'), '%Y%m%dT%H%M%S') - return (t0 < ref_time < t1) + t0 = dt.datetime.strptime(f.split('_')[6].lstrip('V'), '%Y%m%dT%H%M%S') + t1 = dt.datetime.strptime(f.split('_')[7].rstrip('.EOF'), '%Y%m%dT%H%M%S') + return t0 < ref_time < t1 ############################ def state_to_los(svs, llh_targets): - ''' + """ Converts information from a state vector for a satellite orbit, given in terms of position and velocity, to line-of-sight information at each (lon,lat, height) coordinate requested by the user. 
+ Args: ---------- svs - t, x, y, z, vx, vy, vz - time, position, and velocity in ECEF of the sensor @@ -593,27 +569,25 @@ def state_to_los(svs, llh_targets): Returns: ------- LOS - * x 3 matrix of LOS unit vectors in ECEF (*not* ENU) + Example: - >>> import datetime - >>> import numpy + >>> import datetime as dt + >>> import numpy as np >>> from RAiDER.utilFcns import rio_open >>> import RAiDER.losreader as losr >>> lats, lons, heights = np.array([-76.1]), np.array([36.83]), np.array([0]) - >>> time = datetime.datetime(2018,11,12,23,0,0) + >>> time = dt.datetime(2018,11,12,23,0,0) >>> # download the orbit file beforehand >>> esa_orbit_file = 'S1A_OPER_AUX_POEORB_OPOD_20181203T120749_V20181112T225942_20181114T005942.EOF' >>> svs = losr.read_ESA_Orbit_file(esa_orbit_file) >>> LOS = losr.state_to_los(*svs, [lats, lons, heights], xyz) - ''' + """ if isce is None: - raise ImportError(f'isce3 is required for this function. Use conda to install isce3`') + raise ImportError('isce3 is required for this function. Use conda to install isce3`') # check the inputs if np.min(svs.shape) < 4: - raise RuntimeError( - 'state_to_los: At least 4 state vectors are required' - ' for orbit interpolation' - ) + raise RuntimeError('state_to_los: At least 4 state vectors are required for orbit interpolation') # Convert svs to isce3 orbit orb = isce.core.Orbit([ @@ -624,9 +598,8 @@ def state_to_los(svs, llh_targets): ]) # Flatten the input array for convenience - in_shape = llh_targets[0].shape + in_shape = llh_targets[0].shape target_llh = np.stack([x.flatten() for x in llh_targets], axis=-1) - Npts = len(target_llh) # Iterate through targets and compute LOS los_ang, _ = get_radar_pos(target_llh, orb) @@ -639,23 +612,23 @@ def cut_times(times, ref_time, pad): Slice the orbit file around the reference aquisition time. This is done by default using a three-hour window, which for Sentinel-1 empirically works out to be roughly the largest window allowed by the orbit time. + Args: ---------- times: Nt x 1 ndarray - Vector of orbit times as datetime ref_time: datetime - Reference time pad: int - integer time in seconds to use as padding + Returns: ------- idx: Nt x 1 logical ndarray - a mask of times within the padded request time. """ - diff = np.array( - [(x - ref_time).total_seconds() for x in times] - ) + diff = np.array([(x - ref_time).total_seconds() for x in times]) return np.abs(diff) < pad def get_radar_pos(llh, orb): - ''' + """ Calculate the coordinate of the sensor in ECEF at the time corresponding to ***. Args: @@ -668,17 +641,15 @@ def get_radar_pos(llh, orb): ------- los: ndarray - Satellite incidence angle sr: ndarray - Slant range in meters - ''' + """ if isce is None: - raise ImportError(f'isce3 is required for this function. Use conda to install isce3`') + raise ImportError('isce3 is required for this function. 
Use conda to install isce3`') num_iteration = 30 residual_threshold = 1.0e-7 # Get xyz positions of targets here from lat/lon/height - targ_xyz = np.stack( - lla2ecef(llh[:, 0], llh[:, 1], llh[:, 2]), axis=-1 - ) + targ_xyz = np.stack(lla2ecef(llh[:, 0], llh[:, 1], llh[:, 2]), axis=-1) # Get some isce3 constants for this inversion # TODO - Assuming right-looking for now @@ -694,31 +665,32 @@ def get_radar_pos(llh, orb): for ind, pt in enumerate(llh): if not any(np.isnan(pt)): # ISCE3 always uses xy convention - inp = np.array([np.deg2rad(pt[1]), - np.deg2rad(pt[0]), - pt[2]]) + inp = np.array([np.deg2rad(pt[1]), np.deg2rad(pt[0]), pt[2]]) # Local normal vector nv = elp.n_vector(inp[0], inp[1]) # Wavelength does not matter for zero doppler try: aztime, slant_range = isce.geometry.geo2rdr( - inp, elp, orb, dop, 0.06, look, + inp, + elp, + orb, + dop, + 0.06, + look, threshold=residual_threshold, maxiter=num_iteration, - delta_range=10.0) + delta_range=10.0, + ) sat_xyz, _ = orb.interpolate(aztime) sr[ind] = slant_range - delta = sat_xyz - targ_xyz[ind, :] # TODO - if we only ever need cos(lookang), # skip the arccos here and cos above delta = delta / np.linalg.norm(delta) - output[ind] = np.rad2deg( - np.arccos(np.dot(delta, nv)) - ) + output[ind] = np.rad2deg(np.arccos(np.dot(delta, nv))) except Exception as e: raise e @@ -749,32 +721,30 @@ def getTopOfAtmosphere(xyz, look_vecs, toaheight, factor=None): maxIter = 3 else: maxIter = 10 - factor = 1. + factor = 1.0 # Guess top point pos = xyz + toaheight * look_vecs for _ in range(maxIter): pos_llh = ecef2lla(pos[..., 0], pos[..., 1], pos[..., 2]) - pos = pos + look_vecs * ((toaheight - pos_llh[2])/factor)[..., None] + pos = pos + look_vecs * ((toaheight - pos_llh[2]) / factor)[..., None] return pos -def get_orbit(orbit_file: Union[list, str], - ref_time: datetime.datetime, - pad: int): - ''' +def get_orbit(orbit_file: Union[list, str], ref_time: dt.datetime, pad: int): + """ Returns state vectors from an orbit file; state vectors are unique and ordered in terms of time orbit file (str | list): - user-passed file(s) containing statevectors for the sensor (can be download with sentineleof libray). Lists of files are only accepted for Sentinel-1 EOF files. pad (int): - number of seconds to keep around the - requested time (should be about 600 seconds) + requested time (should be about 600 seconds). - ''' + """ if isce is None: - raise ImportError(f'isce3 is required for this function. Use conda to install isce3`') + raise ImportError('isce3 is required for this function. Use conda to install isce3`') # First load the state vectors into an isce orbit svs = np.stack(get_sv(orbit_file, ref_time, pad), axis=-1) @@ -801,7 +771,7 @@ def get_orbit(orbit_file: Union[list, str], def build_ray(model_zs, ht, xyz, LOS, MAX_TROPO_HEIGHT=_ZREF): """ - Compute the ray length in ECEF between each weather model layers + Compute the ray length in ECEF between each weather model layers. 
Only heights up to MAX_TROPO_HEIGHT are considered Assumption: model_zs (model) are assumed to be sorted in height @@ -812,7 +782,7 @@ def build_ray(model_zs, ht, xyz, LOS, MAX_TROPO_HEIGHT=_ZREF): cos_factor = None ray_lengths, low_xyzs, high_xyzs = [], [], [] - for zz in range(model_zs.size-1): + for zz in range(model_zs.size - 1): # Low and High for model interval low_ht = model_zs[zz] high_ht = model_zs[zz + 1] @@ -848,7 +818,7 @@ def build_ray(model_zs, ht, xyz, LOS, MAX_TROPO_HEIGHT=_ZREF): high_xyz = getTopOfAtmosphere(xyz, LOS, high_ht, factor=cos_factor) # Compute ray length - ray_length = np.linalg.norm(high_xyz - low_xyz, axis=-1) + ray_length = np.linalg.norm(high_xyz - low_xyz, axis=-1) # Compute cos_factor for first iteration if cos_factor is None: @@ -858,7 +828,7 @@ def build_ray(model_zs, ht, xyz, LOS, MAX_TROPO_HEIGHT=_ZREF): low_xyzs.append(low_xyz) high_xyzs.append(high_xyz) - ## if all weather model levels are requested the top most layer might not contribute anything + # if all weather model levels are requested the top most layer might not contribute anything if not ray_lengths: return None, None, None else: diff --git a/tools/RAiDER/models/__init__.py b/tools/RAiDER/models/__init__.py index 79666992e..86162384d 100644 --- a/tools/RAiDER/models/__init__.py +++ b/tools/RAiDER/models/__init__.py @@ -6,4 +6,5 @@ from .merra2 import MERRA2 from .ncmr import NCMR + __all__ = ['HRRR', 'HRRRAK', 'GMAO', 'ERA5', 'ERA5T', 'HRES', 'MERRA2'] diff --git a/tools/RAiDER/models/credentials.py b/tools/RAiDER/models/credentials.py index e6c238cc8..8d01ee533 100644 --- a/tools/RAiDER/models/credentials.py +++ b/tools/RAiDER/models/credentials.py @@ -1,6 +1,6 @@ -''' +""" API credential information and help url for downloading weather model data - saved in a hidden file in home directory + saved in a hidden file in home directory. 
api filename weather models UID KEY URL _________________________________________________________________________________ @@ -8,7 +8,7 @@ ecmwfapirc HRES email key https://api.ecmwf.int/v1 netrc GMAO, MERRA2 username password urs.earthdata.nasa.gov HRRR [public access] -''' +""" import os from pathlib import Path @@ -25,7 +25,7 @@ 'HRES': 'ecmwfapirc', 'GMAO': 'netrc', 'MERRA2': 'netrc', - 'HRRR': None + 'HRRR': None, } APIS = { @@ -35,7 +35,7 @@ 'key: {uid}:{key}\n' ), 'help_url': 'https://cds.climate.copernicus.eu/api-how-to', - 'default_host': 'https://cds.climate.copernicus.eu/api/v2' + 'default_host': 'https://cds.climate.copernicus.eu/api/v2', }, 'ecmwfapirc': { 'template': ( @@ -46,7 +46,7 @@ '}}\n' ), 'help_url': 'https://confluence.ecmwf.int/display/WEBAPI/Access+ECMWF+Public+Datasets#AccessECMWFPublicDatasets-key', - 'default_host': 'https://api.ecmwf.int/v1' + 'default_host': 'https://api.ecmwf.int/v1', }, 'netrc': { 'template': ( @@ -55,8 +55,8 @@ ' password {key}\n' ), 'help_url': 'https://wiki.earthdata.nasa.gov/display/EL/How+To+Access+Data+With+cURL+And+Wget', - 'default_host': 'urs.earthdata.nasa.gov' - } + 'default_host': 'urs.earthdata.nasa.gov', + }, } @@ -69,8 +69,7 @@ def _get_envs(model: str) -> Tuple[Optional[str], Optional[str], Optional[str]]: elif model == 'HRES': uid = os.getenv('RAIDER_HRES_EMAIL') key = os.getenv('RAIDER_HRES_API_KEY') - host = os.getenv('RAIDER_HRES_URL', - APIS['ecmwfapirc']['default_host']) + host = os.getenv('RAIDER_HRES_URL', APIS['ecmwfapirc']['default_host']) elif model in ('GMAO', 'MERRA2'): # same as in DockerizedTopsApp uid = os.getenv('EARTHDATA_USERNAME') @@ -81,11 +80,13 @@ def _get_envs(model: str) -> Tuple[Optional[str], Optional[str], Optional[str]]: return uid, key, host -def check_api(model: str, - uid: Optional[str] = None, - key: Optional[str] = None, - output_dir: str = '~/', - update_rc_file: bool = False) -> None: +def check_api( + model: str, + uid: Optional[str] = None, + key: Optional[str] = None, + output_dir: str = '~/', + update_rc_file: bool = False, +) -> None: # Weather model API RC filename # Typically stored in home dir as a hidden file rc_filename = RC_FILENAMES[model] @@ -96,7 +97,7 @@ def check_api(model: str, return # Get the target rc file's path - hidden_ext = '_' if system() == "Windows" else '.' + hidden_ext = '_' if system() == 'Windows' else '.' rc_path = Path(output_dir) / (hidden_ext + rc_filename) rc_path = rc_path.expanduser() @@ -118,19 +119,19 @@ def check_api(model: str, raise ValueError( f'ERROR: {model} API UID not provided in RAiDER arguments and ' 'not present in environment variables.\n' - f'See info for this model\'s API at \033[1m{help_url}\033[0m' + f"See info for this model's API at \033[1m{help_url}\033[0m" ) elif uid is not None and key is None: raise ValueError( f'ERROR: {model} API key not provided in RAiDER arguments and ' 'not present in environment variables.\n' - f'See info for this model\'s API at \033[1m{help_url}\033[0m' + f"See info for this model's API at \033[1m{help_url}\033[0m" ) else: raise ValueError( f'ERROR: {model} API credentials not provided in RAiDER ' 'arguments and not present in environment variables.\n' - f'See info for this model\'s API at \033[1m{help_url}\033[0m' + f"See info for this model's API at \033[1m{help_url}\033[0m" ) # Create file with the API credentials @@ -150,6 +151,7 @@ def check_api(model: str, # so extra care needs to be taken to make sure we only touch the # one that belongs to this URL. 
import netrc + rc_path.touch() netrc_credentials = netrc.netrc(str(rc_path)) netrc_credentials.hosts[url] = (uid, '', key) @@ -157,6 +159,6 @@ def check_api(model: str, rc_path.chmod(0o000600) -def setup_from_env(): +def setup_from_env() -> None: for model in RC_FILENAMES.keys(): check_api(model) diff --git a/tools/RAiDER/models/customExceptions.py b/tools/RAiDER/models/customExceptions.py index 3bc9b69a9..4bf9dc7f0 100644 --- a/tools/RAiDER/models/customExceptions.py +++ b/tools/RAiDER/models/customExceptions.py @@ -1,50 +1,51 @@ class DatetimeFailed(Exception): - def __init__(self, model, time): - msg = f"Weather model {model} failed to download for datetime {time}" + def __init__(self, model, time) -> None: + msg = f'Weather model {model} failed to download for datetime {time}' super().__init__(msg) class DatetimeNotAvailable(Exception): - def __init__(self, model, time): - msg = f"Weather model {model} was not found for datetime {time}" + def __init__(self, model, time) -> None: + msg = f'Weather model {model} was not found for datetime {time}' super().__init__(msg) class DatetimeOutsideRange(Exception): - def __init__(self, model, time): - msg = f"Time {time} is outside the available date range for weather model {model}" + def __init__(self, model, time) -> None: + msg = f'Time {time} is outside the available date range for weather model {model}' super().__init__(msg) class ExistingWeatherModelTooSmall(Exception): - def __init__(self): - msg = 'The weather model passed does not cover all of the input ' \ - 'points; you may need to download a larger area.' + def __init__(self) -> None: + msg = 'The weather model passed does not cover all of the input points; you may need to download a larger area.' super().__init__(msg) class TryToKeepGoingError(Exception): - def __init__(self, date=None): + def __init__(self, date=None) -> None: if date is not None: - msg = 'The weather model does not exist for date {date}, so I will try to use the closest available date.' + msg = f'The weather model does not exist for date {date}, so I will try to use the closest available date.' 
else:
             msg = 'I will try to keep going'
         super().__init__(msg)
-    
+
+
 class CriticalError(Exception):
-    def __init__(self):
+    def __init__(self) -> None:
         msg = 'I have experienced a critical error, please take a look at the log files'
         super().__init__(msg)
 
+
 class WrongNumberOfFiles(Exception):
-    def __init__(self, Nexp, Navail):
+    def __init__(self, Nexp, Navail) -> None:
         msg = 'The number of files downloaded does not match the requested, '
-        'I expected {} and got {}, aborting'.format(Nexp, Navail)
+        msg += f'I expected {Nexp} and got {Navail}, aborting'
         super().__init__(msg)
-    
+
 class NoWeatherModelData(Exception):
-    def __init__(self, custom_msg=None):
+    def __init__(self, custom_msg=None) -> None:
         if custom_msg is None:
             msg = 'No weather model files were available to download, aborting'
         else:
@@ -53,14 +54,13 @@ def __init__(self, custom_msg=None):
 
 class NoStationDataFoundError(Exception):
-    def __init__(self, station_list=None, years=None):
-        if (station_list is None) and (years is None):
+    def __init__(self, station_list=None, years=None) -> None:
+        if station_list is None and years is None:
             msg = 'No GNSS station data was found'
-        elif (years is None):
-            msg = 'No data was found for GNSS stations {}'.format(station_list)
+        elif years is None:
+            msg = f'No data was found for GNSS stations {station_list}'
         elif station_list is None:
-            msg = 'No data was found for years {}'.format(years)
+            msg = f'No data was found for years {years}'
         else:
-            msg = 'No data was found for GNSS stations {} and years {}'.format(station_list, years)
-
+            msg = f'No data was found for GNSS stations {station_list} and years {years}'
         super().__init__(msg)
diff --git a/tools/RAiDER/models/ecmwf.py b/tools/RAiDER/models/ecmwf.py
index 239dfc7fc..e046ee994 100755
--- a/tools/RAiDER/models/ecmwf.py
+++ b/tools/RAiDER/models/ecmwf.py
@@ -1,35 +1,30 @@
-from abc import abstractmethod
-import datetime
+import datetime as dt
 
 import numpy as np
 import xarray as xr
-
 from pyproj import CRS
 
-from RAiDER.logger import logger
 from RAiDER import utilFcns as util
+from RAiDER.logger import logger
 from RAiDER.models.model_levels import (
-    LEVELS_137_HEIGHTS,
-    LEVELS_25_HEIGHTS,
     A_137_HRES,
     B_137_HRES,
+    LEVELS_25_HEIGHTS,
+    LEVELS_137_HEIGHTS,
 )
-
-from RAiDER.models.weatherModel import WeatherModel, TIME_RES
+from RAiDER.models.weatherModel import TIME_RES, WeatherModel
 
 
 class ECMWF(WeatherModel):
-    '''
-    Implement ECMWF models
-    '''
+    """Implement ECMWF models."""
 
-    def __init__(self):
+    def __init__(self) -> None:
         # initialize a weather model
         WeatherModel.__init__(self)
 
         # model constants
-        self._k1 = 0.776 # [K/Pa]
-        self._k2 = 0.233 # [K/Pa]
+        self._k1 = 0.776  # [K/Pa]
+        self._k2 = 0.233  # [K/Pa]
         self._k3 = 3.75e3  # [K^2/Pa]
 
         self._time_res = TIME_RES['ECMWF']
@@ -40,36 +35,29 @@ def __init__(self):
         self._model_level_type = 'ml'  # Default
 
-
     def __pressure_levels__(self):
         self._zlevels = np.flipud(LEVELS_25_HEIGHTS)
-        self._levels = len(self._zlevels)
-
+        self._levels = len(self._zlevels)
 
     def __model_levels__(self):
-        self._levels = 137
+        self._levels = 137
         self._zlevels = np.flipud(LEVELS_137_HEIGHTS)
         self._a = A_137_HRES
         self._b = B_137_HRES
 
-
-    def load_weather(self, f=None, *args, **kwargs):
-        '''
+    def load_weather(self, f=None, *args, **kwargs) -> None:
+        """
         Consistent class method to be implemented across all weather model types.
         As a result of calling this method, all of the variables (x, y, z, p, q,
         t, wet_refractivity, hydrostatic refractivity, e) should be fully populated. 
- ''' - f = self.files[0] if f is None else f + """ + f = f if f is not None else self.files[0] self._load_model_level(f) - - def _load_model_level(self, fname): + def _load_model_level(self, fname) -> None: # read data from netcdf file - lats, lons, xs, ys, t, q, lnsp, z = self._makeDataCubes( - fname, - verbose=False - ) + lats, lons, xs, ys, t, q, lnsp, z = self._makeDataCubes(fname, verbose=False) # ECMWF appears to give me this backwards if lats[0] > lats[1]: @@ -122,161 +110,135 @@ def _load_model_level(self, fname): self._xs = self._lons.copy() self._zs = np.flip(h, axis=2) - - def _fetch(self, out): - ''' - Fetch a weather model from ECMWF - ''' + def _fetch(self, out) -> None: + """Fetch a weather model from ECMWF.""" # bounding box plus a buffer lat_min, lat_max, lon_min, lon_max = self._ll_bounds - # execute the search at ECMWF - self._get_from_ecmwf( - lat_min, - lat_max, - self._lat_res, - lon_min, - lon_max, - self._lon_res, - self._time, - out - ) - return + self._get_from_ecmwf(lat_min, lat_max, self._lat_res, lon_min, lon_max, self._lon_res, self._time, out) - - def _get_from_ecmwf(self, lat_min, lat_max, lat_step, lon_min, lon_max, - lon_step, time, out): + def _get_from_ecmwf(self, lat_min, lat_max, lat_step, lon_min, lon_max, lon_step, time, out) -> None: import ecmwfapi server = ecmwfapi.ECMWFDataServer() - corrected_DT = util.round_date(time, datetime.timedelta(hours=self._time_res)) + corrected_DT = util.round_date(time, dt.timedelta(hours=self._time_res)) if not corrected_DT == time: logger.warning('Rounded given datetime from %s to %s', time, corrected_DT) - server.retrieve({ - "class": self._classname, # ERA-Interim - 'dataset': self._dataset, - "expver": "{}".format(self._expver), - # They warn me against all, but it works well - "levelist": 'all', - "levtype": "ml", # Model levels - "param": "lnsp/q/z/t", # Necessary variables - "stream": "oper", - # date: Specify a single date as "2015-08-01" or a period as - # "2015-08-01/to/2015-08-31". - "date": datetime.datetime.strftime(corrected_DT, "%Y-%m-%d"), - # type: Use an (analysis) unless you have a particular reason to - # use fc (forecast). - "type": "an", - # time: With type=an, time can be any of - # "00:00:00/06:00:00/12:00:00/18:00:00". With type=fc, time can - # be any of "00:00:00/12:00:00", - "time": datetime.time.strftime(corrected_DT.time(), "%H:%M:%S"), - # step: With type=an, step is always "0". With type=fc, step can - # be any of "3/6/9/12". - "step": "0", - # grid: Only regular lat/lon grids are supported. - "grid": '{}/{}'.format(lat_step, lon_step), - "area": '{}/{}/{}/{}'.format(lat_max, lon_min, lat_min, lon_max), # area: N/W/S/E - "format": "netcdf", - "resol": "av", - "target": out, # target: the name of the output file. - }) - - - def _get_from_cds( - self, - lat_min, - lat_max, - lon_min, - lon_max, - acqTime, - outname - ): - """ Used for ERA5 """ + server.retrieve( + { + 'class': self._classname, # ERA-Interim + 'dataset': self._dataset, + 'expver': f'{self._expver}', + # They warn me against all, but it works well + 'levelist': 'all', + 'levtype': 'ml', # Model levels + 'param': 'lnsp/q/z/t', # Necessary variables + 'stream': 'oper', + # date: Specify a single date as "2015-08-01" or a period as + # "2015-08-01/to/2015-08-31". + 'date': dt.datetime.strftime(corrected_DT, '%Y-%m-%d'), + # type: Use an (analysis) unless you have a particular reason to + # use fc (forecast). + 'type': 'an', + # time: With type=an, time can be any of + # "00:00:00/06:00:00/12:00:00/18:00:00". 
With type=fc, time can + # be any of "00:00:00/12:00:00", + 'time': dt.time.strftime(corrected_DT.time(), '%H:%M:%S'), + # step: With type=an, step is always "0". With type=fc, step can + # be any of "3/6/9/12". + 'step': '0', + # grid: Only regular lat/lon grids are supported. + 'grid': f'{lat_step}/{lon_step}', + 'area': f'{lat_max}/{lon_min}/{lat_min}/{lon_max}', # area: N/W/S/E + 'format': 'netcdf', + 'resol': 'av', + 'target': out, # target: the name of the output file. + } + ) + + def _get_from_cds(self, lat_min, lat_max, lon_min, lon_max, acqTime, outname) -> None: + """Used for ERA5.""" import cdsapi + c = cdsapi.Client(verify=0) if self._model_level_type == 'pl': var = ['z', 'q', 't'] - levType = 'pressure_level' else: - var = "129/130/133/152" # 'lnsp', 'q', 'z', 't' - levType = 'model_level' + var = '129/130/133/152' # 'lnsp', 'q', 'z', 't' bbox = [lat_max, lon_min, lat_min, lon_max] # round to the closest legal time - corrected_DT = util.round_date(acqTime, datetime.timedelta(hours=self._time_res)) + corrected_DT = util.round_date(acqTime, dt.timedelta(hours=self._time_res)) if not corrected_DT == acqTime: logger.warning('Rounded given datetime from %s to %s', acqTime, corrected_DT) - # I referenced https://confluence.ecmwf.int/display/CKB/How+to+download+ERA5 dataDict = { - "class": "ea", - "expver": "1", - "levelist": 'all', - "levtype": "{}".format(self._model_level_type), # 'ml' for model levels or 'pl' for pressure levels + 'class': 'ea', + 'expver': '1', + 'levelist': 'all', + 'levtype': f'{self._model_level_type}', # 'ml' for model levels or 'pl' for pressure levels 'param': var, - "stream": "oper", - "type": "an", - "date": "{}".format(corrected_DT.strftime('%Y-%m-%d')), - "time": "{}".format(datetime.time.strftime(corrected_DT.time(), '%H:%M')), + 'stream': 'oper', + 'type': 'an', + 'date': corrected_DT.strftime('%Y-%m-%d'), + 'time': dt.time.strftime(corrected_DT.time(), '%H:%M'), # step: With type=an, step is always "0". With type=fc, step can # be any of "3/6/9/12". 
- "step": "0", - "area": bbox, - "grid": [0.25, .25], - "format": "netcdf"} + 'step': '0', + 'area': bbox, + 'grid': [0.25, 0.25], + 'format': 'netcdf', + } try: c.retrieve('reanalysis-era5-complete', dataDict, outname) - except Exception as e: + except: raise Exception - - def _download_ecmwf(self, lat_min, lat_max, lat_step, lon_min, lon_max, lon_step, time, out): - """ Used for HRES """ + def _download_ecmwf(self, lat_min, lat_max, lat_step, lon_min, lon_max, lon_step, time, out) -> None: + """Used for HRES.""" from ecmwfapi import ECMWFService - server = ECMWFService("mars") + server = ECMWFService('mars') # round to the closest legal time - corrected_DT = util.round_date(time, datetime.timedelta(hours=self._time_res)) + corrected_DT = util.round_date(time, dt.timedelta(hours=self._time_res)) if not corrected_DT == time: logger.warning('Rounded given datetime from %s to %s', time, corrected_DT) if self._model_level_type == 'ml': - param = "129/130/133/152" + param = '129/130/133/152' else: - param = "129.128/130.128/133.128/152" + param = '129.128/130.128/133.128/152' server.execute( { 'class': self._classname, 'dataset': self._dataset, - 'expver': "{}".format(self._expver), - 'resol': "av", - 'stream': "oper", - 'type': "an", - 'levelist': "all", - 'levtype': "{}".format(self._model_level_type), + 'expver': f'{self._expver}', + 'resol': 'av', + 'stream': 'oper', + 'type': 'an', + 'levelist': 'all', + 'levtype': f'{self._model_level_type}', 'param': param, - 'date': datetime.datetime.strftime(corrected_DT, "%Y-%m-%d"), - 'time': "{}".format(datetime.time.strftime(corrected_DT.time(), '%H:%M')), - 'step': "0", - 'grid': "{}/{}".format(lon_step, lat_step), - 'area': "{}/{}/{}/{}".format(lat_max, util.floorish(lon_min, 0.1), util.floorish(lat_min, 0.1), lon_max), - 'format': "netcdf", + 'date': dt.datetime.strftime(corrected_DT, '%Y-%m-%d'), + 'time': dt.time.strftime(corrected_DT.time(), '%H:%M'), + 'step': '0', + 'grid': f'{lon_step}/{lat_step}', + 'area': f'{lat_max}/{util.floorish(lon_min, 0.1)}/{util.floorish(lat_min, 0.1)}/{lon_max}', + 'format': 'netcdf', }, - out + out, ) - - def _load_pressure_level(self, filename, *args, **kwargs): + def _load_pressure_level(self, filename, *args, **kwargs) -> None: with xr.open_dataset(filename) as block: # Pull the data z = np.squeeze(block['z'].values) @@ -316,8 +278,7 @@ def _load_pressure_level(self, filename, *args, **kwargs): # correct heights for latitude self._get_heights(self._lats, geo_hgt) - self._p = np.broadcast_to(levels[np.newaxis, np.newaxis, :], - self._zs.shape) + self._p = np.broadcast_to(levels[np.newaxis, np.newaxis, :], self._zs.shape) # Re-structure from (heights, lats, lons) to (lons, lats, heights) self._t = self._t.transpose(1, 2, 0) @@ -330,12 +291,11 @@ def _load_pressure_level(self, filename, *args, **kwargs): self._t = np.flip(self._t, axis=2) self._q = np.flip(self._q, axis=2) - def _makeDataCubes(self, fname, verbose=False): - ''' + """ Create a cube of data representing temperature and relative humidity - at specified pressure levels - ''' + at specified pressure levels. 
+ """ # get ll_bounds S, N, W, E = self._ll_bounds @@ -359,7 +319,6 @@ def _makeDataCubes(self, fname, verbose=False): ys = lats.copy() if z.size == 0: - raise RuntimeError('There is no data in z, ' - 'you may have a problem with your mask') + raise RuntimeError('There is no data in z, you may have a problem with your mask') return lats, lons, xs, ys, t, q, lnsp, z diff --git a/tools/RAiDER/models/era5.py b/tools/RAiDER/models/era5.py index c06593029..c4e8def5d 100755 --- a/tools/RAiDER/models/era5.py +++ b/tools/RAiDER/models/era5.py @@ -1,16 +1,15 @@ -import datetime -from dateutil.relativedelta import relativedelta +import datetime as dt +from dateutil.relativedelta import relativedelta from pyproj import CRS from RAiDER.models.ecmwf import ECMWF -from RAiDER.logger import logger class ERA5(ECMWF): # I took this from # https://www.ecmwf.int/en/forecasts/documentation-and-support/137-model-levels. - def __init__(self): + def __init__(self) -> None: ECMWF.__init__(self) self._humidityType = 'q' @@ -21,11 +20,11 @@ def __init__(self): self._proj = CRS.from_epsg(4326) # Tuple of min/max years where data is available. - lag_time = 3 # months - end_date = datetime.datetime.today() - relativedelta(months=lag_time) + lag_time = 3 # months + end_date = dt.datetime.today() - relativedelta(months=lag_time) self._valid_range = ( - datetime.datetime(1950, 1, 1).replace(tzinfo=datetime.timezone(offset=datetime.timedelta())), - end_date.replace(tzinfo=datetime.timezone(offset=datetime.timedelta())) + dt.datetime(1950, 1, 1).replace(tzinfo=dt.timezone(offset=dt.timedelta())), + end_date.replace(tzinfo=dt.timezone(offset=dt.timedelta())), ) # Availability lag time in days @@ -34,11 +33,8 @@ def __init__(self): # Default, need to change to ml self.setLevelType('ml') - - def _fetch(self, out): - ''' - Fetch a weather model from ECMWF - ''' + def _fetch(self, out) -> None: + """Fetch a weather model from ECMWF.""" # bounding box plus a buffer lat_min, lat_max, lon_min, lon_max = self._ll_bounds time = self._time @@ -46,9 +42,8 @@ def _fetch(self, out): # execute the search at ECMWF self._get_from_cds(lat_min, lat_max, lon_min, lon_max, time, out) - - def load_weather(self, f=None, *args, **kwargs): - '''Load either pressure or model level data''' + def load_weather(self, f=None, *args, **kwargs) -> None: + """Load either pressure or model level data.""" f = self.files[0] if f is None else f if self._model_level_type == 'pl': self._load_pressure_level(f, *args, **kwargs) diff --git a/tools/RAiDER/models/era5t.py b/tools/RAiDER/models/era5t.py index 456bdee92..577759a20 100644 --- a/tools/RAiDER/models/era5t.py +++ b/tools/RAiDER/models/era5t.py @@ -1,4 +1,4 @@ -import datetime +import datetime as dt from RAiDER.models.era5 import ERA5 @@ -6,16 +6,18 @@ class ERA5T(ERA5): # I took this from # https://www.ecmwf.int/en/forecasts/documentation-and-support/137-model-levels. - def __init__(self): + def __init__(self) -> None: ERA5.__init__(self) self._expver = '0005' self._dataset = 'era5t' self._Name = 'ERA-5T' - self._valid_range = (datetime.datetime(1950, 1, 1).replace(tzinfo=datetime.timezone(offset=datetime.timedelta())), - datetime.datetime.now(datetime.timezone.utc)) # Tuple of min/max years where data is available. + self._valid_range = ( + dt.datetime(1950, 1, 1).replace(tzinfo=dt.timezone(offset=dt.timedelta())), + dt.datetime.now(dt.timezone.utc), + ) # Tuple of min/max years where data is available. 
# Availability lag time in days; actually about 12 hours but unstable on ECMWF side - # https://confluence.ecmwf.int/display/CKB/ERA5%3A+data+documentation - # see data update frequency - self._lag_time = datetime.timedelta(days=1) + # https://confluence.ecmwf.int/display/CKB/ERA5%3A+data+documentation + # see data update frequency + self._lag_time = dt.timedelta(days=1) diff --git a/tools/RAiDER/models/erai.py b/tools/RAiDER/models/erai.py index b91267bf1..836ae6676 100755 --- a/tools/RAiDER/models/erai.py +++ b/tools/RAiDER/models/erai.py @@ -1,4 +1,4 @@ -import datetime +import datetime as dt from RAiDER.models.ecmwf import ECMWF from RAiDER.models.model_levels import A_ERAI, B_ERAI @@ -7,7 +7,7 @@ class ERAI(ECMWF): # A and B parameters to calculate pressures for model levels, # extracted from an ECMWF ERA-Interim GRIB file and then hardcoded here - def __init__(self): + def __init__(self) -> None: ECMWF.__init__(self) self._classname = 'ei' self._expver = '0001' @@ -17,11 +17,11 @@ def __init__(self): # Tuple of min/max years where data is available. self._valid_range = ( - datetime.datetime(1979, 1, 1).replace(tzinfo=datetime.timezone(offset=datetime.timedelta())), - datetime.datetime(2019, 8, 31).replace(tzinfo=datetime.timezone(offset=datetime.timedelta())) + dt.datetime(1979, 1, 1).replace(tzinfo=dt.timezone(offset=dt.timedelta())), + dt.datetime(2019, 8, 31).replace(tzinfo=dt.timezone(offset=dt.timedelta())), ) - self._lag_time = datetime.timedelta(days=30) # Availability lag time in days + self._lag_time = dt.timedelta(days=30) # Availability lag time in days def __model_levels__(self): self._levels = 60 diff --git a/tools/RAiDER/models/generateGACOSVRT.py b/tools/RAiDER/models/generateGACOSVRT.py index 60f703354..3cfcad233 100644 --- a/tools/RAiDER/models/generateGACOSVRT.py +++ b/tools/RAiDER/models/generateGACOSVRT.py @@ -3,26 +3,27 @@ # Copyright 2018 -def makeVRT(filename, dtype='Float32'): - ''' - Use an RSC file to create a GDAL-compatible VRT file for opening GACOS weather model files - ''' +def makeVRT(filename, dtype='Float32') -> None: + """Use an RSC file to create a GDAL-compatible VRT file for opening GACOS weather model files.""" fields = readRSC(filename) - string = vrtStr(fields['XMAX'], fields['YMAX'], fields['X_FIRST'], fields['Y_FIRST'], fields['X_STEP'], fields['Y_STEP'], filename.replace('.rsc', ''), dtype=dtype) - writeStringToFile(string, filename.replace('.rsc', '').replace('.ztd', '') + '.vrt') - - -def writeStringToFile(string, filename): - ''' - Write a string to a VRT file - ''' + string = vrtStr( + fields['XMAX'], + fields['YMAX'], + fields['X_FIRST'], + fields['Y_FIRST'], + fields['X_STEP'], + fields['Y_STEP'], + filename.replace('.rsc', ''), + dtype=dtype, + ) + filename = filename.replace('.rsc', '').replace('.ztd', '') + '.vrt' with open(filename, 'w') as f: f.write(string) def readRSC(rscFilename): fields = {} - with open(rscFilename, 'r') as f: + with open(rscFilename) as f: for line in f: fieldName, value = line.strip().split() fields[fieldName] = value @@ -30,35 +31,34 @@ def readRSC(rscFilename): def vrtStr(xSize, ySize, lon1, lat1, lonStep, latStep, filename, dtype='Float32'): - string = f''' - EPSG:4326 - {lon1}, {lonStep}, 0.0000000000000000e+00, {lat1}, 0.0000000000000000e+00, {latStep} - - {filename} - - -''' - - return string + return ( + f'' + ' EPSG:4326' + f' {lon1}, {lonStep}, 0.0000000000000000e+00, {lat1}, 0.0000000000000000e+00, {latStep}' + f' ' + f' {filename}' + ' ' + '' + ) -def convertAllFiles(dirLoc): - ''' - convert all 
RSC files to VRT files contained in dirLoc - ''' +def convertAllFiles(dirLoc) -> None: + """Convert all RSC files to VRT files contained in dirLoc.""" import glob + files = glob.glob('*.rsc') for f in files: makeVRT(f) -def main(): +def main() -> None: import sys + if len(sys.argv) == 2: makeVRT(sys.argv[1]) elif len(sys.argv) == 3: convertAllFiles(sys.argv[1]) - print('Converting all RSC files in {}'.format(sys.argv[1])) + print(f'Converting all RSC files in {sys.argv[1]}') else: print('Usage: ') print('python3 generateGACOSVRT.py ') diff --git a/tools/RAiDER/models/gmao.py b/tools/RAiDER/models/gmao.py index abd0a4a27..830070df6 100755 --- a/tools/RAiDER/models/gmao.py +++ b/tools/RAiDER/models/gmao.py @@ -1,24 +1,25 @@ +import datetime as dt import os -import datetime -import numpy as np import shutil + import h5py +import numpy as np import pydap.cas.urs import pydap.client from pyproj import CRS -from RAiDER.models.weatherModel import WeatherModel, TIME_RES from RAiDER.logger import logger -from RAiDER.utilFcns import writeWeatherVarsXarray, round_date, requests_retry_session from RAiDER.models.model_levels import ( LEVELS_137_HEIGHTS, ) +from RAiDER.models.weatherModel import TIME_RES, WeatherModel +from RAiDER.utilFcns import requests_retry_session, round_date, writeWeatherVarsXarray class GMAO(WeatherModel): # I took this from GMAO model level weblink # https://opendap.nccs.nasa.gov/dods/GEOS-5/fp/0.25_deg/assim/inst3_3d_asm_Nv - def __init__(self): + def __init__(self) -> None: # initialize a weather model WeatherModel.__init__(self) @@ -29,9 +30,11 @@ def __init__(self): self._dataset = 'gmao' # Tuple of min/max years where data is available. - self._valid_range = (datetime.datetime(2014, 2, 20).replace(tzinfo=datetime.timezone(offset=datetime.timedelta())), - datetime.datetime.now(datetime.timezone.utc)) - self._lag_time = datetime.timedelta(hours=24.0) # Availability lag time in hours + self._valid_range = ( + dt.datetime(2014, 2, 20).replace(tzinfo=dt.timezone(offset=dt.timedelta())), + dt.datetime.now(dt.timezone.utc), + ) + self._lag_time = dt.timedelta(hours=24.0) # Availability lag time in hours # model constants self._k1 = 0.776 # [K/Pa] @@ -40,7 +43,6 @@ def __init__(self): self._time_res = TIME_RES[self._dataset.upper()] - # horizontal grid spacing self._lat_res = 0.25 self._lon_res = 0.3125 @@ -56,11 +58,8 @@ def __init__(self): # Projection self._proj = CRS.from_epsg(4326) - - def _fetch(self, out): - ''' - Fetch weather model data from GMAO - ''' + def _fetch(self, out) -> None: + """Fetch weather model data from GMAO.""" acqTime = self._time # calculate the array indices for slicing the GMAO variable arrays @@ -69,9 +68,9 @@ def _fetch(self, out): lon_min_ind = int((self._ll_bounds[2] - (-180.0)) / self._lon_res) lon_max_ind = int((self._ll_bounds[3] - (-180.0)) / self._lon_res) - T0 = datetime.datetime(2017, 12, 1, 0, 0, 0) + T0 = dt.datetime(2017, 12, 1, 0, 0, 0) # round time to nearest third hour - corrected_DT = round_date(acqTime, datetime.timedelta(hours=self._time_res)) + corrected_DT = round_date(acqTime, dt.timedelta(hours=self._time_res)) if not corrected_DT == acqTime: logger.warning('Rounded given datetime from %s to %s', acqTime, corrected_DT) @@ -85,31 +84,28 @@ def _fetch(self, out): url = 'https://opendap.nccs.nasa.gov/dods/GEOS-5/fp/0.25_deg/assim/inst3_3d_asm_Nv' session = pydap.cas.urs.setup_session('username', 'password', check_url=url) ds = pydap.client.open_url(url, session=session) - qv = ds['qv'].array[ - time_ind, - ml_min:(ml_max + 1), - 
lat_min_ind:(lat_max_ind + 1), - lon_min_ind:(lon_max_ind + 1) - ].data[0] - - p = ds['pl'].array[ - time_ind, - ml_min:(ml_max + 1), - lat_min_ind:(lat_max_ind + 1), - lon_min_ind:(lon_max_ind + 1) - ].data[0] - t = ds['t'].array[ - time_ind, - ml_min:(ml_max + 1), - lat_min_ind:(lat_max_ind + 1), - lon_min_ind:(lon_max_ind + 1) - ].data[0] - h = ds['h'].array[ - time_ind, - ml_min:(ml_max + 1), - lat_min_ind:(lat_max_ind + 1), - lon_min_ind:(lon_max_ind + 1) - ].data[0] + + p = ( + ds['pl'] + .array[ + time_ind, ml_min : (ml_max + 1), lat_min_ind : (lat_max_ind + 1), lon_min_ind : (lon_max_ind + 1) + ] + .data[0] + ) + t = ( + ds['t'] + .array[ + time_ind, ml_min : (ml_max + 1), lat_min_ind : (lat_max_ind + 1), lon_min_ind : (lon_max_ind + 1) + ] + .data[0] + ) + h = ( + ds['h'] + .array[ + time_ind, ml_min : (ml_max + 1), lat_min_ind : (lat_max_ind + 1), lon_min_ind : (lon_max_ind + 1) + ] + .data[0] + ) else: root = 'https://portal.nccs.nasa.gov/datashare/gmao/geos-fp/das/Y{}/M{:02d}/D{:02d}' @@ -127,48 +123,38 @@ def _fetch(self, out): logger.warning('Weather model already exists, skipping download') with h5py.File(f, 'r') as ds: - q = ds['QV'][0, :, lat_min_ind:(lat_max_ind + 1), lon_min_ind:(lon_max_ind + 1)] - p = ds['PL'][0, :, lat_min_ind:(lat_max_ind + 1), lon_min_ind:(lon_max_ind + 1)] - t = ds['T'][0, :, lat_min_ind:(lat_max_ind + 1), lon_min_ind:(lon_max_ind + 1)] - h = ds['H'][0, :, lat_min_ind:(lat_max_ind + 1), lon_min_ind:(lon_max_ind + 1)] + q = ds['QV'][0, :, lat_min_ind : (lat_max_ind + 1), lon_min_ind : (lon_max_ind + 1)] + p = ds['PL'][0, :, lat_min_ind : (lat_max_ind + 1), lon_min_ind : (lon_max_ind + 1)] + t = ds['T'][0, :, lat_min_ind : (lat_max_ind + 1), lon_min_ind : (lon_max_ind + 1)] + h = ds['H'][0, :, lat_min_ind : (lat_max_ind + 1), lon_min_ind : (lon_max_ind + 1)] os.remove(f) - lats = np.arange( - (-90 + lat_min_ind * self._lat_res), - (-90 + (lat_max_ind + 1) * self._lat_res), - self._lat_res - ) + lats = np.arange((-90 + lat_min_ind * self._lat_res), (-90 + (lat_max_ind + 1) * self._lat_res), self._lat_res) lons = np.arange( - (-180 + lon_min_ind * self._lon_res), - (-180 + (lon_max_ind + 1) * self._lon_res), - self._lon_res + (-180 + lon_min_ind * self._lon_res), (-180 + (lon_max_ind + 1) * self._lon_res), self._lon_res ) try: # Note that lat/lon gets written twice for GMAO because they are the same as y/x writeWeatherVarsXarray(lats, lons, h, q, p, t, dt, crs, outName=None, NoDataValue=None, chunk=(1, 91, 144)) - except Exception: - logger.exception("Unable to save weathermodel to file") - + except: + logger.exception('Unable to save weathermodel to file') - def load_weather(self, f=None): - ''' + def load_weather(self, f=None) -> None: + """ Consistent class method to be implemented across all weather model types. As a result of calling this method, all of the variables (x, y, z, p, q, t, wet_refractivity, hydrostatic refractivity, e) should be fully populated. 
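# A minimal sketch (hypothetical bounds) of the index arithmetic used in _fetch
# above: on GMAO's regular global grid starting at -90 deg latitude / -180 deg
# longitude, a lat/lon bound maps to an array index by offsetting and dividing
# by the grid spacing set in __init__.
lat_res, lon_res = 0.25, 0.3125              # GMAO spacing from __init__
S, N, W, E = 34.0, 36.0, -120.0, -118.0      # hypothetical ll_bounds (SNWE)
lat_min_ind = int((S - (-90.0)) / lat_res)   # -> 496
lat_max_ind = int((N - (-90.0)) / lat_res)   # -> 504
lon_min_ind = int((W - (-180.0)) / lon_res)  # -> 192
lon_max_ind = int((E - (-180.0)) / lon_res)  # -> 198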
- ''' + """ f = self.files[0] if f is None else f self._load_model_level(f) - - def _load_model_level(self, filename): - ''' - Get the variables from the GMAO link using OpenDAP - ''' - + def _load_model_level(self, filename) -> None: + """Get the variables from the GMAO link using OpenDAP.""" # adding the import here should become absolute when transition to netcdf from netCDF4 import Dataset + with Dataset(filename, mode='r') as f: lons = np.array(f.variables['x'][:]) lats = np.array(f.variables['y'][:]) @@ -178,7 +164,7 @@ def _load_model_level(self, filename): t = np.array(f.variables['T'][:]) # restructure the 1-D lat/lon in regular 2D grid - _lons, _lats= np.meshgrid(lons, lats) + _lons, _lats = np.meshgrid(lons, lats) # Re-structure everything from (heights, lats, lons) to (lons, lats, heights) p = np.transpose(p) diff --git a/tools/RAiDER/models/hres.py b/tools/RAiDER/models/hres.py index f883b38c6..926b09e79 100755 --- a/tools/RAiDER/models/hres.py +++ b/tools/RAiDER/models/hres.py @@ -1,39 +1,34 @@ -import datetime +import datetime as dt import numpy as np - from pyproj import CRS from RAiDER.models.ecmwf import ECMWF -from RAiDER.models.weatherModel import WeatherModel, TIME_RES from RAiDER.models.model_levels import ( - LEVELS_91_HEIGHTS, - LEVELS_25_HEIGHTS, A_91_HRES, B_91_HRES, + LEVELS_91_HEIGHTS, ) +from RAiDER.models.weatherModel import TIME_RES, WeatherModel class HRES(ECMWF): - ''' - Implement ECMWF models - ''' + """Implement ECMWF models.""" - def __init__(self, level_type='ml'): + def __init__(self, level_type='ml') -> None: # initialize a weather model WeatherModel.__init__(self) # model constants - self._k1 = 0.776 # [K/Pa] - self._k2 = 0.233 # [K/Pa] + self._k1 = 0.776 # [K/Pa] + self._k2 = 0.233 # [K/Pa] self._k3 = 3.75e3 # [K^2/Pa] - # 9 km horizontal grid spacing. This is only used for extending the download-buffer, i.e. not in subsequent processing. - self._lon_res = 9. / 111 # 0.08108115 - self._lat_res = 9. / 111 # 0.08108115 - self._x_res = 9. / 111 # 0.08108115 - self._y_res = 9. / 111 # 0.08108115 + self._lon_res = 9.0 / 111 # 0.08108115 + self._lat_res = 9.0 / 111 # 0.08108115 + self._x_res = 9.0 / 111 # 0.08108115 + self._y_res = 9.0 / 111 # 0.08108115 self._humidityType = 'q' # Default, pressure levels are 'pl' @@ -45,14 +40,16 @@ def __init__(self, level_type='ml'): self._time_res = TIME_RES[self._dataset.upper()] # Tuple of min/max years where data is available. - self._valid_range = (datetime.datetime(1983, 4, 20).replace(tzinfo=datetime.timezone(offset=datetime.timedelta())), - datetime.datetime.now(datetime.timezone.utc)) + self._valid_range = ( + dt.datetime(1983, 4, 20).replace(tzinfo=dt.timezone(offset=dt.timedelta())), + dt.datetime.now(dt.timezone.utc), + ) # Availability lag time in days - self._lag_time = datetime.timedelta(hours=6) + self._lag_time = dt.timedelta(hours=6) self.setLevelType('ml') - def update_a_b(self): + def update_a_b(self) -> None: # Before 2013-06-26, there were only 91 model levels. The mapping coefficients below are extracted # based on https://www.ecmwf.int/en/forecasts/documentation-and-support/91-model-levels self._levels = 91 @@ -60,31 +57,29 @@ def update_a_b(self): self._a = A_91_HRES self._b = B_91_HRES - def load_weather(self, f=None): - ''' + def load_weather(self, f=None) -> None: + """ Consistent class method to be implemented across all weather model types. As a result of calling this method, all of the variables (x, y, z, p, q, t, wet_refractivity, hydrostatic refractivity, e) should be fully populated. 
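# Sketch of the model-level switch implied by load_weather/_fetch below (the
# cutoff date is taken from the code; the acquisition time is hypothetical):
# HRES acquisitions before 2013-06-26 use the 91-level A/B coefficients, so
# update_a_b() is called before the model levels are read.
import datetime as dt

acq_time = dt.datetime(2012, 1, 1, 6, 0, 0)            # hypothetical acquisition
needs_91_levels = acq_time < dt.datetime(2013, 6, 26)  # True here -> update_a_b()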
- ''' + """ f = self.files[0] if f is None else f if self._model_level_type == 'ml': - if (self._time < datetime.datetime(2013, 6, 26, 0, 0, 0)): + if self._time < dt.datetime(2013, 6, 26, 0, 0, 0): self.update_a_b() self._load_model_level(f) elif self._model_level_type == 'pl': self._load_pressure_levels(f) - def _fetch(self,out): - ''' - Fetch a weather model from ECMWF - ''' + def _fetch(self, out) -> None: + """Fetch a weather model from ECMWF.""" # bounding box plus a buffer lat_min, lat_max, lon_min, lon_max = self._ll_bounds time = self._time - if (time < datetime.datetime(2013, 6, 26, 0, 0, 0)): + if time < dt.datetime(2013, 6, 26, 0, 0, 0): self.update_a_b() # execute the search at ECMWF diff --git a/tools/RAiDER/models/hrrr.py b/tools/RAiDER/models/hrrr.py index 3d55e501b..de872d597 100644 --- a/tools/RAiDER/models/hrrr.py +++ b/tools/RAiDER/models/hrrr.py @@ -1,41 +1,44 @@ -import datetime +import datetime as dt import os -import rioxarray -import xarray +from pathlib import Path +import geopandas as gpd import numpy as np - +import xarray as xr from herbie import Herbie -import geopandas as gpd -from pathlib import Path from pyproj import CRS, Transformer from shapely.geometry import Polygon, box -from RAiDER.utilFcns import round_date, transform_coords, rio_profile, rio_stats -from RAiDER.models.weatherModel import WeatherModel, TIME_RES -from RAiDER.models.model_levels import LEVELS_50_HEIGHTS, LEVELS_137_HEIGHTS from RAiDER.logger import logger +from RAiDER.models.model_levels import LEVELS_50_HEIGHTS +from RAiDER.models.weatherModel import TIME_RES, WeatherModel +from RAiDER.utilFcns import round_date + HRRR_CONUS_COVERAGE_POLYGON = Polygon(((-125, 21), (-133, 49), (-60, 49), (-72, 21))) HRRR_AK_COVERAGE_POLYGON = Polygon(((195, 40), (157, 55), (175, 70), (260, 77), (232, 52))) -HRRR_AK_PROJ = CRS.from_string('+proj=stere +ellps=sphere +a=6371229.0 +b=6371229.0 +lat_0=90 +lon_0=225.0 ' - '+x_0=0.0 +y_0=0.0 +lat_ts=60.0 +no_defs +type=crs') +HRRR_AK_PROJ = CRS.from_string( + '+proj=stere +ellps=sphere +a=6371229.0 +b=6371229.0 +lat_0=90 +lon_0=225.0 ' + '+x_0=0.0 +y_0=0.0 +lat_ts=60.0 +no_defs +type=crs' +) # Source: https://eric.clst.org/tech/usgeojson/ AK_GEO = gpd.read_file(Path(__file__).parent / 'data' / 'alaska.geojson.zip').geometry.unary_union -def check_hrrr_dataset_availability(dt: datetime) -> bool: - """Note a file could still be missing within the models valid range""" - H = Herbie(dt, - model='hrrr', - product='nat', - fxx=0) - avail = (H.grib_source is not None) - return avail +def check_hrrr_dataset_availability(datetime: dt.datetime) -> bool: + """Note a file could still be missing within the models valid range.""" + herbie = Herbie( + datetime, + model='hrrr', + product='nat', + fxx=0, + ) + return herbie.grib_source is not None -def download_hrrr_file(ll_bounds, DATE, out, model='hrrr', product='nat', fxx=0, verbose=False): - ''' - Download a HRRR weather model using Herbie + +def download_hrrr_file(ll_bounds, DATE, out, model='hrrr', product='nat', fxx=0, verbose=False) -> None: + """ + Download a HRRR weather model using Herbie. Args: DATE (Python datetime) - Datetime as a Python datetime. 
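# Hedged usage sketch of check_hrrr_dataset_availability above; the Herbie
# arguments mirror that helper (model='hrrr', product='nat', fxx=0) and the
# date is hypothetical.
import datetime as dt
from herbie import Herbie

when = dt.datetime(2021, 6, 1, 12, 0, 0)       # hypothetical acquisition time
h = Herbie(when, model='hrrr', product='nat', fxx=0)
is_available = h.grib_source is not None       # same test as the helper above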
Herbie will automatically return the closest valid time, @@ -48,8 +51,8 @@ def download_hrrr_file(ll_bounds, DATE, out, model='hrrr', product='nat', fxx=0, Returns: None, writes data to a netcdf file - ''' - H = Herbie( + """ + herbie = Herbie( DATE.strftime('%Y-%m-%d %H:%M'), model=model, product=product, @@ -59,13 +62,12 @@ def download_hrrr_file(ll_bounds, DATE, out, model='hrrr', product='nat', fxx=0, save_dir=Path(os.path.dirname(out)), ) - # Iterate through the list of datasets try: - ds_list = H.xarray(":(SPFH|PRES|TMP|HGT):", verbose=verbose) - except ValueError as E: - logger.error (E) - raise ValueError + ds_list = herbie.xarray(':(SPFH|PRES|TMP|HGT):', verbose=verbose) + except ValueError as e: + logger.error(e) + raise ds_out = None # Note order coord names are request for `test_HRRR_ztd` matters @@ -95,36 +97,31 @@ def download_hrrr_file(ll_bounds, DATE, out, model='hrrr', product='nat', fxx=0, ds_out = ds_out.rename({'gh': 'z', coord: 'levels'}) # projection information - ds_out["proj"] = int() + ds_out['proj'] = 0 for k, v in CRS.from_user_input(ds_out.herbie.crs).to_cf().items(): ds_out.proj.attrs[k] = v for var in ds_out.data_vars: ds_out[var].attrs['grid_mapping'] = 'proj' - # pull the grid information proj = CRS.from_cf(ds_out['proj'].attrs) t = Transformer.from_crs(4326, proj, always_xy=True) xl, yl = t.transform(ds_out['longitude'].values, ds_out['latitude'].values) W, E, S, N = np.nanmin(xl), np.nanmax(xl), np.nanmin(yl), np.nanmax(yl) - grid_x = 3000 # meters - grid_y = 3000 # meters - xs = np.arange(W, E+grid_x/2, grid_x) - ys = np.arange(S, N+grid_y/2, grid_y) + grid_x = 3000 # meters + grid_y = 3000 # meters + xs = np.arange(W, E + grid_x / 2, grid_x) + ys = np.arange(S, N + grid_y / 2, grid_y) ds_out['x'] = xs ds_out['y'] = ys ds_sub = ds_out.isel(x=slice(x_min, x_max), y=slice(y_min, y_max)) ds_sub.to_netcdf(out, engine='netcdf4') - return - def get_bounds_indices(SNWE, lats, lons): - ''' - Convert SNWE lat/lon bounds to index bounds - ''' + """Convert SNWE lat/lon bounds to index bounds.""" # Unpack the bounds and find the relevent indices S, N, W, E = SNWE @@ -133,8 +130,8 @@ def get_bounds_indices(SNWE, lats, lons): m1 = (S <= lats) & (N >= lats) & (W <= lons) & (E >= lons) else: raise ValueError( - 'Longitude is either flipped or you are crossing the international date line;' + - 'if the latter please give me longitudes from 0-360' + 'Longitude is either flipped or you are crossing the international date line;' + + 'if the latter please give me longitudes from 0-360' ) if np.sum(m1) == 0: @@ -162,13 +159,10 @@ def get_bounds_indices(SNWE, lats, lons): def load_weather_hrrr(filename): - ''' - Loads a weather model from a HRRR file - ''' + """Loads a weather model from a HRRR file.""" # read data from the netcdf file - ds = xarray.open_dataset(filename, engine='netcdf4') + ds = xr.open_dataset(filename, engine='netcdf4') # Pull the relevant data from the file - pl = ds.levels.values pres = ds['pres'].values.transpose(1, 2, 0) xArr = ds['x'].values yArr = ds['y'].values @@ -183,16 +177,14 @@ def load_weather_hrrr(filename): lons[lons > 180] -= 360 # data cube format should be lats,lons,heights - _xs = np.broadcast_to(xArr[np.newaxis, :, np.newaxis], - geo_hgt.shape) - _ys = np.broadcast_to(yArr[:, np.newaxis, np.newaxis], - geo_hgt.shape) + _xs = np.broadcast_to(xArr[np.newaxis, :, np.newaxis], geo_hgt.shape) + _ys = np.broadcast_to(yArr[:, np.newaxis, np.newaxis], geo_hgt.shape) return _xs, _ys, lons, lats, qs, temps, pres, geo_hgt, proj class 
HRRR(WeatherModel): - def __init__(self): + def __init__(self) -> None: # initialize a weather model super().__init__() @@ -206,10 +198,10 @@ def __init__(self): # Tuple of min/max years where data is available. self._valid_range = ( - datetime.datetime(2016, 7, 15).replace(tzinfo=datetime.timezone(offset=datetime.timedelta())), - datetime.datetime.now(datetime.timezone.utc) + dt.datetime(2016, 7, 15).replace(tzinfo=dt.timezone(offset=dt.timedelta())), + dt.datetime.now(dt.timezone.utc), ) - self._lag_time = datetime.timedelta(hours=3) # Availability lag time in days + self._lag_time = dt.timedelta(hours=3) # Availability lag time in days # model constants self._k1 = 0.776 # [K/Pa] @@ -217,10 +209,10 @@ def __init__(self): self._k3 = 3.75e3 # [K^2/Pa] # 3 km horizontal grid spacing - self._lat_res = 3. / 111 - self._lon_res = 3. / 111 - self._x_res = 3. - self._y_res = 3. + self._lat_res = 3.0 / 111 + self._lon_res = 3.0 / 111 + self._x_res = 3.0 + self._y_res = 3.0 self._Nproc = 1 self._Name = 'HRRR' @@ -245,28 +237,25 @@ def __init__(self): x0 = 0 y0 = 0 earth_radius = 6371229 - self._proj = CRS(f'+proj=lcc +lat_1={lat1} +lat_2={lat2} +lat_0={lat0} '\ - f'+lon_0={lon0} +x_0={x0} +y_0={y0} +a={earth_radius} '\ - f'+b={earth_radius} +units=m +no_defs') + self._proj = CRS( + f'+proj=lcc +lat_1={lat1} +lat_2={lat2} +lat_0={lat0} ' + f'+lon_0={lon0} +x_0={x0} +y_0={y0} +a={earth_radius} ' + f'+b={earth_radius} +units=m +no_defs' + ) self._valid_bounds = HRRR_CONUS_COVERAGE_POLYGON self.setLevelType('nat') - def __model_levels__(self): - self._levels = 50 + self._levels = 50 self._zlevels = np.flipud(LEVELS_50_HEIGHTS) - def __pressure_levels__(self): raise NotImplementedError('Pressure levels do not go high enough for HRRR.') - - def _fetch(self, out): - ''' - Fetch weather model data from HRRR - ''' + def _fetch(self, out) -> None: + """Fetch weather model data from HRRR.""" self._files = out - corrected_DT = round_date(self._time, datetime.timedelta(hours=self._time_res)) + corrected_DT = round_date(self._time, dt.timedelta(hours=self._time_res)) self.checkTime(corrected_DT) if not corrected_DT == self._time: logger.info('Rounded given datetime from %s to %s', self._time, corrected_DT) @@ -277,16 +266,14 @@ def _fetch(self, out): download_hrrr_file(bounds, corrected_DT, out, 'hrrr', self._model_level_type) - - def load_weather(self, f=None, *args, **kwargs): - ''' + def load_weather(self, f=None, *args, **kwargs) -> None: + """ Load a weather model into a python weatherModel object, from self.files if no filename is passed. - ''' + """ if f is None: f = self.files[0] if isinstance(self.files, list) else self.files - _xs, _ys, _lons, _lats, qs, temps, pres, geo_hgt, proj = load_weather_hrrr(f) # convert geopotential height to geometric height @@ -301,18 +288,14 @@ def load_weather(self, f=None, *args, **kwargs): self._lons = _lons self._proj = proj - - def checkValidBounds(self: WeatherModel, ll_bounds: np.ndarray): - ''' + def checkValidBounds(self, ll_bounds: np.ndarray) -> None: + """ Checks whether the given bounding box is valid for the HRRR or HRRRAK - (i.e., intersects with the model domain at all) + (i.e., intersects with the model domain at all). 
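# Hedged sketch of the bounds test performed by checkValidBounds: the AOI box
# either sits fully inside the CONUS coverage polygon, merely intersects it,
# or is checked against the HRRR-AK domain instead. The bounds used here are
# hypothetical.
from shapely.geometry import box

from RAiDER.models.hrrr import HRRR_CONUS_COVERAGE_POLYGON

S, N, W, E = 35.0, 40.0, -120.0, -115.0   # hypothetical ll_bounds (SNWE)
aoi = box(W, S, E, N)
fully_covered = HRRR_CONUS_COVERAGE_POLYGON.contains(aoi)        # True here
partially_covered = aoi.intersects(HRRR_CONUS_COVERAGE_POLYGON)  # also True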
Args: ll_bounds : np.ndarray - - Returns: - The weather model object - ''' + """ S, N, W, E = ll_bounds aoi = box(W, S, E, N) if self._valid_bounds.contains(aoi): @@ -326,7 +309,7 @@ def checkValidBounds(self: WeatherModel, ll_bounds: np.ndarray): Mod = HRRRAK() # valid bounds are in 0->360 to account for dateline crossing W, E = np.mod([W, E], 360) - aoi = box(W, S, E, N) + aoi = box(W, S, E, N) if Mod._valid_bounds.contains(aoi): pass elif aoi.intersects(Mod._valid_bounds): @@ -335,11 +318,9 @@ def checkValidBounds(self: WeatherModel, ll_bounds: np.ndarray): else: raise ValueError('The requested location is unavailable for HRRR') - return Mod - class HRRRAK(WeatherModel): - def __init__(self): + def __init__(self) -> None: # The HRRR-AK model has a few different parameters than HRRR-CONUS. # These will get used if a user requests a bounding box in Alaska super().__init__() @@ -350,10 +331,10 @@ def __init__(self): self._k3 = 3.75e3 # [K^2/Pa] # 3 km horizontal grid spacing - self._lat_res = 3. / 111 - self._lon_res = 3. / 111 - self._x_res = 3. - self._y_res = 3. + self._lat_res = 3.0 / 111 + self._lon_res = 3.0 / 111 + self._x_res = 3.0 + self._y_res = 3.0 self._Nproc = 1 self._Npl = 0 @@ -362,41 +343,39 @@ def __init__(self): self._classname = 'hrrrak' self._dataset = 'hrrrak' - self._Name = "HRRR-AK" + self._Name = 'HRRR-AK' self._time_res = TIME_RES['HRRR-AK'] self._valid_range = ( - datetime.datetime(2018, 7, 13).replace(tzinfo=datetime.timezone(offset=datetime.timedelta())), - datetime.datetime.now(datetime.timezone.utc) + dt.datetime(2018, 7, 13).replace(tzinfo=dt.timezone(offset=dt.timedelta())), + dt.datetime.now(dt.timezone.utc), ) - self._lag_time = datetime.timedelta(hours=3) + self._lag_time = dt.timedelta(hours=3) self._valid_bounds = HRRR_AK_COVERAGE_POLYGON # The projection information gets read directly from the weather model file but we # keep this here for object instantiation. self._proj = HRRR_AK_PROJ self.setLevelType('nat') - def __model_levels__(self): - self._levels = 50 + self._levels = 50 self._zlevels = np.flipud(LEVELS_50_HEIGHTS) - def __pressure_levels__(self): - raise NotImplementedError('hrrr.py: Revisit whether or not pressure levels from HRRR can be used for delay calculations; they do not go high enough compared to native model levels.') - + raise NotImplementedError( + 'hrrr.py: Revisit whether or not pressure levels from HRRR can be used for delay calculations; they do not go high enough compared to native model levels.' 
+ ) - def _fetch(self, out): + def _fetch(self, out) -> None: bounds = self._ll_bounds.copy() bounds[2:] = np.mod(bounds[2:], 360) - corrected_DT = round_date(self._time, datetime.timedelta(hours=self._time_res)) + corrected_DT = round_date(self._time, dt.timedelta(hours=self._time_res)) self.checkTime(corrected_DT) if not corrected_DT == self._time: - logger.info('Rounded given datetime from {} to {}'.format(self._time, corrected_DT)) + logger.info(f'Rounded given datetime from {self._time} to {corrected_DT}') download_hrrr_file(bounds, corrected_DT, out, 'hrrrak', self._model_level_type) - - def load_weather(self, f=None, *args, **kwargs): + def load_weather(self, f=None, *args, **kwargs) -> None: if f is None: f = self.files[0] if isinstance(self.files, list) else self.files _xs, _ys, _lons, _lats, qs, temps, pres, geo_hgt, proj = load_weather_hrrr(f) diff --git a/tools/RAiDER/models/merra2.py b/tools/RAiDER/models/merra2.py index c0c111e3e..82107e558 100755 --- a/tools/RAiDER/models/merra2.py +++ b/tools/RAiDER/models/merra2.py @@ -1,25 +1,23 @@ -import io +import datetime as dt import os -import xarray -import datetime import numpy as np import pydap.cas.urs import pydap.client - +import xarray as xr from pyproj import CRS -from RAiDER.models.weatherModel import WeatherModel from RAiDER.logger import logger -from RAiDER.utilFcns import writeWeatherVarsXarray, read_EarthData_loginInfo from RAiDER.models.model_levels import ( LEVELS_137_HEIGHTS, ) +from RAiDER.models.weatherModel import WeatherModel +from RAiDER.utilFcns import read_EarthData_loginInfo, writeWeatherVarsXarray # Path to Netrc file, can be controlled by env var # Useful for containers - similar to CDSAPI_RC -EARTHDATA_RC = os.environ.get("EARTHDATA_RC", None) +EARTHDATA_RC = os.environ.get('EARTHDATA_RC', None) def Model(): @@ -27,9 +25,9 @@ def Model(): class MERRA2(WeatherModel): - def __init__(self): - + def __init__(self) -> None: import calendar + # initialize a weather model WeatherModel.__init__(self) @@ -40,15 +38,17 @@ def __init__(self): self._dataset = 'merra2' # Tuple of min/max years where data is available. 
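# Worked example (hypothetical "now") of the availability-lag computation just
# below: starting from 2024-05-20 UTC, the 15th of the month minus 60 days is
# 2024-03-16, which rolls forward to the end of March, giving a lag of about
# 50 days.
import calendar
import datetime as dt

utcnow = dt.datetime(2024, 5, 20, tzinfo=dt.timezone.utc)    # hypothetical
enddate = dt.datetime(utcnow.year, utcnow.month, 15) - dt.timedelta(days=60)
enddate = dt.datetime(enddate.year, enddate.month, calendar.monthrange(enddate.year, enddate.month)[1])
lag_days = (utcnow - enddate.replace(tzinfo=dt.timezone.utc)).days   # -> 50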
- utcnow = datetime.datetime.now(datetime.timezone.utc) - enddate = datetime.datetime(utcnow.year, utcnow.month, 15) - datetime.timedelta(days=60) - enddate = datetime.datetime(enddate.year, enddate.month, calendar.monthrange(enddate.year, enddate.month)[1]) - self._valid_range = (datetime.datetime(1980, 1, 1).replace(tzinfo=datetime.timezone(offset=datetime.timedelta())), - datetime.datetime.now(datetime.timezone.utc)) - lag_time = utcnow - enddate.replace(tzinfo=datetime.timezone(offset=datetime.timedelta())) - self._lag_time = datetime.timedelta(days=lag_time.days) # Availability lag time in days + utcnow = dt.datetime.now(dt.timezone.utc) + enddate = dt.datetime(utcnow.year, utcnow.month, 15) - dt.timedelta(days=60) + enddate = dt.datetime(enddate.year, enddate.month, calendar.monthrange(enddate.year, enddate.month)[1]) + self._valid_range = ( + dt.datetime(1980, 1, 1).replace(tzinfo=dt.timezone(offset=dt.timedelta())), + dt.datetime.now(dt.timezone.utc), + ) + lag_time = utcnow - enddate.replace(tzinfo=dt.timezone(offset=dt.timedelta())) + self._lag_time = dt.timedelta(days=lag_time.days) # Availability lag time in days self._time_res = 1 - + # model constants self._k1 = 0.776 # [K/Pa] self._k2 = 0.233 # [K/Pa] @@ -68,12 +68,10 @@ def __init__(self): # Projection self._proj = CRS.from_epsg(4326) - def _fetch(self, out): - ''' - Fetch weather model data from GMAO: note we only extract the lat/lon bounds for this weather model; fetching data is not needed here as we don't actually download any data using OpenDAP - ''' - time = self._time - + def _fetch(self, out) -> None: + """Fetch weather model data from GMAO: note we only extract the lat/lon bounds for this weather model; fetching data is not needed here as we don't actually download any data using OpenDAP.""" + time = self._time + # check whether the file already exists if os.path.exists(out): return @@ -84,15 +82,9 @@ def _fetch(self, out): lon_min_ind = int((self._ll_bounds[2] - (-180.0)) / self._lon_res) lon_max_ind = int((self._ll_bounds[3] - (-180.0)) / self._lon_res) - lats = np.arange( - (-90 + lat_min_ind * self._lat_res), - (-90 + (lat_max_ind + 1) * self._lat_res), - self._lat_res - ) + lats = np.arange((-90 + lat_min_ind * self._lat_res), (-90 + (lat_max_ind + 1) * self._lat_res), self._lat_res) lons = np.arange( - (-180 + lon_min_ind * self._lon_res), - (-180 + (lon_max_ind + 1) * self._lon_res), - self._lon_res + (-180 + lon_min_ind * self._lon_res), (-180 + (lon_max_ind + 1) * self._lon_res), self._lon_res ) lon, lat = np.meshgrid(lons, lats) @@ -110,40 +102,45 @@ def _fetch(self, out): earthdata_usr, earthdata_pwd = read_EarthData_loginInfo(EARTHDATA_RC) # open the dataset and pull the data - url = 'https://goldsmr5.gesdisc.eosdis.nasa.gov/opendap/MERRA2/M2T3NVASM.5.12.4/' + time.strftime('%Y/%m') + '/MERRA2_' + str(url_sub) + '.tavg3_3d_asm_Nv.' + time.strftime('%Y%m%d') + '.nc4' + url = ( + 'https://goldsmr5.gesdisc.eosdis.nasa.gov/opendap/MERRA2/M2T3NVASM.5.12.4/' + + time.strftime('%Y/%m') + + '/MERRA2_' + + str(url_sub) + + '.tavg3_3d_asm_Nv.' 
+ + time.strftime('%Y%m%d') + + '.nc4' + ) session = pydap.cas.urs.setup_session(earthdata_usr, earthdata_pwd, check_url=url) stream = pydap.client.open_url(url, session=session) - q = stream['QV'][0,:,lat_min_ind:lat_max_ind + 1, lon_min_ind:lon_max_ind + 1].data.squeeze() - p = stream['PL'][0,:,lat_min_ind:lat_max_ind + 1, lon_min_ind:lon_max_ind + 1].data.squeeze() - t = stream['T'][0,:,lat_min_ind:lat_max_ind + 1, lon_min_ind:lon_max_ind + 1].data.squeeze() - h = stream['H'][0,:,lat_min_ind:lat_max_ind + 1, lon_min_ind:lon_max_ind + 1].data.squeeze() + q = stream['QV'][0, :, lat_min_ind : lat_max_ind + 1, lon_min_ind : lon_max_ind + 1].data.squeeze() + p = stream['PL'][0, :, lat_min_ind : lat_max_ind + 1, lon_min_ind : lon_max_ind + 1].data.squeeze() + t = stream['T'][0, :, lat_min_ind : lat_max_ind + 1, lon_min_ind : lon_max_ind + 1].data.squeeze() + h = stream['H'][0, :, lat_min_ind : lat_max_ind + 1, lon_min_ind : lon_max_ind + 1].data.squeeze() try: writeWeatherVarsXarray(lat, lon, h, q, p, t, time, self._proj, outName=out) except Exception as e: logger.debug(e) - logger.exception("MERRA-2: Unable to save weathermodel to file") - raise RuntimeError('MERRA-2 failed with the following error: {}'.format(e)) + logger.exception('MERRA-2: Unable to save weathermodel to file') + raise RuntimeError(f'MERRA-2 failed with the following error: {e}') - def load_weather(self, f=None, *args, **kwargs): - ''' + def load_weather(self, f=None, *args, **kwargs) -> None: + """ Consistent class method to be implemented across all weather model types. As a result of calling this method, all of the variables (x, y, z, p, q, t, wet_refractivity, hydrostatic refractivity, e) should be fully populated. - ''' + """ f = self.files[0] if f is None else f self._load_model_level(f) - def _load_model_level(self, filename): - ''' - Get the variables from the GMAO link using OpenDAP - ''' - + def _load_model_level(self, filename) -> None: + """Get the variables from the GMAO link using OpenDAP.""" # adding the import here should become absolute when transition to netcdf - ds = xarray.load_dataset(filename) + ds = xr.load_dataset(filename) lons = ds['longitude'].values lats = ds['latitude'].values h = ds['h'].values diff --git a/tools/RAiDER/models/model_levels.py b/tools/RAiDER/models/model_levels.py index db3e122b6..56891580d 100644 --- a/tools/RAiDER/models/model_levels.py +++ b/tools/RAiDER/models/model_levels.py @@ -1,5 +1,5 @@ -''' -Pre-defined model levels and a, b constants for the different weather models +""" +Pre-defined model levels and a, b constants for the different weather models. **NOTE**: The fixed heights used here are from ECMWF's _geometric_ altitudes (https://confluence.ecmwf.int/display/UDOC/L137+model+level+definitions), @@ -7,7 +7,7 @@ the altitude to include the variation of gravity with height, while geometric altitude is the standard direct vertical distance above mean sea level (MSL)." 
- Wikipedia.org, https://en.wikipedia.org/wiki/International_Standard_Atmosphere -''' +""" LEVELS_137_HEIGHTS = [ 80301.65, @@ -506,7 +506,7 @@ -500, ] -## HRRR Model Levels +# HRRR Model Levels # Computed according to: H = a + b * Z where: # H is the resulting levels in geometric height # a is the Surface geopotential height (in meters) @@ -514,16 +514,18 @@ # averaged in space over CONUS # b is the native (sigma) model levels (https://rapidrefresh.noaa.gov/faq/HRRR.faq.html) # Z is the spatial average geopotential height of the sigma level (in meters) -LEVELS_50_HEIGHTS = [2.61580385e+04, 2.48712879e+04, 2.36910518e+04, 2.25524744e+04, - 2.13986900e+04, 2.02464207e+04, 1.90883153e+04, 1.79427740e+04, - 1.68476065e+04, 1.57399654e+04, 1.45826790e+04, 1.33886515e+04, - 1.22171878e+04, 1.11019360e+04, 1.00395775e+04, 9.01965365e+03, - 8.03486128e+03, 7.09323111e+03, 6.27822334e+03, 5.57101666e+03, - 4.96120000e+03, 4.42159162e+03, 3.94118518e+03, 3.51064883e+03, - 3.12371808e+03, 2.77490670e+03, 2.45941860e+03, 2.17290722e+03, - 1.90394551e+03, 1.66716448e+03, 1.44127808e+03, 1.22697117e+03, - 1.02507126e+03, 8.38877887e+02, 6.74297597e+02, 5.34810131e+02, - 4.18916771e+02, 3.23291544e+02, 2.44985788e+02, 1.81492083e+02, - 1.34383211e+02, 1.02007390e+02, 7.70762881e+01, 5.77739913e+01, - 4.31591299e+01, 3.26389095e+01, 2.52657431e+01, 2.02104423e+01, - 1.66520787e+01, 1.39366382e+01, 0, -10, -20, -50, -100, -200, -500] \ No newline at end of file +LEVELS_50_HEIGHTS = [ + 2.61580385e+04, 2.48712879e+04, 2.36910518e+04, 2.25524744e+04, + 2.13986900e+04, 2.02464207e+04, 1.90883153e+04, 1.79427740e+04, + 1.68476065e+04, 1.57399654e+04, 1.45826790e+04, 1.33886515e+04, + 1.22171878e+04, 1.11019360e+04, 1.00395775e+04, 9.01965365e+03, + 8.03486128e+03, 7.09323111e+03, 6.27822334e+03, 5.57101666e+03, + 4.96120000e+03, 4.42159162e+03, 3.94118518e+03, 3.51064883e+03, + 3.12371808e+03, 2.77490670e+03, 2.45941860e+03, 2.17290722e+03, + 1.90394551e+03, 1.66716448e+03, 1.44127808e+03, 1.22697117e+03, + 1.02507126e+03, 8.38877887e+02, 6.74297597e+02, 5.34810131e+02, + 4.18916771e+02, 3.23291544e+02, 2.44985788e+02, 1.81492083e+02, + 1.34383211e+02, 1.02007390e+02, 7.70762881e+01, 5.77739913e+01, + 4.31591299e+01, 3.26389095e+01, 2.52657431e+01, 2.02104423e+01, + 1.66520787e+01, 1.39366382e+01, 0, -10, -20, -50, -100, -200, -500 +] diff --git a/tools/RAiDER/models/ncmr.py b/tools/RAiDER/models/ncmr.py index 394ab9538..f448c727b 100755 --- a/tools/RAiDER/models/ncmr.py +++ b/tools/RAiDER/models/ncmr.py @@ -2,59 +2,59 @@ Created on Wed Sep 9 10:26:44 2020 @author: prashant Modified by Yang Lei, GPS/Caltech """ -import datetime + +import datetime as dt import os import urllib.request import numpy as np - from pyproj import CRS -from RAiDER.models.weatherModel import WeatherModel, TIME_RES from RAiDER.logger import logger +from RAiDER.models.model_levels import ( + LEVELS_137_HEIGHTS, +) +from RAiDER.models.weatherModel import TIME_RES, WeatherModel from RAiDER.utilFcns import ( read_NCMR_loginInfo, show_progress, writeWeatherVarsXarray, ) -from RAiDER.models.model_levels import ( - LEVELS_137_HEIGHTS, -) class NCMR(WeatherModel): - ''' - Implement NCMRWF NCUM (named as NCMR) model in future - ''' + """Implement NCMRWF NCUM (named as NCMR) model in future.""" - def __init__(self): + def __init__(self) -> None: # initialize a weather model WeatherModel.__init__(self) - self._humidityType = 'q' # q for specific humidity and rh for relative humidity - self._model_level_type = 'ml' # Default, pressure levels are 
'pl', and model levels are "ml" - self._classname = 'ncmr' # name of the custom weather model - self._dataset = 'ncmr' # same name as above - self._Name = 'NCMR' # name of the new weather model (in Capital) + self._humidityType = 'q' # q for specific humidity and rh for relative humidity + self._model_level_type = 'ml' # Default, pressure levels are 'pl', and model levels are "ml" + self._classname = 'ncmr' # name of the custom weather model + self._dataset = 'ncmr' # same name as above + self._Name = 'NCMR' # name of the new weather model (in Capital) self._time_res = TIME_RES[self._dataset.upper()] # Tuple of min/max years where data is available. - self._valid_range = (datetime.datetime(2015, 12, 1).replace(tzinfo=datetime.timezone(offset=datetime.timedelta())), - datetime.datetime.now(datetime.timezone.utc)) + self._valid_range = ( + dt.datetime(2015, 12, 1).replace(tzinfo=dt.timezone(offset=dt.timedelta())), + dt.datetime.now(dt.timezone.utc), + ) # Availability lag time in days/hours - self._lag_time = datetime.timedelta(hours=6) + self._lag_time = dt.timedelta(hours=6) # model constants - self._k1 = 0.776 # [K/Pa] - self._k2 = 0.233 # [K/Pa] + self._k1 = 0.776 # [K/Pa] + self._k2 = 0.233 # [K/Pa] self._k3 = 3.75e3 # [K^2/Pa] # horizontal grid spacing - self._lon_res = .17578125 # grid spacing in longitude - self._lat_res = .11718750 # grid spacing in latitude + self._lon_res = 0.17578125 # grid spacing in longitude + self._lat_res = 0.11718750 # grid spacing in latitude - self._x_res = .17578125 # same as longitude - self._y_res = .11718750 # same as latitude + self._x_res = 0.17578125 # same as longitude + self._y_res = 0.11718750 # same as latitude self._zlevels = np.flipud(LEVELS_137_HEIGHTS) @@ -63,23 +63,21 @@ def __init__(self): # Projection self._proj = CRS.from_epsg(4326) - def _fetch(self, out): - ''' + def _fetch(self, out) -> None: + """ Fetch weather model data from NCMR: note we only extract the lat/lon bounds for this weather model; - fetching data is not needed here as we don't actually download data , data exist in same system - ''' + fetching data is not needed here as we don't actually download data, data exist in same system. + """ time = self._time # Auxillary function: - ''' + """ download data of the NCMR model and save it in desired location - ''' + """ self._files = self._download_ncmr_file(out, time, self._ll_bounds) - def load_weather(self, f=None, *args, **kwargs): - ''' - Load NCMR model variables from existing file - ''' + def load_weather(self, f=None, *args, **kwargs) -> None: + """Load NCMR model variables from existing file.""" f = self.files[0] if f is None else f # bounding box plus a buffer @@ -88,12 +86,11 @@ def load_weather(self, f=None, *args, **kwargs): self._makeDataCubes(f) - def _download_ncmr_file(self, out, date_time, bounding_box): - ''' + def _download_ncmr_file(self, out, date_time, bounding_box) -> None: + """ Download weather model data (whole globe) from NCMR weblink, crop it to the region of interest, and save the cropped data as a standard .nc file of RAiDER (e.g. "NCMR_YYYY_MM_DD_THH_MM_SS.nc"); - Temporarily download data from NCMR ftp 'https://ftp.ncmrwf.gov.in/pub/outgoing/SAC/NCUM_OSF/' and copied in weather_models folder - ''' - + Temporarily download data from NCMR ftp 'https://ftp.ncmrwf.gov.in/pub/outgoing/SAC/NCUM_OSF/' and copied in weather_models folder. 
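# Sketch (hypothetical western bound) of the longitude indexing used in the
# body of this function: the NCMR grid runs 0-360 deg starting at
# 0.087890625 deg, so negative longitudes are shifted by +360 before being
# converted to an array index.
lon_res = 0.17578125        # grid spacing from __init__
lon_first = 0.087890625     # first grid longitude, as used below
W = -77.5                   # hypothetical western bound
if W < 0.0:
    lon_min_ind = int((W + 360.0 - lon_first) / lon_res)   # -> 1606
else:
    lon_min_ind = int((W - lon_first) / lon_res)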
+ """ from netCDF4 import Dataset ############# Use these lines and modify the link when actually downloading NCMR data from a weblink ############# @@ -109,17 +106,17 @@ def _download_ncmr_file(self, out, date_time, bounding_box): ######################################################################################################################## ############# For debugging: use pre-downloaded files; Remove/comment out it when actually downloading NCMR data from a weblink ############# -# filepath = os.path.dirname(out) + '/NCUM_ana_mdllev_20180701_00z.nc' + # filepath = os.path.dirname(out) + '/NCUM_ana_mdllev_20180701_00z.nc' ######################################################################################################################## # calculate the array indices for slicing the GMAO variable arrays lat_min_ind = int((self._bounds[0] - (-89.94141)) / self._lat_res) lat_max_ind = int((self._bounds[1] - (-89.94141)) / self._lat_res) - if (self._bounds[2] < 0.0): + if self._bounds[2] < 0.0: lon_min_ind = int((self._bounds[2] + 360.0 - (0.087890625)) / self._lon_res) else: lon_min_ind = int((self._bounds[2] - (0.087890625)) / self._lon_res) - if (self._bounds[3] < 0.0): + if self._bounds[3] < 0.0: lon_max_ind = int((self._bounds[3] + 360.0 - (0.087890625)) / self._lon_res) else: lon_max_ind = int((self._bounds[3] - (0.087890625)) / self._lon_res) @@ -128,41 +125,63 @@ def _download_ncmr_file(self, out, date_time, bounding_box): ml_max = 70 with Dataset(filepath, 'r', maskandscale=True) as f: - lats = f.variables['latitude'][lat_min_ind:(lat_max_ind + 1)].copy() - if (self._bounds[2] * self._bounds[3] < 0): + lats = f.variables['latitude'][lat_min_ind : (lat_max_ind + 1)].copy() + if self._bounds[2] * self._bounds[3] < 0: lons1 = f.variables['longitude'][lon_min_ind:].copy() - lons2 = f.variables['longitude'][0:(lon_max_ind + 1)].copy() + lons2 = f.variables['longitude'][0 : (lon_max_ind + 1)].copy() lons = np.append(lons1, lons2) else: - lons = f.variables['longitude'][lon_min_ind:(lon_max_ind + 1)].copy() - if (self._bounds[2] * self._bounds[3] < 0): - t1 = f.variables['air_temperature'][ml_min:(ml_max + 1), lat_min_ind:(lat_max_ind + 1), lon_min_ind:].copy() - t2 = f.variables['air_temperature'][ml_min:(ml_max + 1), lat_min_ind:(lat_max_ind + 1), 0:(lon_max_ind + 1)].copy() + lons = f.variables['longitude'][lon_min_ind : (lon_max_ind + 1)].copy() + if self._bounds[2] * self._bounds[3] < 0: + t1 = f.variables['air_temperature'][ + ml_min : (ml_max + 1), lat_min_ind : (lat_max_ind + 1), lon_min_ind: + ].copy() + t2 = f.variables['air_temperature'][ + ml_min : (ml_max + 1), lat_min_ind : (lat_max_ind + 1), 0 : (lon_max_ind + 1) + ].copy() t = np.append(t1, t2, axis=2) else: - t = f.variables['air_temperature'][ml_min:(ml_max + 1), lat_min_ind:(lat_max_ind + 1), lon_min_ind:(lon_max_ind + 1)].copy() + t = f.variables['air_temperature'][ + ml_min : (ml_max + 1), lat_min_ind : (lat_max_ind + 1), lon_min_ind : (lon_max_ind + 1) + ].copy() # Skipping first pressure levels (below 20 meter) - if (self._bounds[2] * self._bounds[3] < 0): - q1 = f.variables['specific_humidity'][(ml_min + 1):(ml_max + 1), lat_min_ind:(lat_max_ind + 1), lon_min_ind:].copy() - q2 = f.variables['specific_humidity'][(ml_min + 1):(ml_max + 1), lat_min_ind:(lat_max_ind + 1), 0:(lon_max_ind + 1)].copy() + if self._bounds[2] * self._bounds[3] < 0: + q1 = f.variables['specific_humidity'][ + (ml_min + 1) : (ml_max + 1), lat_min_ind : (lat_max_ind + 1), lon_min_ind: + ].copy() + q2 = 
f.variables['specific_humidity'][ + (ml_min + 1) : (ml_max + 1), lat_min_ind : (lat_max_ind + 1), 0 : (lon_max_ind + 1) + ].copy() q = np.append(q1, q2, axis=2) else: - q = f.variables['specific_humidity'][(ml_min + 1):(ml_max + 1), lat_min_ind:(lat_max_ind + 1), lon_min_ind:(lon_max_ind + 1)].copy() - if (self._bounds[2] * self._bounds[3] < 0): - p1 = f.variables['air_pressure'][(ml_min + 1):(ml_max + 1), lat_min_ind:(lat_max_ind + 1), lon_min_ind:].copy() - p2 = f.variables['air_pressure'][(ml_min + 1):(ml_max + 1), lat_min_ind:(lat_max_ind + 1), 0:(lon_max_ind + 1)].copy() + q = f.variables['specific_humidity'][ + (ml_min + 1) : (ml_max + 1), lat_min_ind : (lat_max_ind + 1), lon_min_ind : (lon_max_ind + 1) + ].copy() + if self._bounds[2] * self._bounds[3] < 0: + p1 = f.variables['air_pressure'][ + (ml_min + 1) : (ml_max + 1), lat_min_ind : (lat_max_ind + 1), lon_min_ind: + ].copy() + p2 = f.variables['air_pressure'][ + (ml_min + 1) : (ml_max + 1), lat_min_ind : (lat_max_ind + 1), 0 : (lon_max_ind + 1) + ].copy() p = np.append(p1, p2, axis=2) else: - p = f.variables['air_pressure'][(ml_min + 1):(ml_max + 1), lat_min_ind:(lat_max_ind + 1), lon_min_ind:(lon_max_ind + 1)].copy() - - level_hgt = f.variables['level_height'][(ml_min + 1):(ml_max + 1)].copy() - if (self._bounds[2] * self._bounds[3] < 0): - surface_alt1 = f.variables['surface_altitude'][lat_min_ind:(lat_max_ind + 1), lon_min_ind:].copy() - surface_alt2 = f.variables['surface_altitude'][lat_min_ind:(lat_max_ind + 1), 0:(lon_max_ind + 1)].copy() + p = f.variables['air_pressure'][ + (ml_min + 1) : (ml_max + 1), lat_min_ind : (lat_max_ind + 1), lon_min_ind : (lon_max_ind + 1) + ].copy() + + level_hgt = f.variables['level_height'][(ml_min + 1) : (ml_max + 1)].copy() + if self._bounds[2] * self._bounds[3] < 0: + surface_alt1 = f.variables['surface_altitude'][lat_min_ind : (lat_max_ind + 1), lon_min_ind:].copy() + surface_alt2 = f.variables['surface_altitude'][ + lat_min_ind : (lat_max_ind + 1), 0 : (lon_max_ind + 1) + ].copy() surface_alt = np.append(surface_alt1, surface_alt2, axis=1) else: - surface_alt = f.variables['surface_altitude'][lat_min_ind:(lat_max_ind + 1), lon_min_ind:(lon_max_ind + 1)].copy() + surface_alt = f.variables['surface_altitude'][ + lat_min_ind : (lat_max_ind + 1), lon_min_ind : (lon_max_ind + 1) + ].copy() hgt = np.zeros([len(level_hgt), len(surface_alt[:, 1]), len(surface_alt[1, :])]) for i in range(len(level_hgt)): @@ -176,13 +195,11 @@ def _download_ncmr_file(self, out, date_time, bounding_box): try: writeWeatherVarsXarray(lats, lons, hgt, q, p, t, self._time, self._proj, outName=out) - except Exception: - logger.exception("Unable to save weathermodel to file") + except: + logger.exception('Unable to save weathermodel to file') - def _makeDataCubes(self, filename): - ''' - Get the variables from the saved .nc file (named as "NCMR_YYYY_MM_DD_THH_MM_SS.nc") - ''' + def _makeDataCubes(self, filename) -> None: + """Get the variables from the saved .nc file (named as "NCMR_YYYY_MM_DD_THH_MM_SS.nc").""" from netCDF4 import Dataset # adding the import here should become absolute when transition to netcdf @@ -195,10 +212,8 @@ def _makeDataCubes(self, filename): t = np.array(f.variables['T'][:]) # re-assign lons, lats to match heights - _lons = np.broadcast_to(lons[np.newaxis, np.newaxis, :], - t.shape) - _lats = np.broadcast_to(lats[np.newaxis, :, np.newaxis], - t.shape) + _lons = np.broadcast_to(lons[np.newaxis, np.newaxis, :], t.shape) + _lats = np.broadcast_to(lats[np.newaxis, :, np.newaxis], t.shape) # 
Re-structure everything from (heights, lats, lons) to (lons, lats, heights) _lats = np.transpose(_lats) diff --git a/tools/RAiDER/models/plotWeather.py b/tools/RAiDER/models/plotWeather.py index 6be5cdc8d..c5bd8b29b 100755 --- a/tools/RAiDER/models/plotWeather.py +++ b/tools/RAiDER/models/plotWeather.py @@ -5,19 +5,20 @@ class objects. It is not designed to be used on its own apart """ import os -from RAiDER.interpolator import RegularGridInterpolator as Interpolator -from mpl_toolkits.axes_grid1 import make_axes_locatable as mal -import numpy as np -import matplotlib.pyplot as plt + import matplotlib as mpl +import matplotlib.pyplot as plt +import numpy as np +from mpl_toolkits.axes_grid1 import make_axes_locatable as mal + +from RAiDER.interpolator import RegularGridInterpolator as Interpolator + + mpl.use('Agg') def plot_pqt(weatherObj, savefig=True, z1=500, z2=15000): - ''' - Create a plot with pressure, temp, and humidity at two heights - ''' - + """Create a plot with pressure, temp, and humidity at two heights.""" # Get the interpolator intFcn_p = Interpolator((weatherObj._xs, weatherObj._ys, weatherObj._zs), weatherObj._p.swapaxes(0, 1)) @@ -47,7 +48,9 @@ def plot_pqt(weatherObj, savefig=True, z1=500, z2=15000): # setup the plot f = plt.figure(figsize=(18, 14)) - f.suptitle(f'{weatherObj._Name} Pressure/Humidity/Temperature at height {z1}m and {z2}m (values should drop as elevation increases)') + f.suptitle( + f'{weatherObj._Name} Pressure/Humidity/Temperature at height {z1}m and {z2}m (values should drop as elevation increases)' + ) xind = int(np.floor(weatherObj._xs.shape[0] / 2)) yind = int(np.floor(weatherObj._ys.shape[0] / 2)) @@ -55,19 +58,21 @@ def plot_pqt(weatherObj, savefig=True, z1=500, z2=15000): # loop over each plot for ind, plot, title in zip(range(len(plots)), plots, titles): sp = f.add_subplot(3, 3, ind + 1) - im = sp.imshow(np.reshape(plot, x.shape), - cmap='viridis', - extent=[np.nanmin(x), np.nanmax(x), np.nanmin(y), np.nanmax(y)], - origin='lower') + im = sp.imshow( + np.reshape(plot, x.shape), + cmap='viridis', + extent=[np.nanmin(x), np.nanmax(x), np.nanmin(y), np.nanmax(y)], + origin='lower', + ) sp.plot(x[yind, xind], y[yind, xind], 'ko') divider = mal(sp) - cax = divider.append_axes("right", size="4%", pad=0.05) + cax = divider.append_axes('right', size='4%', pad=0.05) plt.colorbar(im, cax=cax) sp.set_title(title) if ind == 0: - sp.set_ylabel('{} m\n'.format(z1)) + sp.set_ylabel(f'{z1} m\n') if ind == 3: - sp.set_ylabel('{} m\n'.format(z2)) + sp.set_ylabel(f'{z2} m\n') # add plots that show each variable with height zdata = weatherObj._zs[:] / 1000 @@ -84,25 +89,27 @@ def plot_pqt(weatherObj, savefig=True, z1=500, z2=15000): sp.plot(weatherObj._t[yind, xind, :] - 273.15, zdata) sp.set_xlabel('Temp (C)') - plt.subplots_adjust(top=0.95, bottom=0.1, left=0.1, right=0.95, hspace=0.2, - wspace=0.3) + plt.subplots_adjust(top=0.95, bottom=0.1, left=0.1, right=0.95, hspace=0.2, wspace=0.3) if savefig: - wd = os.path.dirname(os.path.dirname(weatherObj._out_name)) - f = f'{weatherObj._Name}_weather_hgt{z1}_and_{z2}m.pdf' + wd = os.path.dirname(os.path.dirname(weatherObj._out_name)) + f = f'{weatherObj._Name}_weather_hgt{z1}_and_{z2}m.pdf' plt.savefig(os.path.join(wd, f)) return f def plot_wh(weatherObj, savefig=True, z1=500, z2=15000): - ''' + """ Create a plot with wet refractivity and hydrostatic refractivity, - at two different heights - ''' - + at two different heights. 
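# Self-contained sketch of the imshow + make_axes_locatable colorbar pattern
# used by both plotting helpers in this module; the data here is random and
# purely illustrative.
import matplotlib
matplotlib.use('Agg')   # headless backend, as set at module import time
import matplotlib.pyplot as plt
import numpy as np
from mpl_toolkits.axes_grid1 import make_axes_locatable as mal

fig, ax = plt.subplots()
im = ax.imshow(np.random.rand(10, 10), cmap='viridis', origin='lower')
cax = mal(ax).append_axes('right', size='4%', pad=0.05)   # slim colorbar axis
fig.colorbar(im, cax=cax)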
+ """ # Get the interpolator - intFcn_w = Interpolator((weatherObj._xs, weatherObj._ys, weatherObj._zs), weatherObj._wet_refractivity.swapaxes(0, 1)) - intFcn_h = Interpolator((weatherObj._xs, weatherObj._ys, weatherObj._zs), weatherObj._hydrostatic_refractivity.swapaxes(0, 1)) + intFcn_w = Interpolator( + (weatherObj._xs, weatherObj._ys, weatherObj._zs), weatherObj._wet_refractivity.swapaxes(0, 1) + ) + intFcn_h = Interpolator( + (weatherObj._xs, weatherObj._ys, weatherObj._zs), weatherObj._hydrostatic_refractivity.swapaxes(0, 1) + ) # get the points needed XY = np.meshgrid(weatherObj._xs, weatherObj._ys) @@ -122,10 +129,7 @@ def plot_wh(weatherObj, savefig=True, z1=500, z2=15000): plots = [w1, h1, w2, h2] # titles - titles = ('Wet refractivity {}'.format(z1), - 'Hydrostatic refractivity {}'.format(z1), - '{}'.format(z2), - '{}'.format(z2)) + titles = (f'Wet refractivity {z1}', f'Hydrostatic refractivity {z1}', f'{z2}', f'{z2}') # setup the plot f = plt.figure(figsize=(14, 10)) @@ -134,19 +138,23 @@ def plot_wh(weatherObj, savefig=True, z1=500, z2=15000): # loop over each plot for ind, plot, title in zip(range(len(plots)), plots, titles): sp = f.add_subplot(2, 2, ind + 1) - im = sp.imshow(np.reshape(plot, x.shape), cmap='viridis', - extent=[np.nanmin(x), np.nanmax(x), np.nanmin(y), np.nanmax(y)], origin='lower') + im = sp.imshow( + np.reshape(plot, x.shape), + cmap='viridis', + extent=[np.nanmin(x), np.nanmax(x), np.nanmin(y), np.nanmax(y)], + origin='lower', + ) divider = mal(sp) - cax = divider.append_axes("right", size="4%", pad=0.05) + cax = divider.append_axes('right', size='4%', pad=0.05) plt.colorbar(im, cax=cax) sp.set_title(title) if ind == 0: - sp.set_ylabel('{} m\n'.format(z1)) + sp.set_ylabel(f'{z1} m\n') if ind == 2: - sp.set_ylabel('{} m\n'.format(z2)) + sp.set_ylabel(f'{z2} m\n') if savefig: - wd = os.path.dirname(os.path.dirname(weatherObj._out_name)) - f = f'{weatherObj._Name}_refractivity_hgt{z1}_and_{z2}m.pdf' + wd = os.path.dirname(os.path.dirname(weatherObj._out_name)) + f = f'{weatherObj._Name}_refractivity_hgt{z1}_and_{z2}m.pdf' plt.savefig(os.path.join(wd, f)) return f diff --git a/tools/RAiDER/models/template.py b/tools/RAiDER/models/template.py index a753605ba..e7b2cf7b0 100644 --- a/tools/RAiDER/models/template.py +++ b/tools/RAiDER/models/template.py @@ -1,14 +1,16 @@ -from pyproj import CRS +import datetime as dt + import numpy as np -import datetime -from RAiDER.models.weatherModel import WeatherModel +from pyproj import CRS + from RAiDER.models.model_levels import ( LEVELS_137_HEIGHTS, ) +from RAiDER.models.weatherModel import WeatherModel class customModelReader(WeatherModel): - def __init__(self): + def __init__(self) -> None: WeatherModel.__init__(self) self._humidityType = 'q' # can be "q" (specific humidity) or "rh" (relative humidity) self._model_level_type = 'pl' # Default, pressure levels are "pl", and model levels are "ml" @@ -17,10 +19,12 @@ def __init__(self): # Tuple of min/max years where data is available. # valid range of the dataset. Users need to specify the start date and end date (can be "present") - self._valid_range = (datetime.datetime(2016, 7, 15).replace(tzinfo=datetime.timezone(offset=datetime.timedelta())), - datetime.datetime.now(datetime.timezone.utc)) + self._valid_range = ( + dt.datetime(2016, 7, 15).replace(tzinfo=dt.timezone(offset=dt.timedelta())), + dt.datetime.now(dt.timezone.utc), + ) # Availability lag time. 
Can be specified in hours "hours=3" or in days "days=3" - self._lag_time = datetime.timedelta(hours=3) + self._lag_time = dt.timedelta(hours=3) # Availabile time resolution; i.e. minimum rate model is available in hours. 1 is hourly self._time_res = 1 @@ -31,11 +35,11 @@ def __init__(self): self._k3 = 3.75e3 # [K^2/Pa] # horizontal grid spacing - self._lat_res = 3. / 111 # grid spacing in latitude - self._lon_res = 3. / 111 # grid spacing in longitude - self._x_res = 3. # x-direction grid spacing in the native weather model projection + self._lat_res = 3.0 / 111 # grid spacing in latitude + self._lon_res = 3.0 / 111 # grid spacing in longitude + self._x_res = 3.0 # x-direction grid spacing in the native weather model projection # (if the projection is in lat/lon, it is the same as "self._lon_res") - self._y_res = 3. # y-direction grid spacing in the weather model native projection + self._y_res = 3.0 # y-direction grid spacing in the weather model native projection # (if the projection is in lat/lon, it is the same as "self._lat_res") # zlevels specify fixed heights at which to interpolate the weather model variables @@ -54,19 +58,21 @@ def __init__(self): x0 = 0 y0 = 0 earth_radius = 6371229 - p1 = CRS('+proj=lcc +lat_1={lat1} +lat_2={lat2} +lat_0={lat0} +lon_0={lon0} +x_0={x0} +y_0={y0} +a={a} +b={a} +units=m +no_defs'.format(lat1=lat1, lat2=lat2, lat0=lat0, lon0=lon0, x0=x0, y0=y0, a=earth_radius)) + p1 = CRS( + f'+proj=lcc +lat_1={lat1} +lat_2={lat2} +lat_0={lat0} +lon_0={lon0} +x_0={x0} +y_0={y0} +a={earth_radius} +b={earth_radius} +units=m +no_defs' + ) self._proj = p1 - def _fetch(self, out): - ''' + def _fetch(self, out) -> None: + """ Fetch weather model data from the custom weather model "ABCD" Inputs (no need to change in the custom weather model reader): lats - latitude lons - longitude time - datatime object (year,month,day,hour,minute,second) out - name of downloaded dataset file from the custom weather model server - Nextra - buffer of latitude/longitude for determining the bounding box - ''' + Nextra - buffer of latitude/longitude for determining the bounding box. + """ # Auxilliary function: # download dataset of the custom weather model "ABCD" from a server and then save it to a file named out. # This function needs to be writen by the users. For download from the weather model server, the weather model @@ -75,13 +81,12 @@ def _fetch(self, out): # retrieval to the following "load_weather" function. self._files = self._download_abcd_file(out, 'abcd', self._time, self._ll_bounds) - def load_weather(self, filename): - ''' + def load_weather(self, filename) -> None: + """ Load weather model variables from the downloaded file named filename Inputs: - filename - filename of the downloaded weather model file - ''' - + filename - filename of the downloaded weather model file. + """ # Auxilliary function: # read individual variables (in 3-D cube format with exactly the same dimension) from downloaded file # This function needs to be writen by the users. 
For downloaded file from the weather model server, @@ -131,8 +136,8 @@ def load_weather(self, filename): ########### - def _download_abcd_file(self, out, model_name, date_time, bounding_box): - ''' + def _download_abcd_file(self, out, model_name, date_time, bounding_box) -> None: + """ Auxilliary function: Download weather model data from a server Inputs: @@ -141,12 +146,12 @@ def _download_abcd_file(self, out, model_name, date_time, bounding_box): date_time - datatime object (year,month,day,hour,minute,second) bounding_box - lat/lon bounding box for the region of interest Output: - out - returned filename from input - ''' + out - returned filename from input. + """ pass - def _makeDataCubes(self, filename): - ''' + def _makeDataCubes(self, filename) -> None: + """ Auxilliary function: Read 3-D data cubes from downloaded file or directly from weather model weblink (in which case, there is no need to download and save any file; rather, the weblink needs to be hardcoded in the custom reader, e.g. GMAO) @@ -160,6 +165,6 @@ def _makeDataCubes(self, filename): t - temperature (3-D data cube) q - humidity (3-D data cube; could be relative humidity or specific humidity) p - pressure level (3-D data cube; could be pressure level (preferred) or surface pressure) - hgt - height (3-D data cube; could be geopotential height or topographic height (preferred)) - ''' + hgt - height (3-D data cube; could be geopotential height or topographic height (preferred)). + """ pass diff --git a/tools/RAiDER/models/weatherModel.py b/tools/RAiDER/models/weatherModel.py index 7ad457668..d561aa8b2 100755 --- a/tools/RAiDER/models/weatherModel.py +++ b/tools/RAiDER/models/weatherModel.py @@ -1,43 +1,42 @@ -import datetime +import datetime as dt import os from abc import ABC, abstractmethod +from typing import Optional import numpy as np -import netCDF4 -import rioxarray -import xarray - +import xarray as xr from pyproj import CRS -from shapely.geometry import box from shapely.affinity import translate +from shapely.geometry import box from shapely.ops import unary_union -from RAiDER.constants import _ZREF, _ZMIN, _g0 from RAiDER import utilFcns as util +from RAiDER.constants import _ZMIN, _ZREF, _g0 from RAiDER.interpolate import interpolate_along_axis from RAiDER.interpolator import fillna3D from RAiDER.logger import logger -from RAiDER.models import plotWeather as plots, weatherModel -from RAiDER.models.customExceptions import * -from RAiDER.utilFcns import ( - robmax, robmin, calcgeoh, transform_coords, clip_bbox -) - -TIME_RES = {'GMAO': 3, - 'ECMWF': 1, - 'HRES': 6, - 'HRRR': 1, - 'WRF': 1, - 'NCMR': 1, - 'HRRR-AK': 3, - } +from RAiDER.models import plotWeather as plots +from RAiDER.models.customExceptions import DatetimeOutsideRange +from RAiDER.utilFcns import calcgeoh, clip_bbox, robmax, robmin, transform_coords + + +TIME_RES = { + 'GMAO': 3, + 'ECMWF': 1, + 'HRES': 6, + 'HRRR': 1, + 'WRF': 1, + 'NCMR': 1, + 'HRRR-AK': 3, +} + class WeatherModel(ABC): - ''' - Implement a generic weather model for getting estimated SAR delays - ''' + """Implement a generic weather model for getting estimated SAR delays.""" - def __init__(self): + _dataset: Optional[str] + + def __init__(self) -> None: # Initialize model-specific constants/parameters self._k1 = None self._k2 = None @@ -48,7 +47,7 @@ def __init__(self): self.files = None - self._time_res = None # time resolution of the weather model in hours + self._time_res = None # time resolution of the weather model in hours self._lon_res = None self._lat_res = None @@ -57,16 
+56,16 @@ def __init__(self): self._classname = None self._dataset = None - self._Name = None - self._wmLoc = None + self._Name = None + self._wmLoc = None self._model_level_type = 'ml' self._valid_range = ( - datetime.datetime(1900, 1, 1).replace(tzinfo=datetime.timezone(offset=datetime.timedelta())), - datetime.datetime.now(datetime.timezone.utc).date(), + dt.datetime(1900, 1, 1).replace(tzinfo=dt.timezone(offset=dt.timedelta())), + dt.datetime.now(dt.timezone.utc).date(), ) # Tuple of min/max years where data is available. - self._lag_time = datetime.timedelta(days=30) # Availability lag time in days + self._lag_time = dt.timedelta(days=30) # Availability lag time in days self._time = None self._bbox = None @@ -87,7 +86,7 @@ def __init__(self): self._lats = None self._lons = None self._ll_bounds = None - self._valid_bounds = box(-180, -90, 180, 90) # Shapely box with WSEN bounds + self._valid_bounds = box(-180, -90, 180, 90) # Shapely box with WSEN bounds self._p = None self._q = None @@ -99,115 +98,101 @@ def __init__(self): self._wet_ztd = None self._hydrostatic_ztd = None - - def __str__(self): + def __str__(self) -> str: string = '\n' string += '======Weather Model class object=====\n' - string += 'Weather model time: {}\n'.format(self._time) - string += 'Latitude resolution: {}\n'.format(self._lat_res) - string += 'Longitude resolution: {}\n'.format(self._lon_res) - string += 'Native projection: {}\n'.format(self._proj) - string += 'ZMIN: {}\n'.format(self._zmin) - string += 'ZMAX: {}\n'.format(self._zmax) - string += 'k1 = {}\n'.format(self._k1) - string += 'k2 = {}\n'.format(self._k2) - string += 'k3 = {}\n'.format(self._k3) - string += 'Humidity type = {}\n'.format(self._humidityType) + string += f'Weather model time: {self._time}\n' + string += f'Latitude resolution: {self._lat_res}\n' + string += f'Longitude resolution: {self._lon_res}\n' + string += f'Native projection: {self._proj}\n' + string += f'ZMIN: {self._zmin}\n' + string += f'ZMAX: {self._zmax}\n' + string += f'k1 = {self._k1}\n' + string += f'k2 = {self._k2}\n' + string += f'k3 = {self._k3}\n' + string += f'Humidity type = {self._humidityType}\n' string += '=====================================\n' - string += 'Class name: {}\n'.format(self._classname) - string += 'Dataset: {}\n'.format(self._dataset) + string += f'Class name: {self._classname}\n' + string += f'Dataset: {self._dataset}\n' string += '=====================================\n' - string += 'A: {}\n'.format(self._a) - string += 'B: {}\n'.format(self._b) + string += f'A: {self._a}\n' + string += f'B: {self._b}\n' if self._p is not None: string += 'Number of points in Lon/Lat = {}/{}\n'.format(*self._p.shape[:2]) - string += 'Total number of grid points (3D): {}\n'.format(np.prod(self._p.shape)) + string += f'Total number of grid points (3D): {np.prod(self._p.shape)}\n' if self._xs.size == 0: - string += 'Minimum/Maximum y: {: 4.2f}/{: 4.2f}\n'\ - .format(robmin(self._ys), robmax(self._ys)) - string += 'Minimum/Maximum x: {: 4.2f}/{: 4.2f}\n'\ - .format(robmin(self._xs), robmax(self._xs)) - string += 'Minimum/Maximum zs/heights: {: 10.2f}/{: 10.2f}\n'\ - .format(robmin(self._zs), robmax(self._zs)) - string += '=====================================\n' - return str(string) + string += f'Minimum/Maximum y: {robmin(self._ys): 4.2f}/{robmax(self._ys): 4.2f}\n' + string += f'Minimum/Maximum x: {robmin(self._xs): 4.2f}/{robmax(self._xs): 4.2f}\n' + string += f'Minimum/Maximum zs/heights: {robmin(self._zs): 10.2f}/{robmax(self._zs): 10.2f}\n' + string += 
'=====================================\n' + return string def Model(self): return self._Name - def dtime(self): return self._time_res - def getLLRes(self): return np.max([self._lat_res, self._lon_res]) - - def fetch(self, out, time): - ''' + def fetch(self, out, time) -> None: + """ Checks the input datetime against the valid date range for the model and then - calls the model _fetch routine + calls the model _fetch routine. Args: ---------- out - ll_bounds - 4 x 1 array, SNWE time = UTC datetime - ''' + """ self.checkTime(time) self.setTime(time) # write the error raised by the weather model API to the log try: self._fetch(out) - except Exception as E: - logger.exception(E) + except Exception as e: + logger.exception(e) raise - @abstractmethod def _fetch(self, out): - ''' - Placeholder method. Should be implemented in each weather model type class - ''' + """Placeholder method. Should be implemented in each weather model type class.""" pass - def getTime(self): return self._time - - def setTime(self, time, fmt='%Y-%m-%dT%H:%M:%S'): - ''' Set the time for a weather model ''' + def setTime(self, time, fmt='%Y-%m-%dT%H:%M:%S') -> None: + """Set the time for a weather model.""" if isinstance(time, str): - self._time = datetime.datetime.strptime(time, fmt) - elif isinstance(time, datetime.datetime): + self._time = dt.datetime.strptime(time, fmt) + elif isinstance(time, dt.datetime): self._time = time else: raise ValueError('"time" must be a string or a datetime object') if self._time.tzinfo is None: - self._time = self._time.replace(tzinfo=datetime.timezone(offset=datetime.timedelta())) - + self._time = self._time.replace(tzinfo=dt.timezone(offset=dt.timedelta())) def get_latlon_bounds(self): return self._ll_bounds - - def set_latlon_bounds(self, ll_bounds, Nextra=2, output_spacing=None): - ''' + def set_latlon_bounds(self, ll_bounds, Nextra=2, output_spacing=None) -> None: + """ Need to correct lat/lon bounds because not all of the weather models have valid data exactly bounded by -90/90 (lats) and -180/180 (lons); for GMAO and MERRA2, need to adjust the longitude higher end with an extra buffer; for other models, the exact bounds are close to -90/90 (lats) and -180/180 (lons) and thus can be rounded to the above regions (either in the downloading-file API or subsetting- data API) without problems. 
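The `setTime` refactor above still accepts either an ISO-formatted string or a `datetime` and normalizes both to a timezone-aware value before any comparisons happen. A minimal standalone sketch of that normalization (the format string mirrors the code above; `parse_model_time` is a hypothetical helper name, and `dt.timezone.utc` is equivalent to the `dt.timezone(offset=dt.timedelta())` used in the diff):

```
import datetime as dt

def parse_model_time(time, fmt: str = '%Y-%m-%dT%H:%M:%S') -> dt.datetime:
    """Normalize a string or datetime to a timezone-aware UTC datetime."""
    if isinstance(time, str):
        time = dt.datetime.strptime(time, fmt)
    elif not isinstance(time, dt.datetime):
        raise ValueError('"time" must be a string or a datetime object')
    if time.tzinfo is None:
        # naive datetimes cannot be compared with timezone-aware ones, so attach UTC
        time = time.replace(tzinfo=dt.timezone.utc)
    return time

print(parse_model_time('2023-01-01T11:00:00'))  # 2023-01-01 11:00:00+00:00
```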
- ''' + """ ex_buffer_lon_max = 0.0 if self._Name in 'HRRR HRRR-AK HRES'.split(): - Nextra = 6 # have a bigger buffer + Nextra = 6 # have a bigger buffer else: ex_buffer_lon_max = self._lon_res @@ -216,45 +201,37 @@ def set_latlon_bounds(self, ll_bounds, Nextra=2, output_spacing=None): S, N, W, E = ll_bounds # Adjust bounds if they get near the poles or IDL - pixlat, pixlon = Nextra*self._lat_res, Nextra*self._lon_res + pixlat, pixlon = Nextra * self._lat_res, Nextra * self._lon_res - S= np.max([S - pixlat, -90.0 + pixlat]) - N= np.min([N + pixlat, 90.0 - pixlat]) - W= np.max([W - (pixlon + ex_buffer_lon_max), -180.0 + (pixlon+ex_buffer_lon_max)]) - E= np.min([E + (pixlon + ex_buffer_lon_max), 180.0 - pixlon - ex_buffer_lon_max]) + S = np.max([S - pixlat, -90.0 + pixlat]) + N = np.min([N + pixlat, 90.0 - pixlat]) + W = np.max([W - (pixlon + ex_buffer_lon_max), -180.0 + (pixlon + ex_buffer_lon_max)]) + E = np.min([E + (pixlon + ex_buffer_lon_max), 180.0 - pixlon - ex_buffer_lon_max]) if output_spacing is not None: - S, N, W, E = clip_bbox([S,N,W,E], output_spacing) + S, N, W, E = clip_bbox([S, N, W, E], output_spacing) self._ll_bounds = np.array([S, N, W, E]) - def get_wmLoc(self): - """ Get the path to the direct with the weather model files """ + """Get the path to the direct with the weather model files.""" if self._wmLoc is None: wmLoc = os.path.join(os.getcwd(), 'weather_files') else: wmLoc = self._wmLoc return wmLoc - - def set_wmLoc(self, weather_model_directory:str): - """ Set the path to the directory with the weather model files """ + def set_wmLoc(self, weather_model_directory: str) -> None: + """Set the path to the directory with the weather model files.""" self._wmLoc = weather_model_directory - - def load( - self, - *args, - _zlevels=None, - **kwargs - ): - ''' + def load(self, *args, _zlevels=None, **kwargs): + """ Calls the load_weather method. Each model class should define a load_weather method appropriate for that class. 'args' should be one or more filenames. - ''' + """ # If the weather file has already been processed, do nothing outLoc = self.get_wmLoc() - path_wm_raw = make_raw_weather_data_filename(outLoc, self.Model(), self.getTime()) + path_wm_raw = make_raw_weather_data_filename(outLoc, self.Model(), self.getTime()) self._out_name = self.out_file(outLoc) if os.path.exists(self._out_name): @@ -275,66 +252,55 @@ def load( self._getZTD() return None - @abstractmethod def load_weather(self, *args, **kwargs): - ''' - Placeholder method. Should be implemented in each weather model type class - ''' + """Placeholder method. Should be implemented in each weather model type class.""" pass - def plot(self, plotType='pqt', savefig=True): - ''' - Plotting method. Valid plot types are 'pqt' - ''' + """Plotting method. Valid plot types are 'pqt'.""" if plotType == 'pqt': plot = plots.plot_pqt(self, savefig) elif plotType == 'wh': plot = plots.plot_wh(self, savefig) else: - raise RuntimeError('WeatherModel.plot: No plotType named {}'.format(plotType)) + raise RuntimeError(f'WeatherModel.plot: No plotType named {plotType}') return plot - - def checkTime(self, time): - ''' - Checks the time against the lag time and valid date range for the given model type + def checkTime(self, time) -> None: + """ + Checks the time against the lag time and valid date range for the given model type. 
Parameters: time - Python datetime object - Raises: + Raises: Different errors depending on the issue - ''' + """ start_time = self._valid_range[0] end_time = self._valid_range[1] - if not isinstance(time, datetime.datetime): - raise ValueError('"time" should be a Python datetime object, instead it is {}'.format(time)) - - # This is needed because Python now gets angry if you try to compare non-timezone-aware - # objects with time-zone aware objects. - time = time.replace(tzinfo=datetime.timezone(offset=datetime.timedelta())) + if not isinstance(time, dt.datetime): + raise ValueError(f'"time" should be a Python datetime object, instead it is {time}') - logger.info( - 'Weather model %s is available from %s to %s', - self.Model(), start_time, end_time - ) + # This is needed because Python now gets angry if you try to compare non-timezone-aware + # objects with time-zone aware objects. + time = time.replace(tzinfo=dt.timezone(offset=dt.timedelta())) + + logger.info('Weather model %s is available from %s to %s', self.Model(), start_time, end_time) if time < start_time: raise DatetimeOutsideRange(self.Model(), time) if end_time < time: raise DatetimeOutsideRange(self.Model(), time) - # datetime.datetime.utcnow() is deprecated because Python developers - # want everyone to use timezone-aware datetimes. - if time > datetime.datetime.now(datetime.timezone.utc) - self._lag_time: + # dt.datetime.utcnow() is deprecated because Python developers + # want everyone to use timezone-aware datetimes. + if time > dt.datetime.now(dt.timezone.utc) - self._lag_time: raise DatetimeOutsideRange(self.Model(), time) - - def setLevelType(self, levelType): - '''Set the level type to model levels or pressure levels''' + def setLevelType(self, levelType) -> None: + """Set the level type to model levels or pressure levels.""" if levelType in 'ml pl nat prs'.split(): self._model_level_type = levelType else: @@ -345,25 +311,18 @@ def setLevelType(self, levelType): else: self.__pressure_levels__() - def _convertmb2Pa(self, pres): - ''' - Convert pressure in millibars to Pascals - ''' + """Convert pressure in millibars to Pascals.""" return 100 * pres - - def _get_heights(self, lats, geo_hgt, geo_ht_fill=np.nan): - ''' - Transform geo heights to WGS84 ellipsoidal heights - ''' + def _get_heights(self, lats, geo_hgt, geo_ht_fill=np.nan) -> None: + """Transform geo heights to WGS84 ellipsoidal heights.""" geo_ht_fix = np.where(geo_hgt != geo_ht_fill, geo_hgt, np.nan) - lats_full = np.broadcast_to(lats[...,np.newaxis], geo_ht_fix.shape) - self._zs = util.geo_to_ht(lats_full, geo_ht_fix) - + lats_full = np.broadcast_to(lats[..., np.newaxis], geo_ht_fix.shape) + self._zs = util.geo_to_ht(lats_full, geo_ht_fix) - def _find_e(self): - """Check the type of e-calculation needed""" + def _find_e(self) -> None: + """Check the type of e-calculation needed.""" if self._humidityType == 'rh': self._find_e_from_rh() elif self._humidityType == 'q': @@ -373,51 +332,39 @@ def _find_e(self): self._rh = None self._q = None - - def _find_e_from_q(self): + def _find_e_from_q(self) -> None: """Calculate e, partial pressure of water vapor.""" svp = find_svp(self._t) # We have q = w/(w + 1), so w = q/(1 - q) w = self._q / (1 - self._q) self._e = w * self._R_v * (self._p - svp) / self._R_d - - def _find_e_from_rh(self): + def _find_e_from_rh(self) -> None: """Calculate partial pressure of water vapor.""" svp = find_svp(self._t) self._e = self._rh / 100 * svp - - def _get_wet_refractivity(self): - ''' - Calculate the wet delay from pressure, 
temperature, and e - ''' + def _get_wet_refractivity(self) -> None: + """Calculate the wet delay from pressure, temperature, and e.""" self._wet_refractivity = self._k2 * self._e / self._t + self._k3 * self._e / self._t**2 - - def _get_hydro_refractivity(self): - ''' - Calculate the hydrostatic delay from pressure and temperature - ''' + def _get_hydro_refractivity(self) -> None: + """Calculate the hydrostatic delay from pressure and temperature.""" self._hydrostatic_refractivity = self._k1 * self._p / self._t - def getWetRefractivity(self): return self._wet_refractivity - def getHydroRefractivity(self): return self._hydrostatic_refractivity - - def _adjust_grid(self, ll_bounds=None): - ''' + def _adjust_grid(self, ll_bounds=None) -> None: + """ This function pads the weather grid with a level at self._zmin, if it does not already go that low. <> <> - ''' - + """ if self._zmin < np.nanmin(self._zs): # first add in a new layer at zmin self._zs = np.insert(self._zs, 0, self._zmin) @@ -430,32 +377,24 @@ def _adjust_grid(self, ll_bounds=None): if ll_bounds is not None: self._trimExtent(ll_bounds) - - def _getZTD(self): - ''' + def _getZTD(self) -> None: + """ Compute the full slant tropospheric delay for each weather model grid node, using the reference - height zref - ''' + height zref. + """ wet = self.getWetRefractivity() hydro = self.getHydroRefractivity() # Get the integrated ZTD wet_total, hydro_total = np.zeros(wet.shape), np.zeros(hydro.shape) for level in range(wet.shape[2]): - wet_total[..., level] = 1e-6 * np.trapz( - wet[..., level:], x=self._zs[level:], axis=2 - ) - hydro_total[..., level] = 1e-6 * np.trapz( - hydro[..., level:], x=self._zs[level:], axis=2 - ) + wet_total[..., level] = 1e-6 * np.trapz(wet[..., level:], x=self._zs[level:], axis=2) + hydro_total[..., level] = 1e-6 * np.trapz(hydro[..., level:], x=self._zs[level:], axis=2) self._hydrostatic_ztd = hydro_total self._wet_ztd = wet_total - def _getExtent(self, lats, lons): - ''' - get the bounding box around a set of lats/lons - ''' + """Get the bounding box around a set of lats/lons.""" if (lats.size == 1) & (lons.size == 1): return [lats - self._lat_res, lats + self._lat_res, lons - self._lon_res, lons + self._lon_res] elif (lats.size > 1) & (lons.size > 1): @@ -467,7 +406,6 @@ def _getExtent(self, lats, lons): else: raise RuntimeError('Not a valid lat/lon shape') - @property def bbox(self) -> list: """ @@ -478,18 +416,17 @@ def bbox(self) -> list: list xmin, ymin, xmax, ymax - Raises + Raises: ------ ValueError When `self.files` is None. 
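The refractivity methods and `_getZTD` above compute the two-term wet and one-term hydrostatic refractivity and then integrate each profile from every level up to the model top with the trapezoidal rule. A minimal NumPy sketch of the same calculation on a single toy column (the k1/k2/k3 constants and the 1e-6 scaling follow the code in this diff; the atmospheric profile itself is invented for illustration):

```
import numpy as np

# toy 1x1xN column: heights (m), pressure (Pa), temperature (K), water vapor partial pressure (Pa)
zs = np.linspace(0, 15000, 50)
p = 101325.0 * np.exp(-zs / 8500.0)[None, None, :]
t = (288.0 - 0.0065 * np.minimum(zs, 11000))[None, None, :]
e = 1000.0 * np.exp(-zs / 2000.0)[None, None, :]

k1, k2, k3 = 0.776, 0.233, 3.75e3        # K/Pa, K/Pa, K^2/Pa

wet = k2 * e / t + k3 * e / t**2          # wet refractivity
hydro = k1 * p / t                        # hydrostatic refractivity

# integrate from each level to the top of the column -> zenith delays in meters
wet_ztd = np.zeros_like(wet)
hydro_ztd = np.zeros_like(hydro)
for level in range(wet.shape[2]):
    wet_ztd[..., level] = 1e-6 * np.trapz(wet[..., level:], x=zs[level:], axis=2)
    hydro_ztd[..., level] = 1e-6 * np.trapz(hydro[..., level:], x=zs[level:], axis=2)

print(f'surface ZTD: wet={wet_ztd[0, 0, 0]:.3f} m, hydro={hydro_ztd[0, 0, 0]:.3f} m')
```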
""" - if self._bbox is None: path_weather_model = self.out_file(self.get_wmLoc()) if not os.path.exists(path_weather_model): raise ValueError('Need to save cropped weather model as netcdf') - with xarray.load_dataset(path_weather_model) as ds: + with xr.load_dataset(path_weather_model) as ds: try: xmin, xmax = ds.x.min(), ds.x.max() ymin, ymax = ds.y.min(), ds.y.max() @@ -497,11 +434,11 @@ def bbox(self) -> list: xmin, xmax = ds.longitude.min(), ds.longitude.max() ymin, ymax = ds.latitude.min(), ds.latitude.max() - wm_proj = self._proj - xs, ys = [xmin, xmin, xmax, xmax], [ymin, ymax, ymin, ymax] + wm_proj = self._proj + xs, ys = [xmin, xmin, xmax, xmax], [ymin, ymax, ymin, ymax] lons, lats = transform_coords(wm_proj, CRS(4326), xs, ys) - ## projected weather models may not be aligned N/S - ## should only matter for warning messages + # projected weather models may not be aligned N/S + # should only matter for warning messages W, E = np.min(lons), np.max(lons) # S, N = np.sort([lats[np.argmin(lons)], lats[np.argmax(lons)]]) S, N = np.min(lats), np.max(lats) @@ -509,41 +446,27 @@ def bbox(self) -> list: return self._bbox - def checkValidBounds( - self: weatherModel, - ll_bounds: np.ndarray, - ): - ''' - Checks whether the given bounding box is valid for the model - (i.e., intersects with the model domain at all) + self, + ll_bounds: np.ndarray, + ) -> None: + """Check whether the given bounding box is valid for the model (i.e., intersects with the model domain at all). Args: ll_bounds : np.ndarray - - Returns: - The weather model object - ''' + """ S, N, W, E = ll_bounds - if box(W, S, E, N).intersects(self._valid_bounds): - Mod = self - - else: + if not box(W, S, E, N).intersects(self._valid_bounds): raise ValueError(f'The requested location is unavailable for {self._Name}') - return Mod - - - def checkContainment(self: weatherModel, - ll_bounds, - buffer_deg: float = 1e-5) -> bool: - """" + def checkContainment(self, ll_bounds, buffer_deg: float = 1e-5) -> bool: + """ " Checks containment of weather model bbox of outLats and outLons provided. Args: ---------- - weather_model : weatherModel + weather_model : WeatherModel ll_bounds: an array of floats (SNWE) demarcating bbox of targets buffer_deg : float For x-translates for extents that lie outside of world bounding box, @@ -557,64 +480,54 @@ def checkContainment(self: weatherModel, and False otherwise. 
""" ymin_input, ymax_input, xmin_input, xmax_input = ll_bounds - input_box = box(xmin_input, ymin_input, xmax_input, ymax_input) + input_box = box(xmin_input, ymin_input, xmax_input, ymax_input) xmin, ymin, xmax, ymax = self.bbox weather_model_box = box(xmin, ymin, xmax, ymax) - world_box = box(-180, -90, 180, 90) + world_box = box(-180, -90, 180, 90) # Logger - input_box_str = [f'{x:1.2f}' for x in [xmin_input, ymin_input, - xmax_input, ymax_input]] + input_box_str = [f'{x:1.2f}' for x in [xmin_input, ymin_input, xmax_input, ymax_input]] weath_box_str = [f'{x:1.2f}' for x in [xmin, ymin, xmax, ymax]] weath_box_str = ', '.join(weath_box_str) input_box_str = ', '.join(input_box_str) - logger.info(f'Extent of the weather model is (xmin, ymin, xmax, ymax):' - f'{weath_box_str}') - logger.info(f'Extent of the input is (xmin, ymin, xmax, ymax): ' - f'{input_box_str}') + logger.info(f'Extent of the weather model is (xmin, ymin, xmax, ymax):' f'{weath_box_str}') + logger.info(f'Extent of the input is (xmin, ymin, xmax, ymax): ' f'{input_box_str}') # If the bounding box goes beyond the normal world extents # Look at two x-translates, buffer them, and take their union. if not world_box.contains(weather_model_box): - logger.info('Considering x-translates of weather model +/-360 ' - 'as bounding box outside of -180, -90, 180, 90') - translates = [weather_model_box.buffer(buffer_deg), - translate(weather_model_box, - xoff=360).buffer(buffer_deg), - translate(weather_model_box, - xoff=-360).buffer(buffer_deg) - ] + logger.info( + 'Considering x-translates of weather model +/-360 as bounding box outside of -180, -90, 180, 90' + ) + translates = [ + weather_model_box.buffer(buffer_deg), + translate(weather_model_box, xoff=360).buffer(buffer_deg), + translate(weather_model_box, xoff=-360).buffer(buffer_deg), + ] weather_model_box = unary_union(translates) return weather_model_box.contains(input_box) - - def _isOutside(self, extent1, extent2): - ''' - Determine whether any of extent1 lies outside extent2 - extent1/2 should be a list containing [lower_lat, upper_lat, left_lon, right_lon] - ''' + def _isOutside(self, extent1, extent2) -> bool: + """ + Determine whether any of extent1 lies outside extent2. + extent1/2 should be a list containing [lower_lat, upper_lat, left_lon, right_lon]. + """ t1 = extent1[0] < extent2[0] t2 = extent1[1] > extent2[1] t3 = extent1[2] < extent2[2] t4 = extent1[3] > extent2[3] - if np.any([t1, t2, t3, t4]): - return True - return False - + return np.any([t1, t2, t3, t4]) - def _trimExtent(self, extent): - ''' - get the bounding box around a set of lats/lons - ''' + def _trimExtent(self, extent) -> None: + """Get the bounding box around a set of lats/lons.""" lat = self._lats.copy() lon = self._lons.copy() lat[np.isnan(lat)] = np.nanmean(lat) lon[np.isnan(lon)] = np.nanmean(lon) - mask = (lat >= extent[0]) & (lat <= extent[1]) & \ - (lon >= extent[2]) & (lon <= extent[3]) + mask = (lat >= extent[0]) & (lat <= extent[1]) & (lon >= extent[2]) & (lon <= extent[3]) ma1 = np.sum(mask, axis=1).astype('bool') ma2 = np.sum(mask, axis=0).astype('bool') if np.sum(ma1) == 0 and np.sum(ma2) == 0: @@ -640,9 +553,8 @@ def _trimExtent(self, extent): self._wet_refractivity = self._wet_refractivity[index1:index2, index3:index4, ...] 
self._hydrostatic_refractivity = self._hydrostatic_refractivity[index1:index2, index3:index4, :] - def _calculategeoh(self, z, lnsp): - ''' + """ Function to calculate pressure, geopotential, and geopotential height from the surface pressure and model levels provided by a weather model. The model levels are numbered from the highest eleveation to the lowest. @@ -655,61 +567,44 @@ def _calculategeoh(self, z, lnsp): pressurelvs - The pressure at each of the model levels for each of the input points geoheight - The geopotential heights - ''' + """ return calcgeoh(lnsp, self._t, self._q, z, self._a, self._b, self._R_d, self._levels) - def getProjection(self): - ''' - Returns: the native weather projection, which should be a pyproj object - ''' + """Returns: the native weather projection, which should be a pyproj object.""" return self._proj - def getPoints(self): return self._xs.copy(), self._ys.copy(), self._zs.copy() - - def _uniform_in_z(self, _zlevels=None): - ''' - Interpolate all variables to a regular grid in z - ''' + def _uniform_in_z(self, _zlevels=None) -> None: + """Interpolate all variables to a regular grid in z.""" nx, ny = self._p.shape[:2] # new regular z-spacing if _zlevels is None: try: _zlevels = self._zlevels - except BaseException: + except: _zlevels = np.nanmean(self._zs, axis=(0, 1)) new_zs = np.tile(_zlevels, (nx, ny, 1)) # re-assign values to the uniform z - self._t = interpolate_along_axis( - self._zs, self._t, new_zs, axis=2, fill_value=np.nan - ).astype(np.float32) - self._p = interpolate_along_axis( - self._zs, self._p, new_zs, axis=2, fill_value=np.nan - ).astype(np.float32) - self._e = interpolate_along_axis( - self._zs, self._e, new_zs, axis=2, fill_value=np.nan - ).astype(np.float32) + self._t = interpolate_along_axis(self._zs, self._t, new_zs, axis=2, fill_value=np.nan).astype(np.float32) + self._p = interpolate_along_axis(self._zs, self._p, new_zs, axis=2, fill_value=np.nan).astype(np.float32) + self._e = interpolate_along_axis(self._zs, self._e, new_zs, axis=2, fill_value=np.nan).astype(np.float32) self._zs = _zlevels self._xs = np.unique(self._xs) self._ys = np.unique(self._ys) - - def _checkForNans(self): - ''' - Fill in NaN-values - ''' + def _checkForNans(self) -> None: + """Fill in NaN-values.""" self._p = fillna3D(self._p) - self._t = fillna3D(self._t, fill_value=1e16) # to avoid division by zero later on + self._t = fillna3D(self._t, fill_value=1e16) # to avoid division by zero later on self._e = fillna3D(self._e) - def out_file(self, outLoc): f = make_weather_model_filename( self._Name, @@ -718,11 +613,8 @@ def out_file(self, outLoc): ) return os.path.join(outLoc, f) - def filename(self, time=None, outLoc='weather_files'): - ''' - Create a filename to store the weather model - ''' + """Create a filename to store the weather model.""" os.makedirs(outLoc, exist_ok=True) if time is None: @@ -740,24 +632,22 @@ def filename(self, time=None, outLoc='weather_files'): self.files = [f] return f - def write(self): - ''' + """ By calling the abstract/modular netcdf writer (RAiDER.utilFcns.write2NETCDF4core), write the weather model data and refractivity to an NETCDF4 file that can be accessed by external programs. 
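`_uniform_in_z` above resamples every variable onto one common set of z-levels (by default the mean height profile) so that downstream interpolation can assume a regular vertical grid. A rough NumPy-only sketch of the same idea for a single variable, using `np.interp` per column instead of RAiDER's `interpolate_along_axis` (shapes and values are invented, purely illustrative):

```
import numpy as np

rng = np.random.default_rng(0)
nx, ny, nz = 3, 4, 20

# irregular, column-dependent heights and a variable defined on them
zs = np.sort(rng.uniform(0, 15000, size=(nx, ny, nz)), axis=2)
t = 288.0 - 0.0065 * zs + rng.normal(0, 0.1, size=zs.shape)

# target: one common set of levels, here the mean height profile
zlevels = np.nanmean(zs, axis=(0, 1))

t_uniform = np.full((nx, ny, zlevels.size), np.nan)
for i in range(nx):
    for j in range(ny):
        # np.interp needs increasing sample points; the columns were sorted above
        t_uniform[i, j] = np.interp(zlevels, zs[i, j], t[i, j],
                                    left=np.nan, right=np.nan)

print(t_uniform.shape)  # (3, 4, 20)
```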
- ''' + """ # Generate the filename f = self._out_name attrs_dict = { - "Conventions": 'CF-1.6', - "datetime": datetime.datetime.strftime(self._time, "%Y_%m_%dT%H_%M_%S"), - 'date_created': datetime.datetime.now().strftime("%Y_%m_%dT%H_%M_%S"), - 'title': 'Weather model data and delay calculations', - 'model_name': self._Name - - } + 'Conventions': 'CF-1.6', + 'datetime': dt.datetime.strftime(self._time, '%Y_%m_%dT%H_%M_%S'), + 'date_created': dt.datetime.now().strftime('%Y_%m_%dT%H_%M_%S'), + 'title': 'Weather model data and delay calculations', + 'model_name': self._Name, + } dimension_dict = { 'x': ('x', self._xs), @@ -778,7 +668,7 @@ def write(self): 'hydro_total': (('z', 'y', 'x'), self._hydrostatic_ztd.swapaxes(0, 2).swapaxes(1, 2)), } - ds = xarray.Dataset(data_vars=dataset_dict, coords=dimension_dict, attrs=attrs_dict) + ds = xr.Dataset(data_vars=dataset_dict, coords=dimension_dict, attrs=attrs_dict) # Define units ds['t'].attrs['units'] = 'K' @@ -799,7 +689,7 @@ def write(self): ds['hydro_total'].attrs['standard_name'] = 'total_hydrostatic_refractivity' # projection information - ds["proj"] = int() + ds['proj'] = 0 for k, v in self._proj.to_cf().items(): ds.proj.attrs[k] = v for var in ds.data_vars: @@ -810,38 +700,30 @@ def write(self): return f -def make_weather_model_filename(name, time, ll_bounds): +def make_weather_model_filename(name, time, ll_bounds) -> str: s = np.floor(ll_bounds[0]) - S = f'{np.abs(s):.0f}S' if s <0 else f'{s:.0f}N' + S = f'{np.abs(s):.0f}S' if s < 0 else f'{s:.0f}N' n = np.ceil(ll_bounds[1]) - N = f'{np.abs(n):.0f}S' if n <0 else f'{n:.0f}N' + N = f'{np.abs(n):.0f}S' if n < 0 else f'{n:.0f}N' w = np.floor(ll_bounds[2]) - W = f'{np.abs(w):.0f}W' if w <0 else f'{w:.0f}E' + W = f'{np.abs(w):.0f}W' if w < 0 else f'{w:.0f}E' e = np.ceil(ll_bounds[3]) - E = f'{np.abs(e):.0f}W' if e <0 else f'{e:.0f}E' + E = f'{np.abs(e):.0f}W' if e < 0 else f'{e:.0f}E' return f'{name}_{time.strftime("%Y_%m_%d_T%H_%M_%S")}_{S}_{N}_{W}_{E}.nc' def make_raw_weather_data_filename(outLoc, name, time): - ''' Filename generator for the raw downloaded weather model data ''' - f = os.path.join( - outLoc, - '{}_{}.{}'.format( - name, - datetime.datetime.strftime(time, '%Y_%m_%d_T%H_%M_%S'), - 'nc' - ) - ) + """Filename generator for the raw downloaded weather model data.""" + date_string = dt.datetime.strftime(time, '%Y_%m_%d_T%H_%M_%S') + f = os.path.join(outLoc, f'{name}_{date_string}.nc') return f def find_svp(t): - """ - Calculate standard vapor presure. Should be model-specific - """ + """Calculate standard vapor presure. Should be model-specific.""" # From TRAIN: # Could not find the wrf used equation as they appear to be # mixed with latent heat etc. 
Istead I used the equations used @@ -860,8 +742,8 @@ def find_svp(t): tref = t - t1 wgt = (t - t2) / (t1 - t2) - svpw = (6.1121 * np.exp((17.502 * tref) / (240.97 + tref))) - svpi = (6.1121 * np.exp((22.587 * tref) / (273.86 + tref))) + svpw = 6.1121 * np.exp((17.502 * tref) / (240.97 + tref)) + svpi = 6.1121 * np.exp((22.587 * tref) / (273.86 + tref)) svp = svpi + (svpw - svpi) * wgt**2 ix_bound1 = t > t1 @@ -874,20 +756,18 @@ def find_svp(t): def get_mapping(proj): - '''Get CF-complient projection information from a proj''' + """Get CF-complient projection information from a proj.""" # In case of WGS-84 lat/lon, keep it simple - if proj.to_epsg()==4326: + if proj.to_epsg() == 4326: return 'WGS84' else: return proj.to_wkt() -def checkContainment_raw(path_wm_raw, - ll_bounds, - buffer_deg: float = 1e-5) -> bool: - """" +def checkContainment_raw(path_wm_raw, ll_bounds, buffer_deg: float = 1e-5) -> bool: + """ " Checks if existing raw weather model contains - requested ll_bounds + requested ll_bounds. Args: ---------- @@ -905,8 +785,9 @@ def checkContainment_raw(path_wm_raw, and False otherwise. """ import xarray as xr + ymin_input, ymax_input, xmin_input, xmax_input = ll_bounds - input_box = box(xmin_input, ymin_input, xmax_input, ymax_input) + input_box = box(xmin_input, ymin_input, xmax_input, ymax_input) with xr.open_dataset(path_wm_raw) as ds: try: @@ -916,31 +797,27 @@ def checkContainment_raw(path_wm_raw, ymin, ymax = ds.y.min(), ds.y.max() xmin, xmax = ds.x.min(), ds.x.max() - xmin, xmax = np.mod(np.array([xmin, xmax])+180, 360) - 180 + xmin, xmax = np.mod(np.array([xmin, xmax]) + 180, 360) - 180 weather_model_box = box(xmin, ymin, xmax, ymax) - world_box = box(-180, -90, 180, 90) + world_box = box(-180, -90, 180, 90) # Logger - input_box_str = [f'{x:1.2f}' for x in [xmin_input, ymin_input, - xmax_input, ymax_input]] + input_box_str = [f'{x:1.2f}' for x in [xmin_input, ymin_input, xmax_input, ymax_input]] weath_box_str = [f'{x:1.2f}' for x in [xmin, ymin, xmax, ymax]] weath_box_str = ', '.join(weath_box_str) input_box_str = ', '.join(input_box_str) - # If the bounding box goes beyond the normal world extents # Look at two x-translates, buffer them, and take their union. if not world_box.contains(weather_model_box): - logger.info('Considering x-translates of weather model +/-360 ' - 'as bounding box outside of -180, -90, 180, 90') - translates = [weather_model_box.buffer(buffer_deg), - translate(weather_model_box, - xoff=360).buffer(buffer_deg), - translate(weather_model_box, - xoff=-360).buffer(buffer_deg) - ] + logger.info('Considering x-translates of weather model +/-360 as bounding box outside of -180, -90, 180, 90') + translates = [ + weather_model_box.buffer(buffer_deg), + translate(weather_model_box, xoff=360).buffer(buffer_deg), + translate(weather_model_box, xoff=-360).buffer(buffer_deg), + ] weather_model_box = unary_union(translates) return weather_model_box.contains(input_box) diff --git a/tools/RAiDER/models/wrf.py b/tools/RAiDER/models/wrf.py index 827495a54..5a64febf5 100644 --- a/tools/RAiDER/models/wrf.py +++ b/tools/RAiDER/models/wrf.py @@ -2,7 +2,7 @@ import scipy.io.netcdf as netcdf from pyproj import CRS, Transformer -from RAiDER.models.weatherModel import WeatherModel, TIME_RES +from RAiDER.models.weatherModel import TIME_RES, WeatherModel # Need to incorporate this snippet into this part of the code. @@ -15,31 +15,27 @@ # lats, lons = wrf.wm_nodes(*weather_files) # class WRF(WeatherModel): - ''' - WRF class definition, based on the WeatherModel base class. 
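The reformatted `find_svp` above blends saturation vapor pressure over water and over ice with a quadratic weight between two threshold temperatures. A tiny standalone version of that blend; the threshold values `t1`/`t2` and the final hPa-to-Pa conversion are not visible in the hunk above, so the numbers used here are assumptions for illustration only:

```
import numpy as np

def find_svp_sketch(t, t1=273.16, t2=250.16):
    """Blend saturation vapor pressure over water and ice (illustrative thresholds)."""
    tref = t - t1
    wgt = (t - t2) / (t1 - t2)

    svpw = 6.1121 * np.exp((17.502 * tref) / (240.97 + tref))   # over water, hPa
    svpi = 6.1121 * np.exp((22.587 * tref) / (273.86 + tref))   # over ice, hPa
    svp = svpi + (svpw - svpi) * wgt**2

    # outside the transition zone, fall back to the pure-water / pure-ice curves
    svp = np.where(t > t1, svpw, svp)
    svp = np.where(t < t2, svpi, svp)
    return svp * 100  # assumed hPa -> Pa conversion

print(find_svp_sketch(np.array([230.0, 260.0, 290.0])))
```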
- ''' + """WRF class definition, based on the WeatherModel base class.""" + # TODO: finish implementing - def __init__(self): + def __init__(self) -> None: WeatherModel.__init__(self) self._k1 = 0.776 # K/Pa self._k2 = 0.233 # K/Pa self._k3 = 3.75e3 # K^2/Pa - # Currently WRF is using RH instead of Q to get E self._humidityType = 'rh' self._Name = 'WRF' self._time_res = TIME_RES[self._Name] - def _fetch(self): + def _fetch(self) -> None: pass - def load_weather(self, file1, file2, *args, **kwargs): - ''' - Consistent class method to be implemented across all weather model types - ''' + def load_weather(self, file1, file2, *args, **kwargs) -> None: + """Consistent class method to be implemented across all weather model types.""" try: lons, lats = self._get_wm_nodes(file1) self._read_netcdf(file2) @@ -61,10 +57,8 @@ def load_weather(self, file1, file2, *args, **kwargs): xs = np.mean(xs, axis=0) ys = np.mean(ys, axis=1) - _xs = np.broadcast_to(xs[np.newaxis, np.newaxis, :], - self._p.shape) - _ys = np.broadcast_to(ys[np.newaxis, :, np.newaxis], - self._p.shape) + _xs = np.broadcast_to(xs[np.newaxis, np.newaxis, :], self._p.shape) + _ys = np.broadcast_to(ys[np.newaxis, :, np.newaxis], self._p.shape) # Re-structure everything from (heights, lats, lons) to (lons, lats, heights) self._p = np.transpose(self._p) self._t = np.transpose(self._t) @@ -85,10 +79,8 @@ def _get_wm_nodes(self, nodeFile): return lons, lats - def _read_netcdf(self, weatherFile, defNul=None): - """ - Read weather variables from a netCDF file - """ + def _read_netcdf(self, weatherFile, defNul=None) -> None: + """Read weather variables from a netCDF file.""" if defNul is None: defNul = np.nan @@ -129,10 +121,17 @@ def _read_netcdf(self, weatherFile, defNul=None): # Projection # See http://www.pkrc.net/wrf-lambert.html earthRadius = 6370e3 # <- note Ray had a bug here - p1 = CRS(proj='lcc', lat_1=lat1, - lat_2=lat2, lat_0=lat0, - lon_0=lon0, a=earthRadius, b=earthRadius, - towgs84=(0, 0, 0), no_defs=True) + p1 = CRS( + proj='lcc', + lat_1=lat1, + lat_2=lat2, + lat_0=lat0, + lon_0=lon0, + a=earthRadius, + b=earthRadius, + towgs84=(0, 0, 0), + no_defs=True, + ) self._proj = p1 temps[temps == tNull] = np.nan @@ -155,36 +154,28 @@ def _read_netcdf(self, weatherFile, defNul=None): self._zs = geoh if len(sp.shape) == 1: - self._p = np.broadcast_to( - sp[:, np.newaxis, np.newaxis], self._zs.shape) + self._p = np.broadcast_to(sp[:, np.newaxis, np.newaxis], self._zs.shape) else: self._p = sp class UnitTypeError(Exception): - ''' - Define a unit type exception for easily formatting - error messages for units - ''' + """Define a unit type exception for easily formatting error messages for units.""" def __init___(self, varName, unittype): - msg = "Unknown units for {}: '{}'".format(varName, unittype) + msg = f"Unknown units for {varName}: '{unittype}'" Exception.__init__(self, msg) -def checkUnits(unitCheck, varName): - ''' - Implement a check that the units are as expected - ''' +def checkUnits(unitCheck, varName) -> None: + """Implement a check that the units are as expected.""" unitDict = {'pressure': 'Pa', 'temperature': 'K', 'relative humidity': '%', 'geopotential': 'm'} if unitCheck != unitDict[varName]: raise UnitTypeError(varName, unitCheck) def getNullValue(var): - ''' - Get the null (or fill) value if it exists, otherwise set the null value to defNullValue - ''' + """Get the null (or fill) value if it exists, otherwise set the null value to defNullValue.""" # NetCDF files have the ability to record their nodata value, but in the # 
particular NetCDF files that I'm reading, this field is left # unspecified and a nodata value of -999 is used. The solution I'm using diff --git a/tools/RAiDER/processWM.py b/tools/RAiDER/processWM.py index e84cb55c7..9278c7e8d 100755 --- a/tools/RAiDER/processWM.py +++ b/tools/RAiDER/processWM.py @@ -10,22 +10,25 @@ import matplotlib.pyplot as plt import numpy as np -from typing import List - from RAiDER.logger import logger -from RAiDER.utilFcns import getTimeFromFile -from RAiDER.models.weatherModel import make_raw_weather_data_filename, checkContainment_raw -from RAiDER.models.customExceptions import * +from RAiDER.models.customExceptions import ( + CriticalError, + DatetimeOutsideRange, + ExistingWeatherModelTooSmall, + TryToKeepGoingError, +) +from RAiDER.models.weatherModel import checkContainment_raw, make_raw_weather_data_filename, make_weather_model_filename + def prepareWeatherModel( - weather_model, - time, - ll_bounds, - download_only: bool=False, - makePlots: bool=False, - force_download: bool=False, - ) -> str: - """Parse inputs to download and prepare a weather model grid for interpolation + weather_model, + time, + ll_bounds, + download_only: bool = False, + makePlots: bool = False, + force_download: bool = False, +) -> str: + """Parse inputs to download and prepare a weather model grid for interpolation. Args: weather_model: WeatherModel - instantiated weather model object @@ -38,12 +41,12 @@ def prepareWeatherModel( Returns: str: filename of the netcdf file to which the weather model has been written """ - ## set the bounding box from the in the case that it hasn't been set + # set the bounding box from the in the case that it hasn't been set if weather_model.get_latlon_bounds() is None: weather_model.set_latlon_bounds(ll_bounds) # Ensure the file output location exists - wmLoc = weather_model.get_wmLoc() + wmLoc = weather_model.get_wmLoc() weather_model.setTime(time) # get the path to the less processed weather model file @@ -55,15 +58,16 @@ def prepareWeatherModel( # check whether weather model files exists and/or or should be downloaded if os.path.exists(path_wm_crop) and not force_download: logger.warning( - 'Processed weather model already exists, please remove it ("%s") if you want ' - 'to download a new one.', path_wm_crop) + 'Processed weather model already exists, please remove it ("%s") if you want to download a new one.', + path_wm_crop, + ) # check whether the raw weather model covers this area - elif os.path.exists(path_wm_raw) and \ - checkContainment_raw(path_wm_raw, ll_bounds) and not force_download: + elif os.path.exists(path_wm_raw) and checkContainment_raw(path_wm_raw, ll_bounds) and not force_download: logger.warning( - 'Raw weather model already exists, please remove it ("%s") if you want ' - 'to download a new one.', path_wm_raw) + 'Raw weather model already exists, please remove it ("%s") if you want to download a new one.', + path_wm_raw, + ) # if no weather model files supplied, check the standard location else: @@ -75,19 +79,14 @@ def prepareWeatherModel( # If only downloading, exit now if download_only: - logger.warning( - 'download_only flag selected. No further processing will happen.' - ) + logger.warning('download_only flag selected. No further processing will happen.') return None # Otherwise, load the weather model data f = weather_model.load() if f is not None: - logger.warning( - 'The processed weather model file already exists,' - ' so I will use that.' 
- ) + logger.warning('The processed weather model file already exists, so I will use that.') containment = weather_model.checkContainment(ll_bounds) if not containment and weather_model.Model() not in 'HRRR'.split(): @@ -96,26 +95,19 @@ def prepareWeatherModel( return f # Logging some basic info - logger.debug( - 'Number of weather model nodes: %s', - np.prod(weather_model.getWetRefractivity().shape) - ) + logger.debug('Number of weather model nodes: %s', np.prod(weather_model.getWetRefractivity().shape)) shape = weather_model.getWetRefractivity().shape logger.debug(f'Shape of weather model: {shape}') logger.debug( 'Bounds of the weather model: %.2f/%.2f/%.2f/%.2f (SNWE)', - np.nanmin(weather_model._ys), np.nanmax(weather_model._ys), - np.nanmin(weather_model._xs), np.nanmax(weather_model._xs) + np.nanmin(weather_model._ys), + np.nanmax(weather_model._ys), + np.nanmin(weather_model._xs), + np.nanmax(weather_model._xs), ) logger.debug('Weather model: %s', weather_model.Model()) - logger.debug( - 'Mean value of the wet refractivity: %f', - np.nanmean(weather_model.getWetRefractivity()) - ) - logger.debug( - 'Mean value of the hydrostatic refractivity: %f', - np.nanmean(weather_model.getHydroRefractivity()) - ) + logger.debug('Mean value of the wet refractivity: %f', np.nanmean(weather_model.getWetRefractivity())) + logger.debug('Mean value of the hydrostatic refractivity: %f', np.nanmean(weather_model.getHydroRefractivity())) logger.debug(weather_model) if makePlots: @@ -128,7 +120,7 @@ def prepareWeatherModel( containment = weather_model.checkContainment(ll_bounds) except Exception as e: - logger.exception("Unable to save weathermodel to file") + logger.exception('Unable to save weathermodel to file') logger.exception(e) raise CriticalError @@ -142,21 +134,8 @@ def prepareWeatherModel( return f -def _weather_model_debug( - los, - lats, - lons, - ll_bounds, - weather_model, - wmLoc, - time, - out, - download_only - ): - """ - raiderWeatherModelDebug main function. 
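The reworked `prepareWeatherModel` above keeps its original caching order: skip everything when a processed (cropped) file already exists, reuse an existing raw file only if it still covers the requested bounds, and otherwise download fresh data. A schematic sketch of that decision order (the `has_coverage` flag and file names are placeholders, not RAiDER functions or paths):

```
from pathlib import Path

def decide_action(path_crop: Path, path_raw: Path, has_coverage: bool, force_download: bool) -> str:
    """Mirror the branch order above: cropped file > covering raw file > new download."""
    if path_crop.exists() and not force_download:
        return 'use existing processed weather model'
    if path_raw.exists() and has_coverage and not force_download:
        return 'reuse existing raw weather model'
    return 'download new weather model data'

print(decide_action(Path('weather_files/model_crop.nc'),
                    Path('weather_files/model_raw.nc'),
                    has_coverage=True, force_download=False))
```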
- """ - +def _weather_model_debug(los, lats, lons, ll_bounds, weather_model, wmLoc, time, out, download_only) -> None: + """RaiderWeatherModelDebug main function.""" logger.debug('Starting to run the weather model calculation with debugging plots') logger.debug('Time type: %s', type(time)) logger.debug('Time: %s', time.strftime('%Y%m%d')) @@ -168,11 +147,7 @@ def _weather_model_debug( wmLoc = os.path.join(out, 'weather_files') # weather model calculation - wm_filename = make_weather_model_filename( - weather_model['name'], - time, - ll_bounds - ) + wm_filename = make_weather_model_filename(weather_model['name'], time, ll_bounds) weather_model_file = os.path.join(wmLoc, wm_filename) if not os.path.exists(weather_model_file): @@ -184,9 +159,9 @@ def _weather_model_debug( lons=lons, ll_bounds=ll_bounds, download_only=download_only, - makePlots=True + makePlots=True, ) try: weather_model.write2NETCDF4(weather_model_file) - except Exception: - logger.exception("Unable to save weathermodel to file") + except: + logger.exception('Unable to save weathermodel to file') diff --git a/tools/RAiDER/s1_azimuth_timing.py b/tools/RAiDER/s1_azimuth_timing.py index dae72b19b..9dcdef1f8 100644 --- a/tools/RAiDER/s1_azimuth_timing.py +++ b/tools/RAiDER/s1_azimuth_timing.py @@ -1,4 +1,5 @@ -import datetime +import datetime as dt +from typing import Optional import warnings import asf_search as asf @@ -6,6 +7,7 @@ import pandas as pd from shapely.geometry import Point + try: import isce3.ext.isce3 as isce except ImportError: @@ -15,58 +17,65 @@ from RAiDER.s1_orbits import get_orbits_from_slc_ids_hyp3lib -def _asf_query(point: Point, - start: datetime.datetime, - end: datetime.datetime, - buffer_degrees: float = 2) -> list[str]: - """Using a buffer to get as many SLCs covering a given request as +def _asf_query( + point: Point, + start: dt.datetime, + end: dt.datetime, + buffer_degrees: float = 2 +) -> list[str]: + """ + Using a buffer to get as many SLCs covering a given request as. Parameters ---------- point : Point - start : datetime.datetime - end : datetime.datetime + start : dt.datetime + end : dt.datetime buffer_degrees : float, optional - Returns + Returns: ------- list[str] """ - results = asf.geo_search(intersectsWith=point.buffer(buffer_degrees).wkt, - processingLevel=asf.PRODUCT_TYPE.SLC, - start=start, - end=end, - maxResults=5 - ) + results = asf.geo_search( + intersectsWith=point.buffer(buffer_degrees).wkt, + processingLevel=asf.PRODUCT_TYPE.SLC, + start=start, + end=end, + maxResults=5, + ) slc_ids = [r.properties['sceneName'] for r in results] return slc_ids -def get_slc_id_from_point_and_time(lon: float, - lat: float, - dt: datetime.datetime, - buffer_seconds: int = 600, - buffer_deg: float = 2) -> list: - """Obtains a (non-unique) SLC id from the lon/lat and datetime of inputs. The buffere ensures that +def get_slc_id_from_point_and_time( + lon: float, + lat: float, + datetime: dt.datetime, + buffer_seconds: int = 600, + buffer_deg: float = 2 +) -> list: + """ + Obtains a (non-unique) SLC id from the lon/lat and datetime of inputs. The buffere ensures that an SLC id is within the queried start/end times. Note an S1 scene takes roughly 30 seconds to acquire. 
Parameters ---------- lon : float lat : float - dt : datetime.datetime + datetime : dt.datetime buffer_seconds : int, optional Do not recommend adjusting this, by default 600, to ensure enough padding for multiple orbit files - Returns + Returns: ------- list All slc_ids returned by asf_search """ point = Point(lon, lat) - time_delta = datetime.timedelta(seconds=buffer_seconds) - start = dt - time_delta - end = dt + time_delta + time_delta = dt.timedelta(seconds=buffer_seconds) + start = datetime - time_delta + end = datetime + time_delta # Requires buffer of degrees to get several SLCs and ensure we get correct # orbit files @@ -77,17 +86,19 @@ def get_slc_id_from_point_and_time(lon: float, return slc_ids -def get_azimuth_time_grid(lon_mesh: np.ndarray, - lat_mesh: np.ndarray, - hgt_mesh: np.ndarray, - orb: 'isce.core.Orbit') -> np.ndarray: - ''' +def get_azimuth_time_grid( + lon_mesh: np.ndarray, + lat_mesh: np.ndarray, + hgt_mesh: np.ndarray, + orb: 'isce.core.Orbit' +) -> np.ndarray: + """ Source: https://github.com/dbekaert/RAiDER/blob/dev/tools/RAiDER/losreader.py#L601C1-L674C22 lon_mesh, lat_mesh, hgt_mesh are coordinate arrays (this routine makes a mesh to comute azimuth timing grid) Technically, this is "sensor neutral" since it uses an orb object. - ''' + """ if isce is None: raise ImportError('isce3 is required for this function. Use conda to install isce3`') @@ -99,28 +110,35 @@ def get_azimuth_time_grid(lon_mesh: np.ndarray, look = isce.core.LookSide.Right m, n, p = hgt_mesh.shape - az_arr = np.full((m, n, p), - np.datetime64('NaT'), - # source: https://stackoverflow.com/a/27469108 - dtype='datetime64[ms]') + az_arr = np.full( + (m, n, p), + np.datetime64('NaT'), + # source: https://stackoverflow.com/a/27469108 + dtype='datetime64[ms]', + ) for ind_0 in range(m): for ind_1 in range(n): for ind_2 in range(p): + hgt_pt, lat_pt, lon_pt = ( + hgt_mesh[ind_0, ind_1, ind_2], + lat_mesh[ind_0, ind_1, ind_2], + lon_mesh[ind_0, ind_1, ind_2], + ) - hgt_pt, lat_pt, lon_pt = (hgt_mesh[ind_0, ind_1, ind_2], - lat_mesh[ind_0, ind_1, ind_2], - lon_mesh[ind_0, ind_1, ind_2]) - - input_vec = np.array([np.deg2rad(lon_pt), - np.deg2rad(lat_pt), - hgt_pt]) + input_vec = np.array([np.deg2rad(lon_pt), np.deg2rad(lat_pt), hgt_pt]) aztime, sr = isce.geometry.geo2rdr( - input_vec, elp, orb, dop, 0.06, look, + input_vec, + elp, + orb, + dop, + 0.06, + look, threshold=residual_threshold, maxiter=num_iteration, - delta_range=10.0) + delta_range=10.0, + ) rng_seconds = sr / isce.core.speed_of_light aztime = aztime + rng_seconds @@ -130,10 +148,12 @@ def get_azimuth_time_grid(lon_mesh: np.ndarray, return az_arr -def get_s1_azimuth_time_grid(lon: np.ndarray, - lat: np.ndarray, - hgt: np.ndarray, - dt: datetime.datetime) -> np.ndarray: +def get_s1_azimuth_time_grid( + lon: np.ndarray, + lat: np.ndarray, + hgt: np.ndarray, + datetime: dt.datetime +) -> np.ndarray: """Based on the lon, lat, hgt (3d cube) - obtains an associated s1 orbit file to calculate the azimuth timing across the cube. Requires datetime of acq associated to cube. 
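`get_slc_id_from_point_and_time` above pads the acquisition time by several hundred seconds on each side and buffers the query point by a couple of degrees before searching ASF, so that enough SLCs (and hence orbit files) are returned. A minimal sketch of just the window and geometry construction (the resulting WKT would then be passed to `asf.geo_search` as in the code above; the lon/lat point is hypothetical):

```
import datetime as dt
from shapely.geometry import Point

acq_time = dt.datetime(2023, 1, 1, 11, 0, 0)
buffer_seconds = 600    # generous pad so neighbouring orbit files are covered
buffer_degrees = 2.0    # widen the footprint so several SLCs intersect it

window = dt.timedelta(seconds=buffer_seconds)
start, end = acq_time - window, acq_time + window

aoi = Point(-118.0, 34.0).buffer(buffer_degrees)
print(start, end)
print(aoi.wkt[:40], '...')
```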
@@ -146,9 +166,9 @@ def get_s1_azimuth_time_grid(lon: np.ndarray, 1 dimensional coordinate array or 3d mesh of coordinates hgt : np.ndarray 1 dimensional coordinate array or 3d mesh of coordinates - dt : datetime.datetime + datetime : dt.datetime - Returns + Returns: ------- np.ndarray Cube whose coordinates are hgt x lat x lon with each pixel @@ -160,12 +180,16 @@ def get_s1_azimuth_time_grid(lon: np.ndarray, raise ValueError('Coordinates must be 1d or 3d coordinate arrays') if dims[0] == 1: - hgt_mesh, lat_mesh, lon_mesh = np.meshgrid(hgt, lat, lon, - # indexing keyword argument - # Ensures output dimensions - # align with order the inputs - # height x latitude x longitude - indexing='ij') + hgt_mesh, lat_mesh, lon_mesh = np.meshgrid( + hgt, + lat, + lon, + # indexing keyword argument + # Ensures output dimensions + # align with order the inputs + # height x latitude x longitude + indexing='ij', + ) else: hgt_mesh = hgt lat_mesh = lat @@ -174,44 +198,45 @@ def get_s1_azimuth_time_grid(lon: np.ndarray, try: lon_m = np.mean(lon) lat_m = np.mean(lat) - slc_ids = get_slc_id_from_point_and_time(lon_m, lat_m, dt) + slc_ids = get_slc_id_from_point_and_time(lon_m, lat_m, datetime) except ValueError: warnings.warn('No slc id found for the given datetime and grid; returning empty grid') m, n, p = hgt_mesh.shape - az_arr = np.full((m, n, p), - np.datetime64('NaT'), - dtype='datetime64[ms]') + az_arr = np.full((m, n, p), np.datetime64('NaT'), dtype='datetime64[ms]') return az_arr orb_files = get_orbits_from_slc_ids_hyp3lib(slc_ids) orb_files = [str(of) for of in orb_files] - orb = get_isce_orbit(orb_files, dt, pad=600) + orb = get_isce_orbit(orb_files, datetime, pad=600) az_arr = get_azimuth_time_grid(lon_mesh, lat_mesh, hgt_mesh, orb) return az_arr -def get_n_closest_datetimes(ref_time: datetime.datetime, - n_target_times: int, - time_step_hours: int) -> list[datetime.datetime]: - """Gets n closes times relative to the `round_to_hour_delta` and the +def get_n_closest_datetimes( + ref_time: dt.datetime, + n_target_times: int, + time_step_hours: int +) -> list[dt.datetime]: + """ + Gets n closest times relative to the `round_to_hour_delta` and the `ref_time`. Specifically, if one is interetsted in getting 3 closest times to say 0, 6, 12, 18 UTC times of a ref time `dt`, then: ``` - dt = datetime.datetime(2023, 1, 1, 11, 0, 0) + dt = dt.datetime(2023, 1, 1, 11, 0, 0) get_n_closest_datetimes(dt, 3, 6) ``` gives the desired answer of ``` - [datetime.datetime(2023, 1, 1, 12, 0, 0), - datetime.datetime(2023, 1, 1, 6, 0, 0), - datetime.datetime(2023, 1, 1, 18, 0, 0)] + [dt.datetime(2023, 1, 1, 12, 0, 0), + dt.datetime(2023, 1, 1, 6, 0, 0), + dt.datetime(2023, 1, 1, 18, 0, 0)] ``` Parameters ---------- - ref_time : datetime.datetime + ref_time : dt.datetime Time to round from n_times : int Number of times to get @@ -220,9 +245,9 @@ def get_n_closest_datetimes(ref_time: datetime.datetime, nearest 0, 2, 4, etc. times. Must be divisible by 24 otherwise is not consistent across all days. - Returns + Returns: ------- - list[datetime.datetime] + list[dt.datetime] List of closest dates ordered by absolute proximity. 
If two dates have same distance to ref_time, choose earlier one (more likely to be available) """ @@ -230,9 +255,11 @@ def get_n_closest_datetimes(ref_time: datetime.datetime, closest_times = [] if (24 % time_step_hours) != 0: - raise ValueError('The time step does not evenly divide 24 hours;' - 'Time step has period > 1 day and depends when model ' - 'starts') + raise ValueError( + 'The time step does not evenly divide 24 hours;' + 'Time step has period > 1 day and depends when model ' + 'starts' + ) ts = pd.Timestamp(ref_time) for k in range(iterations): @@ -251,63 +278,68 @@ def get_n_closest_datetimes(ref_time: datetime.datetime, return closest_times -def get_times_for_azimuth_interpolation(ref_time: datetime.datetime, - time_step_hours: int, - buffer_in_seconds: int = 300) -> list[datetime.datetime]: +def get_times_for_azimuth_interpolation( + ref_time: dt.datetime, + time_step_hours: int, + buffer_in_seconds: int = 300 +) -> list[dt.datetime]: """Obtains times needed for azimuth interpolation. Filters 3 closests dates from ref_time so that all returned dates are within `time_step_hours` + `buffer_in_seconds`. This ensures we request dates that are really needed. ``` - dt = datetime.datetime(2023, 1, 1, 11, 1, 0) + dt = dt.datetime(2023, 1, 1, 11, 1, 0) get_times_for_azimuth_interpolation(dt, 1) ``` yields ``` - [datetime.datetime(2023, 1, 1, 11, 0, 0), - datetime.datetime(2023, 1, 1, 12, 0, 0), - datetime.datetime(2023, 1, 1, 10, 0, 0)] + [dt.datetime(2023, 1, 1, 11, 0, 0), + dt.datetime(2023, 1, 1, 12, 0, 0), + dt.datetime(2023, 1, 1, 10, 0, 0)] ``` whereas ``` - dt = datetime.datetime(2023, 1, 1, 11, 30, 0) + dt = dt.datetime(2023, 1, 1, 11, 30, 0) get_times_for_azimuth_interpolation(dt, 1) ``` yields ``` - [datetime.datetime(2023, 1, 1, 11, 0, 0), - datetime.datetime(2023, 1, 1, 12, 0, 0)] + [dt.datetime(2023, 1, 1, 11, 0, 0), + dt.datetime(2023, 1, 1, 12, 0, 0)] ``` Parameters ---------- - ref_time : datetime.datetime + ref_time : dt.datetime A time of acquisition time_step_hours : int Weather model time step, should evenly divide 24 hours buffer_in_seconds : int, optional Buffer for filtering absolute times, by default 300 (or 5 minutes) - Returns + Returns: ------- - list[datetime.datetime] + list[dt.datetime] 2 or 3 closest times within 1 time step (plust the buffer) and the reference time """ # Get 3 closest times closest_times = get_n_closest_datetimes(ref_time, 3, time_step_hours) - def filter_time(time: datetime.datetime): + def filter_time(time: dt.datetime): absolute_time_difference_sec = abs((ref_time - time).total_seconds()) upper_bound_seconds = time_step_hours * 60 * 60 + buffer_in_seconds return absolute_time_difference_sec < upper_bound_seconds + out_times = list(filter(filter_time, closest_times)) return out_times -def get_inverse_weights_for_dates(azimuth_time_array: np.ndarray, - dates: list[datetime.datetime], - inverse_regularizer: float = 1e-9, - temporal_window_hours: float = None) -> list[np.ndarray]: +def get_inverse_weights_for_dates( + azimuth_time_array: np.ndarray, + dates: list[dt.datetime], + inverse_regularizer: float = 1e-9, + temporal_window_hours: Optional[float] = None, +) -> list[np.ndarray]: """Obtains weights according to inverse weighting with respect to the absolute difference between azimuth timing array and dates. The output will be a list with length equal to that of dates and whose entries are arrays each whose shape matches the azimuth_timing_array. 
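`get_n_closest_datetimes` above snaps the reference time to the surrounding model time steps and keeps the n closest, breaking ties toward the earlier time. A simplified pure-datetime version that reproduces the docstring example (this sketch skips the pandas-based rounding used in the code above):

```
import datetime as dt

def n_closest_model_times(ref_time: dt.datetime, n: int, step_hours: int) -> list[dt.datetime]:
    """Return the n model time steps closest to ref_time (ties go to the earlier time)."""
    assert 24 % step_hours == 0, 'step must evenly divide a day'
    day = ref_time.replace(hour=0, minute=0, second=0, microsecond=0)
    # enough candidates on both sides of ref_time to always cover the n closest
    candidates = [day + dt.timedelta(hours=step_hours * k)
                  for k in range(-(n + 1), (24 // step_hours) + n + 1)]
    # sort by distance to ref_time; on equal distance the earlier datetime wins
    candidates.sort(key=lambda t: (abs((t - ref_time).total_seconds()), t))
    return candidates[:n]

ref = dt.datetime(2023, 1, 1, 11, 0, 0)
print([t.strftime('%H:%M') for t in n_closest_model_times(ref, 3, 6)])  # ['12:00', '06:00', '18:00']
```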
@@ -319,7 +351,7 @@ def get_inverse_weights_for_dates(azimuth_time_array: np.ndarray, ---------- azimuth_time_array : np.ndarray Array of type `np.datetime64[ms]` - dates : list[datetime.datetime] + dates : list[dt.datetime] List of datetimes inverse_regularizer : float, optional If a `time` in the azimuth time arr equals one of the given dates, then the regularlizer ensures that the value @@ -330,7 +362,7 @@ def get_inverse_weights_for_dates(azimuth_time_array: np.ndarray, No check of equi-spaced dates are done so not specifying temporal window hours requires dates to be derived from valid model time steps - Returns + Returns: ------- list[np.ndarray] Weighting per pixel with respect to each date @@ -342,7 +374,7 @@ def get_inverse_weights_for_dates(azimuth_time_array: np.ndarray, if n_dates == 0: raise ValueError('No dates provided') - if not all([isinstance(date, datetime.datetime) for date in dates]): + if not all([isinstance(date, dt.datetime) for date in dates]): raise TypeError('dates must be all datetimes') if temporal_window_hours is None: temporal_window_seconds = min([abs((date - dates[0]).total_seconds()) for date in dates[1:]]) @@ -354,7 +386,7 @@ def get_inverse_weights_for_dates(azimuth_time_array: np.ndarray, abs_diff = [np.abs(azimuth_time_array - date) / np.timedelta64(1, 's') for date in dates_np] # Get inverse weighting with mask determined by window - wgts = [1. / (diff + inverse_regularizer) for diff in abs_diff] + wgts = [1.0 / (diff + inverse_regularizer) for diff in abs_diff] masks = [(diff <= temporal_window_seconds).astype(int) for diff in abs_diff] if all([mask.sum() == 0 for mask in masks]): diff --git a/tools/RAiDER/s1_orbits.py b/tools/RAiDER/s1_orbits.py index c7cff6462..217a1191d 100644 --- a/tools/RAiDER/s1_orbits.py +++ b/tools/RAiDER/s1_orbits.py @@ -1,3 +1,4 @@ +import datetime as dt import netrc import os import re @@ -7,6 +8,7 @@ import eof.download from hyp3lib import get_orb + from RAiDER.logger import logger @@ -20,7 +22,7 @@ def _netrc_path() -> Path: def ensure_orbit_credentials() -> Optional[int]: - """Ensure credentials exist for ESA's CDSE and ASF's S1QC to download orbits + """Ensure credentials exist for ESA's CDSE and ASF's S1QC to download orbits. This method will prefer to use CDSE and NASA Earthdata credentials from your `~/.netrc` file if they exist, otherwise will look for environment variables and update or create your `~/.netrc` file. The environment variables @@ -44,9 +46,11 @@ def ensure_orbit_credentials() -> Optional[int]: username = os.environ.get('ESA_USERNAME') password = os.environ.get('ESA_PASSWORD') if username is None or password is None: - raise ValueError('Credentials are required for fetching orbit data from dataspace.copernicus.eu!\n' - 'Either add your credentials to ~/.netrc or set the ESA_USERNAME and ESA_PASSWORD ' - 'environment variables.') + raise ValueError( + 'Credentials are required for fetching orbit data from dataspace.copernicus.eu!\n' + 'Either add your credentials to ~/.netrc or set the ESA_USERNAME and ESA_PASSWORD ' + 'environment variables.' 
+ ) netrc_credentials.hosts[ESA_CDSE_HOST] = (username, None, password) @@ -54,9 +58,11 @@ def ensure_orbit_credentials() -> Optional[int]: username = os.environ.get('EARTHDATA_USERNAME') password = os.environ.get('EARTHDATA_PASSWORD') if username is None or password is None: - raise ValueError('Credentials are required for fetching orbit data from s1qc.asf.alaska.edu!\n' - 'Either add your credentials to ~/.netrc or set the EARTHDATA_USERNAME and' - ' EARTHDATA_PASSWORD environment variables.') + raise ValueError( + 'Credentials are required for fetching orbit data from s1qc.asf.alaska.edu!\n' + 'Either add your credentials to ~/.netrc or set the EARTHDATA_USERNAME and' + ' EARTHDATA_PASSWORD environment variables.' + ) netrc_credentials.hosts[NASA_EDL_HOST] = (username, None, password) @@ -64,7 +70,7 @@ def ensure_orbit_credentials() -> Optional[int]: def get_orbits_from_slc_ids(slc_ids: List[str], directory=Path.cwd()) -> List[Path]: - """Download all orbit files for a set of SLCs + """Download all orbit files for a set of SLCs. This method will ensure that the downloaded orbit files cover the entire acquisition start->stop time @@ -79,11 +85,8 @@ def get_orbits_from_slc_ids(slc_ids: List[str], directory=Path.cwd()) -> List[Pa return orb_files -def get_orbits_from_slc_ids_hyp3lib( - slc_ids: list, orbit_directory: str = None -) -> dict: - """Reference: https://github.com/ACCESS-Cloud-Based-InSAR/DockerizedTopsApp/blob/dev/isce2_topsapp/localize_orbits.py#L23""" - +def get_orbits_from_slc_ids_hyp3lib(slc_ids: list, orbit_directory: str = None) -> dict: + """Reference: https://github.com/ACCESS-Cloud-Based-InSAR/DockerizedTopsApp/blob/dev/isce2_topsapp/localize_orbits.py#L23.""" # Populates env variables to netrc as required for sentineleof _ = ensure_orbit_credentials() esa_username, _, esa_password = netrc.netrc().authenticators(ESA_CDSE_HOST) @@ -105,25 +108,25 @@ def get_orbits_from_slc_ids_hyp3lib( return orbits -def download_eofs(dts: list, missions: list, save_dir: str): - """Wrapper around sentineleof to first try downloading from ASF and fall back to CDSE""" +def download_eofs(datetimes: list[dt.datetime], missions: list, save_dir: str): + """Wrapper around sentineleof to first try downloading from ASF and fall back to CDSE.""" _ = ensure_orbit_credentials() orb_files = [] - for dt, mission in zip(dts, missions): - dt = dt if isinstance(dt, list) else [dt] + for datetime, mission in zip(datetimes, missions): + datetime = datetime if isinstance(datetime, list) else [datetime] mission = mission if isinstance(mission, list) else [mission] try: - orb_file = eof.download.download_eofs(dt, mission, save_dir=save_dir, force_asf=True) + orb_file = eof.download.download_eofs(datetime, mission, save_dir=save_dir, force_asf=True) except: logger.error('Could not download orbit from ASF, trying ESA...') - orb_file = eof.download.download_eofs(dt, mission, save_dir=save_dir, force_asf=False) + orb_file = eof.download.download_eofs(datetime, mission, save_dir=save_dir, force_asf=False) orb_file = orb_file[0] if isinstance(orb_file, list) else orb_file orb_files.append(orb_file) - if not len(orb_files) == len(dts): - raise Exception(f'Missing {len(dts) - len(orb_files)} orbit files! dts={dts}, orb_files={len(orb_files)}') + if not len(orb_files) == len(datetimes): + raise Exception(f'Missing {len(datetimes) - len(orb_files)} orbit files! 
dts={datetimes}, orb_files={len(orb_files)}') return orb_files diff --git a/tools/RAiDER/types/BB.py b/tools/RAiDER/types/BB.py new file mode 100644 index 000000000..260f2de06 --- /dev/null +++ b/tools/RAiDER/types/BB.py @@ -0,0 +1,7 @@ +"""Types to help distinguish different bounding box formats.""" + +SNWE = tuple[float, float, float, float] +WSEN = tuple[float, float, float, float] # used in dem_stitcher + +SN = tuple[float, float] +WE = tuple[float, float] diff --git a/tools/RAiDER/types/RIO.py b/tools/RAiDER/types/RIO.py new file mode 100644 index 000000000..0340cc271 --- /dev/null +++ b/tools/RAiDER/types/RIO.py @@ -0,0 +1,27 @@ +"""Polyfills for several symbols used for types that rasterio doesn't export.""" + +from dataclasses import dataclass +from typing import TypedDict, Union + +import rasterio.crs +import rasterio.transform + + +GDAL = tuple[float, float, float, float, float, float] + +@dataclass +class Statistics: + max: float + mean: float + min: float + std: float + + +class Profile(TypedDict): + driver: str + width: int + height: int + count: int + crs: Union[str, dict, rasterio.crs.CRS] + transform: rasterio.transform.Affine + dtype: str diff --git a/tools/RAiDER/types/__init__.py b/tools/RAiDER/types/__init__.py new file mode 100644 index 000000000..58d00736b --- /dev/null +++ b/tools/RAiDER/types/__init__.py @@ -0,0 +1,10 @@ +"""Types specific to RAiDER.""" + +from typing import Literal, Union + +from pyproj import CRS + + +LookDir = Literal['right', 'left'] +TimeInterpolationMethod = Literal['none', 'center_time', 'azimuth_time_grid'] +CRSLike = Union[CRS, str, int] diff --git a/tools/RAiDER/utilFcns.py b/tools/RAiDER/utilFcns.py old mode 100755 new mode 100644 index c0107928c..d028861ab --- a/tools/RAiDER/utilFcns.py +++ b/tools/RAiDER/utilFcns.py @@ -1,13 +1,29 @@ """Geodesy-related utility functions.""" -import os + +import datetime as dt +import pathlib import re -import xarray +from pathlib import Path +from typing import Any, Optional, Union -from datetime import datetime, timedelta, timezone +import numpy as np +import rasterio +import xarray as xr +import yaml from numpy import ndarray -from pyproj import Transformer, CRS, Proj +from pyproj import CRS, Proj, Transformer + +import RAiDER +from RAiDER.constants import ( + R_EARTH_MAX_WGS84 as Rmax, + R_EARTH_MIN_WGS84 as Rmin, + _THRESHOLD_SECONDS, + _g0 as g0, + _g1 as G1, +) +from RAiDER.logger import logger +from RAiDER.types import BB, RIO, CRSLike -import numpy as np # Optional imports try: @@ -18,38 +34,24 @@ import multiprocessing as mp except ImportError: mp = None -try: - import rasterio -except ImportError: - rasterio = None try: import progressbar except ImportError: progressbar = None -from RAiDER.constants import ( - _g0 as g0, - _g1 as G1, - R_EARTH_MAX_WGS84 as Rmax, - R_EARTH_MIN_WGS84 as Rmin, - _THRESHOLD_SECONDS, -) -from RAiDER.logger import logger - - pbar = None def projectDelays(delay, inc): - '''Project zenith delays to LOS''' - if inc==90: + """Project zenith delays to LOS.""" + if inc == 90: raise ZeroDivisionError return delay / cosd(inc) def floorish(val, frac): - '''Round a value to the lower fractional part''' + """Round a value to the lower fractional part.""" return val - (val % frac) @@ -108,7 +110,7 @@ def enu2ecef( def ecef2enu(xyz, lat, lon, height): - '''Convert ECEF xyz to ENU''' + """Convert ECEF xyz to ENU.""" x, y, z = xyz[..., 0], xyz[..., 1], xyz[..., 2] t = cosd(lon) * x + sind(lon) * y @@ -119,62 +121,54 @@ def ecef2enu(xyz, lat, lon, height): return np.stack((e, n, 
u), axis=-1) -def rio_profile(fname): - ''' - Reads the profile of a rasterio file - ''' - if rasterio is None: - raise ImportError('RAiDER.utilFcns: rio_profile - rasterio is not installed') - - ## need to access subdataset directly - if os.path.basename(fname).startswith('S1-GUNW'): - fname = os.path.join(f'NETCDF:"{fname}":science/grids/data/unwrappedPhase') - with rasterio.open(fname) as src: - profile = src.profile - - elif os.path.exists(fname + '.vrt'): - fname = fname + '.vrt' +def rio_profile(path: Path) -> RIO.Profile: + """Reads the profile of a rasterio file.""" + path_vrt = Path(f'{path}.vrt') - with rasterio.open(fname) as src: - profile = src.profile + if path.name.startswith('S1-GUNW'): + # need to access subdataset directly + path = Path(f'NETCDF:"{path}":science/grids/data/unwrappedPhase') + elif path_vrt.exists(): + path = path_vrt - return profile + with rasterio.open(path) as src: + return src.profile -def rio_extents(profile): - """ Get a bounding box in SNWE from a rasterio profile """ - gt = profile["transform"].to_gdal() - xSize = profile["width"] - ySize = profile["height"] +def rio_extents(profile: RIO.Profile) -> BB.SNWE: + """Get a bounding box in SNWE from a rasterio profile.""" + gt = profile['transform'].to_gdal() + xSize = profile['width'] + ySize = profile['height'] W, E = gt[0], gt[0] + (xSize - 1) * gt[1] + (ySize - 1) * gt[2] N, S = gt[3], gt[3] + (xSize - 1) * gt[4] + (ySize - 1) * gt[5] return S, N, W, E -def rio_open(fname, returnProj=False, userNDV=None, band=None): - ''' - Reads a rasterio-compatible raster file and returns the data and profile - ''' - if rasterio is None: - raise ImportError('RAiDER.utilFcns: rio_open - rasterio is not installed') +def rio_open( + path: Path, + userNDV: Optional[float]=None, + band: Optional[int]=None +) -> tuple[np.ndarray, RIO.Profile]: + """Reads a rasterio-compatible raster file and returns the data and profile.""" + vrt_path = path.with_suffix(path.suffix + '.vrt') + if vrt_path.exists(): + path = vrt_path - if os.path.exists(fname + '.vrt'): - fname = fname + '.vrt' - - with rasterio.open(fname) as src: - profile = src.profile + with rasterio.open(path) as src: + profile: RIO.Profile = src.profile # For all bands - nodata = src.nodatavals + nodata: tuple[float, ...] = src.nodatavals # If user requests a band if band is not None: ndv = nodata[band - 1] - data = src.read(band).squeeze() - nodataToNan(data, [userNDV, nodata[band - 1]]) + data: np.ndarray = src.read(band).squeeze() + nodataToNan(data, [userNDV, ndv]) else: - data = src.read().squeeze() + data: np.ndarray = src.read().squeeze() if data.ndim > 2: for bnd in range(data.shape[0]): val = data[bnd, ...] @@ -182,90 +176,82 @@ def rio_open(fname, returnProj=False, userNDV=None, band=None): else: nodataToNan(data, list(nodata) + [userNDV]) - if data.ndim > 2: - dlist = [] + dlist: list[list[float]] = [] for k in range(data.shape[0]): - dlist.append(data[k,...].copy()) - data = dlist + dlist.append(data[k].copy()) + data = np.array(dlist) - if not returnProj: - return data + return data, profile - else: - return data, profile - -def nodataToNan(inarr, listofvals): - """ - Setting values to nan as needed - """ - inarr = inarr.astype(float) # nans cannot be integers (i.e. in DEM) - for val in listofvals: +def nodataToNan(inarr: np.ndarray, vals: list[Optional[float]]) -> None: + """Setting values to nan as needed.""" + inarr = inarr.astype(float) # nans cannot be integers (i.e. 
in DEM) + for val in vals: if val is not None: inarr[inarr == val] = np.nan -def rio_stats(fname, band=1): - ''' - Read a rasterio-compatible file and pull the metadata. +def rio_stats(path: Path, band: int=1) -> tuple[RIO.Statistics, Optional[CRS], RIO.GDAL]: + """Read a rasterio-compatible file and pull the metadata. Args: - fname - filename to be loaded + path - file path to be loaded band - band number to use for getting statistics Returns: stats - a list of stats for the specified band proj - CRS/projection information for the file gt - geotransform for the data - ''' - if rasterio is None: - raise ImportError('RAiDER.utilFcns: rio_stats - rasterio is not installed') - - if os.path.basename(fname).startswith('S1-GUNW'): - fname = os.path.join(f'NETCDF:"{fname}":science/grids/data/unwrappedPhase') + """ + if path.name.startswith('S1-GUNW'): + path = Path(f'NETCDF:"{path}":science/grids/data/unwrappedPhase') - if os.path.exists(fname + '.vrt'): - fname = fname + '.vrt' + vrt_path = path.with_suffix(path.suffix + '.vrt') + if vrt_path.exists(): + path = vrt_path # Turn off PAM to avoid creating .aux.xml files - with rasterio.Env(GDAL_PAM_ENABLED="NO"): - with rasterio.open(fname) as src: - gt = src.transform.to_gdal() - proj = src.crs + with rasterio.Env(GDAL_PAM_ENABLED='NO'): + with rasterio.open(path) as src: stats = src.statistics(band) + proj = src.crs + gt = src.transform.to_gdal() return stats, proj, gt -def get_file_and_band(filestr): - """ - Support file;bandnum as input for filename strings - """ - parts = filestr.split(";") +def get_file_and_band(filestr: str) -> tuple[Path, int]: + """Support file;bandnum as input for filename strings.""" + parts = filestr.split(';') # Defaults to first band if no bandnum is provided if len(parts) == 1: - return filestr.strip(), 1 + return Path(filestr.strip()), 1 elif len(parts) == 2: - return parts[0].strip(), int(parts[1].strip()) + return Path(parts[0].strip()), int(parts[1].strip()) else: - raise ValueError( - f"Cannot interpret {filestr} as valid filename" - ) - -def writeArrayToRaster(array, filename, noDataValue=0., fmt='ENVI', proj=None, gt=None): - ''' - write a numpy array to a GDAL-readable raster - ''' + raise ValueError(f'Cannot interpret {filestr} as valid filename') + + +def writeArrayToRaster( + array: np.ndarray, + path: Path, + noDataValue: float=0.0, + fmt: str='ENVI', + proj: Optional[CRS]=None, + gt: Optional[RIO.GDAL]=None +) -> None: + """Write a numpy array to a GDAL-readable raster.""" array_shp = np.shape(array) if array.ndim != 2: - raise RuntimeError('writeArrayToRaster: cannot write an array of shape {} to a raster image'.format(array_shp)) + raise RuntimeError(f'writeArrayToRaster: cannot write an array of shape {array_shp} to a raster image') # Data type - if "complex" in str(array.dtype): + if 'complex' in str(array.dtype): dtype = np.complex64 - elif "float" in str(array.dtype): + elif 'float' in str(array.dtype): dtype = np.float32 else: dtype = np.uint8 @@ -278,32 +264,39 @@ def writeArrayToRaster(array, filename, noDataValue=0., fmt='ENVI', proj=None, g except TypeError: trans = gt - ## cant write netcdfs with rasterio in a simple way + # cant write netcdfs with rasterio in a simple way if fmt == 'nc': fmt = 'GTiff' - filename = filename.replace('.nc', '.tif') - - with rasterio.open(filename, mode="w", count=1, - width=array_shp[1], height=array_shp[0], - dtype=dtype, crs=proj, nodata=noDataValue, - driver=fmt, transform=trans) as dst: + path = path.with_suffix('.tif') + + with rasterio.open( + path, + 
mode='w', + count=1, + width=array_shp[1], + height=array_shp[0], + dtype=dtype, + crs=proj, + nodata=noDataValue, + driver=fmt, + transform=trans, + ) as dst: dst.write(array, 1) - logger.info('Wrote: %s', filename) - return + logger.info('Wrote: %s', path) def round_date(date, precision): # First try rounding up # Timedelta since the beginning of time - T0 = datetime.min + T0 = dt.datetime.min try: datedelta = T0 - date except TypeError: - T0 = T0.replace(tzinfo=timezone(offset=timedelta())) + T0 = T0.replace(tzinfo=dt.timezone(offset=dt.timedelta())) datedelta = T0 - date - - # Round that timedelta to the specified precision + + # Round that dt.timedelta to the specified precision rem = datedelta % precision # Add back to get date rounded up round_up = date + rem @@ -312,9 +305,9 @@ def round_date(date, precision): try: datedelta = date - T0 except TypeError: - T0 = T0.replace(tzinfo=timezone(offset=timedelta())) + T0 = T0.replace(tzinfo=dt.timezone(offset=dt.timedelta())) datedelta = date - T0 - + rem = datedelta % precision round_down = date - rem @@ -336,29 +329,23 @@ def _least_nonzero(a): def robmin(a): - ''' - Get the minimum of an array, accounting for empty lists - ''' + """Get the minimum of an array, accounting for empty lists.""" return np.nanmin(a) def robmax(a): - ''' - Get the minimum of an array, accounting for empty lists - ''' + """Get the minimum of an array, accounting for empty lists.""" return np.nanmax(a) def _get_g_ll(lats): - ''' - Compute the variation in gravity constant with latitude - ''' - return G1 * (1 - 0.002637 * cosd(2 * lats) + 0.0000059 * (cosd(2 * lats))**2) + """Compute the variation in gravity constant with latitude.""" + return G1 * (1 - 0.002637 * cosd(2 * lats) + 0.0000059 * (cosd(2 * lats)) ** 2) def get_Re(lats): - ''' - Returns earth radius as a function of latitude for WGS84 + """ + Returns earth radius as a function of latitude for WGS84. Args: lats - ndarray of geodetic latitudes in degrees @@ -374,8 +361,8 @@ def get_Re(lats): array([6378137., 6372770.5219805, 6367417.56705189, 6362078.07851428, 6356752.]) >>> assert output[0] == 6378137 # (Rmax) >>> assert output[-1] == 6356752 # (Rmin) - ''' - return np.sqrt(1 / (((cosd(lats)**2) / Rmax**2) + ((sind(lats)**2) / Rmin**2))) + """ + return np.sqrt(1 / (((cosd(lats) ** 2) / Rmax**2) + ((sind(lats) ** 2) / Rmin**2))) def geo_to_ht(lats, hts): @@ -401,8 +388,8 @@ def geo_to_ht(lats, hts): Returns: ndarray: geometric heights. These are approximate ellipsoidal heights referenced to WGS84 """ - g_ll = _get_g_ll(lats) # gravity function of latitude - Re = get_Re(lats) # Earth radius function of latitude + g_ll = _get_g_ll(lats) # gravity function of latitude + Re = get_Re(lats) # Earth radius function of latitude # Calculate Geometric Height, h h = (hts * Re) / (g_ll / g0 * Re - hts) @@ -411,29 +398,33 @@ def geo_to_ht(lats, hts): def padLower(invar): - ''' - add a layer of data below the lowest current z-level at height zmin - ''' + """Add a layer of data below the lowest current z-level at height zmin.""" new_var = _least_nonzero(invar) return np.concatenate((new_var[:, :, np.newaxis], invar), axis=2) -def round_time(dt, roundTo=60): - ''' +def round_time(datetime, roundTo=60): + """ Round a datetime object to any time lapse in seconds - dt: datetime.datetime object + datetime: dt.datetime object roundTo: Closest number of seconds to round to, default 1 minute. 
Source: https://stackoverflow.com/questions/3463930/how-to-round-the-minute-of-a-datetime-object/10854034#10854034 - ''' - seconds = (dt.replace(tzinfo=None) - dt.min).seconds + """ + seconds = (datetime.replace(tzinfo=None) - datetime.min).seconds rounding = (seconds + roundTo / 2) // roundTo * roundTo - return dt + timedelta(0, rounding - seconds, -dt.microsecond) - - -def writeDelays(aoi, wetDelay, hydroDelay, - wetFilename, hydroFilename=None, - outformat=None, ndv=0.): - """ Write the delay numpy arrays to files in the format specified """ + return datetime + dt.timedelta(0, rounding - seconds, -datetime.microsecond) + + +def writeDelays( + aoi, #: AOI, + wetDelay, + hydroDelay, + wet_path: Path, + hydro_path: Optional[Path]=None, + outformat: str=None, + ndv: float=0.0 +) -> None: + """Write the delay numpy arrays to files in the format specified.""" if pd is None: raise ImportError('pandas is required to write GNSS delays to a file') @@ -443,44 +434,29 @@ def writeDelays(aoi, wetDelay, hydroDelay, # Do different things, depending on the type of input if aoi.type() == 'station_file': - df = pd.read_csv(aoi._filename).drop_duplicates(subset=["Lat", "Lon"]) + df = pd.read_csv(aoi._filename).drop_duplicates(subset=['Lat', 'Lon']) df['wetDelay'] = wetDelay df['hydroDelay'] = hydroDelay df['totalDelay'] = wetDelay + hydroDelay - df.to_csv(wetFilename, index=False) - logger.info('Wrote delays to: %s', wetFilename) + df.to_csv(str(wet_path), index=False) + logger.info('Wrote delays to: %s', wet_path.absolute()) else: + if hydro_path is None: + raise ValueError('Hydro delay file path must be specified if the AOI is not a station file') proj = aoi.projection() - gt = aoi.geotransform() - writeArrayToRaster( - wetDelay, - wetFilename, - noDataValue=ndv, - fmt=outformat, - proj=proj, - gt=gt - ) - writeArrayToRaster( - hydroDelay, - hydroFilename, - noDataValue=ndv, - fmt=outformat, - proj=proj, - gt=gt - ) + gt = aoi.geotransform() + writeArrayToRaster(wetDelay, wet_path, noDataValue=ndv, fmt=outformat, proj=proj, gt=gt) + writeArrayToRaster(hydroDelay, hydro_path, noDataValue=ndv, fmt=outformat, proj=proj, gt=gt) def getTimeFromFile(filename): - ''' - Parse a filename to get a date-time - ''' + """Parse a filename to get a date-time.""" fmt = '%Y_%m_%d_T%H_%M_%S' p = re.compile(r'\d{4}_\d{2}_\d{2}_T\d{2}_\d{2}_\d{2}') out = p.search(filename).group() - return datetime.strptime(out, fmt) - + return dt.datetime.strptime(out, fmt) # Part of the following UTM and WGS84 converter is borrowed from https://gist.github.com/twpayne/4409500 @@ -536,7 +512,7 @@ def WGS84_to_UTM(lon, lat, common_center=False): if common_center: lon0 = np.median(lon) lat0 = np.median(lat) - z0, l0, x0, y0 = project((lon0, lat0)) + z0, l0, _, _ = project((lon0, lat0)) Z = lon.copy() L = np.zeros(lon.shape, dtype=' None: # I added datetime as an input to the function and just copied these two lines from merra2 for the attrs_dict attrs_dict = { - 'datetime': dt.strftime("%Y_%m_%dT%H_%M_%S"), - 'date_created': datetime.now().strftime("%Y_%m_%dT%H_%M_%S"), + 'datetime': datetime.strftime('%Y_%m_%dT%H_%M_%S'), + 'date_created': datetime.now().strftime('%Y_%m_%dT%H_%M_%S'), 'NoDataValue': NoDataValue, 'chunksize': chunk, # 'mapping_name': mapping_name, } - + dimension_dict = { 'latitude': (('y', 'x'), lat), 'longitude': (('y', 'x'), lon), @@ -660,34 +636,34 @@ def writeWeatherVarsXarray(lat, lon, h, q, p, t, dt, crs, outName=None, NoDataVa 't': (('z', 'y', 'x'), t), } - ds = xarray.Dataset( - data_vars=dataset_dict, - 
coords=dimension_dict, - attrs=attrs_dict, - ) - + ds = xr.Dataset( + data_vars=dataset_dict, + coords=dimension_dict, + attrs=attrs_dict, + ) + ds['h'].attrs['standard_name'] = 'mid_layer_heights' ds['p'].attrs['standard_name'] = 'mid_level_pressure' ds['q'].attrs['standard_name'] = 'specific_humidity' ds['t'].attrs['standard_name'] = 'air_temperature' - + ds['h'].attrs['units'] = 'm' ds['p'].attrs['units'] = 'Pa' ds['q'].attrs['units'] = 'kg kg-1' ds['t'].attrs['units'] = 'K' - ds["proj"] = int() + ds['proj'] = 0 for k, v in crs.to_cf().items(): ds.proj.attrs[k] = v for var in ds.data_vars: ds[var].attrs['grid_mapping'] = 'proj' - + ds.to_netcdf(outName) del ds - + def convertLons(inLons): - '''Convert lons from 0-360 to -180-180''' + """Convert lons from 0-360 to -180-180.""" mask = inLons > 180 outLons = inLons outLons[mask] = outLons[mask] - 360 @@ -695,13 +671,12 @@ def convertLons(inLons): def read_NCMR_loginInfo(filepath=None): - from pathlib import Path if filepath is None: filepath = str(Path.home()) + '/.ncmrlogin' - f = open(filepath, 'r') + f = open(filepath) lines = f.readlines() url = lines[0].strip().split(': ')[1] username = lines[1].strip().split(': ')[1] @@ -711,18 +686,17 @@ def read_NCMR_loginInfo(filepath=None): def read_EarthData_loginInfo(filepath=None): - from netrc import netrc - urs_usr, _, urs_pwd = netrc().hosts["urs.earthdata.nasa.gov"] + urs_usr, _, urs_pwd = netrc().hosts['urs.earthdata.nasa.gov'] return urs_usr, urs_pwd -def show_progress(block_num, block_size, total_size): - '''Show download progress''' +def show_progress(block_num, block_size, total_size) -> None: + """Show download progress.""" if progressbar is None: raise ImportError('RAiDER.utilFcns: show_progress - progressbar is not available') - + global pbar if pbar is None: pbar = progressbar.ProgressBar(maxval=total_size) @@ -737,26 +711,22 @@ def show_progress(block_num, block_size, total_size): def getChunkSize(in_shape): - '''Create a reasonable chunk size''' + """Create a reasonable chunk size.""" if mp is None: raise ImportError('RAiDER.utilFcns: getChunkSize - multiprocessing is not available') minChunkSize = 100 maxChunkSize = 1000 cpu_count = mp.cpu_count() - chunkSize = tuple( - max( - min(maxChunkSize, s // cpu_count), - min(s, minChunkSize) - ) for s in in_shape - ) + chunkSize = tuple(max(min(maxChunkSize, s // cpu_count), min(s, minChunkSize)) for s in in_shape) return chunkSize def calcgeoh(lnsp, t, q, z, a, b, R_d, num_levels): - ''' + """ Calculate pressure, geopotential, and geopotential height from the surface pressure and model levels provided by a weather model. The model levels are numbered from the highest eleveation to the lowest. 
+ Args: ---------- lnsp: ndarray - [y, x] array of log surface pressure @@ -766,13 +736,14 @@ def calcgeoh(lnsp, t, q, z, a, b, R_d, num_levels): a: ndarray - [z] vector of a values b: ndarray - [z] vector of b values num_levels: int - integer number of model levels + Returns: ------- geopotential - The geopotential in units of height times acceleration pressurelvs - The pressure at each of the model levels for each of the input points geoheight - The geopotential heights - ''' + """ geopotential = np.zeros_like(t) pressurelvs = np.zeros_like(geopotential) geoheight = np.zeros_like(geopotential) @@ -783,15 +754,13 @@ def calcgeoh(lnsp, t, q, z, a, b, R_d, num_levels): if len(a) != num_levels + 1 or len(b) != num_levels + 1: raise ValueError( - 'I have here a model with {} levels, but parameters a '.format(num_levels) + - 'and b have lengths {} and {} respectively. Of '.format(len(a), len(b)) + - 'course, these three numbers should be equal.') + f'I have here a model with {num_levels} levels, but parameters a and b have lengths {len(a)} and {len(b)} ' + 'respectively. Of course, these three numbers should be equal.' + ) # Integrate up into the atmosphere from *lowest level* z_h = 0 # initial value - for lev, t_level, q_level in zip( - range(num_levels, 0, -1), t[::-1], q[::-1]): - + for lev, t_level, q_level in zip(range(num_levels, 0, -1), t[::-1], q[::-1]): # lev is the level number 1-60, we need a corresponding index # into ts and qs # ilevel = num_levels - lev # << this was Ray's original, but is a typo @@ -835,18 +804,17 @@ def calcgeoh(lnsp, t, q, z, a, b, R_d, num_levels): def transform_coords(proj1, proj2, x, y): """ Transform coordinates from proj1 to proj2 (can be EPSG or crs from proj). - e.g. x, y = transform_coords(4326, 4087, lon, lat) + e.g. x, y = transform_coords(4326, 4087, lon, lat). """ transformer = Transformer.from_crs(proj1, proj2, always_xy=True) return transformer.transform(x, y) def get_nearest_wmtimes(t0, time_delta): - """" - Get the nearest two available times to the requested time given a time step + """Get the nearest two available times to the requested time given a time step. Args: - t0 - user-requested Python datetime + t0 - user-requested Python datetime time_delta - time interval of weather model Returns: @@ -854,18 +822,18 @@ def get_nearest_wmtimes(t0, time_delta): available times to the requested time Example: - >>> import datetime + >>> import datetime as dt >>> from RAiDER.utilFcns import get_nearest_wmtimes - >>> t0 = datetime.datetime(2020,1,1,11,35,0) + >>> t0 = dt.datetime(2020,1,1,11,35,0) >>> get_nearest_wmtimes(t0, 3) - (datetime.datetime(2020, 1, 1, 9, 0), datetime.datetime(2020, 1, 1, 12, 0)) + (dt.datetime(2020, 1, 1, 9, 0), dt.datetime(2020, 1, 1, 12, 0)) """ # get the closest time available - tclose = round_time(t0, roundTo = time_delta * 60 *60) + tclose = round_time(t0, roundTo=time_delta * 60 * 60) # Just calculate both options and take the closest - t2_1 = tclose + timedelta(hours=time_delta) - t2_2 = tclose - timedelta(hours=time_delta) + t2_1 = tclose + dt.timedelta(hours=time_delta) + t2_2 = tclose - dt.timedelta(hours=time_delta) t2 = [t2_1 if get_dt(t2_1, t0) < get_dt(t2_2, t0) else t2_2][0] # If you're within 5 minutes just take the closest time @@ -878,10 +846,10 @@ def get_nearest_wmtimes(t0, time_delta): return [t2, tclose] -def get_dt(t1,t2): - ''' +def get_dt(t1: dt.datetime, t2: dt.datetime) -> float: + """ Helper function for getting the absolute difference in seconds between - two python datetimes + two python datetimes. 
Args: t1, t2 - Python datetimes @@ -890,11 +858,59 @@ def get_dt(t1,t2): Absolute difference in seconds between the two inputs Examples: - >>> import datetime + >>> import datetime as dt >>> from RAiDER.utilFcns import get_dt - >>> get_dt(datetime.datetime(2020,1,1,5,0,0), datetime.datetime(2020,1,1,0,0,0)) + >>> get_dt(dt.datetime(2020,1,1,5,0,0), dt.datetime(2020,1,1,0,0,0)) 18000.0 - ''' + """ return np.abs((t1 - t2).total_seconds()) +# Tell PyYAML how to serialize pathlib Paths +yaml.add_representer( + pathlib.PosixPath, + lambda dumper, data: dumper.represent_scalar( + 'tag:yaml.org,2002:str', + str(data) + ) +) +yaml.add_representer( + tuple, + lambda dumper, data: dumper.represent_sequence( + 'tag:yaml.org,2002:seq', + data + ) +) + +def write_yaml(content: dict[str, Any], dst: Union[str, Path]) -> Path: + """Write a new yaml file from a dictionary with template.yaml as a base. + + Each key-value pair in 'content' will override the one from template.yaml. + """ + yaml_path = Path(RAiDER.__file__).parent / 'cli/examples/template/template.yaml' + + with yaml_path.open() as f: + try: + params = yaml.safe_load(f) + except yaml.YAMLError as exc: + print(exc) + raise ValueError(f'Something is wrong with the yaml file {yaml_path}') + + params = {**params, **content} + + dst = Path(dst) + with dst.open('w') as fh: + yaml.dump(params, fh, default_flow_style=False) + + logger.info('Wrote new cfg file: %s', str(dst)) + return dst + + +def parse_crs(proj: CRSLike) -> CRS: + if isinstance(proj, CRS): + return proj + elif isinstance(proj, str): + return CRS.from_epsg(proj.lstrip('EPSG:')) + elif isinstance(proj, int): + return CRS.from_epsg(proj) + raise TypeError(f'Data type "{type(proj)}" not supported for CRS')
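
A few usage sketches for the helpers introduced above. First, the reworked raster utilities together with the new `RAiDER.types` aliases: `rio_profile` now takes a `pathlib.Path` and `rio_extents` returns an SNWE tuple. The raster file name below is a placeholder, not a file shipped with the project.

```
from pathlib import Path

from RAiDER.types import BB, RIO
from RAiDER.utilFcns import rio_extents, rio_profile

# 'dem.tif' is a placeholder for any rasterio-readable raster.
profile: RIO.Profile = rio_profile(Path('dem.tif'))
bounds: BB.SNWE = rio_extents(profile)
south, north, west, east = bounds
```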
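
Next, a minimal sketch of `write_yaml`, assuming RAiDER is importable and its bundled `cli/examples/template/template.yaml` is present; the override keys shown are illustrative rather than a definitive configuration.

```
from pathlib import Path

from RAiDER.utilFcns import write_yaml

# Keys in the override dict replace the corresponding entries from template.yaml.
overrides = {
    'weather_model': 'ERA5',  # illustrative key
    'look_dir': 'right',      # illustrative key
}
run_config: Path = write_yaml(overrides, 'temp.yaml')
assert run_config.exists()
```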
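
Finally, `parse_crs` normalizes the `CRSLike` union (a `pyproj.CRS`, an `'EPSG:xxxx'` string, or a bare EPSG integer) into a `pyproj.CRS`:

```
from pyproj import CRS

from RAiDER.utilFcns import parse_crs

# All three spellings resolve to the same CRS.
assert parse_crs(4326) == CRS.from_epsg(4326)
assert parse_crs('EPSG:4326') == CRS.from_epsg(4326)
assert parse_crs(CRS.from_epsg(4326)) == CRS.from_epsg(4326)
```

Note that the string branch uses `str.lstrip`, which strips a set of characters rather than a prefix, so `'EPSG:4326'` and `'4326'` both resolve while a lowercase `'epsg:4326'` would not.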