From 24108f00f2173034238265629698c92fc54c6706 Mon Sep 17 00:00:00 2001
From: Pankaj Koti
Date: Wed, 5 Feb 2025 12:15:29 +0530
Subject: [PATCH] Create and run accurate SQL statements when using `ExecutionMode.AIRFLOW_ASYNC` (#1474)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

# Overview

This PR introduces a reliable way to extract the SQL statements run by `dbt-core` so that Airflow asynchronous operators can use them. It fixes the experimental BQ implementation of `ExecutionMode.AIRFLOW_ASYNC` introduced in Cosmos 1.7 (#1230).

Previously, in #1230, we attempted to understand how `dbt-core` runs `--full-refresh` for BQ, and we hard-coded the SQL header in Cosmos as an experimental feature. Since then, we realised that this approach was prone to errors (e.g. #1260) and that it is unrealistic for Cosmos to try to recreate the logic of how `dbt-core` and its adapters generate all the SQL statements for different operations, data warehouses, and types of materialisation.

With this PR, we use `dbt-core` itself to create the complete SQL statements, without `dbt-core` running those transformations. This enables better compatibility with various `dbt-core` features while ensuring correctness in running models.

The drawback of the current approach is that it relies on monkey patching, a technique used to dynamically update the behaviour of a piece of code at run time. Cosmos monkey patches `dbt-core` adapter methods at the moment they would normally execute SQL statements, modifying this behaviour so that the SQL statements are written to disk without performing any operations against the actual data warehouse. The main risk of this strategy is that dbt may change its interface. For this reason, we logged the follow-up ticket https://github.com/astronomer/astronomer-cosmos/issues/1489 to make sure we test the latest versions of dbt and its adapters and confirm that the monkey patching works as expected regardless of the version being used. That said, since the method being monkey patched is part of the `dbt-core` interface with its adapters, we believe the risk of breaking changes is low.

The other challenge with the current approach is that every Cosmos task relies on the following:

1. `dbt-core` being installed alongside the Airflow installation
2. the execution of a significant part of the `dbtRunner` logic

We have logged a follow-up ticket to evaluate the possibility of overcoming these challenges: #1477

## Key Changes

1. Mocked BigQuery Adapter Execution:
   - Introduced `_mock_bigquery_adapter()` to override `BigQueryConnectionManager.execute`, ensuring the SQL is only written to the `target` directory and execution in the warehouse is skipped.
   - The generated SQL is then submitted using Airflow’s `BigQueryInsertJobOperator` in deferrable mode.
2. Refactoring `AbstractDbtBaseOperator`:
   - Previously, `AbstractDbtBaseOperator` inherited from `BaseOperator`, causing conflicts when used with `BigQueryInsertJobOperator` in our `ExecutionMode.AIRFLOW_ASYNC` classes and the interface built in #1483.
   - Refactored it to `AbstractDbtBase` (no longer inheriting from `BaseOperator`), requiring explicit `BaseOperator` initialization in all derived operators.
   - Updated the existing operators below to account for this refactoring, since derived classes now need to initialise `BaseOperator`:
     - `DbtAzureContainerInstanceBaseOperator`
     - `DbtDockerBaseOperator`
     - `DbtGcpCloudRunJobBaseOperator`
     - `DbtKubernetesBaseOperator`
3. Changes to dbt Compilation Workflow:
   - Removed `_add_dbt_compile_task`, which previously pre-generated SQL and uploaded it to remote storage, from where subsequent tasks downloaded the compiled SQL for their execution.
   - Instead, `dbt run` is now invoked directly in each task, using the mocked adapter to generate the full SQL.
   - A future [issue](https://github.com/astronomer/astronomer-cosmos/issues/1477) will assess whether we should reintroduce a compile task that uses the mocked adapter for SQL generation and upload, reducing redundant dbt calls in each task.

## Issue updates

The PR fixes the following issues:

1. closes: #1260
   - Previously, we only supported `--full-refresh` dbt runs with static SQL headers (e.g., CREATE/DROP TABLE).
   - Now, we support dynamic SQL headers based on materializations, including CREATE OR REPLACE TABLE, CREATE OR REPLACE VIEW, etc.
2. closes: #1271
   - dbt macros are evaluated at runtime during the `dbt run` invocation using the mocked adapter, and this PR lays the groundwork for supporting them in async execution mode.
3. closes: #1265
   - Large datasets can now avoid full drops and recreations, enabling incremental model updates.
4. closes: #1261
   - Previously, only tables (`--full-refresh`) were supported; this PR implements logic for handling the different materializations that dbt supports, such as table, view, incremental, ephemeral, and materialized views.
5. closes: #1266
   - Instead of relying on `dbt compile` (which only outputs SELECT statements), we now let dbt generate the complete SQL queries, including SQL headers/DDL statements, for the queries corresponding to the resource nodes and the state of tables/views in the backend warehouse.
6. closes: #1264
   - With this PR, we also support emitting datasets for `ExecutionMode.AIRFLOW_ASYNC`.

## Example DAG showing `ExecutionMode.AIRFLOW_ASYNC` deferring tasks and the dynamic query submitted in the logs

Screenshot 2025-02-04 at 1 02 42 PM

## Next Steps & Considerations:

- It is acknowledged that monkey patching has downsides; however, it currently seems the best approach to achieve our goals, and we understand and accept the risks associated with this method. To mitigate them, we are expanding our test coverage in #1489 to include all currently supported dbt adapter versions in our test matrix. This will ensure compatibility across different dbt versions and help us catch potential issues early.
- Further validate different dbt macros and materializations with `ExecutionMode.AIRFLOW_ASYNC` by seeking feedback from users testing the alpha release https://github.com/astronomer/astronomer-cosmos/releases/tag/astronomer-cosmos-v1.9.0a5 created with the changes from this PR.
- https://github.com/astronomer/astronomer-cosmos/issues/1477: compare the efficiency of generating SQL dynamically vs. pre-compiling and uploading SQL via a separate task.
- Add compatibility across all major cloud data warehouse backends (dbt adapters).
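For reference, the snippet below is a condensed sketch of the mechanism described above, mirroring the `cosmos/dbt_adapters/bigquery.py` module added in this PR; exact import paths depend on the installed `dbt-bigquery`/`dbt-common` versions. `_mock_bigquery_adapter()` patches the connection manager so that `dbt run` renders and writes the full SQL to the `target` directory without touching BigQuery, and `_associate_bigquery_async_op_args()` hands the captured SQL to the deferrable `BigQueryInsertJobOperator`.

```python
from __future__ import annotations

from typing import Any, Optional, Tuple

from cosmos.exceptions import CosmosValueError


def _mock_bigquery_adapter() -> None:
    """Patch BigQueryConnectionManager.execute so that `dbt run` writes the
    fully rendered SQL to the target directory without submitting it to BigQuery."""
    import agate
    from dbt.adapters.bigquery.connections import BigQueryAdapterResponse, BigQueryConnectionManager
    from dbt_common.clients.agate_helper import empty_table

    def execute(
        self, sql, auto_begin=False, fetch=None, limit: Optional[int] = None
    ) -> Tuple[BigQueryAdapterResponse, agate.Table]:
        # Return a canned adapter response instead of running the query in the warehouse.
        return BigQueryAdapterResponse("mock_bigquery_adapter_response"), empty_table()

    BigQueryConnectionManager.execute = execute


def _associate_bigquery_async_op_args(async_op_obj: Any, **kwargs: Any) -> Any:
    """Attach the SQL captured from target/run/ to the deferrable BigQueryInsertJobOperator."""
    sql = kwargs.get("sql")
    if not sql:
        raise CosmosValueError("Keyword argument 'sql' is required for BigQuery Async operator")
    async_op_obj.configuration = {"query": {"query": sql, "useLegacySql": False}}
    return async_op_obj
```

At task run time, `run_command(..., run_as_async=True, async_context=...)` applies the mock before invoking `dbt run`, reads the generated SQL back from `target/run/`, and executes it through the deferrable operator (see `_handle_async_execution` in `cosmos/operators/local.py`).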
--------- Co-authored-by: Tatiana Al-Chueyr Co-authored-by: Pankaj Singh <98807258+pankajastro@users.noreply.github.com> --- CHANGELOG.rst | 6 +- cosmos/__init__.py | 2 +- cosmos/airflow/graph.py | 28 -- cosmos/constants.py | 1 + cosmos/dbt_adapters/__init__.py | 18 ++ cosmos/dbt_adapters/bigquery.py | 33 +++ cosmos/operators/_asynchronous/base.py | 31 ++- cosmos/operators/_asynchronous/bigquery.py | 100 +++---- cosmos/operators/_asynchronous/databricks.py | 1 + cosmos/operators/airflow_async.py | 55 ++-- cosmos/operators/azure_container_instance.py | 52 ++-- cosmos/operators/base.py | 20 +- cosmos/operators/docker.py | 33 ++- cosmos/operators/gcp_cloud_run_job.py | 38 ++- cosmos/operators/kubernetes.py | 28 +- cosmos/operators/local.py | 253 +++++++++++++----- cosmos/operators/virtualenv.py | 18 +- dev/dags/simple_dag_async.py | 2 +- tests/airflow/test_graph.py | 41 +-- tests/dbt_adapters/test_bigquery.py | 45 ++++ tests/dbt_adapters/test_init.py | 15 ++ tests/operators/_asynchronous/test_base.py | 70 +++-- .../operators/_asynchronous/test_bigquery.py | 127 ++++----- tests/operators/test_aws_eks.py | 1 - tests/operators/test_base.py | 44 ++- tests/operators/test_kubernetes.py | 45 ++-- tests/operators/test_local.py | 150 ++++++++++- 27 files changed, 825 insertions(+), 432 deletions(-) create mode 100644 cosmos/dbt_adapters/__init__.py create mode 100644 cosmos/dbt_adapters/bigquery.py create mode 100644 tests/dbt_adapters/test_bigquery.py create mode 100644 tests/dbt_adapters/test_init.py diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 4c9d6c809..8044eade4 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -1,7 +1,7 @@ Changelog ========= -1.9.0a4 (2025-01-29) +1.9.0a5 (2025-02-03) -------------------- Breaking changes @@ -18,6 +18,7 @@ Features * Allow users to opt-out of ``dbtRunner`` during DAG parsing with ``InvocationMode.SUBPROCESS`` by @tatiana in #1495. Check out the `documentation `_. * Add structure to support multiple db for async operator execution by @pankajastro in #1483 * Support overriding the ``profile_config`` per dbt node or folder using config by @tatiana in #1492. More information `here `_. +* Create and run accurate SQL statements when using ``ExecutionMode.AIRFLOW_ASYNC`` by @pankajkoti, @tatiana and @pankajastro in #1474 Bug Fixes @@ -27,9 +28,12 @@ Enhancement * Fix OpenLineage deprecation warning by @CorsettiS in #1449 * Move ``DbtRunner`` related functions into ``dbt/runner.py`` module by @tatiana in #1480 +* Add ``on_warning_callback`` to ``DbtSourceKubernetesOperator`` and refactor previous operators by @LuigiCerone in #1501 + Others +* Ignore dbt package tests when running Cosmos tests by @tatiana in #1502 * GitHub Actions Dependabot: #1487 * Pre-commit updates: #1473, #1493 diff --git a/cosmos/__init__.py b/cosmos/__init__.py index e245fb7e6..7374e9db6 100644 --- a/cosmos/__init__.py +++ b/cosmos/__init__.py @@ -6,7 +6,7 @@ Contains dags, task groups, and operators. 
""" -__version__ = "1.9.0a4" +__version__ = "1.9.0a5" from cosmos.airflow.dag import DbtDag diff --git a/cosmos/airflow/graph.py b/cosmos/airflow/graph.py index ef742f8eb..2c65361c5 100644 --- a/cosmos/airflow/graph.py +++ b/cosmos/airflow/graph.py @@ -11,7 +11,6 @@ from cosmos.config import RenderConfig from cosmos.constants import ( - DBT_COMPILE_TASK_ID, DEFAULT_DBT_RESOURCES, SUPPORTED_BUILD_RESOURCES, TESTABLE_DBT_RESOURCES, @@ -392,32 +391,6 @@ def generate_task_or_group( return task_or_group -def _add_dbt_compile_task( - nodes: dict[str, DbtNode], - dag: DAG, - execution_mode: ExecutionMode, - task_args: dict[str, Any], - tasks_map: dict[str, Any], - task_group: TaskGroup | None, -) -> None: - if execution_mode != ExecutionMode.AIRFLOW_ASYNC: - return - - compile_task_metadata = TaskMetadata( - id=DBT_COMPILE_TASK_ID, - operator_class="cosmos.operators.airflow_async.DbtCompileAirflowAsyncOperator", - arguments=task_args, - extra_context={"dbt_dag_task_group_identifier": _get_dbt_dag_task_group_identifier(dag, task_group)}, - ) - compile_airflow_task = create_airflow_task(compile_task_metadata, dag, task_group=task_group) - - for task_id, task in tasks_map.items(): - if not task.upstream_list: - compile_airflow_task >> task - - tasks_map[DBT_COMPILE_TASK_ID] = compile_airflow_task - - def _get_dbt_dag_task_group_identifier(dag: DAG, task_group: TaskGroup | None) -> str: dag_id = dag.dag_id task_group_id = task_group.group_id if task_group else None @@ -588,7 +561,6 @@ def build_airflow_graph( tasks_map[node_id] = test_task create_airflow_task_dependencies(nodes, tasks_map) - _add_dbt_compile_task(nodes, dag, execution_mode, task_args, tasks_map, task_group) return tasks_map diff --git a/cosmos/constants.py b/cosmos/constants.py index 0513d50d2..a68f5a836 100644 --- a/cosmos/constants.py +++ b/cosmos/constants.py @@ -6,6 +6,7 @@ import aenum from packaging.version import Version +BIGQUERY_PROFILE_TYPE = "bigquery" DBT_PROFILE_PATH = Path(os.path.expanduser("~")).joinpath(".dbt/profiles.yml") DEFAULT_DBT_PROFILE_NAME = "cosmos_profile" DEFAULT_DBT_TARGET_NAME = "cosmos_target" diff --git a/cosmos/dbt_adapters/__init__.py b/cosmos/dbt_adapters/__init__.py new file mode 100644 index 000000000..9c4f4dec0 --- /dev/null +++ b/cosmos/dbt_adapters/__init__.py @@ -0,0 +1,18 @@ +from __future__ import annotations + +from typing import Any + +from cosmos.constants import BIGQUERY_PROFILE_TYPE +from cosmos.dbt_adapters.bigquery import _associate_bigquery_async_op_args, _mock_bigquery_adapter + +PROFILE_TYPE_MOCK_ADAPTER_CALLABLE_MAP = { + BIGQUERY_PROFILE_TYPE: _mock_bigquery_adapter, +} + +PROFILE_TYPE_ASSOCIATE_ARGS_CALLABLE_MAP = { + BIGQUERY_PROFILE_TYPE: _associate_bigquery_async_op_args, +} + + +def associate_async_operator_args(async_operator_obj: Any, profile_type: str, **kwargs: Any) -> Any: + return PROFILE_TYPE_ASSOCIATE_ARGS_CALLABLE_MAP[profile_type](async_operator_obj, **kwargs) diff --git a/cosmos/dbt_adapters/bigquery.py b/cosmos/dbt_adapters/bigquery.py new file mode 100644 index 000000000..e7876e06b --- /dev/null +++ b/cosmos/dbt_adapters/bigquery.py @@ -0,0 +1,33 @@ +from __future__ import annotations + +from typing import Any + +from cosmos.exceptions import CosmosValueError + + +def _mock_bigquery_adapter() -> None: + from typing import Optional, Tuple + + import agate + from dbt.adapters.bigquery.connections import BigQueryAdapterResponse, BigQueryConnectionManager + from dbt_common.clients.agate_helper import empty_table + + def execute( # type: ignore[no-untyped-def] + 
self, sql, auto_begin=False, fetch=None, limit: Optional[int] = None + ) -> Tuple[BigQueryAdapterResponse, agate.Table]: + return BigQueryAdapterResponse("mock_bigquery_adapter_response"), empty_table() + + BigQueryConnectionManager.execute = execute + + +def _associate_bigquery_async_op_args(async_op_obj: Any, **kwargs: Any) -> Any: + sql = kwargs.get("sql") + if not sql: + raise CosmosValueError("Keyword argument 'sql' is required for BigQuery Async operator") + async_op_obj.configuration = { + "query": { + "query": sql, + "useLegacySql": False, + } + } + return async_op_obj diff --git a/cosmos/operators/_asynchronous/base.py b/cosmos/operators/_asynchronous/base.py index e957c9cac..f8d41b88c 100644 --- a/cosmos/operators/_asynchronous/base.py +++ b/cosmos/operators/_asynchronous/base.py @@ -1,9 +1,8 @@ +from __future__ import annotations + import importlib import logging -from abc import ABCMeta -from typing import Any, Sequence - -from airflow.utils.context import Context +from typing import Any from cosmos.airflow.graph import _snake_case_to_camelcase from cosmos.config import ProfileConfig @@ -36,11 +35,16 @@ def _create_async_operator_class(profile_type: str, dbt_class: str) -> Any: return DbtRunLocalOperator -class DbtRunAirflowAsyncFactoryOperator(DbtRunLocalOperator, metaclass=ABCMeta): # type: ignore[misc] +class DbtRunAirflowAsyncFactoryOperator(DbtRunLocalOperator): # type: ignore[misc] - template_fields: Sequence[str] = DbtRunLocalOperator.template_fields + ("project_dir",) # type: ignore[operator] - - def __init__(self, project_dir: str, profile_config: ProfileConfig, **kwargs: Any): + def __init__( + self, + project_dir: str, + profile_config: ProfileConfig, + extra_context: dict[str, object] | None = None, + dbt_kwargs: dict[str, object] | None = None, + **kwargs: Any, + ) -> None: self.project_dir = project_dir self.profile_config = profile_config @@ -51,7 +55,13 @@ def __init__(self, project_dir: str, profile_config: ProfileConfig, **kwargs: An # When using composition instead of inheritance to initialize the async class and run its execute method, # Airflow throws a `DuplicateTaskIdFound` error. 
DbtRunAirflowAsyncFactoryOperator.__bases__ = (async_operator_class,) - super().__init__(project_dir=project_dir, profile_config=profile_config, **kwargs) + super().__init__( + project_dir=project_dir, + profile_config=profile_config, + extra_context=extra_context, + dbt_kwargs=dbt_kwargs, + **kwargs, + ) def create_async_operator(self) -> Any: @@ -60,6 +70,3 @@ def create_async_operator(self) -> Any: async_class_operator = _create_async_operator_class(profile_type, "DbtRun") return async_class_operator - - def execute(self, context: Context) -> None: - super().execute(context) diff --git a/cosmos/operators/_asynchronous/bigquery.py b/cosmos/operators/_asynchronous/bigquery.py index decbf8d77..1c5dc01a8 100644 --- a/cosmos/operators/_asynchronous/bigquery.py +++ b/cosmos/operators/_asynchronous/bigquery.py @@ -1,22 +1,23 @@ from __future__ import annotations -from pathlib import Path -from typing import TYPE_CHECKING, Any, Sequence +from typing import Any, Sequence -from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook +import airflow from airflow.providers.google.cloud.operators.bigquery import BigQueryInsertJobOperator from airflow.utils.context import Context +from packaging.version import Version from cosmos import settings from cosmos.config import ProfileConfig -from cosmos.exceptions import CosmosValueError -from cosmos.settings import remote_target_path, remote_target_path_conn_id +from cosmos.dataset import get_dataset_alias_name +from cosmos.operators.local import AbstractDbtLocalBase +AIRFLOW_VERSION = Version(airflow.__version__) -class DbtRunAirflowAsyncBigqueryOperator(BigQueryInsertJobOperator): # type: ignore[misc] + +class DbtRunAirflowAsyncBigqueryOperator(BigQueryInsertJobOperator, AbstractDbtLocalBase): # type: ignore[misc] template_fields: Sequence[str] = ( - "full_refresh", "gcp_project", "dataset", "location", @@ -27,6 +28,7 @@ def __init__( project_dir: str, profile_config: ProfileConfig, extra_context: dict[str, Any] | None = None, + dbt_kwargs: dict[str, Any] | None = None, **kwargs: Any, ): self.project_dir = project_dir @@ -36,73 +38,35 @@ def __init__( self.gcp_project = profile["project"] self.dataset = profile["dataset"] self.extra_context = extra_context or {} - self.full_refresh = None - if "full_refresh" in kwargs: - self.full_refresh = kwargs.pop("full_refresh") self.configuration: dict[str, Any] = {} + self.dbt_kwargs = dbt_kwargs or {} + task_id = self.dbt_kwargs.pop("task_id") + AbstractDbtLocalBase.__init__( + self, task_id=task_id, project_dir=project_dir, profile_config=profile_config, **self.dbt_kwargs + ) + if kwargs.get("emit_datasets", True) and settings.enable_dataset_alias and AIRFLOW_VERSION >= Version("2.10"): + from airflow.datasets import DatasetAlias + + # ignoring the type because older versions of Airflow raise the follow error in mypy + # error: Incompatible types in assignment (expression has type "list[DatasetAlias]", target has type "str") + dag_id = kwargs.get("dag") + task_group_id = kwargs.get("task_group") + kwargs["outlets"] = [ + DatasetAlias(name=get_dataset_alias_name(dag_id, task_group_id, self.task_id)) + ] # type: ignore super().__init__( gcp_conn_id=self.gcp_conn_id, configuration=self.configuration, deferrable=True, **kwargs, ) + self.async_context = extra_context or {} + self.async_context["profile_type"] = self.profile_config.get_profile_type() + self.async_context["async_operator"] = BigQueryInsertJobOperator - def get_remote_sql(self) -> str: - if not settings.AIRFLOW_IO_AVAILABLE: - raise 
CosmosValueError(f"Cosmos async support is only available starting in Airflow 2.8 or later.") - from airflow.io.path import ObjectStoragePath - - file_path = self.extra_context["dbt_node_config"]["file_path"] # type: ignore - dbt_dag_task_group_identifier = self.extra_context["dbt_dag_task_group_identifier"] - - remote_target_path_str = str(remote_target_path).rstrip("/") - - if TYPE_CHECKING: # pragma: no cover - assert self.project_dir is not None - - project_dir_parent = str(Path(self.project_dir).parent) - relative_file_path = str(file_path).replace(project_dir_parent, "").lstrip("/") - remote_model_path = f"{remote_target_path_str}/{dbt_dag_task_group_identifier}/compiled/{relative_file_path}" - - object_storage_path = ObjectStoragePath(remote_model_path, conn_id=remote_target_path_conn_id) - with object_storage_path.open() as fp: # type: ignore - return fp.read() # type: ignore - - def drop_table_sql(self) -> None: - model_name = self.extra_context["dbt_node_config"]["resource_name"] # type: ignore - sql = f"DROP TABLE IF EXISTS {self.gcp_project}.{self.dataset}.{model_name};" - - hook = BigQueryHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) - self.configuration = { - "query": { - "query": sql, - "useLegacySql": False, - } - } - hook.insert_job(configuration=self.configuration, location=self.location, project_id=self.gcp_project) - - def execute(self, context: Context) -> Any | None: + @property + def base_cmd(self) -> list[str]: + return ["run"] - if not self.full_refresh: - raise CosmosValueError("The async execution only supported for full_refresh") - else: - # It may be surprising to some, but the dbt-core --full-refresh argument fully drops the table before populating it - # https://github.com/dbt-labs/dbt-core/blob/5e9f1b515f37dfe6cdae1ab1aa7d190b92490e24/core/dbt/context/base.py#L662-L666 - # https://docs.getdbt.com/reference/resource-configs/full_refresh#recommendation - # We're emulating this behaviour here - # The compiled SQL has several limitations here, but these will be addressed in the PR: https://github.com/astronomer/astronomer-cosmos/pull/1474. 
- self.drop_table_sql() - sql = self.get_remote_sql() - model_name = self.extra_context["dbt_node_config"]["resource_name"] # type: ignore - # prefix explicit create command to create table - sql = f"CREATE TABLE {self.gcp_project}.{self.dataset}.{model_name} AS {sql}" - self.configuration = { - "query": { - "query": sql, - "useLegacySql": False, - } - } - return super().execute(context) + def execute(self, context: Context, **kwargs: Any) -> None: + self.build_and_run_cmd(context=context, run_as_async=True, async_context=self.async_context) diff --git a/cosmos/operators/_asynchronous/databricks.py b/cosmos/operators/_asynchronous/databricks.py index d49fd0be0..6e39bfd7c 100644 --- a/cosmos/operators/_asynchronous/databricks.py +++ b/cosmos/operators/_asynchronous/databricks.py @@ -1,4 +1,5 @@ # TODO: Implement it +from __future__ import annotations from typing import Any diff --git a/cosmos/operators/airflow_async.py b/cosmos/operators/airflow_async.py index de8d041c4..d6b1bda5a 100644 --- a/cosmos/operators/airflow_async.py +++ b/cosmos/operators/airflow_async.py @@ -1,10 +1,12 @@ from __future__ import annotations import inspect +from typing import Any from cosmos.config import ProfileConfig +from cosmos.constants import BIGQUERY_PROFILE_TYPE from cosmos.operators._asynchronous.base import DbtRunAirflowAsyncFactoryOperator -from cosmos.operators.base import AbstractDbtBaseOperator +from cosmos.operators.base import AbstractDbtBase from cosmos.operators.local import ( DbtBuildLocalOperator, DbtCloneLocalOperator, @@ -18,81 +20,76 @@ DbtTestLocalOperator, ) -_SUPPORTED_DATABASES = ["bigquery"] +_SUPPORTED_DATABASES = [BIGQUERY_PROFILE_TYPE] -from abc import ABCMeta -from airflow.models.baseoperator import BaseOperator - - -class DbtBaseAirflowAsyncOperator(BaseOperator, metaclass=ABCMeta): - def __init__(self, **kwargs) -> None: # type: ignore - if "location" in kwargs: - kwargs.pop("location") - super().__init__(**kwargs) - - -class DbtBuildAirflowAsyncOperator(DbtBaseAirflowAsyncOperator, DbtBuildLocalOperator): # type: ignore +class DbtBuildAirflowAsyncOperator(DbtBuildLocalOperator): pass -class DbtLSAirflowAsyncOperator(DbtBaseAirflowAsyncOperator, DbtLSLocalOperator): # type: ignore +class DbtLSAirflowAsyncOperator(DbtLSLocalOperator): pass -class DbtSeedAirflowAsyncOperator(DbtBaseAirflowAsyncOperator, DbtSeedLocalOperator): # type: ignore +class DbtSeedAirflowAsyncOperator(DbtSeedLocalOperator): pass -class DbtSnapshotAirflowAsyncOperator(DbtBaseAirflowAsyncOperator, DbtSnapshotLocalOperator): # type: ignore +class DbtSnapshotAirflowAsyncOperator(DbtSnapshotLocalOperator): pass -class DbtSourceAirflowAsyncOperator(DbtBaseAirflowAsyncOperator, DbtSourceLocalOperator): # type: ignore +class DbtSourceAirflowAsyncOperator(DbtSourceLocalOperator): pass -class DbtRunAirflowAsyncOperator(DbtRunAirflowAsyncFactoryOperator): # type: ignore +class DbtRunAirflowAsyncOperator(DbtRunAirflowAsyncFactoryOperator): - def __init__( # type: ignore + def __init__( self, project_dir: str, profile_config: ProfileConfig, extra_context: dict[str, object] | None = None, - **kwargs, + **kwargs: Any, ) -> None: # Cosmos attempts to pass many kwargs that async operator simply does not accept. # We need to pop them. 
clean_kwargs = {} - non_async_args = set(inspect.signature(AbstractDbtBaseOperator.__init__).parameters.keys()) + non_async_args = set(inspect.signature(AbstractDbtBase.__init__).parameters.keys()) non_async_args |= set(inspect.signature(DbtLocalBaseOperator.__init__).parameters.keys()) - non_async_args -= {"task_id"} + + dbt_kwargs = {} for arg_key, arg_value in kwargs.items(): - if arg_key not in non_async_args: + if arg_key == "task_id": + clean_kwargs[arg_key] = arg_value + dbt_kwargs[arg_key] = arg_value + elif arg_key not in non_async_args: clean_kwargs[arg_key] = arg_value + else: + dbt_kwargs[arg_key] = arg_value - # The following are the minimum required parameters to run BigQueryInsertJobOperator using the deferrable mode super().__init__( project_dir=project_dir, profile_config=profile_config, extra_context=extra_context, + dbt_kwargs=dbt_kwargs, **clean_kwargs, ) -class DbtTestAirflowAsyncOperator(DbtBaseAirflowAsyncOperator, DbtTestLocalOperator): # type: ignore +class DbtTestAirflowAsyncOperator(DbtTestLocalOperator): pass -class DbtRunOperationAirflowAsyncOperator(DbtBaseAirflowAsyncOperator, DbtRunOperationLocalOperator): # type: ignore +class DbtRunOperationAirflowAsyncOperator(DbtRunOperationLocalOperator): pass -class DbtCompileAirflowAsyncOperator(DbtBaseAirflowAsyncOperator, DbtCompileLocalOperator): # type: ignore +class DbtCompileAirflowAsyncOperator(DbtCompileLocalOperator): pass -class DbtCloneAirflowAsyncOperator(DbtBaseAirflowAsyncOperator, DbtCloneLocalOperator): +class DbtCloneAirflowAsyncOperator(DbtCloneLocalOperator): pass diff --git a/cosmos/operators/azure_container_instance.py b/cosmos/operators/azure_container_instance.py index 7f335bd99..aeeec1a23 100644 --- a/cosmos/operators/azure_container_instance.py +++ b/cosmos/operators/azure_container_instance.py @@ -1,12 +1,13 @@ from __future__ import annotations +import inspect from typing import Any, Callable, Sequence from airflow.utils.context import Context from cosmos.config import ProfileConfig from cosmos.operators.base import ( - AbstractDbtBaseOperator, + AbstractDbtBase, DbtBuildMixin, DbtCloneMixin, DbtLSMixin, @@ -28,13 +29,13 @@ ) -class DbtAzureContainerInstanceBaseOperator(AbstractDbtBaseOperator, AzureContainerInstancesOperator): # type: ignore +class DbtAzureContainerInstanceBaseOperator(AbstractDbtBase, AzureContainerInstancesOperator): # type: ignore """ Executes a dbt core cli command in an Azure Container Instance """ template_fields: Sequence[str] = tuple( - list(AbstractDbtBaseOperator.template_fields) + list(AzureContainerInstancesOperator.template_fields) + list(AbstractDbtBase.template_fields) + list(AzureContainerInstancesOperator.template_fields) ) def __init__( @@ -51,19 +52,40 @@ def __init__( **kwargs: Any, ) -> None: self.profile_config = profile_config - super().__init__( - ci_conn_id=ci_conn_id, - resource_group=resource_group, - name=name, - image=image, - region=region, - remove_on_error=remove_on_error, - fail_if_exists=fail_if_exists, - registry_conn_id=registry_conn_id, - **kwargs, + kwargs.update( + { + "ci_conn_id": ci_conn_id, + "resource_group": resource_group, + "name": name, + "image": image, + "region": region, + "remove_on_error": remove_on_error, + "fail_if_exists": fail_if_exists, + "registry_conn_id": registry_conn_id, + } ) - - def build_and_run_cmd(self, context: Context, cmd_flags: list[str] | None = None) -> None: + super().__init__(**kwargs) + # In PR #1474, we refactored cosmos.operators.base.AbstractDbtBase to remove its inheritance from BaseOperator + 
# and eliminated the super().__init__() call. This change was made to resolve conflicts in parent class + # initializations while adding support for ExecutionMode.AIRFLOW_ASYNC. Operators under this mode inherit + # Airflow provider operators that enable deferrable SQL query execution. Since super().__init__() was removed + # from AbstractDbtBase and different parent classes require distinct initialization arguments, we explicitly + # initialize them (including the BaseOperator) here by segregating the required arguments for each parent class. + base_operator_args = set(inspect.signature(AzureContainerInstancesOperator.__init__).parameters.keys()) + base_kwargs = {} + for arg_key, arg_value in kwargs.items(): + if arg_key in base_operator_args: + base_kwargs[arg_key] = arg_value + base_kwargs["task_id"] = kwargs["task_id"] + AzureContainerInstancesOperator.__init__(self, **base_kwargs) + + def build_and_run_cmd( + self, + context: Context, + cmd_flags: list[str] | None = None, + run_as_async: bool = False, + async_context: dict[str, Any] | None = None, + ) -> Any: self.build_command(context, cmd_flags) self.log.info(f"Running command: {self.command}") result = AzureContainerInstancesOperator.execute(self, context) diff --git a/cosmos/operators/base.py b/cosmos/operators/base.py index 52fb98bac..18019ab92 100644 --- a/cosmos/operators/base.py +++ b/cosmos/operators/base.py @@ -1,20 +1,21 @@ from __future__ import annotations +import logging import os from abc import ABCMeta, abstractmethod from pathlib import Path from typing import Any, Sequence, Tuple import yaml -from airflow.models.baseoperator import BaseOperator from airflow.utils.context import Context, context_merge from airflow.utils.operator_helpers import context_to_airflow_vars from airflow.utils.strings import to_boolean from cosmos.dbt.executable import get_system_dbt +from cosmos.log import get_logger -class AbstractDbtBaseOperator(BaseOperator, metaclass=ABCMeta): +class AbstractDbtBase(metaclass=ABCMeta): """ Executes a dbt core cli command. 
@@ -140,7 +141,6 @@ def __init__( self.cache_dir = cache_dir self.extra_context = extra_context or {} kwargs.pop("full_refresh", None) # usage of this param should be implemented in child classes - super().__init__(**kwargs) def get_env(self, context: Context) -> dict[str, str | bytes | os.PathLike[Any]]: """ @@ -191,6 +191,10 @@ def get_env(self, context: Context) -> dict[str, str | bytes | os.PathLike[Any]] return filtered_env + @property + def log(self) -> logging.Logger: + return get_logger(__name__) + def add_global_flags(self) -> list[str]: flags = [] for global_flag in self.global_flags: @@ -258,10 +262,16 @@ def build_cmd( return dbt_cmd, env @abstractmethod - def build_and_run_cmd(self, context: Context, cmd_flags: list[str]) -> Any: + def build_and_run_cmd( + self, + context: Context, + cmd_flags: list[str], + run_as_async: bool = False, + async_context: dict[str, Any] | None = None, + ) -> Any: """Override this method for the operator to execute the dbt command""" - def execute(self, context: Context) -> Any | None: # type: ignore + def execute(self, context: Context, **kwargs) -> Any | None: # type: ignore if self.extra_context: context_merge(context, self.extra_context) diff --git a/cosmos/operators/docker.py b/cosmos/operators/docker.py index 8dc614cfc..879a8164c 100644 --- a/cosmos/operators/docker.py +++ b/cosmos/operators/docker.py @@ -1,5 +1,6 @@ from __future__ import annotations +import inspect from typing import Any, Callable, Sequence from airflow.utils.context import Context @@ -7,7 +8,7 @@ from cosmos.config import ProfileConfig from cosmos.exceptions import CosmosValueError from cosmos.operators.base import ( - AbstractDbtBaseOperator, + AbstractDbtBase, DbtBuildMixin, DbtCloneMixin, DbtLSMixin, @@ -29,15 +30,13 @@ ) -class DbtDockerBaseOperator(AbstractDbtBaseOperator, DockerOperator): # type: ignore +class DbtDockerBaseOperator(AbstractDbtBase, DockerOperator): # type: ignore """ Executes a dbt core cli command in a Docker container. """ - template_fields: Sequence[str] = tuple( - list(AbstractDbtBaseOperator.template_fields) + list(DockerOperator.template_fields) - ) + template_fields: Sequence[str] = tuple(list(AbstractDbtBase.template_fields) + list(DockerOperator.template_fields)) intercept_flag = False @@ -56,8 +55,28 @@ def __init__( ) super().__init__(image=image, **kwargs) - - def build_and_run_cmd(self, context: Context, cmd_flags: list[str] | None = None) -> Any: + # In PR #1474, we refactored cosmos.operators.base.AbstractDbtBase to remove its inheritance from BaseOperator + # and eliminated the super().__init__() call. This change was made to resolve conflicts in parent class + # initializations while adding support for ExecutionMode.AIRFLOW_ASYNC. Operators under this mode inherit + # Airflow provider operators that enable deferrable SQL query execution. Since super().__init__() was removed + # from AbstractDbtBase and different parent classes require distinct initialization arguments, we explicitly + # initialize them (including the BaseOperator) here by segregating the required arguments for each parent class. 
+ kwargs["image"] = image + base_operator_args = set(inspect.signature(DockerOperator.__init__).parameters.keys()) + base_kwargs = {} + for arg_key, arg_value in kwargs.items(): + if arg_key in base_operator_args: + base_kwargs[arg_key] = arg_value + base_kwargs["task_id"] = kwargs["task_id"] + DockerOperator.__init__(self, **base_kwargs) + + def build_and_run_cmd( + self, + context: Context, + cmd_flags: list[str] | None = None, + run_as_async: bool = False, + async_context: dict[str, Any] | None = None, + ) -> Any: self.build_command(context, cmd_flags) self.log.info(f"Running command: {self.command}") result = DockerOperator.execute(self, context) diff --git a/cosmos/operators/gcp_cloud_run_job.py b/cosmos/operators/gcp_cloud_run_job.py index ef47db2cc..e24191d6a 100644 --- a/cosmos/operators/gcp_cloud_run_job.py +++ b/cosmos/operators/gcp_cloud_run_job.py @@ -8,7 +8,7 @@ from cosmos.config import ProfileConfig from cosmos.log import get_logger from cosmos.operators.base import ( - AbstractDbtBaseOperator, + AbstractDbtBase, DbtBuildMixin, DbtCloneMixin, DbtLSMixin, @@ -41,14 +41,14 @@ ) -class DbtGcpCloudRunJobBaseOperator(AbstractDbtBaseOperator, CloudRunExecuteJobOperator): # type: ignore +class DbtGcpCloudRunJobBaseOperator(AbstractDbtBase, CloudRunExecuteJobOperator): # type: ignore """ Executes a dbt core cli command in a Cloud Run Job instance with dbt installed in it. """ template_fields: Sequence[str] = tuple( - list(AbstractDbtBaseOperator.template_fields) + list(CloudRunExecuteJobOperator.template_fields) + list(AbstractDbtBase.template_fields) + list(CloudRunExecuteJobOperator.template_fields) ) intercept_flag = False @@ -69,8 +69,36 @@ def __init__( self.command = command self.environment_variables = environment_variables or DEFAULT_ENVIRONMENT_VARIABLES super().__init__(project_id=project_id, region=region, job_name=job_name, **kwargs) - - def build_and_run_cmd(self, context: Context, cmd_flags: list[str] | None = None) -> Any: + # In PR #1474, we refactored cosmos.operators.base.AbstractDbtBase to remove its inheritance from BaseOperator + # and eliminated the super().__init__() call. This change was made to resolve conflicts in parent class + # initializations while adding support for ExecutionMode.AIRFLOW_ASYNC. Operators under this mode inherit + # Airflow provider operators that enable deferrable SQL query execution. Since super().__init__() was removed + # from AbstractDbtBase and different parent classes require distinct initialization arguments, we explicitly + # initialize them (including the BaseOperator) here by segregating the required arguments for each parent class. 
+ kwargs.update( + { + "project_id": project_id, + "region": region, + "job_name": job_name, + "command": command, + "environment_variables": environment_variables, + } + ) + base_operator_args = set(inspect.signature(CloudRunExecuteJobOperator.__init__).parameters.keys()) + base_kwargs = {} + for arg_key, arg_value in kwargs.items(): + if arg_key in base_operator_args: + base_kwargs[arg_key] = arg_value + base_kwargs["task_id"] = kwargs["task_id"] + CloudRunExecuteJobOperator.__init__(self, **base_kwargs) + + def build_and_run_cmd( + self, + context: Context, + cmd_flags: list[str] | None = None, + run_as_async: bool = False, + async_context: dict[str, Any] | None = None, + ) -> Any: self.build_command(context, cmd_flags) self.log.info(f"Running command: {self.command}") result = CloudRunExecuteJobOperator.execute(self, context) diff --git a/cosmos/operators/kubernetes.py b/cosmos/operators/kubernetes.py index b00e6380c..8cbc20e1c 100644 --- a/cosmos/operators/kubernetes.py +++ b/cosmos/operators/kubernetes.py @@ -1,5 +1,6 @@ from __future__ import annotations +import inspect from abc import ABC from os import PathLike from typing import Any, Callable, Sequence @@ -10,7 +11,7 @@ from cosmos.config import ProfileConfig from cosmos.dbt.parser.output import extract_log_issues from cosmos.operators.base import ( - AbstractDbtBaseOperator, + AbstractDbtBase, DbtBuildMixin, DbtCloneMixin, DbtLSMixin, @@ -43,14 +44,14 @@ ) -class DbtKubernetesBaseOperator(AbstractDbtBaseOperator, KubernetesPodOperator): # type: ignore +class DbtKubernetesBaseOperator(AbstractDbtBase, KubernetesPodOperator): # type: ignore """ Executes a dbt core cli command in a Kubernetes Pod. """ template_fields: Sequence[str] = tuple( - list(AbstractDbtBaseOperator.template_fields) + list(KubernetesPodOperator.template_fields) + list(AbstractDbtBase.template_fields) + list(KubernetesPodOperator.template_fields) ) intercept_flag = False @@ -58,6 +59,19 @@ class DbtKubernetesBaseOperator(AbstractDbtBaseOperator, KubernetesPodOperator): def __init__(self, profile_config: ProfileConfig | None = None, **kwargs: Any) -> None: self.profile_config = profile_config super().__init__(**kwargs) + # In PR #1474, we refactored cosmos.operators.base.AbstractDbtBase to remove its inheritance from BaseOperator + # and eliminated the super().__init__() call. This change was made to resolve conflicts in parent class + # initializations while adding support for ExecutionMode.AIRFLOW_ASYNC. Operators under this mode inherit + # Airflow provider operators that enable deferrable SQL query execution. Since super().__init__() was removed + # from AbstractDbtBase and different parent classes require distinct initialization arguments, we explicitly + # initialize them (including the BaseOperator) here by segregating the required arguments for each parent class. 
+ base_operator_args = set(inspect.signature(KubernetesPodOperator.__init__).parameters.keys()) + base_kwargs = {} + for arg_key, arg_value in kwargs.items(): + if arg_key in base_operator_args: + base_kwargs[arg_key] = arg_value + base_kwargs["task_id"] = kwargs["task_id"] + KubernetesPodOperator.__init__(self, **base_kwargs) def build_env_args(self, env: dict[str, str | bytes | PathLike[Any]]) -> None: env_vars_dict: dict[str, str] = dict() @@ -69,7 +83,13 @@ def build_env_args(self, env: dict[str, str | bytes | PathLike[Any]]) -> None: self.env_vars: list[Any] = convert_env_vars(env_vars_dict) - def build_and_run_cmd(self, context: Context, cmd_flags: list[str] | None = None) -> Any: + def build_and_run_cmd( + self, + context: Context, + cmd_flags: list[str] | None = None, + run_as_async: bool = False, + async_context: dict[str, Any] | None = None, + ) -> Any: self.build_kube_args(context, cmd_flags) self.log.info(f"Running command: {self.arguments}") result = KubernetesPodOperator.execute(self, context) diff --git a/cosmos/operators/local.py b/cosmos/operators/local.py index e5dbcfd31..91b3dd314 100644 --- a/cosmos/operators/local.py +++ b/cosmos/operators/local.py @@ -1,5 +1,6 @@ from __future__ import annotations +import inspect import json import os import tempfile @@ -15,6 +16,7 @@ import jinja2 from airflow import DAG from airflow.exceptions import AirflowException, AirflowSkipException +from airflow.models import BaseOperator from airflow.models.taskinstance import TaskInstance from airflow.utils.context import Context from airflow.utils.session import NEW_SESSION, create_session, provide_session @@ -66,13 +68,14 @@ parse_number_of_warnings_subprocess, ) from cosmos.dbt.project import create_symlinks +from cosmos.dbt_adapters import PROFILE_TYPE_MOCK_ADAPTER_CALLABLE_MAP, associate_async_operator_args from cosmos.hooks.subprocess import ( FullOutputSubprocessHook, FullOutputSubprocessResult, ) from cosmos.log import get_logger from cosmos.operators.base import ( - AbstractDbtBaseOperator, + AbstractDbtBase, DbtBuildMixin, DbtCloneMixin, DbtCompileMixin, @@ -112,7 +115,7 @@ class OperatorLineage: # type: ignore job_facets: dict[str, str] = dict() -class DbtLocalBaseOperator(AbstractDbtBaseOperator): +class AbstractDbtLocalBase(AbstractDbtBase): """ Executes a dbt core cli command locally. @@ -131,7 +134,7 @@ class DbtLocalBaseOperator(AbstractDbtBaseOperator): and does not inherit the current process environment. 
""" - template_fields: Sequence[str] = AbstractDbtBaseOperator.template_fields + ("compiled_sql", "freshness") # type: ignore[operator] + template_fields: Sequence[str] = AbstractDbtBase.template_fields + ("compiled_sql", "freshness") # type: ignore[operator] template_fields_renderers = { "compiled_sql": "sql", "freshness": "json", @@ -162,17 +165,6 @@ def __init__( self.invocation_mode = invocation_mode self._dbt_runner: dbtRunner | None = None - if kwargs.get("emit_datasets", True) and settings.enable_dataset_alias and AIRFLOW_VERSION >= Version("2.10"): - from airflow.datasets import DatasetAlias - - # ignoring the type because older versions of Airflow raise the follow error in mypy - # error: Incompatible types in assignment (expression has type "list[DatasetAlias]", target has type "str") - dag_id = kwargs.get("dag") - task_group_id = kwargs.get("task_group") - kwargs["outlets"] = [ - DatasetAlias(name=get_dataset_alias_name(dag_id, task_group_id, task_id)) - ] # type: ignore - super().__init__(task_id=task_id, **kwargs) # For local execution mode, we're consistent with the LoadMode.DBT_LS command in forwarding the environment @@ -271,7 +263,7 @@ def store_compiled_sql(self, tmp_project_dir: str, context: Context, session: Se # delete the old records session.query(RenderedTaskInstanceFields).filter( - RenderedTaskInstanceFields.dag_id == self.dag_id, + RenderedTaskInstanceFields.dag_id == self.dag_id, # type: ignore[attr-defined] RenderedTaskInstanceFields.task_id == self.task_id, RenderedTaskInstanceFields.run_id == ti.run_id, ).delete() @@ -401,12 +393,97 @@ def _cache_package_lockfile(self, tmp_project_dir: Path) -> None: if latest_package_lockfile: _copy_cached_package_lockfile_to_project(latest_package_lockfile, tmp_project_dir) + def _read_run_sql_from_target_dir(self, tmp_project_dir: str, sql_context: dict[str, Any]) -> str: + sql_relative_path = sql_context["dbt_node_config"]["file_path"].split(str(self.project_dir))[-1].lstrip("/") + run_sql_path = Path(tmp_project_dir) / "target/run" / Path(self.project_dir).name / sql_relative_path + with run_sql_path.open("r") as sql_file: + sql_content: str = sql_file.read() + return sql_content + + def _clone_project(self, tmp_dir_path: Path) -> None: + self.log.info( + "Cloning project to writable temp directory %s from %s", + tmp_dir_path, + self.project_dir, + ) + create_symlinks(Path(self.project_dir), tmp_dir_path, self.install_deps) + + def _handle_partial_parse(self, tmp_dir_path: Path) -> None: + if self.cache_dir is None: + return + latest_partial_parse = cache._get_latest_partial_parse(Path(self.project_dir), self.cache_dir) + self.log.info("Partial parse is enabled and the latest partial parse file is %s", latest_partial_parse) + if latest_partial_parse is not None: + cache._copy_partial_parse_to_project(latest_partial_parse, tmp_dir_path) + + def _generate_dbt_flags(self, tmp_project_dir: str, profile_path: Path) -> list[str]: + return [ + "--project-dir", + str(tmp_project_dir), + "--profiles-dir", + str(profile_path.parent), + "--profile", + self.profile_config.profile_name, + "--target", + self.profile_config.target_name, + ] + + def _install_dependencies( + self, tmp_dir_path: Path, flags: list[str], env: dict[str, str | bytes | os.PathLike[Any]] + ) -> None: + self._cache_package_lockfile(tmp_dir_path) + deps_command = [self.dbt_executable_path, "deps"] + flags + self.invoke_dbt(command=deps_command, env=env, cwd=tmp_dir_path) + + @staticmethod + def _mock_dbt_adapter(async_context: dict[str, Any] | None) -> None: + if 
not async_context: + raise CosmosValueError("`async_context` is necessary for running the model asynchronously") + if "async_operator" not in async_context: + raise CosmosValueError("`async_operator` needs to be specified in `async_context` when running as async") + if "profile_type" not in async_context: + raise CosmosValueError("`profile_type` needs to be specified in `async_context` when running as async") + profile_type = async_context["profile_type"] + if profile_type not in PROFILE_TYPE_MOCK_ADAPTER_CALLABLE_MAP: + raise CosmosValueError(f"Mock adapter callable function not available for profile_type {profile_type}") + mock_adapter_callable = PROFILE_TYPE_MOCK_ADAPTER_CALLABLE_MAP[profile_type] + mock_adapter_callable() + + def _handle_datasets(self, context: Context) -> None: + inlets = self.get_datasets("inputs") + outlets = self.get_datasets("outputs") + self.log.info("Inlets: %s", inlets) + self.log.info("Outlets: %s", outlets) + self.register_dataset(inlets, outlets, context) + + def _update_partial_parse_cache(self, tmp_dir_path: Path) -> None: + if self.cache_dir is None: + return + partial_parse_file = get_partial_parse_path(tmp_dir_path) + if partial_parse_file.exists(): + cache._update_partial_parse_cache(partial_parse_file, self.cache_dir) + + def _handle_post_execution(self, tmp_project_dir: str, context: Context) -> None: + self.store_freshness_json(tmp_project_dir, context) + self.store_compiled_sql(tmp_project_dir, context) + self.upload_compiled_sql(tmp_project_dir, context) + if self.callback: + self.callback_args.update({"context": context}) + self.callback(tmp_project_dir, **self.callback_args) + + def _handle_async_execution(self, tmp_project_dir: str, context: Context, async_context: dict[str, Any]) -> None: + sql = self._read_run_sql_from_target_dir(tmp_project_dir, async_context) + associate_async_operator_args(self, async_context["profile_type"], sql=sql) + async_context["async_operator"].execute(self, context) + def run_command( self, cmd: list[str], env: dict[str, str | bytes | os.PathLike[Any]], context: Context, - ) -> FullOutputSubprocessResult | dbtRunnerResult: + run_as_async: bool = False, + async_context: dict[str, Any] | None = None, + ) -> FullOutputSubprocessResult | dbtRunnerResult | str: """ Copies the dbt project to a temporary directory and runs the command. 
""" @@ -415,50 +492,27 @@ def run_command( with tempfile.TemporaryDirectory() as tmp_project_dir: - self.log.info( - "Cloning project to writable temp directory %s from %s", - tmp_project_dir, - self.project_dir, - ) tmp_dir_path = Path(tmp_project_dir) env = {k: str(v) for k, v in env.items()} - create_symlinks(Path(self.project_dir), tmp_dir_path, self.install_deps) + self._clone_project(tmp_dir_path) - if self.partial_parse and self.cache_dir is not None: - latest_partial_parse = cache._get_latest_partial_parse(Path(self.project_dir), self.cache_dir) - self.log.info("Partial parse is enabled and the latest partial parse file is %s", latest_partial_parse) - if latest_partial_parse is not None: - cache._copy_partial_parse_to_project(latest_partial_parse, tmp_dir_path) + if self.partial_parse: + self._handle_partial_parse(tmp_dir_path) with self.profile_config.ensure_profile() as profile_values: (profile_path, env_vars) = profile_values env.update(env_vars) + self.log.debug("Using environment variables keys: %s", env.keys()) - flags = [ - "--project-dir", - str(tmp_project_dir), - "--profiles-dir", - str(profile_path.parent), - "--profile", - self.profile_config.profile_name, - "--target", - self.profile_config.target_name, - ] + flags = self._generate_dbt_flags(tmp_project_dir, profile_path) if self.install_deps: - self._cache_package_lockfile(tmp_dir_path) - deps_command = [self.dbt_executable_path, "deps"] - deps_command.extend(flags) - self.invoke_dbt( - command=deps_command, - env=env, - cwd=tmp_project_dir, - ) + self._install_dependencies(tmp_dir_path, flags, env) - full_cmd = cmd + flags - - self.log.debug("Using environment variables keys: %s", env.keys()) + if run_as_async: + self._mock_dbt_adapter(async_context) + full_cmd = cmd + flags result = self.invoke_dbt( command=full_cmd, env=env, @@ -471,25 +525,17 @@ def run_command( ].openlineage_events_completes = self.openlineage_events_completes # type: ignore if self.emit_datasets: - inlets = self.get_datasets("inputs") - outlets = self.get_datasets("outputs") - self.log.info("Inlets: %s", inlets) - self.log.info("Outlets: %s", outlets) - self.register_dataset(inlets, outlets, context) - - if self.partial_parse and self.cache_dir: - partial_parse_file = get_partial_parse_path(tmp_dir_path) - if partial_parse_file.exists(): - cache._update_partial_parse_cache(partial_parse_file, self.cache_dir) - - self.store_freshness_json(tmp_project_dir, context) - self.store_compiled_sql(tmp_project_dir, context) - self.upload_compiled_sql(tmp_project_dir, context) - if self.callback: - self.callback_args.update({"context": context}) - self.callback(tmp_project_dir, **self.callback_args) + self._handle_datasets(context) + + if self.partial_parse: + self._update_partial_parse_cache(tmp_dir_path) + + self._handle_post_execution(tmp_project_dir, context) self.handle_exception(result) + if run_as_async and async_context: + self._handle_async_execution(tmp_project_dir, context, async_context) + return result def calculate_openlineage_events_completes( @@ -576,17 +622,17 @@ def register_dataset(self, new_inlets: list[Dataset], new_outlets: list[Dataset] if AIRFLOW_VERSION < Version("2.10") or not settings.enable_dataset_alias: logger.info("Assigning inlets/outlets without DatasetAlias") with create_session() as session: - self.outlets.extend(new_outlets) - self.inlets.extend(new_inlets) - for task in self.dag.tasks: + self.outlets.extend(new_outlets) # type: ignore[attr-defined] + self.inlets.extend(new_inlets) # type: ignore[attr-defined] + for task 
in self.dag.tasks: # type: ignore[attr-defined] if task.task_id == self.task_id: task.outlets.extend(new_outlets) task.inlets.extend(new_inlets) - DAG.bulk_write_to_db([self.dag], session=session) + DAG.bulk_write_to_db([self.dag], session=session) # type: ignore[attr-defined] session.commit() else: logger.info("Assigning inlets/outlets with DatasetAlias") - dataset_alias_name = get_dataset_alias_name(self.dag, self.task_group, self.task_id) + dataset_alias_name = get_dataset_alias_name(self.dag, self.task_group, self.task_id) # type: ignore[attr-defined] for outlet in new_outlets: context["outlet_events"][dataset_alias_name].add(outlet) @@ -629,11 +675,17 @@ def get_openlineage_facets_on_complete(self, task_instance: TaskInstance) -> Ope ) def build_and_run_cmd( - self, context: Context, cmd_flags: list[str] | None = None + self, + context: Context, + cmd_flags: list[str] | None = None, + run_as_async: bool = False, + async_context: dict[str, Any] | None = None, ) -> FullOutputSubprocessResult | dbtRunnerResult: dbt_cmd, env = self.build_cmd(context=context, cmd_flags=cmd_flags) dbt_cmd = dbt_cmd or [] - result = self.run_command(cmd=dbt_cmd, env=env, context=context) + result = self.run_command( + cmd=dbt_cmd, env=env, context=context, run_as_async=run_as_async, async_context=async_context + ) return result def on_kill(self) -> None: @@ -644,6 +696,43 @@ def on_kill(self) -> None: self.subprocess_hook.send_sigterm() +class DbtLocalBaseOperator(AbstractDbtLocalBase, BaseOperator): + + template_fields: Sequence[str] = AbstractDbtLocalBase.template_fields # type: ignore[operator] + + def __init__(self, *args: Any, **kwargs: Any) -> None: + # In PR #1474, we refactored cosmos.operators.base.AbstractDbtBase to remove its inheritance from BaseOperator + # and eliminated the super().__init__() call. This change was made to resolve conflicts in parent class + # initializations while adding support for ExecutionMode.AIRFLOW_ASYNC. Operators under this mode inherit + # Airflow provider operators that enable deferrable SQL query execution. Since super().__init__() was removed + # from AbstractDbtBase and different parent classes require distinct initialization arguments, we explicitly + # initialize them (including the BaseOperator) here by segregating the required arguments for each parent class. 
+ abstract_dbt_local_base_kwargs = {} + base_operator_kwargs = {} + abstract_dbt_local_base_args_keys = ( + inspect.getfullargspec(AbstractDbtBase.__init__).args + + inspect.getfullargspec(AbstractDbtLocalBase.__init__).args + ) + base_operator_args = set(inspect.signature(BaseOperator.__init__).parameters.keys()) + for arg_key, arg_value in kwargs.items(): + if arg_key in abstract_dbt_local_base_args_keys: + abstract_dbt_local_base_kwargs[arg_key] = arg_value + if arg_key in base_operator_args: + base_operator_kwargs[arg_key] = arg_value + AbstractDbtLocalBase.__init__(self, **abstract_dbt_local_base_kwargs) + if kwargs.get("emit_datasets", True) and settings.enable_dataset_alias and AIRFLOW_VERSION >= Version("2.10"): + from airflow.datasets import DatasetAlias + + # ignoring the type because older versions of Airflow raise the follow error in mypy + # error: Incompatible types in assignment (expression has type "list[DatasetAlias]", target has type "str") + dag_id = kwargs.get("dag") + task_group_id = kwargs.get("task_group") + base_operator_kwargs["outlets"] = [ + DatasetAlias(name=get_dataset_alias_name(dag_id, task_group_id, self.task_id)) + ] # type: ignore + BaseOperator.__init__(self, **base_operator_kwargs) + + class DbtBuildLocalOperator(DbtBuildMixin, DbtLocalBaseOperator): """ Executes a dbt core build command. @@ -660,6 +749,8 @@ class DbtLSLocalOperator(DbtLSMixin, DbtLocalBaseOperator): Executes a dbt core ls command. """ + template_fields: Sequence[str] = DbtLocalBaseOperator.template_fields # type: ignore[operator] + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) @@ -680,6 +771,8 @@ class DbtSnapshotLocalOperator(DbtSnapshotMixin, DbtLocalBaseOperator): Executes a dbt core snapshot command. """ + template_fields: Sequence[str] = DbtLocalBaseOperator.template_fields # type: ignore[operator] + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) @@ -689,6 +782,8 @@ class DbtSourceLocalOperator(DbtSourceMixin, DbtLocalBaseOperator): Executes a dbt source freshness command. """ + template_fields: Sequence[str] = DbtLocalBaseOperator.template_fields # type: ignore[operator] + def __init__(self, *args: Any, on_warning_callback: Callable[..., Any] | None = None, **kwargs: Any) -> None: super().__init__(*args, **kwargs) self.on_warning_callback = on_warning_callback @@ -715,7 +810,7 @@ def _handle_warnings(self, result: FullOutputSubprocessResult | dbtRunnerResult, self.on_warning_callback and self.on_warning_callback(warning_context) - def execute(self, context: Context) -> None: + def execute(self, context: Context, **kwargs: Any) -> None: result = self.build_and_run_cmd(context=context, cmd_flags=self.add_cmd_flags()) if self.on_warning_callback: self._handle_warnings(result, context) @@ -739,6 +834,8 @@ class DbtTestLocalOperator(DbtTestMixin, DbtLocalBaseOperator): and "test_results" of type `List`. Each index in "test_names" corresponds to the same index in "test_results". 
""" + template_fields: Sequence[str] = DbtLocalBaseOperator.template_fields # type: ignore[operator] + def __init__( self, on_warning_callback: Callable[..., Any] | None = None, @@ -774,7 +871,7 @@ def _set_test_result_parsing_methods(self) -> None: self.extract_issues = dbt_runner.extract_message_by_status self.parse_number_of_warnings = dbt_runner.parse_number_of_warnings - def execute(self, context: Context) -> None: + def execute(self, context: Context, **kwargs: Any) -> None: result = self.build_and_run_cmd(context=context, cmd_flags=self.add_cmd_flags()) self._set_test_result_parsing_methods() number_of_warnings = self.parse_number_of_warnings(result) # type: ignore @@ -803,6 +900,8 @@ class DbtDocsLocalOperator(DbtLocalBaseOperator): Use the `callback` parameter to specify a callback function to run after the command completes. """ + template_fields: Sequence[str] = DbtLocalBaseOperator.template_fields # type: ignore[operator] + ui_color = "#8194E0" required_files = ["index.html", "manifest.json", "catalog.json"] base_cmd = ["docs", "generate"] @@ -826,6 +925,8 @@ class DbtDocsCloudLocalOperator(DbtDocsLocalOperator, ABC): Abstract class for operators that upload the generated documentation to cloud storage. """ + template_fields: Sequence[str] = DbtDocsLocalOperator.template_fields # type: ignore[operator] + def __init__( self, connection_id: str, @@ -1021,6 +1122,8 @@ def __init__(self, **kwargs: str) -> None: class DbtCompileLocalOperator(DbtCompileMixin, DbtLocalBaseOperator): + template_fields: Sequence[str] = DbtLocalBaseOperator.template_fields # type: ignore[operator] + def __init__(self, *args: Any, **kwargs: Any) -> None: kwargs["should_upload_compiled_sql"] = True super().__init__(*args, **kwargs) @@ -1031,5 +1134,7 @@ class DbtCloneLocalOperator(DbtCloneMixin, DbtLocalBaseOperator): Executes a dbt core clone command. """ + template_fields: Sequence[str] = DbtLocalBaseOperator.template_fields # type: ignore[operator] + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) diff --git a/cosmos/operators/virtualenv.py b/cosmos/operators/virtualenv.py index 3bd54da99..4026d3eb4 100644 --- a/cosmos/operators/virtualenv.py +++ b/cosmos/operators/virtualenv.py @@ -5,7 +5,7 @@ import time from pathlib import Path from tempfile import TemporaryDirectory -from typing import TYPE_CHECKING, Any, Callable +from typing import TYPE_CHECKING, Any, Callable, Sequence import psutil from airflow.utils.python_virtualenv import prepare_virtualenv @@ -96,6 +96,8 @@ def run_command( cmd: list[str], env: dict[str, str | bytes | os.PathLike[Any]], context: Context, + run_as_async: bool = False, + async_context: dict[str, Any] | None = None, ) -> FullOutputSubprocessResult | dbtRunnerResult: # No virtualenv_dir set, so create a temporary virtualenv if self.virtualenv_dir is None or self.is_virtualenv_dir_temporary: @@ -128,7 +130,7 @@ def clean_dir_if_temporary(self) -> None: self.log.info(f"Deleting the Python virtualenv {self.virtualenv_dir}") shutil.rmtree(str(self.virtualenv_dir), ignore_errors=True) - def execute(self, context: Context) -> None: + def execute(self, context: Context, **kwargs: Any) -> None: try: output = super().execute(context) self.log.info(output) @@ -215,6 +217,8 @@ class DbtLSVirtualenvOperator(DbtVirtualenvBaseOperator, DbtLSLocalOperator): and deleted just after. 
""" + template_fields: Sequence[str] = DbtVirtualenvBaseOperator.template_fields # type: ignore[operator] + def __init__(self, *args: Any, **kwargs: Any): super().__init__(*args, **kwargs) @@ -235,6 +239,8 @@ class DbtSnapshotVirtualenvOperator(DbtVirtualenvBaseOperator, DbtSnapshotLocalO command and deleted just after. """ + template_fields: Sequence[str] = DbtVirtualenvBaseOperator.template_fields # type: ignore[operator] + def __init__(self, *args: Any, **kwargs: Any): super().__init__(*args, **kwargs) @@ -245,6 +251,8 @@ class DbtSourceVirtualenvOperator(DbtVirtualenvBaseOperator, DbtSourceLocalOpera command and deleted just after. """ + template_fields: Sequence[str] = DbtVirtualenvBaseOperator.template_fields # type: ignore[operator] + def __init__(self, *args: Any, **kwargs: Any): super().__init__(*args, **kwargs) @@ -265,6 +273,8 @@ class DbtTestVirtualenvOperator(DbtVirtualenvBaseOperator, DbtTestLocalOperator) and deleted just after. """ + template_fields: Sequence[str] = DbtVirtualenvBaseOperator.template_fields # type: ignore[operator] + def __init__(self, *args: Any, **kwargs: Any): super().__init__(*args, **kwargs) @@ -285,6 +295,8 @@ class DbtDocsVirtualenvOperator(DbtVirtualenvBaseOperator, DbtDocsLocalOperator) command and deleted just after. """ + template_fields: Sequence[str] = DbtVirtualenvBaseOperator.template_fields # type: ignore[operator] + def __init__(self, *args: Any, **kwargs: Any): super().__init__(*args, **kwargs) @@ -294,5 +306,7 @@ class DbtCloneVirtualenvOperator(DbtVirtualenvBaseOperator, DbtCloneLocalOperato Executes a dbt core clone command. """ + template_fields: Sequence[str] = DbtVirtualenvBaseOperator.template_fields # type: ignore[operator] + def __init__(self, *args: Any, **kwargs: Any): super().__init__(*args, **kwargs) diff --git a/dev/dags/simple_dag_async.py b/dev/dags/simple_dag_async.py index 1b2b67651..8fb8cb844 100644 --- a/dev/dags/simple_dag_async.py +++ b/dev/dags/simple_dag_async.py @@ -37,6 +37,6 @@ catchup=False, dag_id="simple_dag_async", tags=["simple"], - operator_args={"full_refresh": True, "location": "northamerica-northeast1"}, + operator_args={"location": "northamerica-northeast1"}, ) # [END airflow_async_execution_mode_example] diff --git a/tests/airflow/test_graph.py b/tests/airflow/test_graph.py index ccbd911be..d86abab74 100644 --- a/tests/airflow/test_graph.py +++ b/tests/airflow/test_graph.py @@ -1,7 +1,7 @@ import os from datetime import datetime from pathlib import Path -from unittest.mock import MagicMock, patch +from unittest.mock import patch import pytest from airflow import __version__ as airflow_version @@ -22,7 +22,6 @@ ) from cosmos.config import ProfileConfig, RenderConfig from cosmos.constants import ( - DBT_COMPILE_TASK_ID, DbtResourceType, ExecutionMode, SourceRenderingBehavior, @@ -31,7 +30,7 @@ ) from cosmos.converter import airflow_kwargs from cosmos.dbt.graph import DbtNode -from cosmos.profiles import GoogleCloudServiceAccountFileProfileMapping, PostgresUserPasswordProfileMapping +from cosmos.profiles import PostgresUserPasswordProfileMapping SAMPLE_PROJ_PATH = Path("/home/user/path/dbt-proj/") SOURCE_RENDERING_BEHAVIOR = SourceRenderingBehavior(os.getenv("SOURCE_RENDERING_BEHAVIOR", "none")) @@ -347,42 +346,6 @@ def test_build_airflow_graph_with_override_profile_config(): assert generated_parent_profile_config.profile_mapping.profile_args["schema"] == "public" -@pytest.mark.integration -@patch("airflow.hooks.base.BaseHook.get_connection", new=MagicMock()) -def 
test_build_airflow_graph_with_dbt_compile_task(): - bigquery_profile_config = ProfileConfig( - profile_name="my-bigquery-db", - target_name="dev", - profile_mapping=GoogleCloudServiceAccountFileProfileMapping( - conn_id="fake_conn", profile_args={"dataset": "release_17"} - ), - ) - with DAG("test-id-dbt-compile", start_date=datetime(2022, 1, 1)) as dag: - task_args = { - "project_dir": SAMPLE_PROJ_PATH, - "conn_id": "fake_conn", - "profile_config": bigquery_profile_config, - } - render_config = RenderConfig( - select=["tag:some"], - test_behavior=TestBehavior.AFTER_ALL, - source_rendering_behavior=SOURCE_RENDERING_BEHAVIOR, - ) - build_airflow_graph( - nodes=sample_nodes, - dag=dag, - execution_mode=ExecutionMode.AIRFLOW_ASYNC, - test_indirect_selection=TestIndirectSelection.EAGER, - task_args=task_args, - dbt_project_name="astro_shop", - render_config=render_config, - ) - - task_ids = [task.task_id for task in dag.tasks] - assert DBT_COMPILE_TASK_ID in task_ids - assert DBT_COMPILE_TASK_ID in dag.tasks[0].upstream_task_ids - - def test_calculate_operator_class(): class_module_import_path = calculate_operator_class(execution_mode=ExecutionMode.KUBERNETES, dbt_class="DbtSeed") assert class_module_import_path == "cosmos.operators.kubernetes.DbtSeedKubernetesOperator" diff --git a/tests/dbt_adapters/test_bigquery.py b/tests/dbt_adapters/test_bigquery.py new file mode 100644 index 000000000..d8921d059 --- /dev/null +++ b/tests/dbt_adapters/test_bigquery.py @@ -0,0 +1,45 @@ +from __future__ import annotations + +from unittest.mock import Mock + +import pytest + +from cosmos.dbt_adapters.bigquery import _associate_bigquery_async_op_args, _mock_bigquery_adapter +from cosmos.exceptions import CosmosValueError + + +@pytest.fixture +def async_operator_mock(): + """Fixture to create a mock async operator object.""" + return Mock() + + +@pytest.mark.integration +def test_mock_bigquery_adapter(): + """Test _mock_bigquery_adapter to verify it modifies BigQueryConnectionManager.execute.""" + from dbt.adapters.bigquery.connections import BigQueryConnectionManager + + _mock_bigquery_adapter() + + assert hasattr(BigQueryConnectionManager, "execute") + + response, table = BigQueryConnectionManager.execute(None, sql="SELECT 1") + assert response._message == "mock_bigquery_adapter_response" + assert table is not None + + +def test_associate_bigquery_async_op_args_valid(async_operator_mock): + """Test _associate_bigquery_async_op_args correctly configures the async operator.""" + sql_query = "SELECT * FROM test_table" + + result = _associate_bigquery_async_op_args(async_operator_mock, sql=sql_query) + + assert result == async_operator_mock + assert result.configuration["query"]["query"] == sql_query + assert result.configuration["query"]["useLegacySql"] is False + + +def test_associate_bigquery_async_op_args_missing_sql(async_operator_mock): + """Test _associate_bigquery_async_op_args raises CosmosValueError when 'sql' is missing.""" + with pytest.raises(CosmosValueError, match="Keyword argument 'sql' is required for BigQuery Async operator"): + _associate_bigquery_async_op_args(async_operator_mock) diff --git a/tests/dbt_adapters/test_init.py b/tests/dbt_adapters/test_init.py new file mode 100644 index 000000000..ce272e333 --- /dev/null +++ b/tests/dbt_adapters/test_init.py @@ -0,0 +1,15 @@ +from __future__ import annotations + +from unittest.mock import Mock + +import pytest + +from cosmos.dbt_adapters import associate_async_operator_args + + +def test_associate_async_operator_args_invalid_profile(): + 
"""Test associate_async_operator_args raises KeyError for an invalid profile type.""" + async_operator_mock = Mock() + + with pytest.raises(KeyError): + associate_async_operator_args(async_operator_mock, "invalid_profile") diff --git a/tests/operators/_asynchronous/test_base.py b/tests/operators/_asynchronous/test_base.py index c01bbd866..f3e49a621 100644 --- a/tests/operators/_asynchronous/test_base.py +++ b/tests/operators/_asynchronous/test_base.py @@ -1,12 +1,13 @@ -from unittest.mock import patch +from __future__ import annotations + +from unittest.mock import MagicMock, patch import pytest -from cosmos import ProfileConfig +from cosmos.config import ProfileConfig from cosmos.operators._asynchronous.base import DbtRunAirflowAsyncFactoryOperator, _create_async_operator_class from cosmos.operators._asynchronous.bigquery import DbtRunAirflowAsyncBigqueryOperator from cosmos.operators.local import DbtRunLocalOperator -from cosmos.profiles import get_automatic_profile_mapping @pytest.mark.parametrize( @@ -25,30 +26,45 @@ def test_create_async_operator_class_success(profile_type, dbt_class, expected_o assert operator_class == expected_operator_class -@patch("cosmos.operators._asynchronous.bigquery.DbtRunAirflowAsyncBigqueryOperator.drop_table_sql") -@patch("cosmos.operators._asynchronous.bigquery.DbtRunAirflowAsyncBigqueryOperator.get_remote_sql") -@patch("cosmos.operators._asynchronous.bigquery.BigQueryInsertJobOperator.execute") -def test_factory_async_class(mock_execute, get_remote_sql, drop_table_sql, mock_bigquery_conn): - profile_mapping = get_automatic_profile_mapping( - mock_bigquery_conn.conn_id, - profile_args={ - "dataset": "my_dataset", - }, - ) - bigquery_profile_config = ProfileConfig( - profile_name="my_profile", target_name="dev", profile_mapping=profile_mapping - ) - factory_class = DbtRunAirflowAsyncFactoryOperator( - task_id="run", - project_dir="/tmp", - profile_config=bigquery_profile_config, - full_refresh=True, - extra_context={"dbt_node_config": {"resource_name": "customer"}}, - ) +@pytest.fixture +def profile_config_mock(): + """Fixture to create a mock ProfileConfig.""" + mock_config = MagicMock(spec=ProfileConfig) + mock_config.get_profile_type.return_value = "bigquery" + return mock_config + + +def test_create_async_operator_class_valid(): + """Test _create_async_operator_class returns the correct async operator class if available.""" + with patch("cosmos.operators._asynchronous.base.importlib.import_module") as mock_import: + mock_class = MagicMock() + mock_import.return_value = MagicMock() + setattr(mock_import.return_value, "DbtRunAirflowAsyncBigqueryOperator", mock_class) + + result = _create_async_operator_class("bigquery", "DbtRun") + assert result == mock_class - async_operator = factory_class.create_async_operator() - assert async_operator == DbtRunAirflowAsyncBigqueryOperator - factory_class.execute(context={}) +def test_create_async_operator_class_fallback(): + """Test _create_async_operator_class falls back to DbtRunLocalOperator when import fails.""" + with patch("cosmos.operators._asynchronous.base.importlib.import_module", side_effect=ModuleNotFoundError): + result = _create_async_operator_class("bigquery", "DbtRun") + assert result == DbtRunLocalOperator + + +class MockAsyncOperator(DbtRunLocalOperator): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + +@patch("cosmos.operators._asynchronous.base._create_async_operator_class", return_value=MockAsyncOperator) +def 
test_dbt_run_airflow_async_factory_operator_init(mock_create_class, profile_config_mock): + + operator = DbtRunAirflowAsyncFactoryOperator( + task_id="test_task", + project_dir="some/path", + profile_config=profile_config_mock, + ) - mock_execute.assert_called_once_with({}) + assert operator is not None + assert isinstance(operator, MockAsyncOperator) diff --git a/tests/operators/_asynchronous/test_bigquery.py b/tests/operators/_asynchronous/test_bigquery.py index 6eb532107..34182784b 100644 --- a/tests/operators/_asynchronous/test_bigquery.py +++ b/tests/operators/_asynchronous/test_bigquery.py @@ -1,96 +1,71 @@ +from __future__ import annotations + from unittest.mock import MagicMock, patch import pytest -from airflow import __version__ as airflow_version -from packaging import version +from airflow.providers.google.cloud.operators.bigquery import BigQueryInsertJobOperator -from cosmos import ProfileConfig -from cosmos.exceptions import CosmosValueError +from cosmos.config import ProfileConfig from cosmos.operators._asynchronous.bigquery import DbtRunAirflowAsyncBigqueryOperator -from cosmos.profiles import get_automatic_profile_mapping -from cosmos.settings import AIRFLOW_IO_AVAILABLE - -def test_bigquery_without_refresh(mock_bigquery_conn): - profile_mapping = get_automatic_profile_mapping( - mock_bigquery_conn.conn_id, - profile_args={ - "dataset": "my_dataset", - }, - ) - bigquery_profile_config = ProfileConfig( - profile_name="my_profile", target_name="dev", profile_mapping=profile_mapping - ) - operator = DbtRunAirflowAsyncBigqueryOperator( - task_id="test_task", project_dir="/tmp", profile_config=bigquery_profile_config - ) - operator.extra_context = { - "dbt_node_config": {"file_path": "/some/path/to/file.sql"}, - "dbt_dag_task_group_identifier": "task_group_1", - } - with pytest.raises(CosmosValueError, match="The async execution only supported for full_refresh"): - operator.execute({}) +@pytest.fixture +def profile_config_mock(): + """Fixture to create a mock ProfileConfig.""" + mock_config = MagicMock(spec=ProfileConfig) + mock_config.get_profile_type.return_value = "bigquery" + mock_config.profile_mapping.conn_id = "google_cloud_default" + mock_config.profile_mapping.profile = {"project": "test_project", "dataset": "test_dataset"} + return mock_config -def test_get_remote_sql_airflow_io_unavailable(mock_bigquery_conn): - profile_mapping = get_automatic_profile_mapping( - mock_bigquery_conn.conn_id, - profile_args={ - "dataset": "my_dataset", - }, - ) - bigquery_profile_config = ProfileConfig( - profile_name="my_profile", target_name="dev", profile_mapping=profile_mapping - ) +def test_dbt_run_airflow_async_bigquery_operator_init(profile_config_mock): + """Test DbtRunAirflowAsyncBigqueryOperator initializes with correct attributes.""" operator = DbtRunAirflowAsyncBigqueryOperator( - task_id="test_task", project_dir="/tmp", profile_config=bigquery_profile_config + task_id="test_task", + project_dir="/path/to/project", + profile_config=profile_config_mock, + dbt_kwargs={"task_id": "test_task"}, ) - operator.extra_context = { - "dbt_node_config": {"file_path": "/some/path/to/file.sql"}, - "dbt_dag_task_group_identifier": "task_group_1", - } + assert isinstance(operator, DbtRunAirflowAsyncBigqueryOperator) + assert isinstance(operator, BigQueryInsertJobOperator) + assert operator.project_dir == "/path/to/project" + assert operator.profile_config == profile_config_mock + assert operator.gcp_conn_id == "google_cloud_default" + assert operator.gcp_project == "test_project" + assert 
operator.dataset == "test_dataset" - if not AIRFLOW_IO_AVAILABLE: - with pytest.raises( - CosmosValueError, match="Cosmos async support is only available starting in Airflow 2.8 or later." - ): - operator.get_remote_sql() - -@pytest.mark.skipif( - version.parse(airflow_version) < version.parse("2.8"), - reason="Airflow object storage supported 2.8 release", -) -def test_get_remote_sql_success(mock_bigquery_conn): - profile_mapping = get_automatic_profile_mapping( - mock_bigquery_conn.conn_id, - profile_args={ - "dataset": "my_dataset", - }, - ) - bigquery_profile_config = ProfileConfig( - profile_name="my_profile", target_name="dev", profile_mapping=profile_mapping - ) +def test_dbt_run_airflow_async_bigquery_operator_base_cmd(profile_config_mock): + """Test base_cmd property returns the correct dbt command.""" operator = DbtRunAirflowAsyncBigqueryOperator( - task_id="test_task", project_dir="/tmp", profile_config=bigquery_profile_config + task_id="test_task", + project_dir="/path/to/project", + profile_config=profile_config_mock, + dbt_kwargs={"task_id": "test_task"}, ) + assert operator.base_cmd == ["run"] - operator.extra_context = { - "dbt_node_config": {"file_path": "/some/path/to/file.sql"}, - "dbt_dag_task_group_identifier": "task_group_1", - } - operator.project_dir = "/tmp" - - mock_object_storage_path = MagicMock() - mock_file = MagicMock() - mock_file.read.return_value = "SELECT * FROM table" - mock_object_storage_path.open.return_value.__enter__.return_value = mock_file +@patch.object(DbtRunAirflowAsyncBigqueryOperator, "build_and_run_cmd") +def test_dbt_run_airflow_async_bigquery_operator_execute(mock_build_and_run_cmd, profile_config_mock): + """Test execute calls build_and_run_cmd with correct parameters.""" + operator = DbtRunAirflowAsyncBigqueryOperator( + task_id="test_task", + project_dir="/path/to/project", + profile_config=profile_config_mock, + dbt_kwargs={"task_id": "test_task"}, + ) - with patch("airflow.io.path.ObjectStoragePath", return_value=mock_object_storage_path): - remote_sql = operator.get_remote_sql() + mock_context = MagicMock() + operator.execute(mock_context) - assert remote_sql == "SELECT * FROM table" - mock_object_storage_path.open.assert_called_once() + mock_build_and_run_cmd.assert_called_once_with( + context=mock_context, + run_as_async=True, + async_context={ + "profile_type": "bigquery", + "async_operator": BigQueryInsertJobOperator, + }, + ) diff --git a/tests/operators/test_aws_eks.py b/tests/operators/test_aws_eks.py index bca007c4d..86f9409b2 100644 --- a/tests/operators/test_aws_eks.py +++ b/tests/operators/test_aws_eks.py @@ -38,7 +38,6 @@ def test_dbt_kubernetes_build_command(): Since we know that the KubernetesOperator is tested, we can just test that the command is built correctly and added to the "arguments" parameter. 
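The `execute` test just above fixes the contract between the async BigQuery operator and the local dbt machinery: the run is delegated to `build_and_run_cmd` with `run_as_async=True` and an `async_context` that names the profile type and the deferrable operator that will receive the generated SQL. A hedged sketch of that call shape (the surrounding class body is elided and the helper function is illustrative, not the exact Cosmos source):

```python
# Sketch of the execute() contract asserted by
# tests/operators/_asynchronous/test_bigquery.py::test_dbt_run_airflow_async_bigquery_operator_execute.
# BigQueryInsertJobOperator is the real Airflow operator; the helper function is illustrative.
from __future__ import annotations

from typing import Any

from airflow.providers.google.cloud.operators.bigquery import BigQueryInsertJobOperator


def run_dbt_model_async(operator: Any, context: dict[str, Any]) -> None:
    # Delegate to the shared local machinery, flagging the run as asynchronous so the
    # mocked adapter only captures SQL and the query is handed to the deferrable operator.
    operator.build_and_run_cmd(
        context=context,
        run_as_async=True,
        async_context={
            "profile_type": "bigquery",
            "async_operator": BigQueryInsertJobOperator,
        },
    )
```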
""" - result_map = { "ls": DbtLSAwsEksOperator(**base_kwargs), "run": DbtRunAwsEksOperator(**base_kwargs), diff --git a/tests/operators/test_base.py b/tests/operators/test_base.py index e97c2d396..7394a7df9 100644 --- a/tests/operators/test_base.py +++ b/tests/operators/test_base.py @@ -1,12 +1,14 @@ +import inspect import sys from datetime import datetime from unittest.mock import patch import pytest +from airflow.models import BaseOperator from airflow.utils.context import Context from cosmos.operators.base import ( - AbstractDbtBaseOperator, + AbstractDbtBase, DbtBuildMixin, DbtCompileMixin, DbtLSMixin, @@ -22,13 +24,13 @@ (sys.version_info.major, sys.version_info.minor) == (3, 12), reason="The error message for the abstract class instantiation seems to have changed between Python 3.11 and 3.12", ) -def test_dbt_base_operator_is_abstract(): +def test_dbt_base_is_abstract(): """Tests that the abstract base operator cannot be instantiated since the base_cmd is not defined.""" expected_error = ( - "Can't instantiate abstract class AbstractDbtBaseOperator with abstract methods base_cmd, build_and_run_cmd" + "Can't instantiate abstract class AbstractDbtBase with abstract methods base_cmd, build_and_run_cmd" ) with pytest.raises(TypeError, match=expected_error): - AbstractDbtBaseOperator() + AbstractDbtBase(project_dir="project_dir") @pytest.mark.skipif( @@ -38,21 +40,21 @@ def test_dbt_base_operator_is_abstract(): def test_dbt_base_operator_is_abstract_py12(): """Tests that the abstract base operator cannot be instantiated since the base_cmd is not defined.""" expected_error = ( - "Can't instantiate abstract class AbstractDbtBaseOperator without an implementation for abstract methods " + "Can't instantiate abstract class AbstractDbtBase without an implementation for abstract methods " "'base_cmd', 'build_and_run_cmd'" ) with pytest.raises(TypeError, match=expected_error): - AbstractDbtBaseOperator() + AbstractDbtBase(project_dir="project_dir") @pytest.mark.parametrize("cmd_flags", [["--some-flag"], []]) -@patch("cosmos.operators.base.AbstractDbtBaseOperator.build_and_run_cmd") +@patch("cosmos.operators.base.AbstractDbtBase.build_and_run_cmd") def test_dbt_base_operator_execute(mock_build_and_run_cmd, cmd_flags, monkeypatch): """Tests that the base operator execute method calls the build_and_run_cmd method with the expected arguments.""" - monkeypatch.setattr(AbstractDbtBaseOperator, "add_cmd_flags", lambda _: cmd_flags) - AbstractDbtBaseOperator.__abstractmethods__ = set() + monkeypatch.setattr(AbstractDbtBase, "add_cmd_flags", lambda _: cmd_flags) + AbstractDbtBase.__abstractmethods__ = set() - base_operator = AbstractDbtBaseOperator(task_id="fake_task", project_dir="fake_dir") + base_operator = AbstractDbtBase(task_id="fake_task", project_dir="fake_dir") base_operator.execute(context={}) mock_build_and_run_cmd.assert_called_once_with(context={}, cmd_flags=cmd_flags) @@ -61,7 +63,7 @@ def test_dbt_base_operator_execute(mock_build_and_run_cmd, cmd_flags, monkeypatc @patch("cosmos.operators.base.context_merge") def test_dbt_base_operator_context_merge_called(mock_context_merge): """Tests that the base operator execute method calls the context_merge method with the expected arguments.""" - base_operator = AbstractDbtBaseOperator( + base_operator = AbstractDbtBase( task_id="fake_task", project_dir="fake_dir", extra_context={"extra": "extra"}, @@ -125,7 +127,7 @@ def test_dbt_base_operator_context_merge( expected_context, ): """Tests that the base operator execute method calls and update 
context""" - base_operator = AbstractDbtBaseOperator( + base_operator = AbstractDbtBase( task_id="fake_task", project_dir="fake_dir", extra_context=extra_context, @@ -173,5 +175,21 @@ def test_dbt_mixin_add_cmd_flags_run_operator(args, expected_flags): def test_abstract_dbt_base_operator_append_env_is_false_by_default(): """Tests that the append_env attribute is set to False by default.""" - base_operator = AbstractDbtBaseOperator(task_id="fake_task", project_dir="fake_dir") + AbstractDbtBase.__abstractmethods__ = set() + base_operator = AbstractDbtBase(task_id="fake_task", project_dir="fake_dir") assert base_operator.append_env is False + + +def test_abstract_dbt_base_is_not_airflow_base_operator(): + AbstractDbtBase.__abstractmethods__ = set() + base_operator = AbstractDbtBase(task_id="fake_task", project_dir="fake_dir") + assert not isinstance(base_operator, BaseOperator) + + +def test_abstract_dbt_base_init_no_super(): + """Test that super().__init__ is not called in AbstractDbtBase""" + init_method = getattr(AbstractDbtBase, "__init__", None) + assert init_method is not None + + source = inspect.getsource(init_method) + assert "super().__init__" not in source diff --git a/tests/operators/test_kubernetes.py b/tests/operators/test_kubernetes.py index aee415f26..0562e28ce 100644 --- a/tests/operators/test_kubernetes.py +++ b/tests/operators/test_kubernetes.py @@ -191,20 +191,27 @@ def test_dbt_kubernetes_build_command(): not module_available, reason="Kubernetes module `airflow.providers.cncf.kubernetes.utils.pod_manager` not available" ) def test_dbt_test_kubernetes_operator_constructor(additional_kwargs, expected_results): + # TODO: Refactor this test so that the asserts test according to the input parameters. test_operator = DbtTestKubernetesOperator( on_warning_callback=(lambda **kwargs: None), **additional_kwargs, **base_kwargs ) - print(additional_kwargs, test_operator.__dict__) - assert isinstance(test_operator.on_success_callback, list) - assert isinstance(test_operator.on_failure_callback, list) - assert test_operator._handle_warnings in test_operator.on_success_callback - assert test_operator._cleanup_pod in test_operator.on_failure_callback - assert len(test_operator.on_success_callback) == expected_results[0] - assert len(test_operator.on_failure_callback) == expected_results[1] + assert isinstance(test_operator.on_success_callback, list) or test_operator.on_success_callback is None + assert isinstance(test_operator.on_failure_callback, list) or test_operator.on_failure_callback is None + + if test_operator.on_success_callback is not None: + assert test_operator._handle_warnings in test_operator.on_success_callback + assert len(test_operator.on_success_callback) == expected_results[0] + + if test_operator.on_failure_callback is not None: + assert test_operator._cleanup_pod in test_operator.on_failure_callback + assert len(test_operator.on_failure_callback) == expected_results[1] + assert test_operator.is_delete_operator_pod_original == expected_results[2] - assert test_operator.on_finish_action_original == OnFinishAction(expected_results[3]) + + expected_action = OnFinishAction(expected_results[3]) + assert test_operator.on_finish_action_original == expected_action @pytest.mark.parametrize( @@ -247,20 +254,28 @@ def test_dbt_test_kubernetes_operator_constructor(additional_kwargs, expected_re not module_available, reason="Kubernetes module `airflow.providers.cncf.kubernetes.utils.pod_manager` not available" ) def 
test_dbt_source_kubernetes_operator_constructor(additional_kwargs, expected_results): + # TODO: Refactor this test so that the asserts test according to the input parameters. source_operator = DbtSourceKubernetesOperator( on_warning_callback=(lambda **kwargs: None), **additional_kwargs, **base_kwargs ) print(additional_kwargs, source_operator.__dict__) - assert isinstance(source_operator.on_success_callback, list) - assert isinstance(source_operator.on_failure_callback, list) - assert source_operator._handle_warnings in source_operator.on_success_callback - assert source_operator._cleanup_pod in source_operator.on_failure_callback - assert len(source_operator.on_success_callback) == expected_results[0] - assert len(source_operator.on_failure_callback) == expected_results[1] + assert isinstance(source_operator.on_success_callback, list) or source_operator.on_success_callback is None + assert isinstance(source_operator.on_failure_callback, list) or source_operator.on_failure_callback is None + + if source_operator.on_success_callback is not None: + assert source_operator._handle_warnings in source_operator.on_success_callback + assert len(source_operator.on_success_callback) == expected_results[0] + + if source_operator.on_failure_callback is not None: + assert source_operator._cleanup_pod in source_operator.on_failure_callback + assert len(source_operator.on_failure_callback) == expected_results[1] + assert source_operator.is_delete_operator_pod_original == expected_results[2] - assert source_operator.on_finish_action_original == OnFinishAction(expected_results[3]) + + expected_action = OnFinishAction(expected_results[3]) + assert source_operator.on_finish_action_original == expected_action class FakePodManager: diff --git a/tests/operators/test_local.py b/tests/operators/test_local.py index 69164a194..34c34d895 100644 --- a/tests/operators/test_local.py +++ b/tests/operators/test_local.py @@ -27,6 +27,7 @@ from cosmos.exceptions import CosmosDbtRunError, CosmosValueError from cosmos.hooks.subprocess import FullOutputSubprocessResult from cosmos.operators.local import ( + AbstractDbtLocalBase, DbtBuildLocalOperator, DbtCloneLocalOperator, DbtCompileLocalOperator, @@ -776,47 +777,89 @@ def test_store_compiled_sql() -> None: ( DbtSeedLocalOperator, {"full_refresh": True}, - {"context": {}, "env": {}, "cmd_flags": ["seed", "--full-refresh"]}, + { + "context": {}, + "env": {}, + "cmd_flags": ["seed", "--full-refresh"], + "run_as_async": False, + "async_context": None, + }, ), ( DbtBuildLocalOperator, {"full_refresh": True}, - {"context": {}, "env": {}, "cmd_flags": ["build", "--full-refresh"]}, + { + "context": {}, + "env": {}, + "cmd_flags": ["build", "--full-refresh"], + "run_as_async": False, + "async_context": None, + }, ), ( DbtRunLocalOperator, {"full_refresh": True}, - {"context": {}, "env": {}, "cmd_flags": ["run", "--full-refresh"]}, + { + "context": {}, + "env": {}, + "cmd_flags": ["run", "--full-refresh"], + "run_as_async": False, + "async_context": None, + }, ), ( DbtCloneLocalOperator, {"full_refresh": True}, - {"context": {}, "env": {}, "cmd_flags": ["clone", "--full-refresh"]}, + { + "context": {}, + "env": {}, + "cmd_flags": ["clone", "--full-refresh"], + "run_as_async": False, + "async_context": None, + }, ), ( DbtTestLocalOperator, {}, - {"context": {}, "env": {}, "cmd_flags": ["test"]}, + {"context": {}, "env": {}, "cmd_flags": ["test"], "run_as_async": False, "async_context": None}, ), ( DbtTestLocalOperator, {"select": []}, - {"context": {}, "env": {}, "cmd_flags": 
["test"]}, + {"context": {}, "env": {}, "cmd_flags": ["test"], "run_as_async": False, "async_context": None}, ), ( DbtTestLocalOperator, {"full_refresh": True, "select": ["tag:daily"], "exclude": ["tag:disabled"]}, - {"context": {}, "env": {}, "cmd_flags": ["test", "--select", "tag:daily", "--exclude", "tag:disabled"]}, + { + "context": {}, + "env": {}, + "cmd_flags": ["test", "--select", "tag:daily", "--exclude", "tag:disabled"], + "run_as_async": False, + "async_context": None, + }, ), ( DbtTestLocalOperator, {"full_refresh": True, "selector": "nightly_snowplow"}, - {"context": {}, "env": {}, "cmd_flags": ["test", "--selector", "nightly_snowplow"]}, + { + "context": {}, + "env": {}, + "cmd_flags": ["test", "--selector", "nightly_snowplow"], + "run_as_async": False, + "async_context": None, + }, ), ( DbtRunOperationLocalOperator, {"args": {"days": 7, "dry_run": True}, "macro_name": "bla"}, - {"context": {}, "env": {}, "cmd_flags": ["run-operation", "bla", "--args", "days: 7\ndry_run: true\n"]}, + { + "context": {}, + "env": {}, + "cmd_flags": ["run-operation", "bla", "--args", "days: 7\ndry_run: true\n"], + "run_as_async": False, + "async_context": None, + }, ), ], ) @@ -1317,3 +1360,92 @@ def test_upload_compiled_sql_should_upload(mock_configure_remote, mock_object_st expected_dest_path = f"mock_remote_path/test_dag/compiled/{rel_path.lstrip('/')}" mock_object_storage_path.assert_any_call(expected_dest_path, conn_id="mock_conn_id") mock_object_storage_path.return_value.copy.assert_any_call(mock_object_storage_path.return_value) + + +MOCK_ADAPTER_CALLABLE_MAP = { + "snowflake": MagicMock(), + "bigquery": MagicMock(), +} + + +@pytest.fixture +def mock_adapter_map(monkeypatch): + monkeypatch.setattr( + "cosmos.operators.local.PROFILE_TYPE_MOCK_ADAPTER_CALLABLE_MAP", + MOCK_ADAPTER_CALLABLE_MAP, + ) + + +def test_mock_dbt_adapter_valid_context(mock_adapter_map): + """ + Test that the _mock_dbt_adapter method calls the correct mock adapter function + when provided with a valid async_context. + """ + async_context = { + "async_operator": MagicMock(), + "profile_type": "bigquery", + } + AbstractDbtLocalBase.__abstractmethods__ = set() + operator = AbstractDbtLocalBase(task_id="test_task", project_dir="test_project", profile_config=MagicMock()) + operator._mock_dbt_adapter(async_context) + + MOCK_ADAPTER_CALLABLE_MAP["bigquery"].assert_called_once() + + +def test_mock_dbt_adapter_missing_async_context(): + """ + Test that the _mock_dbt_adapter method raises a CosmosValueError + when async_context is None. + """ + AbstractDbtLocalBase.__abstractmethods__ = set() + operator = AbstractDbtLocalBase(task_id="test_task", project_dir="test_project", profile_config=MagicMock()) + with pytest.raises(CosmosValueError, match="`async_context` is necessary for running the model asynchronously"): + operator._mock_dbt_adapter(None) + + +def test_mock_dbt_adapter_missing_async_operator(): + """ + Test that the _mock_dbt_adapter method raises a CosmosValueError + when async_operator is missing in async_context. 
+ """ + async_context = { + "profile_type": "snowflake", + } + AbstractDbtLocalBase.__abstractmethods__ = set() + operator = AbstractDbtLocalBase(task_id="test_task", project_dir="test_project", profile_config=MagicMock()) + with pytest.raises( + CosmosValueError, match="`async_operator` needs to be specified in `async_context` when running as async" + ): + operator._mock_dbt_adapter(async_context) + + +def test_mock_dbt_adapter_missing_profile_type(): + """ + Test that the _mock_dbt_adapter method raises a CosmosValueError + when profile_type is missing in async_context. + """ + async_context = { + "async_operator": MagicMock(), + } + AbstractDbtLocalBase.__abstractmethods__ = set() + operator = AbstractDbtLocalBase(task_id="test_task", project_dir="test_project", profile_config=MagicMock()) + with pytest.raises(CosmosValueError, match="`profile_type` needs to be specified in `async_context`"): + operator._mock_dbt_adapter(async_context) + + +def test_mock_dbt_adapter_unsupported_profile_type(mock_adapter_map): + """ + Test that the _mock_dbt_adapter method raises a CosmosValueError + when the profile_type is not supported. + """ + async_context = { + "async_operator": MagicMock(), + "profile_type": "unsupported_profile", + } + AbstractDbtLocalBase.__abstractmethods__ = set() + operator = AbstractDbtLocalBase(task_id="test_task", project_dir="test_project", profile_config=MagicMock()) + with pytest.raises( + CosmosValueError, + match="Mock adapter callable function not available for profile_type unsupported_profile", + ): + operator._mock_dbt_adapter(async_context)