Skip to content

Commit

Permalink
Use dbtRunner in the DAG Processor when using LoadMode.DBT_LS if …
Browse files Browse the repository at this point in the history
…`dbt-core` is available (astronomer#1484)

This PR significantly improves Cosmos resource utilisation in the
scheduler DAG Processor and in the worker nodes when dynamically
converting dbt workflows into Airflow DAGs using the `LoadMode.DBT_LS`.

It introduces support to use `dbtRunner` during DAG parsing if
`dbt-core` and its adapters are in the same Python virtualenv as
Airflow.

This change is particularly relevant given the way Airflow (2.x) parses
DAGs not only as part of the scheduler loop but also whenever a task
executes:

<img width="845" alt="Screenshot 2025-01-23 at 13 51 40"
src="https://github.com/user-attachments/assets/90398307-a26c-4cbd-ae44-a6a5c1e0e98e"
/>

> Diagram extracted from the talk @pankajkoti and I gave at Airflow
Summit 2024
>
https://airflowsummit.org/sessions/2024/overcoming-performance-hurdles-in-integrating-dbt-with-airflow/

When using `LoadMode.DBT_LS`, Cosmos runs `dbt ls` whenever Airflow
parses the DAG (in case of a cache miss).

Suppose there is a Cosmos `DbtDag` with 200 concurrent tasks. If the dbt
project changes, when Airflow and Cosmos attempt to parse the `DbtDag`,
they will invalidate the dbt ls cache. If 200 Cosmos tasks execute
concurrently when there is a cache miss, they will all run the same `dbt
ls` command. Until Cosmos 1.8, Cosmos would always create a subprocess
for each command. If 200 tasks execute in a worker node, this would
represent 400 processes attempting to run concurrently, leading to a
vast CPU — and potentially memory — spike and Out of Memory
(OOM) errors. While this change does not avoid the 200 tasks attempting
to run `dbt ls` concurrently, it avoids each of them creating an
additional subprocess - optimising the resource utilisation.

This change is heavily influenced by changes (astronomer#850) previously made by
@jbandoro, who added support for Cosmos to use `dbtRunner` to execute
dbt commands in the Airflow worker nodes when using
`ExecutionMode.LOCAL` instead of Python's subprocess.

Closes: astronomer#865
  • Loading branch information
tatiana authored Jan 28, 2025
1 parent 4eefe29 commit e11e5ae
Show file tree
Hide file tree
Showing 3 changed files with 178 additions and 38 deletions.
85 changes: 68 additions & 17 deletions cosmos/dbt/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from airflow.models import Variable

import cosmos.dbt.runner as dbt_runner
from cosmos import cache, settings
from cosmos.cache import (
_configure_remote_cache_dir,
Expand Down Expand Up @@ -158,11 +159,8 @@ def is_freshness_effective(freshness: Optional[dict[str, Any]]) -> bool:
return False


def run_command(command: list[str], tmp_dir: Path, env_vars: dict[str, str]) -> str:
def run_command_with_subprocess(command: list[str], tmp_dir: Path, env_vars: dict[str, str]) -> str:
"""Run a command in a subprocess, returning the stdout."""
command = [str(arg) if arg is not None else "<None>" for arg in command]
logger.info("Running command: `%s`", " ".join(command))
logger.debug("Environment variable keys: %s", env_vars.keys())
process = Popen(
command,
stdout=PIPE,
Expand All @@ -186,6 +184,66 @@ def run_command(command: list[str], tmp_dir: Path, env_vars: dict[str, str]) ->
return stdout


def run_command_with_dbt_runner(command: list[str], tmp_dir: Path | None, env_vars: dict[str, str]) -> str:
    """Run a dbt command programmatically via dbtRunner, returning the stdout.

    :param command: Full dbt command, e.g. ``["dbt", "ls", ...]``.
    :param tmp_dir: Working directory for the dbt invocation (may be ``None``).
    :param env_vars: Environment variables exported while the command runs.
    :returns: The command output, one JSON-serialised result item per line.
    :raises CosmosLoadDbtException: If dbt reports a failure, including the
        common case of missing dbt packages during ``dbt ls``.
    """
    response = dbt_runner.run_command(command=command, env=env_vars, cwd=str(tmp_dir))

    # Serialise result items lazily — only when dbt actually produced results.
    stdout = ""
    if response.result:
        result_list = [json.dumps(item.to_dict()) if hasattr(item, "to_dict") else item for item in response.result]
        stdout = "\n".join(result_list)

    stderr = ""
    if not response.success:
        if response.exception:
            stderr = str(response.exception)
            # Missing dbt_packages is a common misconfiguration; surface an actionable message.
            # Guard len(command) so a short command list cannot raise IndexError here.
            if 'Run "dbt deps" to install package dependencies' in stderr and len(command) > 1 and command[1] == "ls":
                raise CosmosLoadDbtException(
                    "Unable to run dbt ls command due to missing dbt_packages. Set RenderConfig.dbt_deps=True."
                )
        elif response.result:
            # No exception object: collect per-node failure messages instead.
            node_names, node_results = dbt_runner.extract_message_by_status(
                response, ["error", "fail", "runtime error"]
            )
            stderr = "\n".join([f"{name}: {result}" for name, result in zip(node_names, node_results)])

    if stderr:
        details = f"stderr: {stderr}\nstdout: {stdout}"
        raise CosmosLoadDbtException(f"Unable to run {command} due to the error:\n{details}")

    return stdout


def run_command(command: list[str], tmp_dir: Path, env_vars: dict[str, str], log_dir: Path | None = None) -> str:
    """Run a dbt command either with dbtRunner or a Python subprocess, returning the stdout.

    dbtRunner is preferred when ``dbt-core`` is importable in the current
    Python environment, since it avoids spawning an extra subprocess.

    :param command: Full dbt command, e.g. ``["dbt", "ls", ...]``.
    :param tmp_dir: Working directory for the dbt invocation.
    :param env_vars: Environment variables exported while the command runs.
    :param log_dir: Optional directory containing dbt's log file; when given,
        the log file content is emitted at debug level.
    :returns: The command's stdout.
    """
    # Check availability once and reuse the answer for both the log line and the dispatch.
    use_dbt_runner = dbt_runner.is_available()
    runner = "dbt Runner" if use_dbt_runner else "Python subprocess"
    command = [str(arg) if arg is not None else "<None>" for arg in command]
    logger.info("Running command with %s: `%s`", runner, " ".join(command))
    logger.debug("Environment variable keys: %s", env_vars.keys())

    if use_dbt_runner:
        stdout = run_command_with_dbt_runner(command, tmp_dir, env_vars)
    else:
        stdout = run_command_with_subprocess(command, tmp_dir, env_vars)

    logger.debug("dbt ls output: %s", stdout)

    # Surface dbt's own log file at debug level when a log directory was provided.
    if log_dir is not None:
        log_filepath = log_dir / DBT_LOG_FILENAME
        logger.debug("dbt logs available in: %s", log_filepath)
        if log_filepath.exists():
            with open(log_filepath) as logfile:
                for line in logfile:
                    logger.debug(line.strip())

    return stdout


def parse_dbt_ls_output(project_path: Path | None, ls_stdout: str) -> dict[str, DbtNode]:
"""Parses the output of `dbt ls` into a dictionary of `DbtNode` instances."""
nodes = {}
Expand Down Expand Up @@ -262,6 +320,7 @@ def __init__(
self.dbt_ls_cache_key = ""
self.dbt_vars = dbt_vars or {}
self.operator_args = operator_args or {}
self.log_dir: Path | None = None

@cached_property
def env_vars(self) -> dict[str, str]:
Expand Down Expand Up @@ -471,15 +530,7 @@ def run_dbt_ls(
ls_command.extend(self.local_flags)
ls_command.extend(ls_args)

stdout = run_command(ls_command, tmp_dir, env_vars)

logger.debug("dbt ls output: %s", stdout)
log_filepath = self.log_dir / DBT_LOG_FILENAME
logger.debug("dbt logs available in: %s", log_filepath)
if log_filepath.exists():
with open(log_filepath) as logfile:
for line in logfile:
logger.debug(line.strip())
stdout = run_command(ls_command, tmp_dir, env_vars, self.log_dir)

if self.should_use_dbt_ls_cache():
self.save_dbt_ls_cache(stdout)
Expand Down Expand Up @@ -540,8 +591,7 @@ def run_dbt_deps(self, dbt_cmd: str, dbt_project_path: Path, env: dict[str, str]
deps_command = [dbt_cmd, "deps"]
deps_command.extend(self.local_flags)
self._add_vars_arg(deps_command)
stdout = run_command(deps_command, dbt_project_path, env)
logger.debug("dbt deps output: %s", stdout)
run_command(deps_command, dbt_project_path, env, self.log_dir)

def load_via_dbt_ls_without_cache(self) -> None:
"""
Expand Down Expand Up @@ -597,11 +647,12 @@ def load_via_dbt_ls_without_cache(self) -> None:
self.profile_config.target_name,
]

self.log_dir = Path(env.get(DBT_LOG_PATH_ENVVAR) or tmpdir_path / DBT_LOG_DIR_NAME)
self.target_dir = Path(env.get(DBT_TARGET_PATH_ENVVAR) or tmpdir_path / DBT_TARGET_DIR_NAME)
env[DBT_LOG_PATH_ENVVAR] = str(self.log_dir)
env[DBT_TARGET_PATH_ENVVAR] = str(self.target_dir)

self.log_dir = Path(env.get(DBT_LOG_PATH_ENVVAR) or tmpdir_path / DBT_LOG_DIR_NAME)
env[DBT_LOG_PATH_ENVVAR] = str(self.log_dir)

if self.render_config.dbt_deps and has_non_empty_dependencies_file(self.project_path):
if is_cache_package_lockfile_enabled(project_path):
latest_package_lockfile = _get_latest_cached_package_lockfile(project_path)
Expand Down
120 changes: 104 additions & 16 deletions tests/dbt/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,13 @@

from cosmos import settings
from cosmos.config import CosmosConfigException, ExecutionConfig, ProfileConfig, ProjectConfig, RenderConfig
from cosmos.constants import DBT_TARGET_DIR_NAME, DbtResourceType, ExecutionMode, SourceRenderingBehavior
from cosmos.constants import (
DBT_LOG_FILENAME,
DBT_TARGET_DIR_NAME,
DbtResourceType,
ExecutionMode,
SourceRenderingBehavior,
)
from cosmos.dbt.graph import (
CosmosLoadDbtException,
DbtGraph,
Expand Down Expand Up @@ -439,15 +445,37 @@ def test_load(


@pytest.mark.integration
@pytest.mark.parametrize("enable_cache_profile", [True, False])
@pytest.mark.parametrize(
"runner,enable_cache_profile",
[
("subprocess", True),
("subprocess", False),
("dbt_runner", True),
("dbt_runner", False),
],
)
@patch("cosmos.config.is_profile_cache_enabled")
@patch("cosmos.dbt.graph.Popen")
@patch("cosmos.dbt.graph.dbt_runner.run_command")
def test_load_via_dbt_ls_does_not_create_target_logs_in_original_folder(
mock_popen, is_profile_cache_enabled, enable_cache_profile, tmp_dbt_project_dir, postgres_profile_config
mock_dbt_runner,
mock_popen,
is_profile_cache_enabled,
runner,
enable_cache_profile,
tmp_dbt_project_dir,
postgres_profile_config,
):
import_patch = None
if runner == "subprocess":
original_sys_modules = sys.modules
import_patch = patch.dict(sys.modules, {"dbt.cli.main": None})
import_patch.start()
mock_popen().communicate.return_value = ("", "")
mock_popen().returncode = 0

is_profile_cache_enabled.return_value = enable_cache_profile
mock_popen().communicate.return_value = ("", "")
mock_popen().returncode = 0

assert not (tmp_dbt_project_dir / "target").exists()
assert not (tmp_dbt_project_dir / "logs").exists()

Expand All @@ -467,7 +495,12 @@ def test_load_via_dbt_ls_does_not_create_target_logs_in_original_folder(
assert not (tmp_dbt_project_dir / "target").exists()
assert not (tmp_dbt_project_dir / "logs").exists()

used_cwd = Path(mock_popen.call_args[0][0][5])
if import_patch is not None:
used_cwd = Path(mock_popen.call_args[0][0][5])
import_patch.stop()
sys.modules = original_sys_modules
else:
used_cwd = Path(mock_dbt_runner.call_args[1]["cwd"])
assert used_cwd != project_config.dbt_project_path
assert not used_cwd.exists()

Expand Down Expand Up @@ -637,7 +670,14 @@ def test_load_via_dbt_ls_with_sources(load_method):


@pytest.mark.integration
def test_load_via_dbt_ls_without_dbt_deps(postgres_profile_config):
@pytest.mark.parametrize("runner", ("subprocess", "dbt_runner"))
def test_load_via_dbt_ls_without_dbt_deps(runner, postgres_profile_config):
some_patch = None
if runner == "subprocess":
original_sys_modules = sys.modules
some_patch = patch.dict(sys.modules, {"dbt.cli.main": None})
some_patch.start()

project_config = ProjectConfig(dbt_project_path=DBT_PROJECTS_ROOT_DIR / DBT_PROJECT_NAME)
render_config = RenderConfig(
dbt_project_path=DBT_PROJECTS_ROOT_DIR / DBT_PROJECT_NAME,
Expand All @@ -658,6 +698,10 @@ def test_load_via_dbt_ls_without_dbt_deps(postgres_profile_config):
expected = "Unable to run dbt ls command due to missing dbt_packages. Set RenderConfig.dbt_deps=True."
assert err_info.value.args[0] == expected

if some_patch is not None:
sys.modules = original_sys_modules
some_patch.stop()


@pytest.mark.integration
def test_load_via_dbt_ls_without_dbt_deps_and_preinstalled_dbt_packages(
Expand Down Expand Up @@ -703,15 +747,18 @@ def test_load_via_dbt_ls_without_dbt_deps_and_preinstalled_dbt_packages(


@pytest.mark.integration
@pytest.mark.parametrize("enable_cache_profile", [True, False])
@pytest.mark.parametrize("enable_cache_profile", (True, False))
@patch("cosmos.config.is_profile_cache_enabled")
def test_load_via_dbt_ls_caching_partial_parsing(
is_profile_cache_enabled, enable_cache_profile, tmp_dbt_project_dir, postgres_profile_config, caplog, tmp_path
@patch.dict(sys.modules, {"dbt.cli.main": None})
def test_load_via_dbt_ls_caching_partial_parsing_subprocess(
is_profile_cache_enabled,
enable_cache_profile,
tmp_dbt_project_dir,
postgres_profile_config,
caplog,
tmp_path,
):
"""
When using RenderConfig.enable_mock_profile=False and defining DbtGraph.cache_dir,
Cosmos should leverage dbt partial parsing.
"""

caplog.set_level(logging.DEBUG)

is_profile_cache_enabled.return_value = enable_cache_profile
Expand Down Expand Up @@ -822,6 +869,7 @@ def test_load_via_dbt_ls_with_zero_returncode_and_non_empty_stderr(

@pytest.mark.integration
@patch("cosmos.dbt.graph.Popen")
@patch.dict(sys.modules, {"dbt.cli.main": None})
def test_load_via_dbt_ls_with_non_zero_returncode(mock_popen, postgres_profile_config):
mock_popen().communicate.return_value = ("", "Some stderr message")
mock_popen().returncode = 1
Expand All @@ -845,6 +893,7 @@ def test_load_via_dbt_ls_with_non_zero_returncode(mock_popen, postgres_profile_c

@pytest.mark.integration
@patch("cosmos.dbt.graph.Popen.communicate", return_value=("Some Runtime Error", ""))
@patch.dict(sys.modules, {"dbt.cli.main": None})
def test_load_via_dbt_ls_with_runtime_error_in_stdout(mock_popen_communicate, postgres_profile_config):
# It may seem strange, but at least until dbt 1.6.0, there are circumstances when it outputs errors to stdout
project_config = ProjectConfig(dbt_project_path=DBT_PROJECTS_ROOT_DIR / DBT_PROJECT_NAME)
Expand Down Expand Up @@ -1104,6 +1153,7 @@ def test_load_via_dbt_ls_file():
],
)
@patch("cosmos.dbt.graph.Popen")
@patch.dict(sys.modules, {"dbt.cli.main": None})
def test_run_command(mock_popen, stdout, returncode):
fake_command = ["fake", "command"]
fake_dir = Path("fake_dir")
Expand All @@ -1121,7 +1171,40 @@ def test_run_command(mock_popen, stdout, returncode):
assert return_value == stdout


@pytest.mark.integration
@patch.dict(sys.modules, {"dbt.cli.main": None})
def test_run_command_success_with_log(tmp_dbt_project_dir):
    # With dbt.cli.main blocked, run_command must fall back to the subprocess runner
    # and still return dbt's stdout while echoing the log file at debug level.
    dbt_project_path = tmp_dbt_project_dir / DBT_PROJECT_NAME
    (dbt_project_path / DBT_LOG_FILENAME).touch()
    output = run_command(
        command=["dbt", "deps"], tmp_dir=dbt_project_path, env_vars=os.environ, log_dir=dbt_project_path
    )
    assert "Installing dbt-labs/dbt_utils" in output


@pytest.mark.integration
def test_run_command_with_dbt_runner_exception(tmp_dbt_project_dir):
    # Without dbt packages installed, `dbt ls` should raise an actionable Cosmos error.
    project_dir = tmp_dbt_project_dir / DBT_PROJECT_NAME
    with pytest.raises(CosmosLoadDbtException) as err_info:
        run_command(command=["dbt", "ls"], tmp_dir=project_dir, env_vars=os.environ)
    assert "Unable to run dbt ls command due to missing dbt_packages" in str(err_info.value)


@pytest.mark.integration
def test_run_command_with_dbt_runner_error(tmp_dbt_project_dir):
    project_dir = tmp_dbt_project_dir / DBT_PROJECT_NAME

    # Remove packages.yml so the project declares no dependencies to install.
    (project_dir / "packages.yml").unlink()

    # Overwrite a model with trivial SQL so the project still parses but `dbt run` fails.
    broken_model = project_dir / "models/staging/stg_orders.sql"
    with open(str(broken_model), "w") as fp:
        fp.write("select 1 as id")

    with pytest.raises(CosmosLoadDbtException) as err_info:
        run_command(command=["dbt", "run"], tmp_dir=project_dir, env_vars=os.environ)
    assert "Unable to run ['dbt', 'run']" in str(err_info.value)


@patch("cosmos.dbt.graph.Popen")
@patch.dict(sys.modules, {"dbt.cli.main": None})
def test_run_command_none_argument(mock_popen, caplog):
fake_command = ["invalid-cmd", None]
fake_dir = Path("fake_dir")
Expand Down Expand Up @@ -1232,6 +1315,7 @@ def test_parse_dbt_ls_output_with_json_without_tags_or_config():
@patch("cosmos.dbt.graph.Popen")
@patch("cosmos.dbt.graph.DbtGraph.update_node_dependency")
@patch("cosmos.config.RenderConfig.validate_dbt_command")
@patch.dict(sys.modules, {"dbt.cli.main": None})
def test_load_via_dbt_ls_project_config_env_vars(
mock_validate, mock_update_nodes, mock_popen, mock_enable_cache, tmp_dbt_project_dir
):
Expand Down Expand Up @@ -1267,6 +1351,7 @@ def test_load_via_dbt_ls_project_config_env_vars(
@patch("cosmos.dbt.graph.Popen")
@patch("cosmos.dbt.graph.DbtGraph.update_node_dependency")
@patch("cosmos.config.RenderConfig.validate_dbt_command")
@patch.dict(sys.modules, {"dbt.cli.main": None})
def test_profile_created_correctly_with_profile_mapping(
mock_validate,
mock_update_nodes,
Expand Down Expand Up @@ -1300,6 +1385,7 @@ def test_profile_created_correctly_with_profile_mapping(
@patch("cosmos.dbt.graph.Popen")
@patch("cosmos.dbt.graph.DbtGraph.update_node_dependency")
@patch("cosmos.config.RenderConfig.validate_dbt_command")
@patch.dict(sys.modules, {"dbt.cli.main": None})
def test_load_via_dbt_ls_project_config_dbt_vars(
mock_validate, mock_update_nodes, mock_popen, mock_use_case, tmp_dbt_project_dir
):
Expand Down Expand Up @@ -1334,6 +1420,7 @@ def test_load_via_dbt_ls_project_config_dbt_vars(
@patch("cosmos.dbt.graph.Popen")
@patch("cosmos.dbt.graph.DbtGraph.update_node_dependency")
@patch("cosmos.config.RenderConfig.validate_dbt_command")
@patch.dict(sys.modules, {"dbt.cli.main": None})
def test_load_via_dbt_ls_render_config_selector_arg_is_used(
mock_validate, mock_update_nodes, mock_popen, mock_enable_cache, tmp_dbt_project_dir
):
Expand Down Expand Up @@ -1370,6 +1457,7 @@ def test_load_via_dbt_ls_render_config_selector_arg_is_used(
@patch("cosmos.dbt.graph.Popen")
@patch("cosmos.dbt.graph.DbtGraph.update_node_dependency")
@patch("cosmos.config.RenderConfig.validate_dbt_command")
@patch.dict(sys.modules, {"dbt.cli.main": None})
def test_load_via_dbt_ls_render_config_no_partial_parse(
mock_validate, mock_update_nodes, mock_popen, mock_enable_cache, tmp_dbt_project_dir
):
Expand Down Expand Up @@ -1568,7 +1656,7 @@ def test_run_dbt_deps(run_command_mock):
graph = DbtGraph(project=project_config)
graph.local_flags = []
graph.run_dbt_deps("dbt", "/some/path", {})
run_command_mock.assert_called_with(["dbt", "deps", "--vars", '{"var-key": "var-value"}'], "/some/path", {})
run_command_mock.assert_called_with(["dbt", "deps", "--vars", '{"var-key": "var-value"}'], "/some/path", {}, None)


@pytest.fixture()
Expand Down Expand Up @@ -1609,7 +1697,7 @@ def test_save_dbt_ls_cache(mock_variable_set, mock_datetime, tmp_dbt_project_dir
hash_dir, hash_args = version.split(",")
assert hash_args == "d41d8cd98f00b204e9800998ecf8427e"
if sys.platform == "darwin":
assert hash_dir == "fa5edac64de49909d4b8cbc4dc8abd4f"
assert hash_dir == "af89237a0cdef7edce53fe4d4160fa79"
else:
assert hash_dir == "9c9f712b6f6f1ace880dfc7f5f4ff051"

Expand Down
Loading

0 comments on commit e11e5ae

Please sign in to comment.