diff --git a/CHANGELOG.md b/CHANGELOG.md index 6ee01f5..a8b3504 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +### Fixed + +- Removed last grouping step when using `keep_all_tags` parameter with `GroupedOsmTagsFilter` filter + ## [0.4.1] - 2024-01-31 ### Changed diff --git a/quackosm/_osm_tags_filters.py b/quackosm/_osm_tags_filters.py index 6da8fba..8b951d8 100644 --- a/quackosm/_osm_tags_filters.py +++ b/quackosm/_osm_tags_filters.py @@ -52,14 +52,12 @@ def merge_osm_tags_filter( elif is_expected_type(osm_tags_filter, GroupedOsmTagsFilter): return _merge_grouped_osm_tags_filter(cast(GroupedOsmTagsFilter, osm_tags_filter)) elif is_expected_type(osm_tags_filter, Iterable): - return _merge_multiple_osm_tags_filters( - [ - merge_osm_tags_filter( - cast(Union[OsmTagsFilter, GroupedOsmTagsFilter], sub_osm_tags_filter) - ) - for sub_osm_tags_filter in osm_tags_filter - ] - ) + return _merge_multiple_osm_tags_filters([ + merge_osm_tags_filter( + cast(Union[OsmTagsFilter, GroupedOsmTagsFilter], sub_osm_tags_filter) + ) + for sub_osm_tags_filter in osm_tags_filter + ]) raise AttributeError( "Provided tags don't match required type definitions" diff --git a/quackosm/_rich_progress.py b/quackosm/_rich_progress.py index dab370b..454362e 100644 --- a/quackosm/_rich_progress.py +++ b/quackosm/_rich_progress.py @@ -68,7 +68,7 @@ def render(self, task: "Task") -> Text: elif task.speed >= 1: return Text(f"{task.speed:.2f} it/s") else: - return Text(f"{1/task.speed:.2f} s/it") # noqa: FURB126 + return Text(f"{1/task.speed:.2f} s/it") # noqa: FURB126 self.progress = Progress( SpinnerColumn(), diff --git a/quackosm/osm_extracts/geofabrik.py b/quackosm/osm_extracts/geofabrik.py index 0c17eec..201c451 100644 --- a/quackosm/osm_extracts/geofabrik.py +++ b/quackosm/osm_extracts/geofabrik.py @@ -16,6 +16,7 @@ __all__ = ["_get_geofabrik_index"] + def _get_geofabrik_index() -> 
gpd.GeoDataFrame: global GEOFABRIK_INDEX_GDF # noqa: PLW0603 @@ -39,7 +40,9 @@ def _load_geofabrik_index() -> gpd.GeoDataFrame: else: result = requests.get( GEOFABRIK_INDEX_URL, - headers={"User-Agent": "QuackOSM Python package (https://github.com/kraina-ai/quackosm)"}, + headers={ + "User-Agent": "QuackOSM Python package (https://github.com/kraina-ai/quackosm)" + }, ) parsed_data = json.loads(result.text) gdf = gpd.GeoDataFrame.from_features(parsed_data["features"]) diff --git a/quackosm/pbf_file_reader.py b/quackosm/pbf_file_reader.py index 0b24b35..987d45b 100644 --- a/quackosm/pbf_file_reader.py +++ b/quackosm/pbf_file_reader.py @@ -1659,7 +1659,9 @@ def _concatenate_results_to_geoparquet( FROM ({parsed_geometries.sql_query()}) """) - grouped_features = self._parse_features_relation_to_groups(unioned_features, explode_tags) + grouped_features = self._parse_features_relation_to_groups( + unioned_features, keep_all_tags=keep_all_tags, explode_tags=explode_tags + ) valid_features_full_relation = self.connection.sql(f""" SELECT * FROM ({grouped_features.sql_query()}) @@ -1870,6 +1872,7 @@ def _parse_features_relation_to_groups( self, features_relation: "duckdb.DuckDBPyRelation", explode_tags: bool, + keep_all_tags: bool, ) -> "duckdb.DuckDBPyRelation": """ Optionally group raw OSM features into groups defined in `GroupedOsmTagsFilter`. @@ -1883,11 +1886,19 @@ def _parse_features_relation_to_groups( Args: features_relation (duckdb.DuckDBPyRelation): Generated features from the loader. explode_tags (bool): Whether to split tags into columns based on OSM tag keys. + keep_all_tags (bool): Works only with the `tags_filter` parameter. + Whether to keep all tags related to the element, or return only those defined + in the `tags_filter`. When `True`, will override the optional grouping defined + in the `tags_filter`. Returns: duckdb.DuckDBPyRelation: Parsed features_relation.
""" - if not self.tags_filter or not is_expected_type(self.tags_filter, GroupedOsmTagsFilter): + if ( + not self.tags_filter + or not is_expected_type(self.tags_filter, GroupedOsmTagsFilter) + or keep_all_tags + ): return features_relation grouped_features_relation: "duckdb.DuckDBPyRelation" diff --git a/tests/base/test_pbf_file_reader.py b/tests/base/test_pbf_file_reader.py index cea2768..0b97832 100644 --- a/tests/base/test_pbf_file_reader.py +++ b/tests/base/test_pbf_file_reader.py @@ -101,6 +101,24 @@ def test_pbf_reader( ) +@pytest.mark.parametrize("tags_filter", [None, HEX2VEC_FILTER, GEOFABRIK_LAYERS]) # type: ignore +@pytest.mark.parametrize("explode_tags", [None, True, False]) # type: ignore +@pytest.mark.parametrize("keep_all_tags", [True, False]) # type: ignore +def test_pbf_to_geoparquet_parsing( + tags_filter: Optional[Union[OsmTagsFilter, GroupedOsmTagsFilter]], + explode_tags: Optional[bool], + keep_all_tags: bool, +): + """Test if pbf to geoparquet conversion works.""" + pbf_file = Path(__file__).parent.parent / "test_files" / "monaco.osm.pbf" + PbfFileReader(tags_filter=tags_filter).get_features_gdf( + file_paths=pbf_file, + ignore_cache=True, + explode_tags=explode_tags, + keep_all_tags=keep_all_tags, + ) + + def test_pbf_reader_geometry_filtering(): # type: ignore """Test proper spatial data filtering in `PbfFileReader`.""" file_name = "d17f922ed15e9609013a6b895e1e7af2d49158f03586f2c675d17b760af3452e.osm.pbf"