diff --git a/dbterd/adapters/algos/base.py b/dbterd/adapters/algos/base.py index 72206a4..2683ae8 100644 --- a/dbterd/adapters/algos/base.py +++ b/dbterd/adapters/algos/base.py @@ -152,7 +152,7 @@ def get_table_from_metadata(model_metadata, exposures=[], **kwargs) -> Table: ), ), node_name=node_name, - raw_sql="TODO", + raw_sql=None, database=node_database, schema=node_schema, columns=[], @@ -561,6 +561,8 @@ def get_table_map_from_metadata(test_node, **kwargs): to_model_possible_values = [ f"{first_test_parent_resource_type}('{first_test_parent_parts[2]}','{first_test_parent_parts[-1]}')", f"{first_test_parent_resource_type}('{first_test_parent_parts[-1]}')", + f'{first_test_parent_resource_type}("{first_test_parent_parts[2]}","{first_test_parent_parts[-1]}")', + f'{first_test_parent_resource_type}("{first_test_parent_parts[-1]}")', ] if test_metadata_to in to_model_possible_values: return test_parents diff --git a/dbterd/adapters/algos/test_relationship.py b/dbterd/adapters/algos/test_relationship.py index fa93fa9..8af0660 100644 --- a/dbterd/adapters/algos/test_relationship.py +++ b/dbterd/adapters/algos/test_relationship.py @@ -101,4 +101,7 @@ def parse(manifest, catalog, **kwargs): logger.info( f"Collected {len(tables)} table(s) and {len(relationships)} relationship(s)" ) - return (tables, relationships) + return ( + sorted(tables, key=lambda tbl: tbl.node_name), + sorted(relationships, key=lambda rel: rel.name), + ) diff --git a/dbterd/adapters/targets/mermaid/mermaid_test_relationship.py b/dbterd/adapters/targets/mermaid/mermaid_test_relationship.py index fbaf560..b3b5677 100644 --- a/dbterd/adapters/targets/mermaid/mermaid_test_relationship.py +++ b/dbterd/adapters/targets/mermaid/mermaid_test_relationship.py @@ -1,3 +1,6 @@ +import re +from typing import Optional + from dbterd.adapters.algos import test_relationship @@ -14,6 +17,56 @@ def run(manifest, catalog, **kwargs): return ("output.md", parse(manifest, catalog, **kwargs)) +def replace_column_name(column_name: str) -> str: + """Replace column names containing special characters. + To prevent mermaid from not being able to render column names that may contain special characters. + + Args: + column_name (str): column name + + Returns: + str: Column name with special characters substituted + """ + return column_name.replace(" ", "-").replace(".", "__") + + +def match_complex_column_type(column_type: str) -> Optional[str]: + """Returns the root type from nested complex types. + As an example, if the input is `Struct`, return `Struct`. + + Args: + column_type (str): column type + + Returns: + Optional[str]: Returns root type if input type is nested complex type, otherwise returns `None` for primitive types + """ + pattern = r"(\w+)<(\w+\s+\w+(\s*,\s*\w+\s+\w+)*)>" + match = re.match(pattern, column_type) + if match: + return match.group(1) + else: + return None + + +def replace_column_type(column_type: str) -> str: + """If type of column contains special characters that cannot be drawn by mermaid, replace them with strings that can be drawn. + If the type string contains a nested complex type, omit it to make it easier to read. + + Args: + column_type (str): column type + + Returns: + str: Type of column with special characters are substituted or omitted + """ + # Some specific DWHs may have types that cannot be drawn in mermaid, such as `Struct`. + # These types may be nested and can be very long, so omit them + complex_column_type = match_complex_column_type(column_type) + if complex_column_type: + return f"{complex_column_type}[OMITTED]" + else: + return column_type.replace(" ", "-") + + def parse(manifest, catalog, **kwargs): """Get the Mermaid content from dbt artifacts @@ -35,7 +88,7 @@ def parse(manifest, catalog, **kwargs): table_name = table.name.upper() columns = "\n".join( [ - f' {x.data_type.replace(" ","-")} {x.name.replace(" ","-")}' + f" {replace_column_type(x.data_type)} {replace_column_name(x.name)}" for x in table.columns ] ) @@ -49,9 +102,9 @@ def parse(manifest, catalog, **kwargs): for rel in relationships: key_from = f'"{rel.table_map[1]}"' key_to = f'"{rel.table_map[0]}"' - reference_text = rel.column_map[0].replace(" ", "-") + reference_text = replace_column_name(rel.column_map[0]) if rel.column_map[0] != rel.column_map[1]: - reference_text += f"--{ rel.column_map[1].replace(' ','-')}" + reference_text += f"--{ replace_column_name(rel.column_map[1])}" mermaid += f" {key_from.upper()} {get_rel_symbol(rel.type)} {key_to.upper()}: {reference_text}\n" return mermaid diff --git a/docs/nav/guide/cli-references.md b/docs/nav/guide/cli-references.md index 0d92cdb..212c277 100644 --- a/docs/nav/guide/cli-references.md +++ b/docs/nav/guide/cli-references.md @@ -403,7 +403,9 @@ Check [Download artifacts from a Job Run](./dbt-cloud/download-artifact-from-a-j ## dbterd run-metadata -Command to generate diagram-as-a-code file by connecting to dbt Cloud Discovery API using GraphQL connection +Command to generate diagram-as-a-code file by connecting to dbt Cloud Discovery API using GraphQL connection. + +Check [this guideline](./dbt-cloud/read-artifact-from-an-environment.md) for more details. **Examples:** === "CLI" diff --git a/mkdocs.yml b/mkdocs.yml index 7e0f60f..56783f7 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -28,7 +28,7 @@ nav: - Download the latest artifacts from a Job: nav/guide/dbt-cloud/download-artifact-from-a-job.md - Read the latest artifacts from an environment: nav/guide/dbt-cloud/read-artifact-from-an-environment.md - Contribution Guideline ❤️: nav/development/contributing-guide.md - - License: license.md + - License 🔑: license.md - Change Log ↗️: https://github.com/datnguye/dbterd/releases" target="_blank theme: diff --git a/tests/unit/adapters/algos/test_test_relationship.py b/tests/unit/adapters/algos/test_test_relationship.py index b9a7d1b..298c8af 100644 --- a/tests/unit/adapters/algos/test_test_relationship.py +++ b/tests/unit/adapters/algos/test_test_relationship.py @@ -2,6 +2,7 @@ from unittest import mock from unittest.mock import MagicMock +import click import pytest from dbterd.adapters.algos import base as base_algo @@ -407,3 +408,385 @@ def test_get_relationship_type(self, meta, type): ) def test_get_node_exposures(self, manifest, expected): assert expected == base_algo.get_node_exposures(manifest=manifest) + + @pytest.mark.parametrize( + "data, expected", + [ + ([], []), + ([{}], []), + ([{}, {}], []), + ([{"models": {"edges": []}, "sources": {"edges": []}}], []), + ], + ) + @mock.patch("dbterd.adapters.algos.base.get_table_from_metadata") + @mock.patch("dbterd.adapters.algos.base.get_node_exposures_from_metadata") + def test_get_tables_from_metadata_w_empty_data( + self, + mock_get_node_exposures_from_metadata, + mock_get_table_from_metadata, + data, + expected, + ): + assert expected == base_algo.get_tables_from_metadata( + data=data, resource_type=["model", "source"] + ) + mock_get_node_exposures_from_metadata.assert_called_once() + assert mock_get_table_from_metadata.call_count == 0 + + @pytest.mark.parametrize( + "resource_type, data, get_table_from_metadata_call_count", + [ + ( + ["model"], + [{"models": {"edges": ["item1", "item2"]}, "sources": {"edges": []}}], + 2, + ), + ( + ["model"], + [ + { + "models": {"edges": ["item1", "item2"]}, + "sources": {"edges": ["source1"]}, + } + ], + 2, + ), + ( + ["model", "source"], + [ + { + "models": {"edges": ["item1", "item2"]}, + "sources": {"edges": ["source1"]}, + } + ], + 3, + ), + ( + [], + [ + { + "models": {"edges": ["item1", "item2"]}, + "sources": {"edges": ["source1"]}, + } + ], + 0, + ), + ], + ) + @mock.patch("dbterd.adapters.algos.base.get_table_from_metadata") + @mock.patch("dbterd.adapters.algos.base.get_node_exposures_from_metadata") + def test_get_tables_from_metadata_w_1_data( + self, + mock_get_node_exposures_from_metadata, + mock_get_table_from_metadata, + resource_type, + data, + get_table_from_metadata_call_count, + ): + base_algo.get_tables_from_metadata(data=data, resource_type=resource_type) + mock_get_node_exposures_from_metadata.assert_called_once() + assert ( + mock_get_table_from_metadata.call_count + == get_table_from_metadata_call_count + ) + + @pytest.mark.parametrize( + "model_metadata, exposures, kwargs, expected", + [ + ( + { + "node": { + "uniqueId": "model.package.name1", + "database": "db1", + "schema": "sc1", + "name": "name1", + "catalog": {}, + } + }, + [], + dict(entity_name_format="resource.package.model"), + Table( + name="model.package.name1", + node_name="model.package.name1", + database="db1", + schema="sc1", + columns=[ + Column(name="unknown", data_type="unknown", description="") + ], + raw_sql=None, + description=None, + ), + ), + ( + { + "node": { + "uniqueId": "model.package.name1", + "database": "db1", + "schema": "sc1", + "name": "name1", + "catalog": {"columns": [{"name": "col1"}]}, + } + }, + [], + dict(entity_name_format="resource.package.model"), + Table( + name="model.package.name1", + node_name="model.package.name1", + database="db1", + schema="sc1", + columns=[Column(name="col1", data_type="", description="")], + raw_sql=None, + description=None, + ), + ), + ( + { + "node": { + "uniqueId": "model.package.name1", + "database": "db1", + "schema": "sc1", + "name": "name1", + "catalog": { + "columns": [ + {"name": "col1", "type": "type1"}, + {"name": "col2", "type": "type2"}, + ] + }, + } + }, + [], + dict(entity_name_format="resource.package.model"), + Table( + name="model.package.name1", + node_name="model.package.name1", + database="db1", + schema="sc1", + columns=[ + Column(name="col1", data_type="type1", description=""), + Column(name="col2", data_type="type2", description=""), + ], + raw_sql=None, + description=None, + ), + ), + ], + ) + def test_get_table_from_metadata(self, model_metadata, exposures, kwargs, expected): + assert expected == base_algo.get_table_from_metadata( + model_metadata=model_metadata, exposures=exposures, **kwargs + ) + + @pytest.mark.parametrize( + "data, kwargs, expected", + [ + ([], dict(resource_type=["model", "source"]), []), + ( + [{"exposures": {"edges": []}}], + dict(resource_type=["model", "source"]), + [], + ), + ( + [ + { + "exposures": { + "edges": [ + { + "node": { + "name": "ex1", + "parents": [{"uniqueId": "model.x"}], + } + } + ] + } + } + ], + dict(resource_type=["model", "source"]), + [dict(node_name="model.x", exposure_name="ex1")], + ), + ( + [ + { + "exposures": { + "edges": [ + { + "node": { + "name": "ex1", + "parents": [ + {"uniqueId": "model.x"}, + {"uniqueId": "model.y"}, + {"uniqueId": "source.z"}, + ], + } + } + ] + } + } + ], + dict(resource_type=["model"]), + [ + dict(node_name="model.x", exposure_name="ex1"), + dict(node_name="model.y", exposure_name="ex1"), + ], + ), + ], + ) + def test_get_node_exposures_from_metadata(self, data, kwargs, expected): + assert expected == base_algo.get_node_exposures_from_metadata( + data=data, **kwargs + ) + + @pytest.mark.parametrize( + "data, kwargs, expected", + [ + ([], dict(algo="test_relationship"), []), + ( + [{"tests": {"edges": []}}], + dict(algo="test_relationship", resource_type=["model", "source"]), + [], + ), + ( + [ + { + "tests": { + "edges": [ + { + "node": { + "uniqueId": "test.relationship_1", + "testMetadata": { + "kwargs": { + "columnName": "coly", + "to": 'ref("x")', + "field": "colx", + } + }, + "parents": [], + } + } + ] + } + } + ], + dict(algo="test_relationship", resource_type=["model", "source"]), + [ + Ref( + name="test.relationship_1", + table_map=["", ""], + column_map=["colx", "coly"], + type="n1", + ) + ], + ), + ( + [ + { + "tests": { + "edges": [ + { + "node": { + "uniqueId": "test.relationship_1", + "meta": {}, + "testMetadata": { + "kwargs": { + "columnName": "coly", + "to": 'ref("x")', + "field": "colx", + } + }, + "parents": [ + {"uniqueId": "model.p.x"}, + {"uniqueId": "model.p.y"}, + ], + } + } + ] + } + } + ], + dict(algo="test_relationship", resource_type=["model", "source"]), + [ + Ref( + name="test.relationship_1", + table_map=["model.p.x", "model.p.y"], + column_map=["colx", "coly"], + type="n1", + ) + ], + ), + ( + [ + { + "tests": { + "edges": [ + { + "node": { + "uniqueId": "test.relationship_1", + "meta": {}, + "testMetadata": { + "kwargs": { + "columnName": "coly", + "to": 'ref("x")', + "field": "colx", + } + }, + "parents": [ + {"uniqueId": "model.p.y"}, + {"uniqueId": "model.p.x"}, + ], + } + } + ] + } + } + ], + dict(algo="test_relationship", resource_type=["model", "source"]), + [ + Ref( + name="test.relationship_1", + table_map=["model.p.x", "model.p.y"], + column_map=["colx", "coly"], + type="n1", + ) + ], + ), + ], + ) + def test_get_relationships_from_metadata(self, data, kwargs, expected): + assert expected == base_algo.get_relationships_from_metadata( + data=data, **kwargs + ) + + @pytest.mark.parametrize( + "data, kwargs", + [ + ( + [ + { + "tests": { + "edges": [ + { + "node": { + "uniqueId": "test.relationship_1", + "meta": {}, + "testMetadata": { + "kwargs": { + "columnName": "coly", + "to": 'ref("x")', + "field": "colx", + } + }, + "parents": [ + {"uniqueId": "model.p.x"}, + ], + } + } + ] + } + } + ], + dict(algo="test_relationship", resource_type=["model", "source"]), + ), + ], + ) + def test_get_relationships_from_metadata_error(self, data, kwargs): + with pytest.raises(click.BadParameter): + base_algo.get_relationships_from_metadata(data=data, **kwargs) diff --git a/tests/unit/adapters/dbt_cloud/__init__.py b/tests/unit/adapters/dbt_cloud/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/unit/adapters/test_dbt_cloud.py b/tests/unit/adapters/dbt_cloud/test_administrative.py similarity index 100% rename from tests/unit/adapters/test_dbt_cloud.py rename to tests/unit/adapters/dbt_cloud/test_administrative.py diff --git a/tests/unit/adapters/dbt_cloud/test_discovery.py b/tests/unit/adapters/dbt_cloud/test_discovery.py new file mode 100644 index 0000000..1a6f130 --- /dev/null +++ b/tests/unit/adapters/dbt_cloud/test_discovery.py @@ -0,0 +1,94 @@ +from unittest import mock + +import pytest + +from dbterd.adapters.dbt_cloud.discovery import DbtCloudMetadata + + +class TestDbtCloudMetadata: + @pytest.fixture + def dbtCloudMetadata(self) -> DbtCloudMetadata: + return DbtCloudMetadata( + dbt_cloud_host_url="irrelevant_url", + dbt_cloud_service_token="irrelevant_st", + dbt_cloud_environment_id="irrelevant_env_id", + dbt_cloud_query_file_path="irrelevant_query_file_path", + ) + + @pytest.mark.parametrize( + "graphql_query_data, extract_data", + [ + ({}, [{}]), + ( + dict( + environment=dict( + applied=dict( + models=dict(edges=[]), + sources=dict(edges=[]), + exposures=dict(edges=[]), + tests=dict(edges=[]), + ) + ) + ), + [ + dict( + models=dict(edges=[]), + sources=dict(edges=[]), + exposures=dict(edges=[]), + tests=dict(edges=[]), + ) + ], + ), + ], + ) + @mock.patch("dbterd.adapters.dbt_cloud.graphql.GraphQLHelper.query") + def test_query_erd_data(self, mock_graphql_query, graphql_query_data, extract_data): + mock_graphql_query.return_value = graphql_query_data + assert extract_data == DbtCloudMetadata().query_erd_data() + assert mock_graphql_query.call_count == len(extract_data) + + @mock.patch("dbterd.adapters.dbt_cloud.graphql.GraphQLHelper.query") + def test_query_erd_data_no_polling(self, mock_graphql_query): + mock_graphql_query.return_value = {} + assert [{}] == DbtCloudMetadata().query_erd_data(poll_until_end=False) + assert mock_graphql_query.call_count == 1 + + @mock.patch("dbterd.adapters.dbt_cloud.graphql.GraphQLHelper.query") + def test_query_erd_data_polling_twice(self, mock_graphql_query): + mock_graphql_query.side_effect = [ + dict( + environment=dict( + applied=dict( + models=dict(edges=[], pageInfo=dict(hasNextPage=True)), + sources=dict(edges=[]), + exposures=dict(edges=[]), + tests=dict(edges=[]), + ) + ) + ), + dict( + environment=dict( + applied=dict( + models=dict(edges=[]), + sources=dict(edges=[]), + exposures=dict(edges=[]), + tests=dict(edges=[]), + ) + ) + ), + ] + assert [ + dict( + models=dict(edges=[], pageInfo=dict(hasNextPage=True)), + sources=dict(edges=[]), + exposures=dict(edges=[]), + tests=dict(edges=[]), + ), + dict( + models=dict(edges=[]), + sources=dict(edges=[]), + exposures=dict(edges=[]), + tests=dict(edges=[]), + ), + ] == DbtCloudMetadata().query_erd_data() + assert mock_graphql_query.call_count == 2 diff --git a/tests/unit/adapters/dbt_cloud/test_graphql.py b/tests/unit/adapters/dbt_cloud/test_graphql.py new file mode 100644 index 0000000..40f80e8 --- /dev/null +++ b/tests/unit/adapters/dbt_cloud/test_graphql.py @@ -0,0 +1,58 @@ +from unittest import mock + +import pytest + +from dbterd.adapters.dbt_cloud.graphql import GraphQLHelper + + +class MockResponse: + def __init__(self, status_code, data=None) -> None: + self.status_code = status_code + self.data = data + + def json(self): + return dict(data=self.data) + + +class TestGraphQL: + @pytest.mark.parametrize( + "kwargs, expected", + [ + ( + dict(), + dict(host_url=None, service_token=None), + ), + ( + dict(dbt_cloud_host_url="host_url", dbt_cloud_service_token="token"), + dict(host_url="host_url", service_token="token"), + ), + ], + ) + def test_init(self, kwargs, expected): + helper = GraphQLHelper(**kwargs) + assert vars(helper) == expected + assert helper.request_headers == { + "authorization": f"Bearer {helper.service_token}", + "content-type": "application/json", + } + assert helper.api_endpoint == f"https://{helper.host_url}/graphql/" + + @mock.patch("dbterd.adapters.dbt_cloud.administrative.requests.post") + def test_query(self, mock_requests_post): + mock_requests_post.return_value = MockResponse(status_code=200, data={}) + assert {} == GraphQLHelper().query(query="irrelevant", **dict()) + assert mock_requests_post.call_count == 1 + + @mock.patch("dbterd.adapters.dbt_cloud.administrative.requests.post") + def test_query_failed(self, mock_requests_post): + mock_requests_post.return_value = MockResponse( + status_code="irrelevant", data={} + ) + assert GraphQLHelper().query(query="irrelevant", **dict()) is None + assert mock_requests_post.call_count == 1 + + @mock.patch("dbterd.adapters.dbt_cloud.administrative.requests.post") + def test_query_with_exception(self, mock_requests_post): + mock_requests_post.side_effect = Exception("any error") + assert GraphQLHelper().query(query="irrelevant", **dict()) is None + assert mock_requests_post.call_count == 1 diff --git a/tests/unit/adapters/dbt_cloud/test_query.py b/tests/unit/adapters/dbt_cloud/test_query.py new file mode 100644 index 0000000..262c80d --- /dev/null +++ b/tests/unit/adapters/dbt_cloud/test_query.py @@ -0,0 +1,6 @@ +from dbterd.adapters.dbt_cloud.query import Query + + +class TestQuery: + def test_get_file_content_error(self): + assert Query().get_file_content(file_path="invalid-file-path") is None diff --git a/tests/unit/adapters/targets/mermaid/test_mermaid_test_relationship.py b/tests/unit/adapters/targets/mermaid/test_mermaid_test_relationship.py index 7aa9801..e7c19ee 100644 --- a/tests/unit/adapters/targets/mermaid/test_mermaid_test_relationship.py +++ b/tests/unit/adapters/targets/mermaid/test_mermaid_test_relationship.py @@ -249,6 +249,55 @@ class TestMermaidTestRelationship: } """, ), + ( + [ + Table( + name="model.dbt_resto.table1", + node_name="model.dbt_resto.table1", + database="--database--", + schema="--schema--", + columns=[ + Column(name="name1.first_name", data_type="name1-type") + ], + raw_sql="--irrelevant--", + ), + Table( + name="model.dbt_resto.table2", + node_name="model.dbt_resto.table2", + database="--database2--", + schema="--schema2--", + columns=[ + Column(name="name2.first_name", data_type="name2-type2"), + Column( + name="complex_struct", + data_type="Struct", + ), + ], + raw_sql="--irrelevant--", + ), + ], + [ + Ref( + name="test.dbt_resto.relationships_table1", + table_map=["model.dbt_resto.table2", "model.dbt_resto.table1"], + column_map=["name2.first_name", "name1.first_name"], + ), + ], + [], + [], + ["model", "source"], + False, + """erDiagram + "MODEL.DBT_RESTO.TABLE1" { + name1-type name1__first_name + } + "MODEL.DBT_RESTO.TABLE2" { + name2-type2 name2__first_name + Struct[OMITTED] complex_struct + } + "MODEL.DBT_RESTO.TABLE1" }|--|| "MODEL.DBT_RESTO.TABLE2": name2__first_name--name1__first_name + """, + ), ], ) def test_parse( diff --git a/tests/unit/adapters/test_base.py b/tests/unit/adapters/test_base.py index d2065a2..79d14c3 100644 --- a/tests/unit/adapters/test_base.py +++ b/tests/unit/adapters/test_base.py @@ -10,6 +10,50 @@ class TestBase: + @mock.patch("dbterd.adapters.base.Executor.evaluate_kwargs") + @mock.patch("dbterd.adapters.base.Executor._Executor__run_metadata_by_strategy") + def test_run_metadata(self, mock_run_metadata_by_strategy, mock_evaluate_kwargs): + Executor(ctx=click.Context(command=click.BaseCommand("dummy"))).run_metadata() + mock_evaluate_kwargs.assert_called_once() + mock_run_metadata_by_strategy.assert_called_once() + + @mock.patch("dbterd.adapters.base.DbtCloudMetadata.query_erd_data") + @mock.patch("dbterd.adapters.base.Executor._Executor__save_result") + def test___run_metadata_by_strategy(self, mock_query_erd_data, mock_save_result): + Executor( + ctx=click.Context(command=click.BaseCommand("dummy")) + )._Executor__run_metadata_by_strategy(target="dbml", algo="test_relationship") + mock_query_erd_data.assert_called_once() + mock_save_result.assert_called_once() + + @mock.patch("builtins.open") + def test___save_result(self, mock_open): + Executor( + ctx=click.Context(command=click.BaseCommand("dummy")) + )._Executor__save_result(path="irrelevant", data=("file_name", {})) + mock_open.assert_called_once_with("irrelevant/file_name", "w") + + @mock.patch("dbterd.adapters.base.DbtCloudArtifact.get") + @mock.patch("dbterd.adapters.base.Executor._Executor__read_manifest") + @mock.patch("dbterd.adapters.base.Executor._Executor__read_catalog") + @mock.patch("dbterd.adapters.base.Executor._Executor__save_result") + def test___run_by_strategy_w_dbt_cloud( + self, + mock_cloud_artifact_get, + mock_read_manifest, + mock_read_catalog, + mock_save_result, + ): + Executor( + ctx=click.Context(command=click.BaseCommand("dummy")) + )._Executor__run_by_strategy( + target="dbml", algo="test_relationship", dbt_cloud=True + ) + mock_cloud_artifact_get.assert_called_once() + mock_read_manifest.assert_called_once() + mock_read_catalog.assert_called_once() + mock_save_result.assert_called_once() + def test_worker(self): worker = Executor(ctx=click.Context(command=click.BaseCommand("dummy"))) assert worker.filename_manifest == "manifest.json" @@ -59,9 +103,10 @@ def test__get_selection__error(self, mock_dbt_invocation): worker._Executor__get_selection() @pytest.mark.parametrize( - "kwargs, expected", + "command, kwargs, expected", [ ( + "run", dict( select=[], exclude=[], @@ -74,6 +119,7 @@ def test__get_selection__error(self, mock_dbt_invocation): ), ), ( + "run", dict(select=[], exclude=[], dbt=True), dict( dbt=True, @@ -84,6 +130,7 @@ def test__get_selection__error(self, mock_dbt_invocation): ), ), ( + "run", dict(select=[], exclude=[], dbt=True, dbt_auto_artifacts=True), dict( dbt=True, @@ -95,6 +142,7 @@ def test__get_selection__error(self, mock_dbt_invocation): ), ), ( + "run", dict(select=[], exclude=[], dbt_cloud=True), dict( dbt_cloud=True, @@ -114,16 +162,20 @@ def test_evaluate_kwargs( mock_get_artifacts_for_erd, mock_get_selection, mock_get_dir, + command, kwargs, expected, ): - worker = Executor(ctx=click.Context(command=click.BaseCommand("run"))) + worker = Executor(ctx=click.Context(command=click.BaseCommand(command))) mock_get_dir.return_value = ("/path/ad", "/path/dpd") mock_get_selection.return_value = ["yolo"] assert expected == worker.evaluate_kwargs(**kwargs) mock_get_dir.assert_called_once() - if kwargs.get("dbt_auto_artifacts"): - mock_get_artifacts_for_erd.assert_called_once() + if command == "run": + if kwargs.get("dbt") and kwargs.get("dbt_auto_artifacts"): + mock_get_artifacts_for_erd.assert_called_once() + else: + assert mock_get_artifacts_for_erd.called_count == 0 @pytest.mark.parametrize( "kwargs, mock_isfile_se, expected", diff --git a/tests/unit/cli/test_runner.py b/tests/unit/cli/test_runner.py index b04272a..a89e891 100644 --- a/tests/unit/cli/test_runner.py +++ b/tests/unit/cli/test_runner.py @@ -138,3 +138,10 @@ def test_invoke_run_failed_to_write_output( mock_read_c.assert_called_once() mock_engine_parse.assert_called_once() mock_open_w.assert_called_once() + + def test_invoke_run_metadata_ok(self, dbterd: dbterdRunner) -> None: + with mock.patch( + "dbterd.cli.main.Executor.run_metadata", return_value=None + ) as mock_run_metadata: + dbterd.invoke(["run-metadata"]) + mock_run_metadata.assert_called_once()