Skip to content

Commit 5c5639d

Browse files
authored
Remove GDAL and RichDEM dependancy from tests (#675)
1 parent c316ff1 commit 5c5639d

8 files changed

+76
-313
lines changed

.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,7 @@ xdem/_version.py
146146

147147
# Example data downloaded/produced during tests
148148
examples/data/
149+
tests/test_data/
149150

150151
doc/source/basic_examples/
151152
doc/source/advanced_examples/

Makefile

+9-23
Original file line numberDiff line numberDiff line change
@@ -11,17 +11,18 @@ ifndef VENV
1111
VENV = "venv"
1212
endif
1313

14-
# Python version requirement
15-
PYTHON_VERSION_REQUIRED = 3.10
16-
14+
# Python global variables definition
15+
PYTHON_VERSION_MIN = 3.10
16+
# Set PYTHON if not defined in command line
17+
# Example: PYTHON="python3.10" make venv to use python 3.10 for the venv
18+
# By default the default python3 of the system.
1719
ifndef PYTHON
18-
# Try to find python version required
19-
PYTHON = "python$(PYTHON_VERSION_REQUIRED)"
20+
PYTHON = "python3"
2021
endif
2122
PYTHON_CMD=$(shell command -v $(PYTHON))
2223

23-
PYTHON_VERSION_CUR=$(shell $(PYTHON_CMD) -c 'import sys; print("%d.%d" % sys.version_info[0:2])')
24-
PYTHON_VERSION_OK=$(shell $(PYTHON_CMD) -c 'import sys; req_ver = tuple(map(int, "$(PYTHON_VERSION_REQUIRED)".split("."))); cur_ver = sys.version_info[0:2]; print(int(cur_ver == req_ver))')
24+
PYTHON_VERSION_CUR=$(shell $(PYTHON_CMD) -c 'import sys; print("%d.%d"% sys.version_info[0:2])')
25+
PYTHON_VERSION_OK=$(shell $(PYTHON_CMD) -c 'import sys; cur_ver = sys.version_info[0:2]; min_ver = tuple(map(int, "$(PYTHON_VERSION_MIN)".split("."))); print(int(cur_ver >= min_ver))')
2526

2627
############### Check python version supported ############
2728

@@ -30,7 +31,7 @@ ifeq (, $(PYTHON_CMD))
3031
endif
3132

3233
ifeq ($(PYTHON_VERSION_OK), 0)
33-
$(error "Requires Python version == $(PYTHON_VERSION_REQUIRED). Current version is $(PYTHON_VERSION_CUR)")
34+
$(error "Requires Python version >= $(PYTHON_VERSION_MIN). Current version is $(PYTHON_VERSION_CUR)")
3435
endif
3536

3637
################ MAKE Targets ######################
@@ -45,19 +46,6 @@ venv: ## Create a virtual environment in 'venv' directory if it doesn't exist
4546
@touch ${VENV}/bin/activate
4647
@${VENV}/bin/python -m pip install --upgrade wheel setuptools pip
4748

48-
.PHONY: install-gdal
49-
install-gdal: ## Install GDAL version matching the system's GDAL via pip
50-
@if command -v gdalinfo >/dev/null 2>&1; then \
51-
GDAL_VERSION=$$(gdalinfo --version | awk '{print $$2}'); \
52-
echo "System GDAL version: $$GDAL_VERSION"; \
53-
${VENV}/bin/pip install gdal==$$GDAL_VERSION; \
54-
else \
55-
echo "Warning: GDAL not found on the system. Proceeding without GDAL."; \
56-
echo "Try installing GDAL by running the following commands depending on your system:"; \
57-
echo "Debian/Ubuntu: sudo apt-get install -y gdal-bin libgdal-dev"; \
58-
echo "Red Hat/CentOS: sudo yum install -y gdal gdal-devel"; \
59-
echo "Then run 'make install-gdal' to proceed with GDAL installation."; \
60-
fi
6149

6250
.PHONY: install
6351
install: venv ## Install xDEM for development (depends on venv)
@@ -66,8 +54,6 @@ install: venv ## Install xDEM for development (depends on venv)
6654
@test -f .git/hooks/pre-commit || echo "Installing pre-commit hooks"
6755
@test -f .git/hooks/pre-commit || ${VENV}/bin/pre-commit install -t pre-commit
6856
@test -f .git/hooks/pre-push || ${VENV}/bin/pre-commit install -t pre-push
69-
@echo "Attempting to install GDAL..."
70-
@make install-gdal
7157
@echo "xdem installed in development mode in virtualenv ${VENV}"
7258
@echo "To use: source ${VENV}/bin/activate; xdem -h"
7359

dev-environment.yml

-2
Original file line numberDiff line numberDiff line change
@@ -28,13 +28,11 @@ dependencies:
2828
- scikit-learn
2929

3030
# Test dependencies
31-
- gdal # To test against GDAL
3231
- pytest
3332
- pytest-xdist
3433
- pyyaml
3534
- flake8
3635
- pylint
37-
- richdem # To test against richdem
3836

3937
# Doc dependencies
4038
- sphinx

setup.cfg

-1
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,6 @@ test =
6161
flake8
6262
pylint
6363
scikit-learn
64-
richdem
6564
doc =
6665
sphinx
6766
sphinx-book-theme

tests/conftest.py

+15-132
Original file line numberDiff line numberDiff line change
@@ -1,142 +1,25 @@
1-
from typing import Callable, List, Union
1+
import os
2+
from typing import Callable
23

3-
import geoutils as gu
4-
import numpy as np
54
import pytest
6-
import richdem as rd
7-
from geoutils.raster import RasterType
85

9-
from xdem._typing import NDArrayf
6+
from xdem.examples import download_and_extract_tarball
107

11-
12-
@pytest.fixture(scope="session") # type: ignore
13-
def raster_to_rda() -> Callable[[RasterType], rd.rdarray]:
14-
def _raster_to_rda(rst: RasterType) -> rd.rdarray:
15-
"""
16-
Convert geoutils.Raster to richDEM rdarray.
17-
"""
18-
arr = rst.data.filled(rst.nodata).squeeze()
19-
rda = rd.rdarray(arr, no_data=rst.nodata)
20-
rda.geotransform = rst.transform.to_gdal()
21-
return rda
22-
23-
return _raster_to_rda
24-
25-
26-
@pytest.fixture(scope="session") # type: ignore
27-
def get_terrainattr_richdem(raster_to_rda: Callable[[RasterType], rd.rdarray]) -> Callable[[RasterType, str], NDArrayf]:
28-
def _get_terrainattr_richdem(rst: RasterType, attribute: str = "slope_radians") -> NDArrayf:
29-
"""
30-
Derive terrain attribute for DEM opened with geoutils.Raster using RichDEM.
31-
"""
32-
rda = raster_to_rda(rst)
33-
terrattr = rd.TerrainAttribute(rda, attrib=attribute)
34-
terrattr[terrattr == terrattr.no_data] = np.nan
35-
return np.array(terrattr)
36-
37-
return _get_terrainattr_richdem
8+
_TESTDATA_DIRECTORY = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "tests", "test_data"))
389

