diff --git a/.github/workflows/changelog-check.yml b/.github/workflows/changelog-check.yml index 3b1e740..232d149 100644 --- a/.github/workflows/changelog-check.yml +++ b/.github/workflows/changelog-check.yml @@ -13,4 +13,4 @@ on: jobs: call-changelog-check-workflow: - uses: ASFHyP3/actions/.github/workflows/reusable-changelog-check.yml@v0.12.0 + uses: ASFHyP3/actions/.github/workflows/reusable-changelog-check.yml@v0.13.2 diff --git a/.github/workflows/create-jira-issue.yml b/.github/workflows/create-jira-issue.yml index d95ef84..7646baa 100644 --- a/.github/workflows/create-jira-issue.yml +++ b/.github/workflows/create-jira-issue.yml @@ -6,7 +6,7 @@ on: jobs: call-create-jira-issue-workflow: - uses: ASFHyP3/actions/.github/workflows/reusable-create-jira-issue.yml@v0.12.0 + uses: ASFHyP3/actions/.github/workflows/reusable-create-jira-issue.yml@v0.13.2 secrets: JIRA_BASE_URL: ${{ secrets.JIRA_BASE_URL }} JIRA_USER_EMAIL: ${{ secrets.JIRA_USER_EMAIL }} diff --git a/.github/workflows/labeled-pr-check.yml b/.github/workflows/labeled-pr-check.yml index f408f3b..465aaa8 100644 --- a/.github/workflows/labeled-pr-check.yml +++ b/.github/workflows/labeled-pr-check.yml @@ -12,4 +12,4 @@ on: jobs: call-labeled-pr-check-workflow: - uses: ASFHyP3/actions/.github/workflows/reusable-labeled-pr-check.yml@v0.12.0 + uses: ASFHyP3/actions/.github/workflows/reusable-labeled-pr-check.yml@v0.13.2 diff --git a/.github/workflows/release-checklist-comment.yml b/.github/workflows/release-checklist-comment.yml index b5c711f..1c3f8a9 100644 --- a/.github/workflows/release-checklist-comment.yml +++ b/.github/workflows/release-checklist-comment.yml @@ -9,7 +9,7 @@ on: jobs: call-release-workflow: - uses: ASFHyP3/actions/.github/workflows/reusable-release-checklist-comment.yml@v0.12.0 + uses: ASFHyP3/actions/.github/workflows/reusable-release-checklist-comment.yml@v0.13.2 permissions: pull-requests: write secrets: diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 26887e5..0a4dd92 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -8,7 +8,7 @@ on: jobs: call-release-workflow: # Docs: https://github.com/ASFHyP3/actions - uses: ASFHyP3/actions/.github/workflows/reusable-release.yml@v0.12.0 + uses: ASFHyP3/actions/.github/workflows/reusable-release.yml@v0.13.2 with: release_prefix: burst2safe secrets: diff --git a/.github/workflows/static-analysis.yml b/.github/workflows/static-analysis.yml index d8f7b24..35c37dd 100644 --- a/.github/workflows/static-analysis.yml +++ b/.github/workflows/static-analysis.yml @@ -5,7 +5,7 @@ on: [pull_request] jobs: call-secrets-analysis-workflow: # Docs: https://github.com/ASFHyP3/actions - uses: ASFHyP3/actions/.github/workflows/reusable-secrets-analysis.yml@v0.12.0 + uses: ASFHyP3/actions/.github/workflows/reusable-secrets-analysis.yml@v0.13.2 check-with-black: runs-on: ubuntu-latest @@ -18,4 +18,4 @@ jobs: call-ruff-workflow: # Docs: https://github.com/ASFHyP3/actions - uses: ASFHyP3/actions/.github/workflows/reusable-ruff.yml@v0.12.0 + uses: ASFHyP3/actions/.github/workflows/reusable-ruff.yml@v0.13.2 diff --git a/CHANGELOG.md b/CHANGELOG.md index d4b2c48..3de0950 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,18 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [PEP 440](https://www.python.org/dev/peps/pep-0440/) and uses [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [1.4.0] + +### Added +* download.py to support asynchronous downloads. 
+* Support for EDL token-based authentication.
+
+### Changed
+* Authorization behavior so that EDL credentials from an EDL token are prioritized over a username/password in either a netrc file or the environment.
+* Authorization behavior so that an EDL username/password from a user's netrc file is prioritized over environment variables. Username/password provided as environment variables are now written to the netrc file.
+* Switched to an asynchronous download approach.
+* In burst2stack.py all input files are now downloaded first.
+
 ## [1.3.1]
 
 ### Changed
diff --git a/README.md b/README.md
index bdfcd81..8be34fa 100644
--- a/README.md
+++ b/README.md
@@ -29,12 +29,16 @@ conda install -c conda-forge burst2safe
 ### Credentials
 To use `burst2safe`, you must provide your Earthdata Login credentials via two environment variables
-(`EARTHDATA_USERNAME` and `EARTHDATA_PASSWORD`), or via your `.netrc` file.
+(`EARTHDATA_USERNAME` and `EARTHDATA_PASSWORD`), or via your `.netrc` file. Alternatively, you can use an Earthdata Login token stored in the `EARTHDATA_TOKEN` environment variable.
 
 If you do not already have an Earthdata account, you can sign up [here](https://urs.earthdata.nasa.gov/home).
 
 If you would like to set up Earthdata Login via your `.netrc` file, check out this [guide](https://harmony.earthdata.nasa.gov/docs#getting-started) to get started.
 
+If you would like to set up Earthdata Login via a token, check out this [guide](https://urs.earthdata.nasa.gov/documentation/for_users/user_token) to get started.
+
+Note that `burst2safe` prefers authorization information in this order: token > `.netrc` > username/password in the environment. For example, if you have both a `.netrc` file and a token configured, the token will be used.
+
 ## burst2safe usage
 The `burst2safe` command line tool can be run using the following structure:
 ```bash
diff --git a/environment.yml b/environment.yml
index 68b5866..e9a62ca 100644
--- a/environment.yml
+++ b/environment.yml
@@ -12,6 +12,7 @@ dependencies:
   - tifffile>=2022.04.22
   - asf_search
   - dateparser!=1.1.0
+  - aiohttp
   # For packaging, and testing
   - pytest
   - pytest-cov
diff --git a/pyproject.toml b/pyproject.toml
index 2951a39..fb80cd4 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -41,6 +41,7 @@ dependencies = [
     "tifffile>=2022.04.22",
     "asf_search",
     "dateparser!=1.1.0",
+    "aiohttp",
 ]
 
 [project.urls]
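The precedence described in the README maps directly onto the new `check_earthdata_credentials` helper added to `auth.py` below. A minimal sketch of the intended behavior — the token and credential values here are placeholders, not real secrets:

```python
import os

from burst2safe.auth import check_earthdata_credentials

# Placeholder values for illustration only; a real EDL token comes from
# https://urs.earthdata.nasa.gov/documentation/for_users/user_token
os.environ['EARTHDATA_TOKEN'] = 'dummy-token'
os.environ['EARTHDATA_USERNAME'] = 'dummy-user'
os.environ['EARTHDATA_PASSWORD'] = 'dummy-pass'

# The token wins even though a username/password pair is also available.
assert check_earthdata_credentials() == 'token'

# With no token set, the netrc file is consulted next; environment
# credentials are used last, and check_earthdata_credentials(append=True)
# will write them into the netrc file for subsequent runs (note this
# modifies your real ~/.netrc, so don't run it with dummy values).
del os.environ['EARTHDATA_TOKEN']
```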
diff --git a/src/burst2safe/auth.py b/src/burst2safe/auth.py
index b42b859..2086660 100644
--- a/src/burst2safe/auth.py
+++ b/src/burst2safe/auth.py
@@ -6,6 +6,7 @@
 
 EARTHDATA_HOST = 'urs.earthdata.nasa.gov'
+TOKEN_ENV_VAR = 'EARTHDATA_TOKEN'
 
 
 def get_netrc() -> Path:
@@ -57,21 +58,51 @@ def find_creds_in_netrc(service) -> Tuple[str, str]:
     return None, None
 
 
-def get_earthdata_credentials() -> Tuple[str, str]:
-    """Get NASA EarthData credentials from the environment or netrc file.
+def write_credentials_to_netrc_file(username: str, password: str) -> None:
+    """Write credentials to the netrc file.
+
+    Args:
+        username: NASA EarthData username
+        password: NASA EarthData password
+    """
+    netrc_file = get_netrc()
+    if not netrc_file.exists():
+        netrc_file.touch()
+
+    with open(netrc_file, 'a') as f:
+        f.write(f'machine {EARTHDATA_HOST} login {username} password {password}\n')
+
+
+def check_earthdata_credentials(append=False) -> str:
+    """Check for NASA EarthData credentials, preferring an EDL token, then the netrc file,
+    then environment variables. Credentials found only in the environment are written to
+    the netrc file when `append` is True.
+
+    Args:
+        append: Whether to append the credentials to the netrc file if they are only found in the environment
 
     Returns:
-        Tuple of the NASA EarthData username and password
+        The location of the preferred credentials ('netrc' or 'token')
     """
-    username, password = find_creds_in_env('EARTHDATA_USERNAME', 'EARTHDATA_PASSWORD')
-    if username and password:
-        return username, password
+    if os.getenv(TOKEN_ENV_VAR):
+        return 'token'
 
     username, password = find_creds_in_netrc(EARTHDATA_HOST)
     if username and password:
-        return username, password
+        return 'netrc'
+
+    username, password = find_creds_in_env('EARTHDATA_USERNAME', 'EARTHDATA_PASSWORD')
+    if username and password:
+        if append:
+            write_credentials_to_netrc_file(username, password)
+            return 'netrc'
+        else:
+            raise ValueError(
+                'NASA Earthdata credentials only found in environment variables, '
+                'but appending to the netrc file is not allowed. Please allow appending to the netrc file.'
+            )
 
     raise ValueError(
-        'Please provide NASA EarthData credentials via the '
-        'EARTHDATA_USERNAME and EARTHDATA_PASSWORD environment variables, or your netrc file.'
+        'Please provide NASA Earthdata credentials via your .netrc file, '
+        'the EARTHDATA_USERNAME and EARTHDATA_PASSWORD environment variables, '
+        'or an EDL token via the EARTHDATA_TOKEN environment variable.'
     )
diff --git a/src/burst2safe/burst2safe.py b/src/burst2safe/burst2safe.py
index bcbbec9..e2789e1 100644
--- a/src/burst2safe/burst2safe.py
+++ b/src/burst2safe/burst2safe.py
@@ -8,8 +8,9 @@
 from shapely.geometry import Polygon
 
 from burst2safe import utils
+from burst2safe.download import download_bursts
 from burst2safe.safe import Safe
-from burst2safe.search import download_bursts, find_bursts
+from burst2safe.search import find_bursts
 
 
 DESCRIPTION = """Convert a set of ASF burst SLCs to the ESA SAFE format.
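The `burst2stack.py` changes below switch the return type from a single `Path` to a `List[Path]` (one SAFE per absolute orbit), so callers should iterate over the result. A hypothetical invocation, with made-up orbit and extent values:

```python
from datetime import datetime
from pathlib import Path

from shapely import geometry

from burst2safe.burst2stack import burst2stack

# Hypothetical relative orbit and extent, for illustration only.
safe_paths = burst2stack(
    rel_orbit=106,
    start_date=datetime(2024, 1, 1),
    end_date=datetime(2024, 2, 1),
    extent=geometry.box(-117.0, 35.0, -116.5, 35.5),
    work_dir=Path.cwd(),
)
for safe_path in safe_paths:  # one SAFE per absolute orbit in the stack
    print(safe_path)
```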
diff --git a/src/burst2safe/burst2stack.py b/src/burst2safe/burst2stack.py
index 508d3aa..795db40 100644
--- a/src/burst2safe/burst2stack.py
+++ b/src/burst2safe/burst2stack.py
@@ -1,16 +1,16 @@
 """A tool for converting stacks of ASF burst SLCs to stacks of SAFEs"""
 
 from argparse import ArgumentParser
-from collections.abc import Iterable
 from datetime import datetime
 from pathlib import Path
-from typing import Optional
+from typing import Iterable, List, Optional
 
 from shapely.geometry import Polygon
 
 from burst2safe import utils
-from burst2safe.burst2safe import burst2safe
-from burst2safe.search import find_stack_orbits
+from burst2safe.download import download_bursts
+from burst2safe.safe import Safe
+from burst2safe.search import find_group
 
 
 DESCRIPTION = """Convert a stack of ASF burst SLCs to a stack of ESA SAFEs.
@@ -32,7 +32,7 @@ def burst2stack(
     all_anns: bool = False,
     keep_files: bool = False,
     work_dir: Optional[Path] = None,
-) -> Path:
-    """Convert a stack of burst granules to a stack of ESA SAFEs. Wraps the burst2safe function to handle multiple dates.
+) -> List[Path]:
+    """Convert a stack of burst granules to a stack of ESA SAFEs. Searches for and downloads all bursts up front, then creates one SAFE per absolute orbit.
@@ -41,30 +41,52 @@ def burst2stack(
         start_date: The start date of the bursts
         end_date: The end date of the bursts
         extent: The bounding box of the bursts
-        swaths: List of swaths to include
         polarizations: List of polarizations to include
+        swaths: List of swaths to include
         mode: The collection mode to use (IW or EW) (default: IW)
         min_bursts: The minimum number of bursts per swath (default: 1)
         all_anns: Include product annotation files for all swaths, regardless of included bursts
         keep_files: Keep the intermediate files
         work_dir: The directory to create the SAFE in (default: current directory)
     """
-    absolute_orbits = find_stack_orbits(rel_orbit, extent, start_date, end_date)
-    print(f'Creating SAFEs for {len(absolute_orbits)} time periods...')
-    for orbit in absolute_orbits:
-        print()
-        burst2safe(
-            granules=None,
-            orbit=orbit,
-            extent=extent,
-            polarizations=polarizations,
-            swaths=swaths,
-            mode=mode,
-            min_bursts=min_bursts,
-            all_anns=all_anns,
-            keep_files=keep_files,
-            work_dir=work_dir,
-        )
+    burst_search_results = find_group(
+        rel_orbit,
+        extent,
+        polarizations,
+        swaths,
+        mode,
+        min_bursts,
+        use_relative_orbit=True,
+        start_date=start_date,
+        end_date=end_date,
+    )
+    burst_infos = utils.get_burst_infos(burst_search_results, work_dir)
+    abs_orbits = utils.drop_duplicates([burst_info.absolute_orbit for burst_info in burst_infos])
+    print(f'Found {len(burst_infos)} burst(s), comprising {len(abs_orbits)} SAFE(s).')
+
+    print('Checking burst group validity...')
+    burst_sets = [[bi for bi in burst_infos if bi.absolute_orbit == orbit] for orbit in abs_orbits]
+    # Check burst group validity before downloading so that we fail fast
+    for burst_set in burst_sets:
+        Safe.check_group_validity(burst_set)
+
+    print('Downloading data...')
+    download_bursts(burst_infos)
+    print('Download complete.')
+
+    print('Creating SAFEs...')
+    safe_paths = []
+    for burst_set in burst_sets:
+        for info in burst_set:
+            info.add_shape_info()
+            info.add_start_stop_utc()
+        safe = Safe(burst_set, all_anns, work_dir)
+        safe_path = safe.create_safe()
+        safe_paths.append(safe_path)
+        if not keep_files:
+            safe.cleanup()
+    print('SAFEs created!')
+
+    return safe_paths
 
 
 def main() -> None:
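The new `download.py` module below combines two patterns worth calling out: a tenacity retry that polls while the server answers HTTP 202 (ASF's burst extractor responds 202 while a product is presumably still being generated), and an `asyncio.gather` fan-out over one coroutine per file. A standalone sketch of that combination, independent of the burst-specific logic:

```python
import asyncio

import aiohttp
from tenacity import retry, retry_if_result, stop_after_delay, wait_random


# Poll until the server stops answering 202. tenacity transparently supports
# decorating coroutines, which is how download.py uses it below.
@retry(
    reraise=True, retry=retry_if_result(lambda r: r.status == 202), wait=wait_random(0, 1), stop=stop_after_delay(60)
)
async def poll_until_ready(session: aiohttp.ClientSession, url: str) -> aiohttp.ClientResponse:
    response = await session.get(url)
    response.raise_for_status()
    return response


async def fetch_all(urls: list) -> None:
    # One coroutine per URL, all awaited concurrently.
    async with aiohttp.ClientSession() as session:
        responses = await asyncio.gather(*[poll_until_ready(session, url) for url in urls])
        for response in responses:
            response.close()
```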
diff --git a/src/burst2safe/download.py b/src/burst2safe/download.py
new file mode 100644
index 0000000..521043f
--- /dev/null
+++ b/src/burst2safe/download.py
@@ -0,0 +1,119 @@
+import asyncio
+import os
+from pathlib import Path
+from typing import Iterable
+
+import aiohttp
+from tenacity import retry, retry_if_result, stop_after_attempt, stop_after_delay, wait_random
+
+from burst2safe.auth import TOKEN_ENV_VAR, check_earthdata_credentials
+from burst2safe.utils import BurstInfo
+
+
+COOKIE_URL = 'https://sentinel1.asf.alaska.edu/METADATA_RAW/SA/S1A_IW_RAW__0SSV_20141229T072718_20141229T072750_003931_004B96_B79F.iso.xml'
+
+
+def get_url_dict(burst_infos: Iterable[BurstInfo], force: bool = False) -> dict:
+    """Get a dictionary of URLs to download. Keys are save paths, and values are download URLs.
+
+    Args:
+        burst_infos: A list of BurstInfo objects
+        force: If True, include URLs even if the file already exists
+
+    Returns:
+        A dictionary of URLs to download
+    """
+    url_dict = {}
+    for burst_info in burst_infos:
+        if force or not burst_info.data_path.exists():
+            url_dict[burst_info.data_path] = burst_info.data_url
+        if force or not burst_info.metadata_path.exists():
+            url_dict[burst_info.metadata_path] = burst_info.metadata_url
+    return url_dict
+
+
+@retry(
+    reraise=True, retry=retry_if_result(lambda r: r.status == 202), wait=wait_random(0, 1), stop=stop_after_delay(120)
+)
+async def get_async(session: aiohttp.ClientSession, url: str) -> aiohttp.ClientResponse:
+    """Retry a GET request until a non-202 response is received.
+
+    Args:
+        session: An aiohttp ClientSession
+        url: The URL to download
+
+    Returns:
+        The response object
+    """
+    response = await session.get(url)
+    response.raise_for_status()
+    return response
+
+
+@retry(reraise=True, stop=stop_after_attempt(3))
+async def download_burst_url_async(session: aiohttp.ClientSession, url: str, file_path: Path) -> None:
+    """Retry a burst URL GET request until a non-202 response is received, then download the file.
+
+    Args:
+        session: An aiohttp ClientSession
+        url: The URL to download
+        file_path: The path to save the downloaded data to
+    """
+    response = await get_async(session, url)
+
+    if file_path.suffix in ['.tif', '.tiff']:
+        returned_filename = response.content_disposition.filename
+    elif file_path.suffix == '.xml':
+        url_parts = str(response.url).split('/')
+        ext = response.content_disposition.filename.split('.')[-1]
+        returned_filename = f'{url_parts[3]}_{url_parts[5]}.{ext}'
+    else:
+        raise ValueError(f'Invalid file extension: {file_path.suffix}')
+
+    if file_path.name != returned_filename:
+        raise ValueError(f'Race condition encountered, incorrect URL returned for file: {file_path.name}')
+
+    try:
+        with open(file_path, 'wb') as f:
+            async for chunk in response.content.iter_chunked(2**14):
+                f.write(chunk)
+    except Exception as e:
+        file_path.unlink(missing_ok=True)
+        raise e
+    finally:
+        response.close()
+
+
+async def download_bursts_async(url_dict: dict) -> None:
+    """Download a dictionary of URLs asynchronously.
+
+    Args:
+        url_dict: A dictionary of URLs to download
+    """
+    auth_type = check_earthdata_credentials(append=True)
+    headers = {'Authorization': f'Bearer {os.getenv(TOKEN_ENV_VAR)}'} if auth_type == 'token' else {}
+    async with aiohttp.ClientSession(headers=headers, trust_env=True) as session:
+        if auth_type == 'token':
+            # FIXME: Needed while the burst extractor API doesn't support EDL tokens
+            cookie_response = await session.get(COOKIE_URL)
+            cookie_response.raise_for_status()
+            cookie_response.close()
+
+        tasks = []
+        for file_path, url in url_dict.items():
+            tasks.append(download_burst_url_async(session, url, file_path))
+        await asyncio.gather(*tasks)
+
+
+def download_bursts(burst_infos: Iterable[BurstInfo]) -> None:
+    """Download the burst data and metadata files using an async queue.
+
+    Args:
+        burst_infos: A list of BurstInfo objects
+    """
+    url_dict = get_url_dict(burst_infos)
+    asyncio.run(download_bursts_async(url_dict))
+    full_dict = get_url_dict(burst_infos, force=True)
+    missing_data = [x for x in full_dict.keys() if not x.exists()]
+    if missing_data:
+        missing_names = ', '.join([x.name for x in missing_data])
+        raise ValueError(f'Error downloading, missing files: {missing_names}')
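End to end, the module is driven through `download_bursts`, which skips files already on disk and verifies everything landed afterwards. A hypothetical usage sketch (the granule name is made up for illustration):

```python
from pathlib import Path

from burst2safe.download import download_bursts
from burst2safe.search import find_granules
from burst2safe.utils import get_burst_infos

# Hypothetical burst granule name, for illustration only.
results = find_granules(['S1_136231_IW2_20200604T022312_VV_7C85-BURST'])
burst_infos = get_burst_infos(results, Path.cwd())
download_bursts(burst_infos)  # raises ValueError if any file is still missing
```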
diff --git a/src/burst2safe/search.py b/src/burst2safe/search.py
index 500b1d7..74c7de8 100644
--- a/src/burst2safe/search.py
+++ b/src/burst2safe/search.py
@@ -1,22 +1,15 @@
-"""A package for converting ASF burst SLCs to the SAFE format"""
-
 import warnings
 from collections.abc import Iterable
-from concurrent.futures import ProcessPoolExecutor
 from datetime import datetime
 from itertools import product
-from multiprocessing import cpu_count
 from pathlib import Path
-from typing import List, Optional
+from typing import List, Optional, Tuple
 
 import asf_search
 import numpy as np
 from asf_search.Products.S1BurstProduct import S1BurstProduct
 from shapely.geometry import Polygon
 
-from burst2safe.auth import get_earthdata_credentials
-from burst2safe.utils import BurstInfo, download_url_with_retries
-
 warnings.filterwarnings('ignore')
 
@@ -39,29 +32,6 @@ def find_granules(granules: Iterable[str]) -> List[S1BurstProduct]:
     return list(results)
 
 
-def find_stack_orbits(rel_orbit: int, extent: Polygon, start_date: datetime, end_date: datetime) -> List[int]:
-    """Find all orbits in a stack using ASF Search.
-
-    Args:
-        rel_orbit: The relative orbit number of the stack
-        start_date: The start date of the stack
-        end_date: The end date of the stack
-
-    Returns:
-        List of absolute orbit numbers
-    """
-    dataset = asf_search.constants.DATASET.SLC_BURST
-    search_results = asf_search.geo_search(
-        dataset=dataset,
-        relativeOrbit=rel_orbit,
-        intersectsWith=extent.centroid.wkt,
-        start=start_date.strftime('%Y-%m-%d'),
-        end=end_date.strftime('%Y-%m-%d'),
-    )
-    absolute_orbits = list(set([int(result.properties['orbit']) for result in search_results]))
-    return absolute_orbits
-
-
 def add_surrounding_bursts(bursts: List[S1BurstProduct], min_bursts: int) -> List[S1BurstProduct]:
     """Add bursts to the list to ensure each swath has at least `min_bursts` bursts.
     All bursts must be from the same absolute orbit, swath, and polarization.
@@ -95,28 +65,35 @@ def add_surrounding_bursts(bursts: List[S1BurstProduct], min_bursts: int) -> Lis
     return search_results
 
 
-def find_swath_pol_group(
-    search_results: List[S1BurstProduct], pol: str, swath: Optional[str], min_bursts: int
+def get_burst_group(
+    search_results: List[S1BurstProduct],
+    pol: str,
+    swath: Optional[str] = None,
+    orbit: Optional[int] = None,
+    min_bursts: int = 0,
 ) -> List[S1BurstProduct]:
-    """Find a group of bursts with the same polarization and swath.
+    """Find a group of bursts with the same polarization, swath, and (optionally) orbit.
     Add surrounding bursts if the group is too small.
Args: search_results: A list of S1BurstProduct objects pol: The polarization to search for swath: The swath to search for + orbit: The absolute orbit number of the bursts min_bursts: The minimum number of bursts per swath Returns: An updated list of S1BurstProduct objects """ + params = [] + if orbit: + search_results = [result for result in search_results if result.properties['orbit'] == orbit] + params.append(f'orbit {orbit}') if swath: search_results = [result for result in search_results if result.properties['burst']['subswath'] == swath] - search_results = [result for result in search_results if result.properties['polarization'] == pol] - - params = [f'polarization {pol}'] - if swath: params.append(f'swath {swath}') + search_results = [result for result in search_results if result.properties['polarization'] == pol] + params.append(f'polarization {pol}') params = ', '.join(params) if not search_results: @@ -131,26 +108,18 @@ def find_swath_pol_group( return search_results -def find_group( - orbit: int, - footprint: Polygon, - polarizations: Optional[Iterable] = None, - swaths: Optional[Iterable] = None, - mode: str = 'IW', - min_bursts: int = 1, -) -> List[S1BurstProduct]: - """Find burst groups using ASF Search. +def sanitize_group_search_inputs( + polarizations: Optional[Iterable] = None, swaths: Optional[Iterable] = None, mode: str = 'IW' +) -> Tuple[List[str], List[str]]: + """Sanitize inputs for group search. Args: - orbit: The absolute orbit number of the bursts - footprint: The bounding box of the bursts polarizations: List of polarizations to include (default: VV) swaths: List of swaths to include (default: all) mode: The collection mode to use (IW or EW) (default: IW) - min_bursts: The minimum number of bursts per swath (default: 1) Returns: - A list of S1BurstProduct objects + A tuple of sanitized polarizations and swaths """ if polarizations is None: polarizations = ['VV'] @@ -172,14 +141,82 @@ def find_group( if bad_swaths: raise ValueError(f'Invalid swaths: {" ".join(bad_swaths)}') - dataset = asf_search.constants.DATASET.SLC_BURST - search_results = asf_search.geo_search( - dataset=dataset, absoluteOrbit=orbit, intersectsWith=footprint.wkt, beamMode=mode - ) - final_results = [] - for pol, swath in product(polarizations, swaths): - sub_results = find_swath_pol_group(search_results, pol, swath, min_bursts) - final_results.extend(sub_results) + return polarizations, swaths + + +def add_missing_bursts( + search_results: List[S1BurstProduct], + polarizations: List[str], + swaths: List[str], + min_bursts: int, + use_relative_orbit: bool, +) -> List[S1BurstProduct]: + """Add missing bursts to the search results to ensure each swath/pol combo has at least `min_bursts` bursts. 
+
+    Args:
+        search_results: A list of S1BurstProduct objects
+        polarizations: List of polarizations to include
+        swaths: List of swaths to include
+        min_bursts: The minimum number of bursts per swath
+        use_relative_orbit: Use relative orbit number instead of absolute orbit number
+
+    Returns:
+        A list of S1BurstProduct objects
+    """
+    grouped_results = []
+    if use_relative_orbit:
+        absolute_orbits = list(set([int(result.properties['orbit']) for result in search_results]))
+        group_definitions = product(polarizations, swaths, absolute_orbits)
+    else:
+        group_definitions = product(polarizations, swaths)
+
+    for group_definition in group_definitions:
+        sub_results = get_burst_group(search_results, *group_definition, min_bursts=min_bursts)
+        grouped_results.extend(sub_results)
+    return grouped_results
+
+
+def find_group(
+    orbit: int,
+    footprint: Polygon,
+    polarizations: Optional[Iterable] = None,
+    swaths: Optional[Iterable] = None,
+    mode: str = 'IW',
+    min_bursts: int = 1,
+    start_date: Optional[datetime] = None,
+    end_date: Optional[datetime] = None,
+    use_relative_orbit: bool = False,
+) -> List[S1BurstProduct]:
+    """Find burst groups using ASF Search.
+
+    Args:
+        orbit: The absolute orbit number of the bursts
+        footprint: The bounding box of the bursts
+        polarizations: List of polarizations to include (default: VV)
+        swaths: List of swaths to include (default: all)
+        mode: The collection mode to use (IW or EW) (default: IW)
+        min_bursts: The minimum number of bursts per swath (default: 1)
+        start_date: The start date for relative orbit search
+        end_date: The end date for relative orbit search
+        use_relative_orbit: Use relative orbit number instead of absolute orbit number (default: False)
+
+    Returns:
+        A list of S1BurstProduct objects
+    """
+    if use_relative_orbit and not (start_date and end_date):
+        raise ValueError('You must provide start and end dates when using relative orbit number.')
+
+    polarizations, swaths = sanitize_group_search_inputs(polarizations, swaths, mode)
+    opts = dict(dataset=asf_search.constants.DATASET.SLC_BURST, intersectsWith=footprint.wkt, beamMode=mode)
+    if use_relative_orbit:
+        opts['relativeOrbit'] = orbit
+        opts['start'] = f'{start_date.strftime("%Y-%m-%d")}T00:00:00Z'
+        opts['end'] = f'{end_date.strftime("%Y-%m-%d")}T23:59:59Z'
+    else:
+        opts['absoluteOrbit'] = orbit
+    search_results = asf_search.geo_search(**opts)
+
+    final_results = add_missing_bursts(search_results, polarizations, swaths, min_bursts, use_relative_orbit)
 
     return final_results
 
@@ -214,27 +251,3 @@ def find_bursts(
         'You must provide either a list of granules or minimum set of group parameters (orbit, and footprint).'
     )
     return results
-
-
-def download_bursts(burst_infos: Iterable[BurstInfo]) -> None:
-    """Download the burst data and metadata files using multiple workers.
- - Args: - burst_infos: A list of BurstInfo objects - """ - downloads = {} - for burst_info in burst_infos: - downloads[burst_info.data_path] = burst_info.data_url - downloads[burst_info.metadata_path] = burst_info.metadata_url - download_info = [(value, key.parent, key.name) for key, value in downloads.items()] - urls, dirs, names = zip(*download_info) - - username, password = get_earthdata_credentials() - session = asf_search.ASFSession().auth_with_creds(username, password) - n_workers = min(len(urls), max(cpu_count() - 2, 1)) - if n_workers == 1: - for url, dir, name in zip(urls, dirs, names): - download_url_with_retries(url, dir, name, session) - else: - with ProcessPoolExecutor(max_workers=n_workers) as executor: - executor.map(download_url_with_retries, urls, dirs, names, [session] * len(urls)) diff --git a/src/burst2safe/utils.py b/src/burst2safe/utils.py index a565299..3b04e6c 100644 --- a/src/burst2safe/utils.py +++ b/src/burst2safe/utils.py @@ -8,7 +8,6 @@ from pathlib import Path from typing import Dict, List, Optional -import asf_search import lxml.etree as ET from asf_search.Products.S1BurstProduct import S1BurstProduct from osgeo import gdal, ogr, osr @@ -223,31 +222,6 @@ def get_subxml_from_metadata( return desired_metadata -def download_url_with_retries( - url: str, path: str, filename: str = None, session: asf_search.ASFSession = None, max_retries: int = 3 -) -> None: - """Download a file using asf_search.download_url with retries and backoff. - - Args: - url: The URL to download - path: The path to save the file to - filename: The name of the file to save - session: The ASF session to use - max_retries: The maximum number of retries - """ - n_retries = 0 - file_exists = False - while n_retries < max_retries and not file_exists: - asf_search.download_url(url, path, filename, session) - - n_retries += 1 - if Path(path, filename).exists(): - file_exists = True - - if not file_exists: - raise ValueError(f'Failed to download {filename} after {max_retries} attempts.') - - def flatten(list_of_lists: List[List]) -> List: """Flatten a list of lists.""" return [item for sublist in list_of_lists for item in sublist] diff --git a/tests/test_auth.py b/tests/test_auth.py index 5bbf0e8..95bd60d 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -1,5 +1,7 @@ from pathlib import Path +import pytest + from burst2safe import auth @@ -35,3 +37,55 @@ def test_find_creds_in_netrc(tmp_path, monkeypatch): m.setattr(auth, 'get_netrc', lambda: tmp_path / '.netrc') (tmp_path / '.netrc').write_text('') assert auth.find_creds_in_netrc('test') == (None, None) + + +def test_write_credentials_to_netrc_file(tmp_path, monkeypatch): + with monkeypatch.context() as m: + m.setattr(auth, 'get_netrc', lambda: tmp_path / '.netrc') + auth.write_credentials_to_netrc_file('foo', 'bar') + assert (tmp_path / '.netrc').read_text() == 'machine urs.earthdata.nasa.gov login foo password bar\n' + + +def test_check_earthdata_credentials_token(tmp_path, monkeypatch): + with monkeypatch.context() as m: + m.setenv('EARTHDATA_TOKEN', 'foo') + assert auth.check_earthdata_credentials() == 'token' + + +def test_check_earthdata_credentials_netrc(tmp_path, monkeypatch): + netrc_path = tmp_path / '.netrc' + netrc_path.touch() + netrc_path.write_text('machine urs.earthdata.nasa.gov login foo password bar\n') + with monkeypatch.context() as m: + m.delenv('EARTHDATA_TOKEN', raising=False) + m.setenv('EARTHDATA_USERNAME', 'baz') + m.setenv('EARTHDATA_PASSWORD', 'buzz') + m.setattr(auth, 'get_netrc', lambda: netrc_path) + 
assert auth.check_earthdata_credentials() == 'netrc'
+        assert netrc_path.read_text() == 'machine urs.earthdata.nasa.gov login foo password bar\n'
+
+
+def test_check_earthdata_credentials_env(tmp_path, monkeypatch):
+    netrc_path = tmp_path / '.netrc'
+    with monkeypatch.context() as m:
+        m.delenv('EARTHDATA_TOKEN', raising=False)
+        m.setenv('EARTHDATA_USERNAME', 'baz')
+        m.setenv('EARTHDATA_PASSWORD', 'buzz')
+        m.setattr(auth, 'get_netrc', lambda: netrc_path)
+
+        with pytest.raises(ValueError, match='NASA Earthdata credentials only found in environment variables*'):
+            auth.check_earthdata_credentials()
+
+        assert auth.check_earthdata_credentials(append=True) == 'netrc'
+        assert netrc_path.read_text() == 'machine urs.earthdata.nasa.gov login baz password buzz\n'
+
+
+def test_check_earthdata_credentials_none(tmp_path, monkeypatch):
+    netrc_path = tmp_path / '.netrc'
+    with monkeypatch.context() as m:
+        m.delenv('EARTHDATA_TOKEN', raising=False)
+        m.delenv('EARTHDATA_USERNAME', raising=False)
+        m.delenv('EARTHDATA_PASSWORD', raising=False)
+        m.setattr(auth, 'get_netrc', lambda: netrc_path)
+        with pytest.raises(ValueError, match='Please provide NASA Earthdata credentials*'):
+            auth.check_earthdata_credentials()
diff --git a/tests/test_download.py b/tests/test_download.py
new file mode 100644
index 0000000..a4a4ad8
--- /dev/null
+++ b/tests/test_download.py
@@ -0,0 +1,34 @@
+from collections import namedtuple
+
+from burst2safe import download
+
+
+def test_get_url_dict(tmp_path):
+    DummyBurst = namedtuple('DummyBurst', ['data_path', 'data_url', 'metadata_path', 'metadata_url'])
+    burst_infos = [
+        DummyBurst(
+            data_path=tmp_path / 'data1.tif',
+            data_url='http://data1.tif',
+            metadata_path=tmp_path / 'metadata1.xml',
+            metadata_url='http://metadata1.xml',
+        ),
+        DummyBurst(
+            data_path=tmp_path / 'data2.tiff',
+            data_url='http://data2.tiff',
+            metadata_path=tmp_path / 'metadata2.xml',
+            metadata_url='http://metadata2.xml',
+        ),
+    ]
+    url_dict = download.get_url_dict(burst_infos)
+    expected = {
+        tmp_path / 'data1.tif': 'http://data1.tif',
+        tmp_path / 'metadata1.xml': 'http://metadata1.xml',
+        tmp_path / 'data2.tiff': 'http://data2.tiff',
+        tmp_path / 'metadata2.xml': 'http://metadata2.xml',
+    }
+    assert url_dict == expected
+
+    del expected[tmp_path / 'data1.tif']
+    (tmp_path / 'data1.tif').touch()
+    url_dict = download.get_url_dict(burst_infos)
+    assert url_dict == expected
diff --git a/tests/test_search.py b/tests/test_search.py
index 6ff5bf0..fbb0183 100644
--- a/tests/test_search.py
+++ b/tests/test_search.py
@@ -34,3 +34,21 @@ def test_add_surrounding_bursts(product):
     mock_search.assert_called_once_with(
         dataset='SLC-BURST', absoluteOrbit=1, polarization='VV', fullBurstID=burst_ids
     )
+
+
+def test_sanitize_group_search_inputs():
+    pols, swaths = search.sanitize_group_search_inputs()
+    assert pols == ['VV']
+    assert swaths == [None]
+
+    assert search.sanitize_group_search_inputs(polarizations=['HH'])[0] == ['HH']
+    assert search.sanitize_group_search_inputs(swaths=['IW2'])[1] == ['IW2']
+
+    with pytest.raises(ValueError, match='Invalid polarization*'):
+        search.sanitize_group_search_inputs(polarizations=['VV', 'BB'])
+
+    with pytest.raises(ValueError, match='Invalid swath*'):
+        search.sanitize_group_search_inputs(swaths=['IW1'], mode='EW')
+
+    with pytest.raises(ValueError, match='Invalid swath*'):
+        search.sanitize_group_search_inputs(swaths=['EW1'], mode='IW')
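The new tests do not exercise the missing-file check at the end of `download_bursts`; a sketch of a test that could cover it, reusing the `DummyBurst` style from `test_download.py` above and stubbing out the network:

```python
from collections import namedtuple

import pytest

from burst2safe import download


async def _no_download(url_dict):
    """Stand-in for download_bursts_async that performs no network calls."""
    return None


def test_download_bursts_raises_on_missing_files(tmp_path, monkeypatch):
    DummyBurst = namedtuple('DummyBurst', ['data_path', 'data_url', 'metadata_path', 'metadata_url'])
    burst_info = DummyBurst(
        data_path=tmp_path / 'data1.tif',
        data_url='http://data1.tif',
        metadata_path=tmp_path / 'metadata1.xml',
        metadata_url='http://metadata1.xml',
    )
    monkeypatch.setattr(download, 'download_bursts_async', _no_download)
    # Nothing was written to disk, so the post-download completeness check should fail.
    with pytest.raises(ValueError, match='missing files'):
        download.download_bursts([burst_info])
```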