From 6ddc3c25f8a6d33bb38d6c99203214012f40c903 Mon Sep 17 00:00:00 2001
From: Pankaj Koti
Date: Wed, 19 Feb 2025 14:28:39 +0530
Subject: [PATCH] Extend VirtualEnv operator and mock dbt adapters for setup &
 teardown tasks in ExecutionMode.AIRFLOW_ASYNC (#1544)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

## Summary

This PR reworks the setup and teardown tasks for `ExecutionMode.AIRFLOW_ASYNC` so that they build on the `DbtRunVirtualenvOperator`. It ensures that dbt adapters are installed into the virtualenv created by the `DbtRunVirtualenvOperator` and are properly mocked within that virtual environment.

## Key Changes

✅ **Extending SetupAsyncOperator and TeardownAsyncOperator with DbtRunVirtualenvOperator and mocking dbt adapters**

- These operators now inherit from `DbtRunVirtualenvOperator` and override `run_subprocess`.
- The operators load the profile-specific mock function from the appropriate `cosmos.operators._asynchronous` submodule.
- The extracted function body is injected dynamically into the dbt CLI entry script within the virtual environment, so that the mocked dbt adapter is applied before any dbt command executes (see the sketch below).
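
To make the injection step concrete, here is a minimal, self-contained sketch of the technique. The module path and the `_mock_<profile_type>_adapter` naming mirror this PR; the `venv_bin` path and the `bigquery` profile type are illustrative placeholders, not values the operator hardcodes:

```python
import inspect
import textwrap
from pathlib import Path

from cosmos._utils.importer import load_method_from_module

# Illustrative values; in the operator these come from self._py_bin and
# self.profile_config.get_profile_type().
venv_bin = Path("/tmp/cosmos-venv/bin")
profile_type = "bigquery"

# Load the profile-specific mock function shipped with Cosmos.
mock_fn = load_method_from_module(
    f"cosmos.operators._asynchronous.{profile_type}",
    f"_mock_{profile_type}_adapter",
)

# Keep only the function body (drop the `def` line) and dedent it so it can
# run at module level inside the entry script.
mock_body = textwrap.dedent("\n".join(inspect.getsource(mock_fn).split("\n")[1:]))

# Splice the body into the venv's dbt entry script, right after the shebang,
# so the adapter is mocked before the dbt CLI entry point runs.
dbt_script = venv_bin / "dbt"
lines = dbt_script.read_text().splitlines(keepends=True)
if lines and lines[0].startswith("#!"):
    lines.insert(1, mock_body)
dbt_script.write_text("".join(lines))
```
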
✅ **Decoupled dbt adapters from the Airflow environment**

- With this change, dbt adapters (e.g., `dbt-bigquery`, `dbt-postgres`) no longer need to be installed in the same environment as the Airflow installation.
- The `ExecutionConfig` now exposes `async_py_requirements`, which DAG authors can set to ensure that the necessary dbt dependencies are installed inside the virtual environment for `ExecutionMode.AIRFLOW_ASYNC` (usage sketched below).
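
A hedged usage sketch, assuming a BigQuery target — the paths, profile names, and DAG id below are placeholders, not values this PR prescribes:

```python
from datetime import datetime

from cosmos import DbtDag, ExecutionConfig, ExecutionMode, ProfileConfig, ProjectConfig

dag = DbtDag(
    # Placeholder project and profile values; substitute your own.
    project_config=ProjectConfig(dbt_project_path="/usr/local/airflow/dags/dbt/jaffle_shop"),
    profile_config=ProfileConfig(
        profile_name="default",
        target_name="dev",
        profiles_yml_filepath="/usr/local/airflow/dags/dbt/jaffle_shop/profiles.yml",
    ),
    execution_config=ExecutionConfig(
        execution_mode=ExecutionMode.AIRFLOW_ASYNC,
        # Installed into the setup/teardown virtualenv; the adapter no longer
        # needs to be present in the Airflow environment itself.
        async_py_requirements=["dbt-bigquery"],
    ),
    schedule_interval=None,
    start_date=datetime(2023, 1, 1),
    catchup=False,
    dag_id="my_async_dag",
)
```
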
closes: #1533

---------

Co-authored-by: Tatiana Al-Chueyr
Co-authored-by: Pankaj Singh <98807258+pankajastro@users.noreply.github.com>
---
 cosmos/airflow/graph.py                    |  32 ++++++-
 cosmos/config.py                           |   3 +
 cosmos/converter.py                        |   1 +
 cosmos/operators/_asynchronous/__init__.py |  51 +++++++++-
 cosmos/operators/local.py                  |   2 +-
 cosmos/operators/virtualenv.py             |   4 +-
 dev/dags/simple_dag_async.py               |   3 +-
 tests/airflow/test_graph.py                |  18 +++-
 tests/operators/_asynchronous/test_base.py | 104 ++++++++++++++++++++-
 tests/operators/test_airflow_async.py      |  28 ++++++
 10 files changed, 234 insertions(+), 12 deletions(-)

diff --git a/cosmos/airflow/graph.py b/cosmos/airflow/graph.py
index ee03e17f9..f96d3ba2a 100644
--- a/cosmos/airflow/graph.py
+++ b/cosmos/airflow/graph.py
@@ -25,6 +25,7 @@
 from cosmos.core.airflow import get_airflow_task as create_airflow_task
 from cosmos.core.graph.entities import Task as TaskMetadata
 from cosmos.dbt.graph import DbtNode
+from cosmos.exceptions import CosmosValueError
 from cosmos.log import get_logger
 from cosmos.settings import enable_setup_async_task, enable_teardown_async_task
 
@@ -413,14 +414,19 @@ def _add_dbt_setup_async_task(
     tasks_map: dict[str, Any],
     task_group: TaskGroup | None,
     render_config: RenderConfig | None = None,
+    async_py_requirements: list[str] | None = None,
 ) -> None:
     if execution_mode != ExecutionMode.AIRFLOW_ASYNC:
         return
 
+    if not async_py_requirements:
+        raise CosmosValueError("ExecutionConfig.AIRFLOW_ASYNC needs async_py_requirements to be set")
+
     if render_config is not None:
         task_args["select"] = render_config.select
         task_args["selector"] = render_config.selector
         task_args["exclude"] = render_config.exclude
+    task_args["py_requirements"] = async_py_requirements
 
     setup_task_metadata = TaskMetadata(
         id=DBT_SETUP_ASYNC_TASK_ID,
@@ -495,14 +501,19 @@ def _add_teardown_task(
     tasks_map: dict[str, Any],
     task_group: TaskGroup | None,
     render_config: RenderConfig | None = None,
+    async_py_requirements: list[str] | None = None,
 ) -> None:
     if execution_mode != ExecutionMode.AIRFLOW_ASYNC:
         return
 
+    if not async_py_requirements:
+        raise CosmosValueError("ExecutionConfig.AIRFLOW_ASYNC needs async_py_requirements to be set")
+
     if render_config is not None:
         task_args["select"] = render_config.select
         task_args["selector"] = render_config.selector
         task_args["exclude"] = render_config.exclude
+    task_args["py_requirements"] = async_py_requirements
 
     teardown_task_metadata = TaskMetadata(
         id=DBT_TEARDOWN_ASYNC_TASK_ID,
@@ -529,6 +540,7 @@ def build_airflow_graph(
     render_config: RenderConfig,
     task_group: TaskGroup | None = None,
     on_warning_callback: Callable[..., Any] | None = None,  # argument specific to the DBT test command
+    async_py_requirements: list[str] | None = None,
 ) -> dict[str, Union[TaskGroup, BaseOperator]]:
     """
     Instantiate dbt `nodes` as Airflow tasks within the given `task_group` (optional) or `dag` (mandatory).
@@ -626,9 +638,25 @@ def build_airflow_graph(
     create_airflow_task_dependencies(nodes, tasks_map)
 
     if enable_setup_async_task:
-        _add_dbt_setup_async_task(dag, execution_mode, task_args, tasks_map, task_group, render_config=render_config)
+        _add_dbt_setup_async_task(
+            dag,
+            execution_mode,
+            task_args,
+            tasks_map,
+            task_group,
+            render_config=render_config,
+            async_py_requirements=async_py_requirements,
+        )
     if enable_teardown_async_task:
-        _add_teardown_task(dag, execution_mode, task_args, tasks_map, task_group, render_config=render_config)
+        _add_teardown_task(
+            dag,
+            execution_mode,
+            task_args,
+            tasks_map,
+            task_group,
+            render_config=render_config,
+            async_py_requirements=async_py_requirements,
+        )
 
     return tasks_map
diff --git a/cosmos/config.py b/cosmos/config.py
index 37b4a67c8..0fe17ce6f 100644
--- a/cosmos/config.py
+++ b/cosmos/config.py
@@ -398,6 +398,8 @@ class ExecutionConfig:
     :param dbt_project_path: Configures the DBT project location accessible at runtime for dag execution. This is the project path in a docker container for ExecutionMode.DOCKER or ExecutionMode.KUBERNETES. Mutually Exclusive with ProjectConfig.dbt_project_path
     :param virtualenv_dir: Directory path to locate the (cached) virtual env that should be used for execution when execution mode is set to `ExecutionMode.VIRTUALENV`
+    :param async_py_requirements: A list of Python packages to install when `ExecutionMode.AIRFLOW_ASYNC` (Experimental) is used. This parameter is required only if both `enable_setup_async_task` and `enable_teardown_async_task` are set to `True`.
+        Example: `["dbt-postgres==1.5.0"]`
     """
 
     execution_mode: ExecutionMode = ExecutionMode.LOCAL
@@ -409,6 +411,7 @@ class ExecutionConfig:
     virtualenv_dir: str | Path | None = None
     project_path: Path | None = field(init=False)
+    async_py_requirements: list[str] | None = None
 
     def __post_init__(self, dbt_project_path: str | Path | None) -> None:
         if self.invocation_mode and self.execution_mode not in (ExecutionMode.LOCAL, ExecutionMode.VIRTUALENV):
diff --git a/cosmos/converter.py b/cosmos/converter.py
index 7c4205f37..7c917a022 100644
--- a/cosmos/converter.py
+++ b/cosmos/converter.py
@@ -335,6 +335,7 @@ def __init__(
             dbt_project_name=render_config.project_name,
             on_warning_callback=on_warning_callback,
             render_config=render_config,
+            async_py_requirements=execution_config.async_py_requirements,
         )
 
         current_time = time.perf_counter()
diff --git a/cosmos/operators/_asynchronous/__init__.py b/cosmos/operators/_asynchronous/__init__.py
index 84d6820d2..6853313f9 100644
--- a/cosmos/operators/_asynchronous/__init__.py
+++ b/cosmos/operators/_asynchronous/__init__.py
@@ -1,17 +1,42 @@
 from __future__ import annotations
 
+import inspect
+import textwrap
+from pathlib import Path
 from typing import Any
 
 from airflow.utils.context import Context
 
-from cosmos.operators.local import DbtRunLocalOperator as DbtRunOperator
+from cosmos._utils.importer import load_method_from_module
+from cosmos.hooks.subprocess import FullOutputSubprocessResult
+from cosmos.operators.virtualenv import DbtRunVirtualenvOperator
 
 
-class SetupAsyncOperator(DbtRunOperator):
+class SetupAsyncOperator(DbtRunVirtualenvOperator):
     def __init__(self, *args: Any, **kwargs: Any):
         kwargs["emit_datasets"] = False
         super().__init__(*args, **kwargs)
 
+    def run_subprocess(self, command: list[str], env: dict[str, str], cwd: str) -> FullOutputSubprocessResult:
+        profile_type = self.profile_config.get_profile_type()
+        if not self._py_bin:
+            raise AttributeError("_py_bin attribute not set for VirtualEnv operator")
+        dbt_executable_path = str(Path(self._py_bin).parent / "dbt")
+        asynchronous_operator_module = f"cosmos.operators._asynchronous.{profile_type}"
+        mock_function_name = f"_mock_{profile_type}_adapter"
+        mock_function = load_method_from_module(asynchronous_operator_module, mock_function_name)
+        mock_function_full_source = inspect.getsource(mock_function)
+        mock_function_body = textwrap.dedent("\n".join(mock_function_full_source.split("\n")[1:]))
+
+        with open(dbt_executable_path) as f:
+            dbt_entrypoint_script = f.readlines()
+        if dbt_entrypoint_script[0].startswith("#!"):
+            dbt_entrypoint_script.insert(1, mock_function_body)
+        with open(dbt_executable_path, "w") as f:
+            f.writelines(dbt_entrypoint_script)
+
+        return super().run_subprocess(command, env, cwd)
+
     def execute(self, context: Context, **kwargs: Any) -> None:
         async_context = {"profile_type": self.profile_config.get_profile_type()}
         self.build_and_run_cmd(
@@ -19,11 +44,31 @@ def execute(self, context: Context, **kwargs: Any) -> None:
         )
 
 
-class TeardownAsyncOperator(DbtRunOperator):
+class TeardownAsyncOperator(DbtRunVirtualenvOperator):
     def __init__(self, *args: Any, **kwargs: Any):
         kwargs["emit_datasets"] = False
         super().__init__(*args, **kwargs)
 
+    def run_subprocess(self, command: list[str], env: dict[str, str], cwd: str) -> FullOutputSubprocessResult:
+        profile_type = self.profile_config.get_profile_type()
+        if not self._py_bin:
+            raise AttributeError("_py_bin attribute not set for VirtualEnv operator")
+        dbt_executable_path = str(Path(self._py_bin).parent / "dbt")
+        asynchronous_operator_module = f"cosmos.operators._asynchronous.{profile_type}"
+        mock_function_name = f"_mock_{profile_type}_adapter"
+        mock_function = load_method_from_module(asynchronous_operator_module, mock_function_name)
+        mock_function_full_source = inspect.getsource(mock_function)
+        mock_function_body = textwrap.dedent("\n".join(mock_function_full_source.split("\n")[1:]))
+
+        with open(dbt_executable_path) as f:
+            dbt_entrypoint_script = f.readlines()
+        if dbt_entrypoint_script[0].startswith("#!"):
+            dbt_entrypoint_script.insert(1, mock_function_body)
+        with open(dbt_executable_path, "w") as f:
+            f.writelines(dbt_entrypoint_script)
+
+        return super().run_subprocess(command, env, cwd)
+
     def execute(self, context: Context, **kwargs: Any) -> Any:
         async_context = {"profile_type": self.profile_config.get_profile_type(), "teardown_task": True}
         self.build_and_run_cmd(
diff --git a/cosmos/operators/local.py b/cosmos/operators/local.py
index 6652ad411..21fa6ae91 100644
--- a/cosmos/operators/local.py
+++ b/cosmos/operators/local.py
@@ -531,7 +531,7 @@ def run_command(
         if self.install_deps:
             self._install_dependencies(tmp_dir_path, flags, env)
 
-        if run_as_async:
+        if run_as_async and not enable_setup_async_task:
             self._mock_dbt_adapter(async_context)
 
         full_cmd = cmd + flags
diff --git a/cosmos/operators/virtualenv.py b/cosmos/operators/virtualenv.py
index 4026d3eb4..32d5d6a9a 100644
--- a/cosmos/operators/virtualenv.py
+++ b/cosmos/operators/virtualenv.py
@@ -105,7 +105,7 @@ def run_command(
             with TemporaryDirectory(prefix="cosmos-venv") as tempdir:
                 self.virtualenv_dir = Path(tempdir)
                 self._py_bin = self._prepare_virtualenv()
-                return super().run_command(cmd, env, context)
+                return super().run_command(cmd, env, context, run_as_async=run_as_async, async_context=async_context)
 
         try:
             self.log.info(f"Checking if the virtualenv lock {str(self._lock_file)} exists")
@@ -117,7 +117,7 @@ def run_command(
             self.log.info("Acquiring the virtualenv lock")
             self._acquire_venv_lock()
             self._py_bin = self._prepare_virtualenv()
-            return super().run_command(cmd, env, context)
+            return super().run_command(cmd, env, context, run_as_async=run_as_async, async_context=async_context)
         finally:
             self.log.info("Releasing virtualenv lock")
             self._release_venv_lock()
diff --git a/dev/dags/simple_dag_async.py b/dev/dags/simple_dag_async.py
index 4eb910132..23e9836cf 100644
--- a/dev/dags/simple_dag_async.py
+++ b/dev/dags/simple_dag_async.py
@@ -26,6 +26,7 @@
     profile_config=profile_config,
     execution_config=ExecutionConfig(
         execution_mode=ExecutionMode.AIRFLOW_ASYNC,
+        async_py_requirements=["dbt-bigquery"],
    ),
     render_config=RenderConfig(
         select=["path:models"],
@@ -37,6 +38,6 @@
     catchup=False,
     dag_id="simple_dag_async",
     tags=["simple"],
-    operator_args={"location": "northamerica-northeast1", "install_deps": True},
+    operator_args={"location": "US", "install_deps": True},
 )
 # [END airflow_async_execution_mode_example]
diff --git a/tests/airflow/test_graph.py b/tests/airflow/test_graph.py
index d86abab74..a3d9474dc 100644
--- a/tests/airflow/test_graph.py
+++ b/tests/airflow/test_graph.py
@@ -1,7 +1,7 @@
 import os
 from datetime import datetime
 from pathlib import Path
-from unittest.mock import patch
+from unittest.mock import Mock, patch
 
 import pytest
 from airflow import __version__ as airflow_version
@@ -11,6 +11,7 @@
 from packaging import version
 
 from cosmos.airflow.graph import (
+    _add_teardown_task,
     _snake_case_to_camelcase,
     build_airflow_graph,
     calculate_detached_node_name,
@@ -30,6 +31,7 @@
 )
 from cosmos.converter import airflow_kwargs
 from cosmos.dbt.graph import DbtNode
+from cosmos.exceptions import CosmosValueError
 from cosmos.profiles import PostgresUserPasswordProfileMapping
 
 SAMPLE_PROJ_PATH = Path("/home/user/path/dbt-proj/")
@@ -982,3 +984,17 @@ def test_custom_meta():
             assert task.queue == "custom_queue"
         else:
             assert task.queue == "default"
+
+
+def test_add_teardown_task_raises_error_without_async_py_requirements():
+    """Test that an error is raised if async_py_requirements is not provided."""
+    task_args = {}
+
+    sample_dag = DAG(dag_id="test_dag")
+    sample_tasks_map = {
+        "task_1": Mock(downstream_list=[]),
+        "task_2": Mock(downstream_list=[]),
+    }
+
+    with pytest.raises(CosmosValueError, match="ExecutionConfig.AIRFLOW_ASYNC needs async_py_requirements to be set"):
+        _add_teardown_task(sample_dag, ExecutionMode.AIRFLOW_ASYNC, task_args, sample_tasks_map, None, None)
diff --git a/tests/operators/_asynchronous/test_base.py b/tests/operators/_asynchronous/test_base.py
index 6bcad07aa..4a1e953bb 100644
--- a/tests/operators/_asynchronous/test_base.py
+++ b/tests/operators/_asynchronous/test_base.py
@@ -1,11 +1,12 @@
 from __future__ import annotations
 
-from unittest.mock import MagicMock, Mock, patch
+from unittest.mock import MagicMock, Mock, mock_open, patch
 
 import pytest
 
 from cosmos.config import ProfileConfig
-from cosmos.operators._asynchronous import TeardownAsyncOperator
+from cosmos.hooks.subprocess import FullOutputSubprocessResult
+from cosmos.operators._asynchronous import SetupAsyncOperator, TeardownAsyncOperator
 from cosmos.operators._asynchronous.base import DbtRunAirflowAsyncFactoryOperator, _create_async_operator_class
 from cosmos.operators._asynchronous.bigquery import DbtRunAirflowAsyncBigqueryOperator
 from cosmos.operators._asynchronous.databricks import DbtRunAirflowAsyncDatabricksOperator
@@ -79,3 +80,102 @@ def test_teardown_execute(mock_build_and_run_cmd):
     )
     operator.execute({})
     mock_build_and_run_cmd.assert_called_once()
+
+
+@pytest.fixture
+def mock_operator_params():
+    return {
+        "task_id": "test_task",
+        "project_dir": "/tmp",
+        "profile_config": MagicMock(get_profile_type=MagicMock(return_value="bigquery")),
+    }
+
+
+@pytest.fixture
+def mock_load_method():
+    """Mock load_method_from_module to return a fake function."""
+    mock_function = MagicMock()
+    mock_function.__name__ = "_mock_bigquery_adapter"
+    mock_function.__module__ = "cosmos.operators._asynchronous.bigquery"
+    with patch("cosmos._utils.importer.load_method_from_module", return_value=mock_function):
+        yield mock_function
+
+
+@pytest.fixture
+def mock_file_operations():
+    """Mock file reading/writing operations."""
+    with patch("builtins.open", mock_open(read_data="#!/usr/bin/env python\n")) as mock_file:
+        yield mock_file
+
+
+@pytest.fixture
+def mock_super_run_subprocess():
+    with patch(
+        "cosmos.operators.virtualenv.DbtRunVirtualenvOperator.run_subprocess",
+        return_value=FullOutputSubprocessResult(0, "", ""),
+    ) as mock_run:
+        yield mock_run
+
+
+def test_setup_run_subprocess(mock_operator_params, mock_load_method, mock_file_operations, mock_super_run_subprocess):
+    op = SetupAsyncOperator(**mock_operator_params)
+    op._py_bin = "/fake/venv/bin/python"
+    command = ["dbt", "run"]
+    env = {}
+    cwd = "/tmp"
+
+    op.run_subprocess(command, env, cwd)
+
+    mock_file_operations.assert_called_with("/fake/venv/bin/dbt", "w")
+    mock_super_run_subprocess.assert_called_once_with(command, env, cwd)
+
+
+def test_teardown_run_subprocess(
+    mock_operator_params, mock_load_method, mock_file_operations, mock_super_run_subprocess
+):
+    op = TeardownAsyncOperator(**mock_operator_params)
+
+    command = ["dbt", "clean"]
+    env = {}
+    cwd = "/tmp"
+
+    op.run_subprocess(command, env, cwd)
+
+    mock_file_operations.assert_called_with("/fake/venv/bin/dbt", "w")
+    mock_super_run_subprocess.assert_called_once_with(command, env, cwd)
+
+
+def test_setup_execute(mock_operator_params):
+    op = SetupAsyncOperator(**mock_operator_params)
+
+    with patch.object(op, "build_and_run_cmd") as mock_build_and_run:
+        op.execute(context={})
+
+    mock_build_and_run.assert_called_once_with(
+        context={}, cmd_flags=op.dbt_cmd_flags, run_as_async=True, async_context={"profile_type": "bigquery"}
+    )
+
+
+def test_setup_run_subprocess_py_bin_unset(
+    mock_operator_params, mock_load_method, mock_file_operations, mock_super_run_subprocess
+):
+    op = SetupAsyncOperator(**mock_operator_params)
+    command = ["dbt", "run"]
+    env = {}
+    cwd = "/tmp"
+
+    with pytest.raises(AttributeError, match="_py_bin attribute not set for VirtualEnv operator"):
+        op.run_subprocess(command, env, cwd)
+
+
+def test_teardown_run_subprocess_py_bin_unset(
+    mock_operator_params, mock_load_method, mock_file_operations, mock_super_run_subprocess
+):
+    op = TeardownAsyncOperator(**mock_operator_params)
+    command = ["dbt", "run"]
+    env = {}
+    cwd = "/tmp"
+
+    with pytest.raises(AttributeError, match="_py_bin attribute not set for VirtualEnv operator"):
+        op.run_subprocess(command, env, cwd)
diff --git a/tests/operators/test_airflow_async.py b/tests/operators/test_airflow_async.py
index 309a8341e..1f36052c7 100644
--- a/tests/operators/test_airflow_async.py
+++ b/tests/operators/test_airflow_async.py
@@ -4,6 +4,7 @@
 import pytest
 
 from cosmos import DbtDag, ExecutionConfig, ExecutionMode, ProfileConfig, ProjectConfig
+from cosmos.exceptions import CosmosValueError
 from cosmos.operators.airflow_async import (
     DbtBuildAirflowAsyncOperator,
     DbtCompileAirflowAsyncOperator,
@@ -46,6 +47,7 @@ def test_airflow_async_operator_init(mock_bigquery_conn):
         profile_config=profile_config,
         execution_config=ExecutionConfig(
             execution_mode=ExecutionMode.AIRFLOW_ASYNC,
+            async_py_requirements=["dbt-bigquery"],
         ),
         schedule_interval=None,
         start_date=datetime(2023, 1, 1),
@@ -55,6 +57,32 @@ def test_airflow_async_operator_init(mock_bigquery_conn):
     )
 
 
+@pytest.mark.integration
+def test_airflow_async_operator_init_no_async_py_requirements_raises_error(mock_bigquery_conn):
+    """Test that a CosmosValueError is raised when async_py_requirements is not set for ExecutionMode.AIRFLOW_ASYNC."""
+    profile_mapping = get_automatic_profile_mapping(mock_bigquery_conn.conn_id, {})
+
+    profile_config = ProfileConfig(
+        profile_name="airflow_db",
+        target_name="bq",
+        profile_mapping=profile_mapping,
+    )
+
+    with pytest.raises(CosmosValueError, match="ExecutionConfig.AIRFLOW_ASYNC needs async_py_requirements to be set"):
+        DbtDag(
+            project_config=ProjectConfig(dbt_project_path=DBT_PROJECTS_ROOT_DIR / DBT_PROJECT_NAME),
+            profile_config=profile_config,
+            execution_config=ExecutionConfig(
+                execution_mode=ExecutionMode.AIRFLOW_ASYNC,
+            ),
+            schedule_interval=None,
+            start_date=datetime(2023, 1, 1),
+            catchup=False,
+            dag_id="simple_dag_async",
+            operator_args={"location": "us", "install_deps": True},
+        )
+
+
 def test_dbt_build_airflow_async_operator_inheritance():
     assert issubclass(DbtBuildAirflowAsyncOperator, DbtBuildLocalOperator)