3910

4011
@pytest.fixture(scope="session") # type: ignore
41-
def get_terrain_attribute_richdem(
42-
get_terrainattr_richdem: Callable[[RasterType, str], NDArrayf]
43-
) -> Callable[[RasterType, Union[str, list[str]], bool, float, float, float], Union[RasterType, list[RasterType]]]:
44-
def _get_terrain_attribute_richdem(
45-
dem: RasterType,
46-
attribute: Union[str, List[str]],
47-
degrees: bool = True,
48-
hillshade_altitude: float = 45.0,
49-
hillshade_azimuth: float = 315.0,
50-
hillshade_z_factor: float = 1.0,
51-
) -> Union[RasterType, List[RasterType]]:
52-
"""
53-
Derive one or multiple terrain attributes from a DEM using RichDEM.
54-
"""
55-
if isinstance(attribute, str):
56-
attribute = [attribute]
57-
58-
if not isinstance(dem, gu.Raster):
59-
raise ValueError("DEM must be a geoutils.Raster object.")
60-
61-
terrain_attributes = {}
62-
63-
# Check which products should be made to optimize the processing
64-
make_aspect = any(attr in attribute for attr in ["aspect", "hillshade"])
65-
make_slope = any(
66-
attr in attribute
67-
for attr in [
68-
"slope",
69-
"hillshade",
70-
"planform_curvature",
71-
"aspect",
72-
"profile_curvature",
73-
"maximum_curvature",
74-
]
75-
)
76-
make_hillshade = "hillshade" in attribute
77-
make_curvature = "curvature" in attribute
78-
make_planform_curvature = "planform_curvature" in attribute or "maximum_curvature" in attribute
79-
make_profile_curvature = "profile_curvature" in attribute or "maximum_curvature" in attribute
80-
81-
if make_slope:
82-
terrain_attributes["slope"] = get_terrainattr_richdem(dem, "slope_radians")
83-
84-
if make_aspect:
85-
# The aspect of RichDEM is returned in degrees, we convert to radians to match the others
86-
terrain_attributes["aspect"] = np.deg2rad(get_terrainattr_richdem(dem, "aspect"))
87-
# For flat slopes, RichDEM returns a 90° aspect by default, while GDAL return a 180° aspect
88-
# We stay consistent with GDAL
89-
slope_tmp = get_terrainattr_richdem(dem, "slope_radians")
90-
terrain_attributes["aspect"][slope_tmp == 0] = np.pi
91-
92-
if make_hillshade:
93-
# If a different z-factor was given, slopemap with exaggerated gradients.
94-
if hillshade_z_factor != 1.0:
95-
slopemap = np.arctan(np.tan(terrain_attributes["slope"]) * hillshade_z_factor)
96-
else:
97-
slopemap = terrain_attributes["slope"]
98-
99-
azimuth_rad = np.deg2rad(360 - hillshade_azimuth)
100-
altitude_rad = np.deg2rad(hillshade_altitude)
101-
102-
# The operation below yielded the closest hillshade to GDAL (multiplying by 255 did not work)
103-
# As 0 is generally no data for this uint8, we add 1 and then 0.5 for the rounding to occur between
104-
# 1 and 255
105-
terrain_attributes["hillshade"] = np.clip(
106-
1.5
107-
+ 254
108-
* (
109-
np.sin(altitude_rad) * np.cos(slopemap)
110-
+ np.cos(altitude_rad) * np.sin(slopemap) * np.sin(azimuth_rad - terrain_attributes["aspect"])
111-
),
112-
0,
113-
255,
114-
).astype("float32")
115-
116-
if make_curvature:
117-
terrain_attributes["curvature"] = get_terrainattr_richdem(dem, "curvature")
118-
119-
if make_planform_curvature:
120-
terrain_attributes["planform_curvature"] = get_terrainattr_richdem(dem, "planform_curvature")
121-
122-
if make_profile_curvature:
123-
terrain_attributes["profile_curvature"] = get_terrainattr_richdem(dem, "profile_curvature")
124-
125-
# Convert the unit if wanted.
126-
if degrees:
127-
for attr in ["slope", "aspect"]:
128-
if attr not in terrain_attributes:
129-
continue
130-
terrain_attributes[attr] = np.rad2deg(terrain_attributes[attr])
131-
132-
output_attributes = [terrain_attributes[key].reshape(dem.shape) for key in attribute]
12+
def get_test_data_path() -> Callable[[str], str]:
13+
def _get_test_data_path(filename: str, overwrite: bool = False) -> str:
14+
"""Get file from test_data"""
15+
download_and_extract_tarball(dir="test_data", target_dir=_TESTDATA_DIRECTORY, overwrite=overwrite)
16+
file_path = os.path.join(_TESTDATA_DIRECTORY, filename)
13317

