Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move DbtRunner related functions into dbt/runner.py module #1480

Merged
merged 11 commits into from
Jan 23, 2025
21 changes: 19 additions & 2 deletions cosmos/dbt/parser/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,12 @@
import re
from typing import TYPE_CHECKING, List, Tuple

import deprecation

if TYPE_CHECKING:
from dbt.cli.main import dbtRunnerResult

from cosmos import __version__ as cosmos_version # type: ignore[attr-defined]
from cosmos.hooks.subprocess import FullOutputSubprocessResult

DBT_NO_TESTS_MSG = "Nothing to do"
Expand Down Expand Up @@ -40,7 +43,14 @@ def parse_number_of_warnings_subprocess(result: FullOutputSubprocessResult) -> i
return num


def parse_number_of_warnings_dbt_runner(result: dbtRunnerResult) -> int:
# Python 3.13 exposes a deprecated operator, we can replace it in the future
@deprecation.deprecated(
deprecated_in="1.9",
removed_in="2.0",
current_version=cosmos_version,
details="Use the `cosmos.dbt.runner.parse_number_of_warnings` instead.",
) # type: ignore[misc]
def parse_number_of_warnings_dbt_runner(result: dbtRunnerResult) -> int: # type: ignore[misc]
"""Parses a dbt runner result and returns the number of warnings found. This only works for dbtRunnerResult
from invoking dbt build, compile, run, seed, snapshot, test, or run-operation.
"""
Expand Down Expand Up @@ -105,9 +115,16 @@ def clean_line(line: str) -> str:
return test_names, test_results


# Python 3.13 exposes a deprecated operator, we can replace it in the future
@deprecation.deprecated(
deprecated_in="1.9",
removed_in="2.0",
current_version=cosmos_version,
details="Use the `cosmos.dbt.runner.extract_message_by_status` instead.",
) # type: ignore[misc]
def extract_dbt_runner_issues(
result: dbtRunnerResult, status_levels: list[str] = ["warn"]
) -> Tuple[List[str], List[str]]:
) -> Tuple[List[str], List[str]]: # type: ignore[misc]
"""
Extracts messages from the dbt runner result and returns them as a formatted string.

Expand Down
113 changes: 113 additions & 0 deletions cosmos/dbt/runner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
from __future__ import annotations

import sys
from functools import lru_cache
from typing import TYPE_CHECKING

from cosmos.dbt.project import change_working_directory, environ
from cosmos.exceptions import CosmosDbtRunError
from cosmos.log import get_logger

if "pytest" in sys.modules:
# We set the cache limit to 0, so nothing gets cached by default when
# running tests
cache = lru_cache(maxsize=0)
else: # pragma: no cover
try:
# Available since Python 3.9
from functools import cache
except ImportError:
cache = lru_cache(maxsize=None)


logger = get_logger(__name__)

if TYPE_CHECKING: # pragma: no cover
from dbt.cli.main import dbtRunner, dbtRunnerResult


@cache
def is_available() -> bool:
"""
Checks if the dbt runner is available (if dbt-core is installed in the same Python virtualenv as Airflow)."
"""
try:
from dbt.cli.main import dbtRunner # noqa
except ImportError:
return False
return True


@cache
def get_runner() -> dbtRunner:
"""
Retrieves a dbtRunner instance.
"""
from dbt.cli.main import dbtRunner

return dbtRunner()


def run_command(command: list[str], env: dict[str, str], cwd: str) -> dbtRunnerResult:
"""
Invokes the dbt command programmatically.
"""
# Exclude the dbt executable path from the command. This step is necessary because we are using the same
# command that is used by `InvocationMode.SUBPROCESS`, and in that scenario the first command is necessarily the path
# to the dbt executable.
cli_args = command[1:]
with change_working_directory(cwd), environ(env):
logger.info("Trying to run dbtRunner with:\n %s\n in %s", cli_args, cwd)
runner = get_runner()
result = runner.invoke(cli_args)
return result


def extract_message_by_status(
result: dbtRunnerResult, status_levels: list[str] = ["warn"]
) -> tuple[list[str], list[str]]:
"""
Extracts messages from the dbt runner result and returns them as a formatted string.

