# Add structure to support multiple db for async operator execution (#1483)

Pluggable Async Operator Interface 
------------------------------------

This PR enhances the initial async operator support in Cosmos, introduced in
[PR #1230](#1230). The changes decouple `DbtRunAirflowAsyncOperator`
from `BigQueryInsertJobOperator`, making it more flexible and allowing
support for async operators with other data sources in the future.

It introduces the `DbtRunAirflowAsyncFactoryOperator` class,
which dynamically selects the parent class containing
the async operator implementation based on the dbt profile.

I’ve added a template for implementing the Databricks async operator.

~~ATM I have moved the async operator-related code to the path
`cosmos/operators/_async/`, but I am open to suggestions.~~
After discussing this with the team, I have moved the async operator-related
code to the path `cosmos/operators/_asynchronous/`.

## Design principle 

The `DbtRunAirflowAsyncFactoryOperator` class uses the Factory Method design
pattern combined with dynamic inheritance at runtime. A minimal sketch of the
pattern is shown after this list.
- **Factory Method:** The `create_async_operator()` method selects a
specific async operator class based on the provided `profile_config`. This
allows the operator to adapt to different types of async operators at
runtime.
- **Dynamic Inheritance:** The class dynamically changes its base class
(`__bases__`) to the async operator class returned by the factory
method. This ensures the correct async class is used during execution.
- **Execution:** The `execute()` method calls `super().execute()` to
trigger the execution logic, but it dynamically uses the appropriate
operator class for async behavior.
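
Below is a minimal, self-contained sketch of this pattern. It is an illustration only, with simplified stand-in class names (`FactoryOperator`, `LocalParent`, `BigQueryAsyncParent`, `DatabricksAsyncParent`); the actual Cosmos implementation lives in `cosmos/operators/_asynchronous/base.py` in the diff below.

```python
# Minimal sketch of the Factory Method + dynamic inheritance pattern (illustration only).
# All class names here are simplified stand-ins, not real Cosmos classes.


class LocalParent:
    def execute(self) -> str:
        return "executed via the local (non-async) parent"


class BigQueryAsyncParent:
    def execute(self) -> str:
        return "executed via the BigQuery async parent"


class DatabricksAsyncParent:
    def execute(self) -> str:
        return "executed via the Databricks async parent"


class FactoryOperator(LocalParent):
    """Selects its own parent class at runtime based on the profile type."""

    def __init__(self, profile_type: str) -> None:
        # Factory method: pick the parent class for this profile type.
        async_parent = self.create_async_operator(profile_type)
        # Dynamic inheritance: swap __bases__ so super() resolves to that parent.
        FactoryOperator.__bases__ = (async_parent,)

    @staticmethod
    def create_async_operator(profile_type: str) -> type:
        parents = {"bigquery": BigQueryAsyncParent, "databricks": DatabricksAsyncParent}
        # Fall back to the local parent when the profile type is not supported.
        return parents.get(profile_type, LocalParent)

    def execute(self) -> str:
        # Delegates to whichever parent class was injected at runtime.
        return super().execute()


print(FactoryOperator("bigquery").execute())  # executed via the BigQuery async parent
```

Note that reassigning `__bases__` mutates the class itself rather than a single instance, which is also true of the real implementation.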

### Class hierarchy
```
                                       BigQueryInsertJobOperator
                                                  |
DbtRunLocalOperator    DbtRunAirflowAsyncBigqueryOperator    DbtRunAirflowAsyncDatabricksOperator   (one of these async classes is injected as the parent at runtime)
              \                  /
        DbtRunAirflowAsyncFactoryOperator
                      |
           DbtRunAirflowAsyncOperator
```

## How to add a new async db operator
- Implement the operator under the path `cosmos/operators/_asynchronous/`.
- The operator class path must follow the format
  `cosmos.operators._asynchronous.{profile_type}.{dbt_class}{_snake_case_to_camelcase(execution_mode)}{profile_type.capitalize()}Operator`
  — for example, for `profile_type="bigquery"` and `dbt_class="DbtRun"` this resolves to
  `cosmos.operators._asynchronous.bigquery.DbtRunAirflowAsyncBigqueryOperator`.
- For a detailed example, see the dummy implementation added for
  `Databricks`; a skeleton for a hypothetical Snowflake operator is also sketched below.
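
As an illustration of the convention only, a skeleton for a hypothetical Snowflake async operator could look like the following. Snowflake support, the module path, and the class name are not part of this PR; for `profile_type="snowflake"`, the factory would resolve the class path to `cosmos.operators._asynchronous.snowflake.DbtRunAirflowAsyncSnowflakeOperator`.

```python
# Hypothetical file: cosmos/operators/_asynchronous/snowflake.py (illustrative only, not part of this PR).
# For profile_type="snowflake", dbt_class="DbtRun", and execution_mode="airflow_async",
# the factory resolves the class path to
# cosmos.operators._asynchronous.snowflake.DbtRunAirflowAsyncSnowflakeOperator.

from typing import Any

from airflow.models.baseoperator import BaseOperator
from airflow.utils.context import Context


class DbtRunAirflowAsyncSnowflakeOperator(BaseOperator):
    def __init__(self, *args: Any, **kwargs: Any):
        super().__init__(*args, **kwargs)

    def execute(self, context: Context) -> None:
        # Replace with a deferrable Snowflake implementation of the dbt run command.
        raise NotImplementedError()
```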

## Example DAG
```python
import os
from datetime import datetime
from pathlib import Path

from cosmos import DbtDag, ExecutionConfig, ExecutionMode, ProfileConfig, ProjectConfig, RenderConfig
from cosmos.profiles import GoogleCloudServiceAccountDictProfileMapping

DEFAULT_DBT_ROOT_PATH = Path(__file__).parent / "dbt"
DBT_ROOT_PATH = Path(os.getenv("DBT_ROOT_PATH", DEFAULT_DBT_ROOT_PATH))

profile_config = ProfileConfig(
    profile_name="default",
    target_name="dev",
    profile_mapping=GoogleCloudServiceAccountDictProfileMapping(
        conn_id="gcp_gs_conn", profile_args={"dataset": "release_17", "project": "astronomer-dag-authoring"}
    ),
)


# [START airflow_async_execution_mode_example]
simple_dag_async = DbtDag(
    # dbt/cosmos-specific parameters
    project_config=ProjectConfig(
        DBT_ROOT_PATH / "original_jaffle_shop",
    ),
    profile_config=profile_config,
    execution_config=ExecutionConfig(
        execution_mode=ExecutionMode.AIRFLOW_ASYNC,
    ),
    render_config=RenderConfig(
        select=["path:models"],
        # test_behavior=TestBehavior.NONE
    ),
    # normal dag parameters
    schedule_interval=None,
    start_date=datetime(2023, 1, 1),
    catchup=False,
    dag_id="simple_dag_async",
    tags=["simple"],
    operator_args={"full_refresh": True, "location": "northamerica-northeast1"},
)
# [END airflow_async_execution_mode_example]
``` 

<img width="1687" alt="Screenshot 2025-01-23 at 3 30 48 PM"
src="https://github.com/user-attachments/assets/bd7a13cc-7a52-4d55-947e-23a618b47e68"
/>

**Graph View**
<img width="1688" alt="Screenshot 2025-01-27 at 3 41 14 PM"
src="https://github.com/user-attachments/assets/d74a1851-39d9-4575-b9cb-11c286094643"
/>


closes: #1238
closes: #1239
pankajastro authored Jan 27, 2025
1 parent 2b99777 commit bdc8746
Showing 12 changed files with 379 additions and 121 deletions.
65 changes: 65 additions & 0 deletions cosmos/operators/_asynchronous/base.py
@@ -0,0 +1,65 @@
import importlib
import logging
from abc import ABCMeta
from typing import Any, Sequence

from airflow.utils.context import Context

from cosmos.airflow.graph import _snake_case_to_camelcase
from cosmos.config import ProfileConfig
from cosmos.constants import ExecutionMode
from cosmos.operators.local import DbtRunLocalOperator

log = logging.getLogger(__name__)


def _create_async_operator_class(profile_type: str, dbt_class: str) -> Any:
"""
Dynamically constructs and returns an asynchronous operator class for the given profile type and dbt class name.
The function constructs a class path string for an asynchronous operator, based on the provided `profile_type` and
`dbt_class`. It attempts to import the corresponding class dynamically and return it. If the class cannot be found,
it falls back to returning the `DbtRunLocalOperator` class.
:param profile_type: The dbt profile type
:param dbt_class: The dbt class name. Example DbtRun, DbtTest.
"""
execution_mode = ExecutionMode.AIRFLOW_ASYNC.value
class_path = f"cosmos.operators._asynchronous.{profile_type}.{dbt_class}{_snake_case_to_camelcase(execution_mode)}{profile_type.capitalize()}Operator"
try:
module_path, class_name = class_path.rsplit(".", 1)
module = importlib.import_module(module_path)
operator_class = getattr(module, class_name)
return operator_class
except (ModuleNotFoundError, AttributeError):
log.info("Error in loading class: %s. falling back to DbtRunLocalOperator", class_path)
return DbtRunLocalOperator


class DbtRunAirflowAsyncFactoryOperator(DbtRunLocalOperator, metaclass=ABCMeta): # type: ignore[misc]

    template_fields: Sequence[str] = DbtRunLocalOperator.template_fields + ("project_dir",)  # type: ignore[operator]

    def __init__(self, project_dir: str, profile_config: ProfileConfig, **kwargs: Any):
        self.project_dir = project_dir
        self.profile_config = profile_config

        async_operator_class = self.create_async_operator()

        # Dynamically modify the base classes.
        # This is necessary because the async operator class is only known at runtime.
        # When using composition instead of inheritance to initialize the async class and run its execute method,
        # Airflow throws a `DuplicateTaskIdFound` error.
        DbtRunAirflowAsyncFactoryOperator.__bases__ = (async_operator_class,)
        super().__init__(project_dir=project_dir, profile_config=profile_config, **kwargs)

    def create_async_operator(self) -> Any:

        profile_type = self.profile_config.get_profile_type()

        async_class_operator = _create_async_operator_class(profile_type, "DbtRun")

        return async_class_operator

    def execute(self, context: Context) -> None:
        super().execute(context)
108 changes: 108 additions & 0 deletions cosmos/operators/_asynchronous/bigquery.py
@@ -0,0 +1,108 @@
from __future__ import annotations

from pathlib import Path
from typing import TYPE_CHECKING, Any, Sequence

from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook
from airflow.providers.google.cloud.operators.bigquery import BigQueryInsertJobOperator
from airflow.utils.context import Context

from cosmos import settings
from cosmos.config import ProfileConfig
from cosmos.exceptions import CosmosValueError
from cosmos.settings import remote_target_path, remote_target_path_conn_id


class DbtRunAirflowAsyncBigqueryOperator(BigQueryInsertJobOperator): # type: ignore[misc]

    template_fields: Sequence[str] = (
        "full_refresh",
        "gcp_project",
        "dataset",
        "location",
    )

    def __init__(
        self,
        project_dir: str,
        profile_config: ProfileConfig,
        extra_context: dict[str, Any] | None = None,
        **kwargs: Any,
    ):
        self.project_dir = project_dir
        self.profile_config = profile_config
        self.gcp_conn_id = self.profile_config.profile_mapping.conn_id  # type: ignore
        profile = self.profile_config.profile_mapping.profile  # type: ignore
        self.gcp_project = profile["project"]
        self.dataset = profile["dataset"]
        self.extra_context = extra_context or {}
        self.full_refresh = None
        if "full_refresh" in kwargs:
            self.full_refresh = kwargs.pop("full_refresh")
        self.configuration: dict[str, Any] = {}
        super().__init__(
            gcp_conn_id=self.gcp_conn_id,
            configuration=self.configuration,
            deferrable=True,
            **kwargs,
        )

    def get_remote_sql(self) -> str:
        if not settings.AIRFLOW_IO_AVAILABLE:
            raise CosmosValueError(f"Cosmos async support is only available starting in Airflow 2.8 or later.")
        from airflow.io.path import ObjectStoragePath

        file_path = self.extra_context["dbt_node_config"]["file_path"]  # type: ignore
        dbt_dag_task_group_identifier = self.extra_context["dbt_dag_task_group_identifier"]

        remote_target_path_str = str(remote_target_path).rstrip("/")

        if TYPE_CHECKING:  # pragma: no cover
            assert self.project_dir is not None

        project_dir_parent = str(Path(self.project_dir).parent)
        relative_file_path = str(file_path).replace(project_dir_parent, "").lstrip("/")
        remote_model_path = f"{remote_target_path_str}/{dbt_dag_task_group_identifier}/compiled/{relative_file_path}"

        object_storage_path = ObjectStoragePath(remote_model_path, conn_id=remote_target_path_conn_id)
        with object_storage_path.open() as fp:  # type: ignore
            return fp.read()  # type: ignore

    def drop_table_sql(self) -> None:
        model_name = self.extra_context["dbt_node_config"]["resource_name"]  # type: ignore
        sql = f"DROP TABLE IF EXISTS {self.gcp_project}.{self.dataset}.{model_name};"

        hook = BigQueryHook(
            gcp_conn_id=self.gcp_conn_id,
            impersonation_chain=self.impersonation_chain,
        )
        self.configuration = {
            "query": {
                "query": sql,
                "useLegacySql": False,
            }
        }
        hook.insert_job(configuration=self.configuration, location=self.location, project_id=self.gcp_project)

    def execute(self, context: Context) -> Any | None:

        if not self.full_refresh:
            raise CosmosValueError("The async execution only supported for full_refresh")
        else:
            # It may be surprising to some, but the dbt-core --full-refresh argument fully drops the table before populating it
            # https://github.com/dbt-labs/dbt-core/blob/5e9f1b515f37dfe6cdae1ab1aa7d190b92490e24/core/dbt/context/base.py#L662-L666
            # https://docs.getdbt.com/reference/resource-configs/full_refresh#recommendation
            # We're emulating this behaviour here
            # The compiled SQL has several limitations here, but these will be addressed in the PR: https://github.com/astronomer/astronomer-cosmos/pull/1474.
            self.drop_table_sql()
            sql = self.get_remote_sql()
            model_name = self.extra_context["dbt_node_config"]["resource_name"]  # type: ignore
            # prefix explicit create command to create table
            sql = f"CREATE TABLE {self.gcp_project}.{self.dataset}.{model_name} AS {sql}"
            self.configuration = {
                "query": {
                    "query": sql,
                    "useLegacySql": False,
                }
            }
            return super().execute(context)
14 changes: 14 additions & 0 deletions cosmos/operators/_asynchronous/databricks.py
@@ -0,0 +1,14 @@
# TODO: Implement it

from typing import Any

from airflow.models.baseoperator import BaseOperator
from airflow.utils.context import Context


class DbtRunAirflowAsyncDatabricksOperator(BaseOperator):
    def __init__(self, *args: Any, **kwargs: Any):
        super().__init__(*args, **kwargs)

    def execute(self, context: Context) -> None:
        raise NotImplementedError()
115 changes: 9 additions & 106 deletions cosmos/operators/airflow_async.py
@@ -1,16 +1,9 @@
from __future__ import annotations

import inspect
from pathlib import Path
from typing import TYPE_CHECKING, Any, Sequence

from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook
from airflow.providers.google.cloud.operators.bigquery import BigQueryInsertJobOperator
from airflow.utils.context import Context

from cosmos import settings
from cosmos.config import ProfileConfig
from cosmos.exceptions import CosmosValueError
from cosmos.operators._asynchronous.base import DbtRunAirflowAsyncFactoryOperator
from cosmos.operators.base import AbstractDbtBaseOperator
from cosmos.operators.local import (
    DbtBuildLocalOperator,
@@ -24,7 +17,6 @@
    DbtSourceLocalOperator,
    DbtTestLocalOperator,
)
from cosmos.settings import remote_target_path, remote_target_path_conn_id

_SUPPORTED_DATABASES = ["bigquery"]

@@ -35,8 +27,8 @@

class DbtBaseAirflowAsyncOperator(BaseOperator, metaclass=ABCMeta):
    def __init__(self, **kwargs) -> None:  # type: ignore
        self.location = kwargs.pop("location")
        self.configuration = kwargs.pop("configuration", {})
        if "location" in kwargs:
            kwargs.pop("location")
        super().__init__(**kwargs)


@@ -60,47 +52,17 @@ class DbtSourceAirflowAsyncOperator(DbtBaseAirflowAsyncOperator, DbtSourceLocalO
    pass


class DbtRunAirflowAsyncOperator(BigQueryInsertJobOperator): # type: ignore

    template_fields: Sequence[str] = (
        "full_refresh",
        "project_dir",
        "gcp_project",
        "dataset",
        "location",
    )
class DbtRunAirflowAsyncOperator(DbtRunAirflowAsyncFactoryOperator): # type: ignore

    def __init__(  # type: ignore
        self,
        project_dir: str,
        profile_config: ProfileConfig,
        location: str,  # This is a mandatory parameter when using BigQueryInsertJobOperator with deferrable=True
        full_refresh: bool = False,
        extra_context: dict[str, object] | None = None,
        configuration: dict[str, object] | None = None,
        **kwargs,
    ) -> None:
        # dbt task param
        self.project_dir = project_dir
        self.extra_context = extra_context or {}
        self.full_refresh = full_refresh
        self.profile_config = profile_config
        if not self.profile_config or not self.profile_config.profile_mapping:
            raise CosmosValueError(f"Cosmos async support is only available when using ProfileMapping")

        self.profile_type: str = profile_config.get_profile_type()  # type: ignore
        if self.profile_type not in _SUPPORTED_DATABASES:
            raise CosmosValueError(f"Async run are only supported: {_SUPPORTED_DATABASES}")

        # airflow task param
        self.location = location
        self.configuration = configuration or {}
        self.gcp_conn_id = self.profile_config.profile_mapping.conn_id  # type: ignore
        profile = self.profile_config.profile_mapping.profile
        self.gcp_project = profile["project"]
        self.dataset = profile["dataset"]

        # Cosmos attempts to pass many kwargs that BigQueryInsertJobOperator simply does not accept.

        # Cosmos attempts to pass many kwargs that async operator simply does not accept.
        # We need to pop them.
        clean_kwargs = {}
        non_async_args = set(inspect.signature(AbstractDbtBaseOperator.__init__).parameters.keys())
@@ -113,71 +75,12 @@ def __init__(  # type: ignore

        # The following are the minimum required parameters to run BigQueryInsertJobOperator using the deferrable mode
        super().__init__(
            gcp_conn_id=self.gcp_conn_id,
            configuration=self.configuration,
            location=self.location,
            deferrable=True,
            project_dir=project_dir,
            profile_config=profile_config,
            extra_context=extra_context,
            **clean_kwargs,
        )

    def get_remote_sql(self) -> str:
        if not settings.AIRFLOW_IO_AVAILABLE:
            raise CosmosValueError(f"Cosmos async support is only available starting in Airflow 2.8 or later.")
        from airflow.io.path import ObjectStoragePath

        file_path = self.extra_context["dbt_node_config"]["file_path"]  # type: ignore
        dbt_dag_task_group_identifier = self.extra_context["dbt_dag_task_group_identifier"]

        remote_target_path_str = str(remote_target_path).rstrip("/")

        if TYPE_CHECKING:
            assert self.project_dir is not None

        project_dir_parent = str(Path(self.project_dir).parent)
        relative_file_path = str(file_path).replace(project_dir_parent, "").lstrip("/")
        remote_model_path = f"{remote_target_path_str}/{dbt_dag_task_group_identifier}/compiled/{relative_file_path}"

        object_storage_path = ObjectStoragePath(remote_model_path, conn_id=remote_target_path_conn_id)
        with object_storage_path.open() as fp:  # type: ignore
            return fp.read()  # type: ignore

    def drop_table_sql(self) -> None:
        model_name = self.extra_context["dbt_node_config"]["resource_name"]  # type: ignore
        sql = f"DROP TABLE IF EXISTS {self.gcp_project}.{self.dataset}.{model_name};"

        hook = BigQueryHook(
            gcp_conn_id=self.gcp_conn_id,
            impersonation_chain=self.impersonation_chain,
        )
        self.configuration = {
            "query": {
                "query": sql,
                "useLegacySql": False,
            }
        }
        hook.insert_job(configuration=self.configuration, location=self.location, project_id=self.gcp_project)

    def execute(self, context: Context) -> Any | None:
        if not self.full_refresh:
            raise CosmosValueError("The async execution only supported for full_refresh")
        else:
            # It may be surprising to some, but the dbt-core --full-refresh argument fully drops the table before populating it
            # https://github.com/dbt-labs/dbt-core/blob/5e9f1b515f37dfe6cdae1ab1aa7d190b92490e24/core/dbt/context/base.py#L662-L666
            # https://docs.getdbt.com/reference/resource-configs/full_refresh#recommendation
            # We're emulating this behaviour here
            self.drop_table_sql()
            sql = self.get_remote_sql()
            model_name = self.extra_context["dbt_node_config"]["resource_name"]  # type: ignore
            # prefix explicit create command to create table
            sql = f"CREATE TABLE {self.gcp_project}.{self.dataset}.{model_name} AS {sql}"
            self.configuration = {
                "query": {
                    "query": sql,
                    "useLegacySql": False,
                }
            }
            return super().execute(context)


class DbtTestAirflowAsyncOperator(DbtBaseAirflowAsyncOperator, DbtTestLocalOperator): # type: ignore
    pass
1 change: 0 additions & 1 deletion tests/airflow/test_graph.py
@@ -312,7 +312,6 @@ def test_build_airflow_graph_with_dbt_compile_task():
        "project_dir": SAMPLE_PROJ_PATH,
        "conn_id": "fake_conn",
        "profile_config": bigquery_profile_config,
        "location": "",
    }
    render_config = RenderConfig(
        select=["tag:some"],
24 changes: 24 additions & 0 deletions tests/conftest.py
@@ -0,0 +1,24 @@
import json
from unittest.mock import patch

import pytest
from airflow.models.connection import Connection


@pytest.fixture()
def mock_bigquery_conn(): # type: ignore
"""
Mocks and returns an Airflow BigQuery connection.
"""
extra = {
"project": "my_project",
"key_path": "my_key_path.json",
}
conn = Connection(
conn_id="my_bigquery_connection",
conn_type="google_cloud_platform",
extra=json.dumps(extra),
)

with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn):
yield conn
