diff --git a/cosmos/operators/airflow_async.py b/cosmos/operators/airflow_async.py
index d6b1bda5a..ff874a06a 100644
--- a/cosmos/operators/airflow_async.py
+++ b/cosmos/operators/airflow_async.py
@@ -8,6 +8,7 @@
 from cosmos.operators._asynchronous.base import DbtRunAirflowAsyncFactoryOperator
 from cosmos.operators.base import AbstractDbtBase
 from cosmos.operators.local import (
+    AbstractDbtLocalBase,
     DbtBuildLocalOperator,
     DbtCloneLocalOperator,
     DbtCompileLocalOperator,
@@ -58,6 +59,7 @@ def __init__(
         clean_kwargs = {}
         non_async_args = set(inspect.signature(AbstractDbtBase.__init__).parameters.keys())
         non_async_args |= set(inspect.signature(DbtLocalBaseOperator.__init__).parameters.keys())
+        non_async_args |= set(inspect.signature(AbstractDbtLocalBase.__init__).parameters.keys())
 
         dbt_kwargs = {}
 
diff --git a/dev/dags/simple_dag_async.py b/dev/dags/simple_dag_async.py
index 8fb8cb844..4eb910132 100644
--- a/dev/dags/simple_dag_async.py
+++ b/dev/dags/simple_dag_async.py
@@ -37,6 +37,6 @@
     catchup=False,
     dag_id="simple_dag_async",
     tags=["simple"],
-    operator_args={"location": "northamerica-northeast1"},
+    operator_args={"location": "northamerica-northeast1", "install_deps": True},
 )
 # [END airflow_async_execution_mode_example]
diff --git a/tests/conftest.py b/tests/conftest.py
index d553fb7b4..6f7991c6d 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -10,10 +10,7 @@ def mock_bigquery_conn():  # type: ignore
     """
     Mocks and returns an Airflow BigQuery connection.
     """
-    extra = {
-        "project": "my_project",
-        "key_path": "my_key_path.json",
-    }
+    extra = {"project": "my_project", "key_path": "my_key_path.json", "dataset": "test"}
     conn = Connection(
         conn_id="my_bigquery_connection",
         conn_type="google_cloud_platform",
diff --git a/tests/operators/test_airflow_async.py b/tests/operators/test_airflow_async.py
index 0f0d5cdf7..309a8341e 100644
--- a/tests/operators/test_airflow_async.py
+++ b/tests/operators/test_airflow_async.py
@@ -1,3 +1,9 @@
+from datetime import datetime
+from pathlib import Path
+
+import pytest
+
+from cosmos import DbtDag, ExecutionConfig, ExecutionMode, ProfileConfig, ProjectConfig
 from cosmos.operators.airflow_async import (
     DbtBuildAirflowAsyncOperator,
     DbtCompileAirflowAsyncOperator,
@@ -18,6 +24,35 @@
     DbtSourceLocalOperator,
     DbtTestLocalOperator,
 )
+from cosmos.profiles import get_automatic_profile_mapping
+
+DBT_PROJECTS_ROOT_DIR = Path(__file__).parent.parent.parent / "dev/dags/dbt"
+DBT_PROJECT_NAME = "original_jaffle_shop"
+
+
+@pytest.mark.integration
+def test_airflow_async_operator_init(mock_bigquery_conn):
+    """Test that Airflow can correctly parse an async operator with operator args"""
+    profile_mapping = get_automatic_profile_mapping(mock_bigquery_conn.conn_id, {})
+
+    profile_config = ProfileConfig(
+        profile_name="airflow_db",
+        target_name="bq",
+        profile_mapping=profile_mapping,
+    )
+
+    DbtDag(
+        project_config=ProjectConfig(dbt_project_path=DBT_PROJECTS_ROOT_DIR / DBT_PROJECT_NAME),
+        profile_config=profile_config,
+        execution_config=ExecutionConfig(
+            execution_mode=ExecutionMode.AIRFLOW_ASYNC,
+        ),
+        schedule_interval=None,
+        start_date=datetime(2023, 1, 1),
+        catchup=False,
+        dag_id="simple_dag_async",
+        operator_args={"location": "us", "install_deps": True},
+    )
 
 
 def test_dbt_build_airflow_async_operator_inheritance():
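
Note on the first hunk: the async operator splits incoming kwargs by checking them against the __init__ signatures of the dbt base classes, so a parameter declared only on AbstractDbtLocalBase (presumably install_deps, which the example DAG and the new test now pass) was not recognized before this change and leaked into the async operator's own kwargs. The standalone sketch below illustrates that signature-based filtering pattern with hypothetical stand-in classes (_Base, _LocalBase) and a split_kwargs helper; it is an approximation of the idea under those assumptions, not the actual Cosmos implementation.

import inspect


class _Base:
    """Stand-in for AbstractDbtBase: declares a couple of shared dbt arguments."""

    def __init__(self, task_id, profile_config=None):
        self.task_id = task_id
        self.profile_config = profile_config


class _LocalBase(_Base):
    """Stand-in for AbstractDbtLocalBase: adds a local-only argument such as install_deps."""

    def __init__(self, install_deps=False, **kwargs):
        self.install_deps = install_deps
        super().__init__(**kwargs)


def split_kwargs(**kwargs):
    """Split kwargs into dbt-specific ones vs. everything else, by signature inspection."""
    # Collect every parameter name declared on the base-class constructors.
    dbt_args = set(inspect.signature(_Base.__init__).parameters.keys())
    dbt_args |= set(inspect.signature(_LocalBase.__init__).parameters.keys())

    dbt_kwargs, clean_kwargs = {}, {}
    for key, value in kwargs.items():
        target = dbt_kwargs if key in dbt_args else clean_kwargs
        target[key] = value
    return dbt_kwargs, clean_kwargs


dbt_kwargs, clean_kwargs = split_kwargs(task_id="run", install_deps=True, location="us")
print(dbt_kwargs)    # {'task_id': 'run', 'install_deps': True}
print(clean_kwargs)  # {'location': 'us'}

If _LocalBase's parameters were omitted from the collected set (the pre-patch situation for AbstractDbtLocalBase), install_deps would end up in clean_kwargs and be forwarded to a constructor that does not accept it, which is the kind of DAG-parsing failure the new integration test guards against.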