Skip to content

Commit

Permalink
solve the issue of fig_share tests were not running with other tests
Browse files Browse the repository at this point in the history
  • Loading branch information
MAfarrag committed Jan 6, 2025
1 parent 1d0a261 commit 0ec4ccb
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 22 deletions.
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,6 @@ markers = [
"e2e: end-to-end test (deselect with '-m \"not e2e\"')",
"mock: mock test (deselect with '-m \"not mock\"')",
"integration: mock test (deselect with '-m \"not integration\"')",
"fig_share: mock test (deselect with '-m \"not fig_share\"')"
]


Expand Down
3 changes: 3 additions & 0 deletions src/Hapi/parameters/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -517,6 +517,7 @@ def download_files(
Path(download_dir) if isinstance(download_dir, str) else download_dir
)
files = self.list_files(set_id, version)

for file in files:
dest_path = download_dir / file["name"]
FileManager.download_file(file["download_url"], dest_path)
Expand Down Expand Up @@ -682,6 +683,8 @@ def get_parameter_set(

if download_dir is None:
download_dir = self.download_dir / f"{set_id}"
else:
download_dir = Path(download_dir) / f"{set_id}"

self.manager.download_files(set_id, download_dir, self.version)
logger.debug(f"Downloaded parameter set: {set_id} to {download_dir}")
Expand Down
20 changes: 12 additions & 8 deletions tests/rrm/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,34 +3,38 @@

import geopandas as gpd
import numpy as np
import pytest
from geopandas import GeoDataFrame
from osgeo import gdal

from Hapi.parameters.parameters import Parameter
from tests.rrm.calibration.conftest import *
from tests.rrm.catchment.conftest import *

data_dir = os.getenv("HAPI_DATA_DIR")
@pytest.fixture(scope="session")
def hapi_data_dir() -> str:
data_dir = os.getenv("HAPI_DATA_DIR")

if not os.path.exists(data_dir):
raise FileNotFoundError("HAPI_DATA_DIR not found")
if data_dir is None or not os.path.exists(data_dir):
raise ValueError("please set the `HAPI_DATA_DIR` emvironment variable")
return data_dir


@pytest.fixture(scope="session")
def download_03_parameter():
def download_03_parameter(hapi_data_dir: str):
"""Download Parameter Set 03"""
if not os.path.exists(f"{data_dir}/3"):
if not os.path.exists(f"{hapi_data_dir}/3"):
par = Parameter()
par.get_parameter_set(3)


@pytest.fixture(scope="session")
def download_max_min_parameter():
def download_max_min_parameter(hapi_data_dir: str):
"""Download Parameter Set 03"""
par = Parameter()
if not os.path.exists(f"{data_dir}/max"):
if not os.path.exists(f"{hapi_data_dir}/max"):
par.get_parameter_set("max")
if not os.path.exists(f"{data_dir}/min"):
if not os.path.exists(f"{hapi_data_dir}/min"):
par.get_parameter_set("min")


Expand Down
30 changes: 17 additions & 13 deletions tests/rrm/parameters/test_parameters.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import os
import shutil
from pathlib import Path
from unittest.mock import MagicMock
from unittest.mock import MagicMock, patch

import pytest

from Hapi import __file__ as hapi_init
from Hapi.parameters.parameters import (
FigshareAPIClient,
FileManager,
Expand Down Expand Up @@ -174,10 +174,9 @@ def test_download_files(self, parameter_manager, mock_api_client, tmp_path):
{"name": "file2.txt", "download_url": "http://example.com/file2"},
]
}
mock_download = MagicMock()
FileManager.download_file = mock_download

parameter_manager.download_files(set_id=1, download_dir=tmp_path)
with patch("Hapi.parameters.parameters.FileManager.download_file") as mock_download:
parameter_manager.download_files(set_id=1, download_dir=tmp_path)

mock_api_client.send_request.assert_called_once_with("GET", "articles/19999901")
assert mock_download.call_count == 2, "Two files should be downloaded."
Expand Down Expand Up @@ -271,6 +270,10 @@ def test_integration_download_files(self, parameter_manager):
assert (
len(downloaded_files) == 19
), "Files should be downloaded to the specified directory."
try:
shutil.rmtree(int_test_dir)
except PermissionError:
pass

def test_integration_get_article_id(self, parameter_manager):
"""Integration test for mapping a friendly ID to an article ID."""
Expand Down Expand Up @@ -305,13 +308,12 @@ def int_test_dir(self, tmp_path):
"""Provide a temporary directory for testing file downloads."""
return tmp_path / "integration_test_parameters"

@pytest.mark.fig_share
def test_integration_get_parameters(self, int_test_dir):
def test_integration_get_parameters(self):
"""Integration test for downloading all parameter sets."""
parameter = Parameter(version=1)
int_test_dir.mkdir(parents=True, exist_ok=True)
int_test_dir = parameter.download_dir

parameter.get_parameters(int_test_dir)
parameter.get_parameters()

downloaded_files = list(int_test_dir.glob("**/*"))
assert (
Expand All @@ -325,8 +327,8 @@ def test_integration_get_parameter_set_with_download_dir(self, int_test_dir):
int_test_dir.mkdir(parents=True, exist_ok=True)

parameter.get_parameter_set(1, int_test_dir)

downloaded_files = list(int_test_dir.glob("**/*"))

assert (
len(downloaded_files) > 0
), "Parameter sets should be downloaded to the specified directory."
Expand Down Expand Up @@ -374,12 +376,14 @@ def mock_file_manager(self):
@pytest.fixture
def parameter(self, mock_parameter_manager):
"""Fixture to provide a Parameter instance with a mocked ParameterManager."""
parameter_instance = Parameter(version=1)
with patch("os.getenv", return_value="/mocked/path/to/data"):
parameter_instance = Parameter(version=1)
parameter_instance.manager = mock_parameter_manager
return parameter_instance

def test_get_parameters(self, parameter, mock_parameter_manager, tmp_path):
"""Test downloading all parameter sets."""

parameter.get_parameters(tmp_path)

# Ensure download_files was called for each parameter set ID
Expand All @@ -388,7 +392,7 @@ def test_get_parameters(self, parameter, mock_parameter_manager, tmp_path):
), "download_files should be called for each parameter set ID."
for set_id in ParameterManager.PARAMETER_SET_ID:
mock_parameter_manager.download_files.assert_any_call(
set_id, tmp_path, parameter.version
set_id, Path(f"{tmp_path}/{set_id}"), parameter.version
)

def test_get_parameter_set(self, parameter, mock_parameter_manager, tmp_path):
Expand All @@ -398,7 +402,7 @@ def test_get_parameter_set(self, parameter, mock_parameter_manager, tmp_path):

# Ensure download_files was called with the correct arguments
mock_parameter_manager.download_files.assert_called_once_with(
set_id, tmp_path, parameter.version
set_id, Path(f"{tmp_path}/{set_id}"), parameter.version
)

def test_list_parameter_names(self):
Expand Down

0 comments on commit 0ec4ccb

Please sign in to comment.