From 5ea52171b93f345cd375869763094069a4ec1a15 Mon Sep 17 00:00:00 2001 From: Pankaj Koti Date: Wed, 29 Jan 2025 15:58:18 +0530 Subject: [PATCH] Moment of glory 2 --- cosmos/operators/_asynchronous/base.py | 8 +++---- cosmos/operators/_asynchronous/bigquery.py | 9 ++++---- cosmos/operators/airflow_async.py | 25 +++++++++++++++++----- cosmos/operators/base.py | 4 ++-- cosmos/operators/local.py | 2 +- 5 files changed, 31 insertions(+), 17 deletions(-) diff --git a/cosmos/operators/_asynchronous/base.py b/cosmos/operators/_asynchronous/base.py index b56a5d075..ee7e85443 100644 --- a/cosmos/operators/_asynchronous/base.py +++ b/cosmos/operators/_asynchronous/base.py @@ -48,10 +48,10 @@ def __init__(self, project_dir: str, profile_config: ProfileConfig, extra_contex # When using composition instead of inheritance to initialize the async class and run its execute method, # Airflow throws a `DuplicateTaskIdFound` error. DbtRunAirflowAsyncFactoryOperator.__bases__ = (async_operator_class,) - super().__init__(project_dir=project_dir, profile_config=profile_config, dbt_kwargs=dbt_kwargs, **kwargs) - self.async_context = extra_context - self.async_context["profile_type"] = "bigquery" - self.async_context["async_operator"] = async_operator_class + super().__init__(project_dir=project_dir, profile_config=profile_config, extra_context=extra_context, dbt_kwargs=dbt_kwargs, **kwargs) + # self.async_context = extra_context + # self.async_context["profile_type"] = "bigquery" + # self.async_context["async_operator"] = async_operator_class def create_async_operator(self) -> Any: diff --git a/cosmos/operators/_asynchronous/bigquery.py b/cosmos/operators/_asynchronous/bigquery.py index 29788dcdf..9e7953a35 100644 --- a/cosmos/operators/_asynchronous/bigquery.py +++ b/cosmos/operators/_asynchronous/bigquery.py @@ -37,19 +37,18 @@ def __init__( if "full_refresh" in kwargs: self.full_refresh = kwargs.pop("full_refresh") self.configuration: dict[str, Any] = {} + task_id = dbt_kwargs.pop("task_id") + AbstractDbtLocalBase.__init__( + self, task_id=task_id, project_dir=project_dir, profile_config=profile_config, **dbt_kwargs + ) super().__init__( gcp_conn_id=self.gcp_conn_id, configuration=self.configuration, deferrable=True, **kwargs, ) - task_id = dbt_kwargs.pop("task_id") # DbtRunMixin.__init__(self, **dbt_kwargs) # breakpoint() - - AbstractDbtLocalBase.__init__( - self, task_id=task_id, project_dir=project_dir, profile_config=profile_config, **dbt_kwargs - ) self.dbt_kwargs = dbt_kwargs self.async_context = extra_context self.async_context["profile_type"] = self.profile_config.get_profile_type() diff --git a/cosmos/operators/airflow_async.py b/cosmos/operators/airflow_async.py index 4a31bd070..eaf9480d1 100644 --- a/cosmos/operators/airflow_async.py +++ b/cosmos/operators/airflow_async.py @@ -4,6 +4,7 @@ from cosmos.config import ProfileConfig from cosmos.operators._asynchronous.bigquery import DbtRunAirflowAsyncBigqueryOperator +from cosmos.operators._asynchronous.base import DbtRunAirflowAsyncFactoryOperator from cosmos.operators.base import AbstractDbtBase from cosmos.operators.local import ( DbtBuildLocalOperator, @@ -40,8 +41,15 @@ class DbtLSAirflowAsyncOperator(DbtBaseAirflowAsyncOperator, DbtLSLocalOperator) pass -class DbtSeedAirflowAsyncOperator(DbtBaseAirflowAsyncOperator, DbtSeedLocalOperator): # type: ignore - pass +class DbtSeedAirflowAsyncOperator(DbtSeedLocalOperator): # type: ignore + def __init__(self, *args, **kwargs) -> None: + clean_kwargs = {} + base_operator_args = set(inspect.signature(BaseOperator.__init__).parameters.keys()) + for arg_key, arg_value in kwargs.items(): + if arg_key in base_operator_args: + clean_kwargs[arg_key] = arg_value + BaseOperator.__init__(self, **clean_kwargs) + super().__init__(*args, **kwargs) class DbtSnapshotAirflowAsyncOperator(DbtBaseAirflowAsyncOperator, DbtSnapshotLocalOperator): # type: ignore @@ -52,7 +60,7 @@ class DbtSourceAirflowAsyncOperator(DbtBaseAirflowAsyncOperator, DbtSourceLocalO pass -class DbtRunAirflowAsyncOperator(DbtRunAirflowAsyncBigqueryOperator): # type: ignore +class DbtRunAirflowAsyncOperator(DbtRunAirflowAsyncFactoryOperator): # type: ignore def __init__( # type: ignore self, @@ -89,8 +97,15 @@ def __init__( # type: ignore ) -class DbtTestAirflowAsyncOperator(DbtBaseAirflowAsyncOperator, DbtTestLocalOperator): # type: ignore - pass +class DbtTestAirflowAsyncOperator(DbtTestLocalOperator): # type: ignore + def __init__(self, *args, **kwargs) -> None: + clean_kwargs = {} + base_operator_args = set(inspect.signature(BaseOperator.__init__).parameters.keys()) + for arg_key, arg_value in kwargs.items(): + if arg_key in base_operator_args: + clean_kwargs[arg_key] = arg_value + super().__init__(*args, **kwargs) + BaseOperator.__init__(self, **clean_kwargs) class DbtRunOperationAirflowAsyncOperator(DbtBaseAirflowAsyncOperator, DbtRunOperationLocalOperator): # type: ignore diff --git a/cosmos/operators/base.py b/cosmos/operators/base.py index 1837be01c..f07f7a493 100644 --- a/cosmos/operators/base.py +++ b/cosmos/operators/base.py @@ -266,7 +266,7 @@ def build_and_run_cmd( ) -> Any: """Override this method for the operator to execute the dbt command""" - def execute(self, context: Context) -> Any | None: # type: ignore + def execute(self, context: Context, **kwargs) -> Any | None: # type: ignore if self.extra_context: context_merge(context, self.extra_context) @@ -371,7 +371,7 @@ class DbtRunMixin: def __init__(self, full_refresh: bool | str = False, **kwargs: Any) -> None: self.full_refresh = full_refresh - # super().__init__(**kwargs) + super().__init__(**kwargs) def add_cmd_flags(self) -> list[str]: flags = [] diff --git a/cosmos/operators/local.py b/cosmos/operators/local.py index 1395dc1f6..6f9b2578e 100644 --- a/cosmos/operators/local.py +++ b/cosmos/operators/local.py @@ -841,7 +841,7 @@ def _set_test_result_parsing_methods(self) -> None: self.extract_issues = dbt_runner.extract_message_by_status self.parse_number_of_warnings = dbt_runner.parse_number_of_warnings - def execute(self, context: Context) -> None: + def execute(self, context: Context, **kwargs) -> None: result = self.build_and_run_cmd(context=context, cmd_flags=self.add_cmd_flags()) self._set_test_result_parsing_methods() number_of_warnings = self.parse_number_of_warnings(result) # type: ignore