Skip to content

Commit

Permalink
Address @tatiana's review feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
pankajkoti committed Jan 23, 2025
1 parent e521051 commit 8ef4ac5
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 29 deletions.
31 changes: 29 additions & 2 deletions cosmos/mocked_dbt_adapters.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
from __future__ import annotations

from typing import Any

from cosmos.constants import BIGQUERY_PROFILE_TYPE
from cosmos.exceptions import CosmosValueError


def mock_bigquery_adapter() -> None:
def _mock_bigquery_adapter() -> None:
from typing import Optional, Tuple

import agate
Expand All @@ -17,5 +22,27 @@ def execute( # type: ignore[no-untyped-def]


PROFILE_TYPE_MOCK_ADAPTER_CALLABLE_MAP = {
BIGQUERY_PROFILE_TYPE: mock_bigquery_adapter,
BIGQUERY_PROFILE_TYPE: _mock_bigquery_adapter,
}


def _associate_bigquery_async_op_args(async_op_obj: Any, **kwargs: Any) -> Any:
    """Attach a BigQuery query-job configuration to ``async_op_obj``.

    The SQL text is read from the ``sql`` keyword argument and stored on the
    object's ``configuration`` attribute as a query job with legacy SQL
    disabled. Any other keyword arguments are ignored.

    :param async_op_obj: Object that receives the ``configuration`` attribute.
    :param kwargs: Must contain a non-empty ``sql`` entry.
    :returns: The same ``async_op_obj`` instance, after mutation.
    :raises CosmosValueError: If ``sql`` is absent or falsy (e.g. ``None`` or ``""``).
    """
    query_text = kwargs.get("sql")
    if query_text:
        query_job = {
            "query": query_text,
            "useLegacySql": False,
        }
        async_op_obj.configuration = {"query": query_job}
        return async_op_obj

    # Without SQL there is nothing to submit, so fail before touching the object.
    raise CosmosValueError("Keyword argument 'sql' is required for BigQuery Async operator")


# Registry keyed by profile type. Each value is a callable that receives an
# operator object plus keyword arguments and attaches the profile-specific
# async-operator arguments to that object. Only BigQuery is registered here.
PROFILE_TYPE_ASSOCIATE_ARGS_CALLABLE_MAP = {
BIGQUERY_PROFILE_TYPE: _associate_bigquery_async_op_args,
}


def _associate_async_operator_args(async_operator_obj: Any, profile_type: str, **kwargs: Any) -> Any:
    """Dispatch to the argument-association callable registered for ``profile_type``.

    :param async_operator_obj: Object forwarded as the first positional argument
        to the registered callable.
    :param profile_type: Key looked up in ``PROFILE_TYPE_ASSOCIATE_ARGS_CALLABLE_MAP``.
    :param kwargs: Forwarded unchanged to the registered callable.
    :returns: Whatever the registered callable returns.
    :raises KeyError: If no callable is registered for ``profile_type``.
    """
    associate_args = PROFILE_TYPE_ASSOCIATE_ARGS_CALLABLE_MAP[profile_type]
    return associate_args(async_operator_obj, **kwargs)
20 changes: 5 additions & 15 deletions cosmos/operators/airflow_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,7 @@ class DbtSourceAirflowAsyncOperator(DbtBaseAirflowAsyncOperator, DbtSourceLocalO

class DbtRunAirflowAsyncOperator(BigQueryInsertJobOperator, DbtRunLocalOperator): # type: ignore

template_fields: Sequence[str] = (
"full_refresh",
"project_dir",
"location",
)
template_fields: Sequence[str] = DbtRunLocalOperator.template_fields + ("full_refresh", "project_dir", "location") # type: ignore[operator]

def __init__( # type: ignore
self,
Expand Down Expand Up @@ -98,18 +94,12 @@ def __init__( # type: ignore
deferrable=True,
**kwargs,
)
self.extra_context = extra_context or {}
self.extra_context["profile_type"] = self.profile_type
self.async_context = extra_context or {}
self.async_context["profile_type"] = self.profile_type
self.async_context["async_operator"] = BigQueryInsertJobOperator

def execute(self, context: Context) -> Any | None:
sql = self.build_and_run_cmd(context, return_sql=True, sql_context=self.extra_context)
self.configuration = {
"query": {
"query": sql,
"useLegacySql": False,
}
}
return super().execute(context)
return self.build_and_run_cmd(context, run_as_async=True, async_context=self.async_context)


class DbtTestAirflowAsyncOperator(DbtBaseAirflowAsyncOperator, DbtTestLocalOperator): # type: ignore
Expand Down
8 changes: 7 additions & 1 deletion cosmos/operators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,13 @@ def build_cmd(
return dbt_cmd, env

@abstractmethod
def build_and_run_cmd(self, context: Context, cmd_flags: list[str]) -> Any:
def build_and_run_cmd(
self,
context: Context,
cmd_flags: list[str],
run_as_async: bool = False,
async_context: dict[str, Any] | None = None,
) -> Any:
"""Override this method for the operator to execute the dbt command"""

def execute(self, context: Context) -> Any | None: # type: ignore
Expand Down
27 changes: 16 additions & 11 deletions cosmos/operators/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@
FullOutputSubprocessResult,
)
from cosmos.log import get_logger
from cosmos.mocked_dbt_adapters import PROFILE_TYPE_MOCK_ADAPTER_CALLABLE_MAP
from cosmos.mocked_dbt_adapters import PROFILE_TYPE_ASSOCIATE_ARGS_CALLABLE_MAP, PROFILE_TYPE_MOCK_ADAPTER_CALLABLE_MAP
from cosmos.operators.base import (
AbstractDbtBaseOperator,
DbtBuildMixin,
Expand Down Expand Up @@ -432,8 +432,8 @@ def run_command(
cmd: list[str],
env: dict[str, str | bytes | os.PathLike[Any]],
context: Context,
return_sql: bool = False,
sql_context: dict[str, Any] | None = None,
run_as_async: bool = False,
async_context: dict[str, Any] | None = None,
) -> FullOutputSubprocessResult | dbtRunnerResult | str:
"""
Copies the dbt project to a temporary directory and runs the command.
Expand Down Expand Up @@ -486,8 +486,10 @@ def run_command(
full_cmd = cmd + flags

self.log.debug("Using environment variables keys: %s", env.keys())
if return_sql and sql_context:
profile_type = sql_context["profile_type"]
if run_as_async:
if not async_context:
raise CosmosValueError("async_context is necessary for running the model asynchronously.")
profile_type = async_context["profile_type"]
mock_adapter_callable = PROFILE_TYPE_MOCK_ADAPTER_CALLABLE_MAP.get(profile_type)
if not mock_adapter_callable:
raise CosmosValueError(
Expand Down Expand Up @@ -526,9 +528,10 @@ def run_command(
self.callback(tmp_project_dir, **self.callback_args)
self.handle_exception(result)

if return_sql and sql_context:
sql_content = self._read_run_sql_from_target_dir(tmp_project_dir, sql_context)
return sql_content
if run_as_async and async_context:
sql = self._read_run_sql_from_target_dir(tmp_project_dir, async_context)
PROFILE_TYPE_ASSOCIATE_ARGS_CALLABLE_MAP[profile_type](self, sql=sql)
async_context["async_operator"].execute(self, context)

return result

Expand Down Expand Up @@ -672,12 +675,14 @@ def build_and_run_cmd(
self,
context: Context,
cmd_flags: list[str] | None = None,
return_sql: bool = False,
sql_context: dict[str, Any] | None = None,
run_as_async: bool = False,
async_context: dict[str, Any] | None = None,
) -> FullOutputSubprocessResult | dbtRunnerResult:
dbt_cmd, env = self.build_cmd(context=context, cmd_flags=cmd_flags)
dbt_cmd = dbt_cmd or []
result = self.run_command(cmd=dbt_cmd, env=env, context=context, return_sql=return_sql, sql_context=sql_context)
result = self.run_command(
cmd=dbt_cmd, env=env, context=context, run_as_async=run_as_async, async_context=async_context
)
return result

def on_kill(self) -> None:
Expand Down

0 comments on commit 8ef4ac5

Please sign in to comment.