134-
if isinstance(dem, gu.Raster):
135-
output_attributes = [
136-
gu.Raster.from_array(attr, transform=dem.transform, crs=dem.crs, nodata=-99999)
137-
for attr in output_attributes
138-
]
18+
if not os.path.exists(file_path):
19+
if overwrite:
20+
raise FileNotFoundError(f"The file {filename} was not found in the test_data directory.")
21+
file_path = _get_test_data_path(filename, overwrite=True)
13922

140-
return output_attributes if len(output_attributes) > 1 else output_attributes[0]
23+
return file_path
14124

142-
return _get_terrain_attribute_richdem
25+
return _get_test_data_path

tests/test_coreg/test_affine.py

+4-50
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from __future__ import annotations
44

5+
import os.path
56
import warnings
67

78
import geopandas as gpd
@@ -11,7 +12,6 @@
1112
import pytransform3d
1213
import rasterio as rio
1314
from geoutils import Raster, Vector
14-
from geoutils._typing import NDArrayNum
1515
from geoutils.raster import RasterType
1616
from geoutils.raster.geotransformations import _translate
1717
from scipy.ndimage import binary_dilation
@@ -42,53 +42,6 @@ def load_examples(crop: bool = True) -> tuple[RasterType, RasterType, Vector]:
4242
return reference_dem, to_be_aligned_dem, glacier_mask
4343

