Skip to content

Commit

Permalink
Moment of glory 2
Browse files Browse the repository at this point in the history
  • Loading branch information
pankajkoti committed Jan 29, 2025
1 parent 31161bf commit 5ea5217
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 17 deletions.
8 changes: 4 additions & 4 deletions cosmos/operators/_asynchronous/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,10 @@ def __init__(self, project_dir: str, profile_config: ProfileConfig, extra_contex
# When using composition instead of inheritance to initialize the async class and run its execute method,
# Airflow throws a `DuplicateTaskIdFound` error.
DbtRunAirflowAsyncFactoryOperator.__bases__ = (async_operator_class,)
super().__init__(project_dir=project_dir, profile_config=profile_config, dbt_kwargs=dbt_kwargs, **kwargs)
self.async_context = extra_context
self.async_context["profile_type"] = "bigquery"
self.async_context["async_operator"] = async_operator_class
super().__init__(project_dir=project_dir, profile_config=profile_config, extra_context=extra_context, dbt_kwargs=dbt_kwargs, **kwargs)
# self.async_context = extra_context
# self.async_context["profile_type"] = "bigquery"
# self.async_context["async_operator"] = async_operator_class

def create_async_operator(self) -> Any:

Expand Down
9 changes: 4 additions & 5 deletions cosmos/operators/_asynchronous/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,19 +37,18 @@ def __init__(
if "full_refresh" in kwargs:
self.full_refresh = kwargs.pop("full_refresh")
self.configuration: dict[str, Any] = {}
task_id = dbt_kwargs.pop("task_id")
AbstractDbtLocalBase.__init__(
self, task_id=task_id, project_dir=project_dir, profile_config=profile_config, **dbt_kwargs
)
super().__init__(
gcp_conn_id=self.gcp_conn_id,
configuration=self.configuration,
deferrable=True,
**kwargs,
)
task_id = dbt_kwargs.pop("task_id")
# DbtRunMixin.__init__(self, **dbt_kwargs)
# breakpoint()

AbstractDbtLocalBase.__init__(
self, task_id=task_id, project_dir=project_dir, profile_config=profile_config, **dbt_kwargs
)
self.dbt_kwargs = dbt_kwargs
self.async_context = extra_context
self.async_context["profile_type"] = self.profile_config.get_profile_type()
Expand Down
25 changes: 20 additions & 5 deletions cosmos/operators/airflow_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from cosmos.config import ProfileConfig
from cosmos.operators._asynchronous.bigquery import DbtRunAirflowAsyncBigqueryOperator
from cosmos.operators._asynchronous.base import DbtRunAirflowAsyncFactoryOperator
from cosmos.operators.base import AbstractDbtBase
from cosmos.operators.local import (
DbtBuildLocalOperator,
Expand Down Expand Up @@ -40,8 +41,15 @@ class DbtLSAirflowAsyncOperator(DbtBaseAirflowAsyncOperator, DbtLSLocalOperator)
pass


class DbtSeedAirflowAsyncOperator(DbtBaseAirflowAsyncOperator, DbtSeedLocalOperator): # type: ignore
pass
class DbtSeedAirflowAsyncOperator(DbtSeedLocalOperator): # type: ignore
def __init__(self, *args, **kwargs) -> None:
clean_kwargs = {}
base_operator_args = set(inspect.signature(BaseOperator.__init__).parameters.keys())
for arg_key, arg_value in kwargs.items():
if arg_key in base_operator_args:
clean_kwargs[arg_key] = arg_value
BaseOperator.__init__(self, **clean_kwargs)
super().__init__(*args, **kwargs)


class DbtSnapshotAirflowAsyncOperator(DbtBaseAirflowAsyncOperator, DbtSnapshotLocalOperator): # type: ignore
Expand All @@ -52,7 +60,7 @@ class DbtSourceAirflowAsyncOperator(DbtBaseAirflowAsyncOperator, DbtSourceLocalO
pass


class DbtRunAirflowAsyncOperator(DbtRunAirflowAsyncBigqueryOperator): # type: ignore
class DbtRunAirflowAsyncOperator(DbtRunAirflowAsyncFactoryOperator): # type: ignore

def __init__( # type: ignore
self,
Expand Down Expand Up @@ -89,8 +97,15 @@ def __init__( # type: ignore
)


class DbtTestAirflowAsyncOperator(DbtBaseAirflowAsyncOperator, DbtTestLocalOperator): # type: ignore
pass
class DbtTestAirflowAsyncOperator(DbtTestLocalOperator): # type: ignore
def __init__(self, *args, **kwargs) -> None:
clean_kwargs = {}
base_operator_args = set(inspect.signature(BaseOperator.__init__).parameters.keys())
for arg_key, arg_value in kwargs.items():
if arg_key in base_operator_args:
clean_kwargs[arg_key] = arg_value
super().__init__(*args, **kwargs)
BaseOperator.__init__(self, **clean_kwargs)


class DbtRunOperationAirflowAsyncOperator(DbtBaseAirflowAsyncOperator, DbtRunOperationLocalOperator): # type: ignore
Expand Down
4 changes: 2 additions & 2 deletions cosmos/operators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ def build_and_run_cmd(
) -> Any:
"""Override this method for the operator to execute the dbt command"""

def execute(self, context: Context) -> Any | None: # type: ignore
def execute(self, context: Context, **kwargs) -> Any | None: # type: ignore
if self.extra_context:
context_merge(context, self.extra_context)

Expand Down Expand Up @@ -371,7 +371,7 @@ class DbtRunMixin:

def __init__(self, full_refresh: bool | str = False, **kwargs: Any) -> None:
self.full_refresh = full_refresh
# super().__init__(**kwargs)
super().__init__(**kwargs)

def add_cmd_flags(self) -> list[str]:
flags = []
Expand Down
2 changes: 1 addition & 1 deletion cosmos/operators/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -841,7 +841,7 @@ def _set_test_result_parsing_methods(self) -> None:
self.extract_issues = dbt_runner.extract_message_by_status
self.parse_number_of_warnings = dbt_runner.parse_number_of_warnings

def execute(self, context: Context) -> None:
def execute(self, context: Context, **kwargs) -> None:
result = self.build_and_run_cmd(context=context, cmd_flags=self.add_cmd_flags())
self._set_test_result_parsing_methods()
number_of_warnings = self.parse_number_of_warnings(result) # type: ignore
Expand Down

0 comments on commit 5ea5217

Please sign in to comment.