This function iterates over dbtRunnerResult messages in dbt run. It extracts results that match the
status levels provided and appends them to a list of issues.

:param result: dbtRunnerResult object containing the output to be parsed.
:param status_levels: List of strings, where each string is a result status level. Default is ["warn"].
:return: two lists of strings, the first one containing the node names and the second one
containing the node result message.
"""
node_names = []
node_results = []

for node_result in result.result.results: # type: ignore
if node_result.status in status_levels:
node_names.append(str(node_result.node.name))
node_results.append(str(node_result.message))

return node_names, node_results


def parse_number_of_warnings(result: dbtRunnerResult) -> int:
"""Parses a dbt runner result and returns the number of warnings found. This only works for dbtRunnerResult
from invoking dbt build, compile, run, seed, snapshot, test, or run-operation.
"""
num = 0
for run_result in result.result.results: # type: ignore
if run_result.status == "warn":
num += 1
return num


def handle_exception_if_needed(result: dbtRunnerResult) -> None:
"""
Given a dbtRunnerResult, identify if it failed and handle the exception, if necessary.
"""
# dbtRunnerResult has an attribute `success` that is False if the command failed.
if not result.success:
if result.exception:
raise CosmosDbtRunError(f"dbt invocation did not complete with unhandled error: {result.exception}")
else:
node_names, node_results = extract_message_by_status(result, ["error", "fail", "runtime error"])
error_message = "\n".join([f"{name}: {result}" for name, result in zip(node_names, node_results)])
raise CosmosDbtRunError(f"dbt invocation completed with errors: {error_message}")
4 changes: 4 additions & 0 deletions cosmos/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,9 @@ class CosmosValueError(ValueError):
"""Raised when a Cosmos config value is invalid."""


class CosmosDbtRunError(Exception):
"""Raised when there are exceptions running DbtRunner"""


class AirflowCompatibilityError(Exception):
"""Raised when Cosmos features are limited for Airflow version being used."""
49 changes: 14 additions & 35 deletions cosmos/operators/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from cosmos.constants import FILE_SCHEME_AIRFLOW_DEFAULT_CONN_ID_MAP, InvocationMode
from cosmos.dataset import get_dataset_alias_name
from cosmos.dbt.project import get_partial_parse_path, has_non_empty_dependencies_file
from cosmos.exceptions import AirflowCompatibilityError, CosmosValueError
from cosmos.exceptions import AirflowCompatibilityError, CosmosDbtRunError, CosmosValueError
from cosmos.settings import remote_target_path, remote_target_path_conn_id

try:
Expand All @@ -51,18 +51,17 @@

from sqlalchemy.orm import Session

import cosmos.dbt.runner as dbt_runner
from cosmos.config import ProfileConfig
from cosmos.constants import (
OPENLINEAGE_PRODUCER,
)
from cosmos.dbt.parser.output import (
extract_dbt_runner_issues,
tatiana marked this conversation as resolved.
Show resolved Hide resolved
extract_freshness_warn_msg,
extract_log_issues,
parse_number_of_warnings_dbt_runner,
tatiana marked this conversation as resolved.
Show resolved Hide resolved
parse_number_of_warnings_subprocess,
)
from cosmos.dbt.project import change_working_directory, create_symlinks, environ
from cosmos.dbt.project import create_symlinks
from cosmos.hooks.subprocess import (
FullOutputSubprocessHook,
FullOutputSubprocessResult,
Expand Down Expand Up @@ -210,14 +209,12 @@ def _discover_invocation_mode(self) -> None:
be used since it is faster than subprocess. If dbtRunner is not available, it will fall back to subprocess.
This method is called at runtime to work in the environment where the operator is running.
"""
try:
from dbt.cli.main import dbtRunner # noqa
except ImportError:
self.invocation_mode = InvocationMode.SUBPROCESS
self.log.info("Could not import dbtRunner. Falling back to subprocess for invoking dbt.")
else:
if dbt_runner.is_available():
self.invocation_mode = InvocationMode.DBT_RUNNER
self.log.info("dbtRunner is available. Using dbtRunner for invoking dbt.")
else:
self.invocation_mode = InvocationMode.SUBPROCESS
self.log.info("Could not import dbtRunner. Falling back to subprocess for invoking dbt.")