4444

45-
def gdal_reproject_horizontal_shift_samecrs(filepath_example: str, xoff: float, yoff: float) -> NDArrayNum:
46-
"""
47-
Reproject horizontal shift in same CRS with GDAL for testing purposes.
48-
49-
:param filepath_example: Path to raster file.
50-
:param xoff: X shift in georeferenced unit.
51-
:param yoff: Y shift in georeferenced unit.
52-
53-
:return: Reprojected shift array in the same CRS.
54-
"""
55-
56-
from osgeo import gdal, gdalconst
57-
58-
# Open source raster from file
59-
src = gdal.Open(filepath_example, gdalconst.GA_ReadOnly)
60-
61-
# Create output raster in memory
62-
driver = "MEM"
63-
method = gdal.GRA_Bilinear
64-
drv = gdal.GetDriverByName(driver)
65-
dest = drv.Create("", src.RasterXSize, src.RasterYSize, 1, gdal.GDT_Float32)
66-
proj = src.GetProjection()
67-
ndv = src.GetRasterBand(1).GetNoDataValue()
68-
dest.SetProjection(proj)
69-
70-
# Shift the horizontally shifted geotransform
71-
gt = src.GetGeoTransform()
72-
gtl = list(gt)
73-
gtl[0] += xoff
74-
gtl[3] += yoff
75-
dest.SetGeoTransform(tuple(gtl))
76-
77-
# Copy the raster metadata of the source to dest
78-
dest.SetMetadata(src.GetMetadata())
79-
dest.GetRasterBand(1).SetNoDataValue(ndv)
80-
dest.GetRasterBand(1).Fill(ndv)
81-
82-
# Reproject with resampling
83-
gdal.ReprojectImage(src, dest, proj, proj, method)
84-
85-
# Extract reprojected array
86-
array = dest.GetRasterBand(1).ReadAsArray().astype("float32")
87-
array[array == ndv] = np.nan
88-
89-
return array
90-
91-
9245
class TestAffineCoreg:
9346

9447
ref, tba, outlines = load_examples() # Load example reference, to-be-aligned and mask.
@@ -121,7 +74,7 @@ class TestAffineCoreg:
12174
"xoff_yoff",
12275
[(ref.res[0], ref.res[1]), (10 * ref.res[0], 10 * ref.res[1]), (-1.2 * ref.res[0], -1.2 * ref.res[1])],
12376
) # type: ignore
124-
def test_reproject_horizontal_shift_samecrs__gdal(self, xoff_yoff: tuple[float, float]) -> None:
77+
def test_reproject_horizontal_shift_samecrs__gdal(self, xoff_yoff: tuple[float, float], get_test_data_path) -> None:
12578
"""Check that the same-CRS reprojection based on SciPy (replacing Rasterio due to subpixel errors)
12679
is accurate by comparing to GDAL."""
12780

@@ -135,7 +88,8 @@ def test_reproject_horizontal_shift_samecrs__gdal(self, xoff_yoff: tuple[float,
13588
)
13689

13790
# Reproject with GDAL
138-
output2 = gdal_reproject_horizontal_shift_samecrs(filepath_example=ref.filename, xoff=xoff, yoff=yoff)
91+
path_output2 = get_test_data_path(os.path.join("gdal", f"shifted_reprojected_xoff{xoff}_yoff{yoff}.tif"))
92+
output2 = Raster(path_output2).data.data
13993

14094
# Reproject and NaN propagation is exactly the same for shifts that are a multiple of pixel resolution
14195
if xoff % ref.res[0] == 0 and yoff % ref.res[1] == 0:

0 commit comments

Comments
 (0)