Add structure to support multiple DBs for async operator execution (#1483)

Pluggable Async Operator Interface
----------------------------------

This PR enhances the initial async operator support in Cosmos, introduced in [PR #1230](#1230). The changes decouple `DbtRunAirflowAsyncOperator` from `BigQueryInsertJobOperator`, making it more flexible and allowing async operators for other data sources to be supported in the future.

It introduces the `DbtRunAirflowAsyncFactoryOperator` class, which dynamically selects the parent class containing the async operator implementation based on the dbt profile. A template for implementing the Databricks async operator is also included.

~~Initially the async operator-related code was moved to `cosmos/operators/_async/`, open for suggestions.~~ After discussing this with the team, the async operator-related code now lives at `cosmos/operators/_asynchronous/`.

## Design principle

The `DbtRunAirflowAsyncFactoryOperator` class uses a Factory Method design pattern combined with dynamic inheritance at runtime (a standalone toy sketch of this pattern is included after the screenshots below).

- **Factory Method:** The `create_async_operator()` method generates a specific async operator class based on the `profile_config` provided. This allows the operator to adapt to different types of async operators at runtime.
- **Dynamic inheritance:** The class dynamically changes its base classes (`__bases__`) to the async operator class created by the factory method. This ensures the correct async class is used during execution.
- **Execution:** The `execute()` method calls `super().execute()` to trigger the execution logic, while dynamically using the appropriate operator class for async behaviour.

### Class hierarchy

```
                 BigQueryInsertJobOperator
                            |
DbtRunLocalOperator    DbtRunAirflowAsyncBigqueryOperator    DbtRunAirflowAsyncDatabricksOperator
                       (these parent classes are injected at runtime)
                  \                                      /
                     DbtRunAirflowAsyncFactoryOperator
                                    |
                        DbtRunAirflowAsyncOperator
```

## How to add a new async db operator

- Implement the operator under `cosmos/operators/_asynchronous/`.
- The operator module and class should follow the format `cosmos.operators._asynchronous.{profile_type}.{dbt_class}{_snake_case_to_camelcase(execution_mode)}{profile_type.capitalize()}Operator`.
- For a detailed example, a dummy implementation for `Databricks` has been added.

## Example DAG

```python
import os
from datetime import datetime
from pathlib import Path

from cosmos import DbtDag, ExecutionConfig, ExecutionMode, ProfileConfig, ProjectConfig, RenderConfig
from cosmos.profiles import GoogleCloudServiceAccountDictProfileMapping

DEFAULT_DBT_ROOT_PATH = Path(__file__).parent / "dbt"
DBT_ROOT_PATH = Path(os.getenv("DBT_ROOT_PATH", DEFAULT_DBT_ROOT_PATH))

profile_config = ProfileConfig(
    profile_name="default",
    target_name="dev",
    profile_mapping=GoogleCloudServiceAccountDictProfileMapping(
        conn_id="gcp_gs_conn", profile_args={"dataset": "release_17", "project": "astronomer-dag-authoring"}
    ),
)

# [START airflow_async_execution_mode_example]
simple_dag_async = DbtDag(
    # dbt/cosmos-specific parameters
    project_config=ProjectConfig(
        DBT_ROOT_PATH / "original_jaffle_shop",
    ),
    profile_config=profile_config,
    execution_config=ExecutionConfig(
        execution_mode=ExecutionMode.AIRFLOW_ASYNC,
    ),
    render_config=RenderConfig(
        select=["path:models"],
        # test_behavior=TestBehavior.NONE
    ),
    # normal dag parameters
    schedule_interval=None,
    start_date=datetime(2023, 1, 1),
    catchup=False,
    dag_id="simple_dag_async",
    tags=["simple"],
    operator_args={"full_refresh": True, "location": "northamerica-northeast1"},
)
# [END airflow_async_execution_mode_example]
```

<img width="1687" alt="Screenshot 2025-01-23 at 3 30 48 PM" src="https://github.com/user-attachments/assets/bd7a13cc-7a52-4d55-947e-23a618b47e68" />

**Graph View**

<img width="1688" alt="Screenshot 2025-01-27 at 3 41 14 PM" src="https://github.com/user-attachments/assets/d74a1851-39d9-4575-b9cb-11c286094643" />

closes: #1238
closes: #1239
1 parent 2b99777 · commit bdc8746
Showing 12 changed files with 379 additions and 121 deletions.
Empty file.
```python
import importlib
import logging
from abc import ABCMeta
from typing import Any, Sequence

from airflow.utils.context import Context

from cosmos.airflow.graph import _snake_case_to_camelcase
from cosmos.config import ProfileConfig
from cosmos.constants import ExecutionMode
from cosmos.operators.local import DbtRunLocalOperator

log = logging.getLogger(__name__)


def _create_async_operator_class(profile_type: str, dbt_class: str) -> Any:
    """
    Dynamically constructs and returns an asynchronous operator class for the given profile type and dbt class name.

    The function constructs a class path string for an asynchronous operator, based on the provided `profile_type` and
    `dbt_class`. It attempts to import the corresponding class dynamically and return it. If the class cannot be found,
    it falls back to returning the `DbtRunLocalOperator` class.

    :param profile_type: The dbt profile type
    :param dbt_class: The dbt class name. Example: DbtRun, DbtTest.
    """
    execution_mode = ExecutionMode.AIRFLOW_ASYNC.value
    class_path = f"cosmos.operators._asynchronous.{profile_type}.{dbt_class}{_snake_case_to_camelcase(execution_mode)}{profile_type.capitalize()}Operator"
    try:
        module_path, class_name = class_path.rsplit(".", 1)
        module = importlib.import_module(module_path)
        operator_class = getattr(module, class_name)
        return operator_class
    except (ModuleNotFoundError, AttributeError):
        log.info("Error in loading class: %s. Falling back to DbtRunLocalOperator", class_path)
        return DbtRunLocalOperator


class DbtRunAirflowAsyncFactoryOperator(DbtRunLocalOperator, metaclass=ABCMeta):  # type: ignore[misc]

    template_fields: Sequence[str] = DbtRunLocalOperator.template_fields + ("project_dir",)  # type: ignore[operator]

    def __init__(self, project_dir: str, profile_config: ProfileConfig, **kwargs: Any):
        self.project_dir = project_dir
        self.profile_config = profile_config

        async_operator_class = self.create_async_operator()

        # Dynamically modify the base classes.
        # This is necessary because the async operator class is only known at runtime.
        # When using composition instead of inheritance to initialize the async class and run its execute method,
        # Airflow throws a `DuplicateTaskIdFound` error.
        DbtRunAirflowAsyncFactoryOperator.__bases__ = (async_operator_class,)
        super().__init__(project_dir=project_dir, profile_config=profile_config, **kwargs)

    def create_async_operator(self) -> Any:
        profile_type = self.profile_config.get_profile_type()
        async_class_operator = _create_async_operator_class(profile_type, "DbtRun")
        return async_class_operator

    def execute(self, context: Context) -> None:
        super().execute(context)
```
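A hedged usage sketch of `_create_async_operator_class`: the import path of the factory module is an assumption (the file name is not shown in this extract), but the resolution and fallback behaviour follow directly from the code above.

```python
from cosmos.operators.local import DbtRunLocalOperator

# Assumed import path for the factory module shown above (file name not visible in this diff extract).
from cosmos.operators._asynchronous.base import _create_async_operator_class

# "bigquery" resolves to
# cosmos.operators._asynchronous.bigquery.DbtRunAirflowAsyncBigqueryOperator
bigquery_cls = _create_async_operator_class("bigquery", "DbtRun")
print(bigquery_cls.__name__)  # DbtRunAirflowAsyncBigqueryOperator

# A profile type without an async implementation falls back to DbtRunLocalOperator.
fallback_cls = _create_async_operator_class("unknown_profile", "DbtRun")
assert fallback_cls is DbtRunLocalOperator
```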
```python
from __future__ import annotations

from pathlib import Path
from typing import TYPE_CHECKING, Any, Sequence

from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook
from airflow.providers.google.cloud.operators.bigquery import BigQueryInsertJobOperator
from airflow.utils.context import Context

from cosmos import settings
from cosmos.config import ProfileConfig
from cosmos.exceptions import CosmosValueError
from cosmos.settings import remote_target_path, remote_target_path_conn_id


class DbtRunAirflowAsyncBigqueryOperator(BigQueryInsertJobOperator):  # type: ignore[misc]

    template_fields: Sequence[str] = (
        "full_refresh",
        "gcp_project",
        "dataset",
        "location",
    )

    def __init__(
        self,
        project_dir: str,
        profile_config: ProfileConfig,
        extra_context: dict[str, Any] | None = None,
        **kwargs: Any,
    ):
        self.project_dir = project_dir
        self.profile_config = profile_config
        self.gcp_conn_id = self.profile_config.profile_mapping.conn_id  # type: ignore
        profile = self.profile_config.profile_mapping.profile  # type: ignore
        self.gcp_project = profile["project"]
        self.dataset = profile["dataset"]
        self.extra_context = extra_context or {}
        self.full_refresh = None
        if "full_refresh" in kwargs:
            self.full_refresh = kwargs.pop("full_refresh")
        self.configuration: dict[str, Any] = {}
        super().__init__(
            gcp_conn_id=self.gcp_conn_id,
            configuration=self.configuration,
            deferrable=True,
            **kwargs,
        )

    def get_remote_sql(self) -> str:
        if not settings.AIRFLOW_IO_AVAILABLE:
            raise CosmosValueError("Cosmos async support is only available starting in Airflow 2.8 or later.")
        from airflow.io.path import ObjectStoragePath

        file_path = self.extra_context["dbt_node_config"]["file_path"]  # type: ignore
        dbt_dag_task_group_identifier = self.extra_context["dbt_dag_task_group_identifier"]

        remote_target_path_str = str(remote_target_path).rstrip("/")

        if TYPE_CHECKING:  # pragma: no cover
            assert self.project_dir is not None

        project_dir_parent = str(Path(self.project_dir).parent)
        relative_file_path = str(file_path).replace(project_dir_parent, "").lstrip("/")
        remote_model_path = f"{remote_target_path_str}/{dbt_dag_task_group_identifier}/compiled/{relative_file_path}"

        object_storage_path = ObjectStoragePath(remote_model_path, conn_id=remote_target_path_conn_id)
        with object_storage_path.open() as fp:  # type: ignore
            return fp.read()  # type: ignore

    def drop_table_sql(self) -> None:
        model_name = self.extra_context["dbt_node_config"]["resource_name"]  # type: ignore
        sql = f"DROP TABLE IF EXISTS {self.gcp_project}.{self.dataset}.{model_name};"

        hook = BigQueryHook(
            gcp_conn_id=self.gcp_conn_id,
            impersonation_chain=self.impersonation_chain,
        )
        self.configuration = {
            "query": {
                "query": sql,
                "useLegacySql": False,
            }
        }
        hook.insert_job(configuration=self.configuration, location=self.location, project_id=self.gcp_project)

    def execute(self, context: Context) -> Any | None:
        if not self.full_refresh:
            raise CosmosValueError("Async execution is only supported with full_refresh.")
        else:
            # It may be surprising to some, but the dbt-core --full-refresh argument fully drops the table before populating it
            # https://github.com/dbt-labs/dbt-core/blob/5e9f1b515f37dfe6cdae1ab1aa7d190b92490e24/core/dbt/context/base.py#L662-L666
            # https://docs.getdbt.com/reference/resource-configs/full_refresh#recommendation
            # We're emulating this behaviour here.
            # The compiled SQL has several limitations here, but these will be addressed in the PR:
            # https://github.com/astronomer/astronomer-cosmos/pull/1474.
            self.drop_table_sql()
            sql = self.get_remote_sql()
            model_name = self.extra_context["dbt_node_config"]["resource_name"]  # type: ignore
            # Prefix an explicit CREATE TABLE command so the compiled SELECT creates the table.
            sql = f"CREATE TABLE {self.gcp_project}.{self.dataset}.{model_name} AS {sql}"
            self.configuration = {
                "query": {
                    "query": sql,
                    "useLegacySql": False,
                }
            }
            return super().execute(context)
```
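To make the full-refresh flow above concrete, here is roughly what ends up being submitted to BigQuery for a hypothetical model named `stg_customers`, reusing the project and dataset from the example DAG; the model name and compiled SQL are illustrative stand-ins, not values from this commit.

```python
# Illustrative values: project/dataset come from the example DAG; the model name and
# compiled SELECT are hypothetical stand-ins for what get_remote_sql() would return.
gcp_project, dataset, model_name = "astronomer-dag-authoring", "release_17", "stg_customers"
compiled_sql = "SELECT * FROM raw.customers"

# 1. drop_table_sql(): drop the existing table first, emulating dbt's --full-refresh.
drop_query = f"DROP TABLE IF EXISTS {gcp_project}.{dataset}.{model_name};"

# 2. execute(): wrap the compiled SELECT in an explicit CREATE TABLE ... AS statement and
#    submit it through BigQueryInsertJobOperator in deferrable mode.
create_query = f"CREATE TABLE {gcp_project}.{dataset}.{model_name} AS {compiled_sql}"
configuration = {"query": {"query": create_query, "useLegacySql": False}}

print(drop_query)
print(configuration["query"]["query"])
```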
```python
# TODO: Implement it

from typing import Any

from airflow.models.baseoperator import BaseOperator
from airflow.utils.context import Context


class DbtRunAirflowAsyncDatabricksOperator(BaseOperator):
    def __init__(self, *args: Any, **kwargs: Any):
        super().__init__(*args, **kwargs)

    def execute(self, context: Context) -> None:
        raise NotImplementedError()
```
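Following the "How to add a new async db operator" convention from the commit message, a new backend would look like the stub above with the names swapped. For example, a purely hypothetical Snowflake stub (not part of this commit) might live at `cosmos/operators/_asynchronous/snowflake.py`:

```python
# Hypothetical example only: not part of this commit.
from typing import Any

from airflow.models.baseoperator import BaseOperator
from airflow.utils.context import Context


class DbtRunAirflowAsyncSnowflakeOperator(BaseOperator):
    """Would be resolved by the factory for profile_type='snowflake'."""

    def __init__(self, *args: Any, **kwargs: Any):
        super().__init__(*args, **kwargs)

    def execute(self, context: Context) -> None:
        raise NotImplementedError()
```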
```python
import json
from unittest.mock import patch

import pytest
from airflow.models.connection import Connection


@pytest.fixture()
def mock_bigquery_conn():  # type: ignore
    """
    Mocks and returns an Airflow BigQuery connection.
    """
    extra = {
        "project": "my_project",
        "key_path": "my_key_path.json",
    }
    conn = Connection(
        conn_id="my_bigquery_connection",
        conn_type="google_cloud_platform",
        extra=json.dumps(extra),
    )

    with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn):
        yield conn
```
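A small usage sketch for the fixture above; the test body is illustrative, not part of this commit. While the fixture is active, any hook lookup returns the mocked connection.

```python
from airflow.hooks.base import BaseHook


def test_mock_bigquery_conn(mock_bigquery_conn):
    # Illustrative test: with the fixture active, BaseHook.get_connection is patched
    # to return the mocked BigQuery connection regardless of the conn_id passed.
    conn = BaseHook.get_connection("my_bigquery_connection")
    assert conn.conn_type == "google_cloud_platform"
    assert conn.conn_id == mock_bigquery_conn.conn_id
```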
Empty file.