def handle_exception_subprocess(self, result: FullOutputSubprocessResult) -> None:
if self.skip_exit_code is not None and result.exit_code == self.skip_exit_code:
Expand All @@ -228,13 +225,7 @@ def handle_exception_subprocess(self, result: FullOutputSubprocessResult) -> Non

def handle_exception_dbt_runner(self, result: dbtRunnerResult) -> None:
"""dbtRunnerResult has an attribute `success` that is False if the command failed."""
if not result.success:
if result.exception:
raise AirflowException(f"dbt invocation did not complete with unhandled error: {result.exception}")
else:
node_names, node_results = extract_dbt_runner_issues(result, ["error", "fail", "runtime error"])
error_message = "\n".join([f"{name}: {result}" for name, result in zip(node_names, node_results)])
raise AirflowException(f"dbt invocation completed with errors: {error_message}")
return dbt_runner.handle_exception_if_needed(result)

@provide_session
def store_compiled_sql(self, tmp_project_dir: str, context: Context, session: Session = NEW_SESSION) -> None:
Expand Down Expand Up @@ -393,24 +384,12 @@ def run_subprocess(self, command: list[str], env: dict[str, str], cwd: str) -> F

def run_dbt_runner(self, command: list[str], env: dict[str, str], cwd: str) -> dbtRunnerResult:
"""Invokes the dbt command programmatically."""
try:
from dbt.cli.main import dbtRunner
except ImportError:
raise ImportError(
if not dbt_runner.is_available():
raise CosmosDbtRunError(
"Could not import dbt core. Ensure that dbt-core >= v1.5 is installed and available in the environment where the operator is running."
)

if self._dbt_runner is None:
self._dbt_runner = dbtRunner()

# Exclude the dbt executable path from the command
cli_args = command[1:]
self.log.info("Trying to run dbtRunner with:\n %s\n in %s", cli_args, cwd)

with change_working_directory(cwd), environ(env):
result = self._dbt_runner.invoke(cli_args)

return result
return dbt_runner.run_command(command, env, cwd)

def _cache_package_lockfile(self, tmp_project_dir: Path) -> None:
project_dir = Path(self.project_dir)
Expand Down Expand Up @@ -723,7 +702,7 @@ def _handle_warnings(self, result: FullOutputSubprocessResult | dbtRunnerResult,
if self.invocation_mode == InvocationMode.SUBPROCESS:
self.extract_issues = extract_freshness_warn_msg
elif self.invocation_mode == InvocationMode.DBT_RUNNER:
self.extract_issues = extract_dbt_runner_issues
self.extract_issues = dbt_runner.extract_message_by_status

test_names, test_results = self.extract_issues(result)

Expand Down Expand Up @@ -789,8 +768,8 @@ def _set_test_result_parsing_methods(self) -> None:
self.extract_issues = lambda result: extract_log_issues(result.full_output)
self.parse_number_of_warnings = parse_number_of_warnings_subprocess
elif self.invocation_mode == InvocationMode.DBT_RUNNER:
self.extract_issues = extract_dbt_runner_issues
self.parse_number_of_warnings = parse_number_of_warnings_dbt_runner
self.extract_issues = dbt_runner.extract_message_by_status
self.parse_number_of_warnings = dbt_runner.parse_number_of_warnings

def execute(self, context: Context) -> None:
result = self.build_and_run_cmd(context=context, cmd_flags=self.add_cmd_flags())
Expand Down
1 change: 1 addition & 0 deletions docs/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
aenum
deprecation
msgpack
apache-airflow
pydantic
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ dependencies = [
"aenum",
"attrs",
"apache-airflow>=2.4.0",
"deprecation", # Python 3.13 exposes a deprecated operator, we can remove this dependency in the future
"importlib-metadata; python_version < '3.8'",
"Jinja2>=3.0.0",
"msgpack",
Expand Down
Loading
Loading