Extend VirtualEnv operator and mock dbt adapters for setup & teardown tasks in ExecutionMode.AIRFLOW_ASYNC (astronomer#1544)

## Summary

This PR extends the setup and teardown tasks for `ExecutionMode.AIRFLOW_ASYNC` to run on the
`DbtRunVirtualenvOperator`. It ensures that dbt adapters are installed in the
virtualenv created by the `DbtRunVirtualenvOperator` and are properly mocked
within that virtual environment.

## Key Changes

✅ **`SetupAsyncOperator` and `TeardownAsyncOperator` now extend
`DbtRunVirtualenvOperator` and mock dbt adapters**
- These operators inherit from `DbtRunVirtualenvOperator` and override
`run_subprocess`.
- The operators load the profile-specific mock function from the
corresponding `cosmos.operators._asynchronous` submodule.
- The body of that function is injected dynamically into the dbt CLI entry
script inside the virtual environment, so the mocked dbt adapters are applied
before any dbt command executes (see the sketch after this list).
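
A minimal sketch of that injection step, adapted from the `run_subprocess` implementations in the diff below (the standalone helper `inject_adapter_mock` is hypothetical; the operators inline this logic directly):

```python
import inspect
import textwrap
from pathlib import Path

from cosmos._utils.importer import load_method_from_module


def inject_adapter_mock(py_bin: str, profile_type: str) -> None:
    # The dbt console script lives next to the virtualenv's python binary.
    dbt_executable_path = str(Path(py_bin).parent / "dbt")

    # Resolve e.g. cosmos.operators._asynchronous.bigquery._mock_bigquery_adapter.
    module = f"cosmos.operators._asynchronous.{profile_type}"
    mock_function = load_method_from_module(module, f"_mock_{profile_type}_adapter")

    # Keep only the function body (drop the `def` line and dedent) so it runs
    # at import time once spliced into the entrypoint script.
    source = inspect.getsource(mock_function)
    body = textwrap.dedent("\n".join(source.split("\n")[1:]))

    with open(dbt_executable_path) as f:
        script = f.readlines()
    if script[0].startswith("#!"):
        script.insert(1, body)  # splice the mock right after the shebang
    with open(dbt_executable_path, "w") as f:
        f.writelines(script)
```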

✅ **Decoupled dbt Adapters from Airflow Environment**
- With this change, dbt adapters (e.g., `dbt-bigquery`, `dbt-postgres`)
no longer need to be installed in the same environment as the Airflow
installation.
- `ExecutionConfig` now exposes `async_py_requirements`, which DAG authors can
set to ensure the necessary dbt dependencies are installed inside the virtual
environment for `ExecutionMode.AIRFLOW_ASYNC` (see the example after this list).
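
For example, a DAG author can pass the requirements through `ExecutionConfig` (a sketch based on `dev/dags/simple_dag_async.py` from this PR; the project path and profile details are placeholders):

```python
from cosmos import DbtDag, ExecutionConfig, ExecutionMode, ProfileConfig, ProjectConfig

simple_dag_async = DbtDag(
    dag_id="simple_dag_async",
    project_config=ProjectConfig("/usr/local/airflow/dbt/jaffle_shop"),  # placeholder path
    profile_config=ProfileConfig(
        profile_name="default",
        target_name="dev",
        profiles_yml_filepath="/usr/local/airflow/dbt/profiles.yml",  # placeholder path
    ),
    execution_config=ExecutionConfig(
        execution_mode=ExecutionMode.AIRFLOW_ASYNC,
        # Installed into the setup/teardown virtualenv instead of the Airflow environment.
        async_py_requirements=["dbt-bigquery"],
    ),
    operator_args={"location": "US", "install_deps": True},
    catchup=False,
)
```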


closes: astronomer#1533 

---------

Co-authored-by: Tatiana Al-Chueyr <tatiana.alchueyr@gmail.com>
Co-authored-by: Pankaj Singh <98807258+pankajastro@users.noreply.github.com>
3 people authored Feb 19, 2025
1 parent 968b80e commit 6ddc3c2
Showing 10 changed files with 234 additions and 12 deletions.
32 changes: 30 additions & 2 deletions cosmos/airflow/graph.py
@@ -25,6 +25,7 @@
from cosmos.core.airflow import get_airflow_task as create_airflow_task
from cosmos.core.graph.entities import Task as TaskMetadata
from cosmos.dbt.graph import DbtNode
from cosmos.exceptions import CosmosValueError
from cosmos.log import get_logger
from cosmos.settings import enable_setup_async_task, enable_teardown_async_task

@@ -413,14 +414,19 @@ def _add_dbt_setup_async_task(
tasks_map: dict[str, Any],
task_group: TaskGroup | None,
render_config: RenderConfig | None = None,
async_py_requirements: list[str] | None = None,
) -> None:
if execution_mode != ExecutionMode.AIRFLOW_ASYNC:
return

if not async_py_requirements:
raise CosmosValueError("ExecutionConfig.AIRFLOW_ASYNC needs async_py_requirements to be set")

if render_config is not None:
task_args["select"] = render_config.select
task_args["selector"] = render_config.selector
task_args["exclude"] = render_config.exclude
task_args["py_requirements"] = async_py_requirements

setup_task_metadata = TaskMetadata(
id=DBT_SETUP_ASYNC_TASK_ID,
@@ -495,14 +501,19 @@ def _add_teardown_task(
tasks_map: dict[str, Any],
task_group: TaskGroup | None,
render_config: RenderConfig | None = None,
async_py_requirements: list[str] | None = None,
) -> None:
if execution_mode != ExecutionMode.AIRFLOW_ASYNC:
return

if not async_py_requirements:
raise CosmosValueError("ExecutionConfig.AIRFLOW_ASYNC needs async_py_requirements to be set")

if render_config is not None:
task_args["select"] = render_config.select
task_args["selector"] = render_config.selector
task_args["exclude"] = render_config.exclude
task_args["py_requirements"] = async_py_requirements

teardown_task_metadata = TaskMetadata(
id=DBT_TEARDOWN_ASYNC_TASK_ID,
@@ -529,6 +540,7 @@ def build_airflow_graph(
render_config: RenderConfig,
task_group: TaskGroup | None = None,
on_warning_callback: Callable[..., Any] | None = None, # argument specific to the DBT test command
async_py_requirements: list[str] | None = None,
) -> dict[str, Union[TaskGroup, BaseOperator]]:
"""
Instantiate dbt `nodes` as Airflow tasks within the given `task_group` (optional) or `dag` (mandatory).
@@ -626,9 +638,25 @@

create_airflow_task_dependencies(nodes, tasks_map)
if enable_setup_async_task:
_add_dbt_setup_async_task(dag, execution_mode, task_args, tasks_map, task_group, render_config=render_config)
_add_dbt_setup_async_task(
dag,
execution_mode,
task_args,
tasks_map,
task_group,
render_config=render_config,
async_py_requirements=async_py_requirements,
)
if enable_teardown_async_task:
_add_teardown_task(dag, execution_mode, task_args, tasks_map, task_group, render_config=render_config)
_add_teardown_task(
dag,
execution_mode,
task_args,
tasks_map,
task_group,
render_config=render_config,
async_py_requirements=async_py_requirements,
)
return tasks_map


3 changes: 3 additions & 0 deletions cosmos/config.py
@@ -398,6 +398,8 @@ class ExecutionConfig:
:param dbt_project_path: Configures the DBT project location accessible at runtime for dag execution. This is the project path in a docker container for ExecutionMode.DOCKER or ExecutionMode.KUBERNETES. Mutually Exclusive with ProjectConfig.dbt_project_path
:param virtualenv_dir: Directory path to locate the (cached) virtual env that
should be used for execution when execution mode is set to `ExecutionMode.VIRTUALENV`
:param async_py_requirements: A list of Python packages to install when `ExecutionMode.AIRFLOW_ASYNC` (Experimental) is used. This parameter is required only if both `enable_setup_async_task` and `enable_teardown_async_task` are set to `True`.
Example: `["dbt-postgres==1.5.0"]`
"""

execution_mode: ExecutionMode = ExecutionMode.LOCAL
@@ -409,6 +411,7 @@ class ExecutionConfig:
virtualenv_dir: str | Path | None = None

project_path: Path | None = field(init=False)
async_py_requirements: list[str] | None = None

def __post_init__(self, dbt_project_path: str | Path | None) -> None:
if self.invocation_mode and self.execution_mode not in (ExecutionMode.LOCAL, ExecutionMode.VIRTUALENV):
1 change: 1 addition & 0 deletions cosmos/converter.py
@@ -335,6 +335,7 @@ def __init__(
dbt_project_name=render_config.project_name,
on_warning_callback=on_warning_callback,
render_config=render_config,
async_py_requirements=execution_config.async_py_requirements,
)

current_time = time.perf_counter()
51 changes: 48 additions & 3 deletions cosmos/operators/_asynchronous/__init__.py
@@ -1,29 +1,74 @@
from __future__ import annotations

import inspect
import textwrap
from pathlib import Path
from typing import Any

from airflow.utils.context import Context

from cosmos.operators.local import DbtRunLocalOperator as DbtRunOperator
from cosmos._utils.importer import load_method_from_module
from cosmos.hooks.subprocess import FullOutputSubprocessResult
from cosmos.operators.virtualenv import DbtRunVirtualenvOperator


class SetupAsyncOperator(DbtRunOperator):
class SetupAsyncOperator(DbtRunVirtualenvOperator):
def __init__(self, *args: Any, **kwargs: Any):
kwargs["emit_datasets"] = False
super().__init__(*args, **kwargs)

def run_subprocess(self, command: list[str], env: dict[str, str], cwd: str) -> FullOutputSubprocessResult:
profile_type = self.profile_config.get_profile_type()
if not self._py_bin:
raise AttributeError("_py_bin attribute not set for VirtualEnv operator")
dbt_executable_path = str(Path(self._py_bin).parent / "dbt")
asynchronous_operator_module = f"cosmos.operators._asynchronous.{profile_type}"
mock_function_name = f"_mock_{profile_type}_adapter"
mock_function = load_method_from_module(asynchronous_operator_module, mock_function_name)
mock_function_full_source = inspect.getsource(mock_function)
mock_function_body = textwrap.dedent("\n".join(mock_function_full_source.split("\n")[1:]))

with open(dbt_executable_path) as f:
dbt_entrypoint_script = f.readlines()
if dbt_entrypoint_script[0].startswith("#!"):
dbt_entrypoint_script.insert(1, mock_function_body)
with open(dbt_executable_path, "w") as f:
f.writelines(dbt_entrypoint_script)

return super().run_subprocess(command, env, cwd)

def execute(self, context: Context, **kwargs: Any) -> None:
async_context = {"profile_type": self.profile_config.get_profile_type()}
self.build_and_run_cmd(
context=context, cmd_flags=self.dbt_cmd_flags, run_as_async=True, async_context=async_context
)


class TeardownAsyncOperator(DbtRunOperator):
class TeardownAsyncOperator(DbtRunVirtualenvOperator):
def __init__(self, *args: Any, **kwargs: Any):
kwargs["emit_datasets"] = False
super().__init__(*args, **kwargs)

def run_subprocess(self, command: list[str], env: dict[str, str], cwd: str) -> FullOutputSubprocessResult:
profile_type = self.profile_config.get_profile_type()
if not self._py_bin:
raise AttributeError("_py_bin attribute not set for VirtualEnv operator")
dbt_executable_path = str(Path(self._py_bin).parent / "dbt")
asynchronous_operator_module = f"cosmos.operators._asynchronous.{profile_type}"
mock_function_name = f"_mock_{profile_type}_adapter"
mock_function = load_method_from_module(asynchronous_operator_module, mock_function_name)
mock_function_full_source = inspect.getsource(mock_function)
mock_function_body = textwrap.dedent("\n".join(mock_function_full_source.split("\n")[1:]))

with open(dbt_executable_path) as f:
dbt_entrypoint_script = f.readlines()
if dbt_entrypoint_script[0].startswith("#!"):
dbt_entrypoint_script.insert(1, mock_function_body)
with open(dbt_executable_path, "w") as f:
f.writelines(dbt_entrypoint_script)

return super().run_subprocess(command, env, cwd)

def execute(self, context: Context, **kwargs: Any) -> Any:
async_context = {"profile_type": self.profile_config.get_profile_type(), "teardown_task": True}
self.build_and_run_cmd(
2 changes: 1 addition & 1 deletion cosmos/operators/local.py
@@ -531,7 +531,7 @@ def run_command(
if self.install_deps:
self._install_dependencies(tmp_dir_path, flags, env)

if run_as_async:
if run_as_async and not enable_setup_async_task:
self._mock_dbt_adapter(async_context)

full_cmd = cmd + flags
4 changes: 2 additions & 2 deletions cosmos/operators/virtualenv.py
@@ -105,7 +105,7 @@ def run_command(
with TemporaryDirectory(prefix="cosmos-venv") as tempdir:
self.virtualenv_dir = Path(tempdir)
self._py_bin = self._prepare_virtualenv()
return super().run_command(cmd, env, context)
return super().run_command(cmd, env, context, run_as_async=run_as_async, async_context=async_context)

try:
self.log.info(f"Checking if the virtualenv lock {str(self._lock_file)} exists")
@@ -117,7 +117,7 @@ def run_command(
self.log.info("Acquiring the virtualenv lock")
self._acquire_venv_lock()
self._py_bin = self._prepare_virtualenv()
return super().run_command(cmd, env, context)
return super().run_command(cmd, env, context, run_as_async=run_as_async, async_context=async_context)
finally:
self.log.info("Releasing virtualenv lock")
self._release_venv_lock()
3 changes: 2 additions & 1 deletion dev/dags/simple_dag_async.py
@@ -26,6 +26,7 @@
profile_config=profile_config,
execution_config=ExecutionConfig(
execution_mode=ExecutionMode.AIRFLOW_ASYNC,
async_py_requirements=["dbt-bigquery"],
),
render_config=RenderConfig(
select=["path:models"],
@@ -37,6 +38,6 @@
catchup=False,
dag_id="simple_dag_async",
tags=["simple"],
operator_args={"location": "northamerica-northeast1", "install_deps": True},
operator_args={"location": "US", "install_deps": True},
)
# [END airflow_async_execution_mode_example]
18 changes: 17 additions & 1 deletion tests/airflow/test_graph.py
@@ -1,7 +1,7 @@
import os
from datetime import datetime
from pathlib import Path
from unittest.mock import patch
from unittest.mock import Mock, patch

import pytest
from airflow import __version__ as airflow_version
@@ -11,6 +11,7 @@
from packaging import version

from cosmos.airflow.graph import (
_add_teardown_task,
_snake_case_to_camelcase,
build_airflow_graph,
calculate_detached_node_name,
@@ -30,6 +31,7 @@
)
from cosmos.converter import airflow_kwargs
from cosmos.dbt.graph import DbtNode
from cosmos.exceptions import CosmosValueError
from cosmos.profiles import PostgresUserPasswordProfileMapping

SAMPLE_PROJ_PATH = Path("/home/user/path/dbt-proj/")
@@ -982,3 +984,17 @@ def test_custom_meta():
assert task.queue == "custom_queue"
else:
assert task.queue == "default"


def test_add_teardown_task_raises_error_without_async_py_requirements():
"""Test that an error is raised if async_py_requirements is not provided."""
task_args = {}

sample_dag = DAG(dag_id="test_dag")
sample_tasks_map = {
"task_1": Mock(downstream_list=[]),
"task_2": Mock(downstream_list=[]),
}

with pytest.raises(CosmosValueError, match="ExecutionConfig.AIRFLOW_ASYNC needs async_py_requirements to be set"):
_add_teardown_task(sample_dag, ExecutionMode.AIRFLOW_ASYNC, task_args, sample_tasks_map, None, None)
104 changes: 102 additions & 2 deletions tests/operators/_asynchronous/test_base.py
@@ -1,11 +1,12 @@
from __future__ import annotations

from unittest.mock import MagicMock, Mock, patch
from unittest.mock import MagicMock, Mock, mock_open, patch

import pytest

from cosmos.config import ProfileConfig
from cosmos.operators._asynchronous import TeardownAsyncOperator
from cosmos.hooks.subprocess import FullOutputSubprocessResult
from cosmos.operators._asynchronous import SetupAsyncOperator, TeardownAsyncOperator
from cosmos.operators._asynchronous.base import DbtRunAirflowAsyncFactoryOperator, _create_async_operator_class
from cosmos.operators._asynchronous.bigquery import DbtRunAirflowAsyncBigqueryOperator
from cosmos.operators._asynchronous.databricks import DbtRunAirflowAsyncDatabricksOperator
@@ -79,3 +80,102 @@ def test_teardown_execute(mock_build_and_run_cmd):
)
operator.execute({})
mock_build_and_run_cmd.assert_called_once()


@pytest.fixture
def mock_operator_params():
return {
"task_id": "test_task",
"project_dir": "/tmp",
"profile_config": MagicMock(get_profile_type=MagicMock(return_value="bigquery")),
}


@pytest.fixture
def mock_load_method():
"""Mock load_method_from_module to return a fake function."""
mock_function = MagicMock()
mock_function.__name__ = "_mock_bigquery_adapter"
mock_function.__module__ = "cosmos.operators._asynchronous.bigquery"
with patch("cosmos._utils.importer.load_method_from_module", return_value=mock_function):
yield mock_function


@pytest.fixture
def mock_file_operations():
"""Mock file reading/writing operations."""
with patch("builtins.open", mock_open(read_data="#!/usr/bin/env python\n")) as mock_file:
yield mock_file


@pytest.fixture
def mock_super_run_subprocess():
with patch(
"cosmos.operators.virtualenv.DbtRunVirtualenvOperator.run_subprocess",
return_value=FullOutputSubprocessResult(0, "", ""),
) as mock_run:
yield mock_run


def test_setup_run_subprocess(mock_operator_params, mock_load_method, mock_file_operations, mock_super_run_subprocess):
op = SetupAsyncOperator(**mock_operator_params)
op._py_bin = "/fake/venv/bin/python"
command = ["dbt", "run"]
env = {}
cwd = "/tmp"

op.run_subprocess(command, env, cwd)

mock_file_operations.assert_called_with("/fake/venv/bin/dbt", "w")
mock_super_run_subprocess.assert_called_once_with(command, env, cwd)


def test_teardown_run_subprocess(
mock_operator_params, mock_load_method, mock_file_operations, mock_super_run_subprocess
):
op = TeardownAsyncOperator(**mock_operator_params)
op._py_bin = "/fake/venv/bin/python"

command = ["dbt", "clean"]
env = {}
cwd = "/tmp"

op.run_subprocess(command, env, cwd)

mock_file_operations.assert_called_with("/fake/venv/bin/dbt", "w")
mock_super_run_subprocess.assert_called_once_with(command, env, cwd)


def test_setup_execute(mock_operator_params):
op = SetupAsyncOperator(**mock_operator_params)

with patch.object(op, "build_and_run_cmd") as mock_build_and_run:
op.execute(context={})

mock_build_and_run.assert_called_once_with(
context={}, cmd_flags=op.dbt_cmd_flags, run_as_async=True, async_context={"profile_type": "bigquery"}
)


def test_setup_run_subprocess_py_bin_unset(
mock_operator_params, mock_load_method, mock_file_operations, mock_super_run_subprocess
):
op = SetupAsyncOperator(**mock_operator_params)
command = ["dbt", "run"]
env = {}
cwd = "/tmp"

with pytest.raises(AttributeError, match="_py_bin attribute not set for VirtualEnv operator"):
op.run_subprocess(command, env, cwd)


def test_teardown_run_subprocess_py_bin_unset(
mock_operator_params, mock_load_method, mock_file_operations, mock_super_run_subprocess
):
op = TeardownAsyncOperator(**mock_operator_params)
command = ["dbt", "run"]
env = {}
cwd = "/tmp"

with pytest.raises(AttributeError, match="_py_bin attribute not set for VirtualEnv operator"):
op.run_subprocess(command, env, cwd)