Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move DbtRunner related functions into dbt/runner.py module #1480

Merged
merged 11 commits into from
Jan 23, 2025
111 changes: 111 additions & 0 deletions cosmos/dbt/runner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
from __future__ import annotations

import sys
from functools import lru_cache
from typing import TYPE_CHECKING

from cosmos.dbt.project import change_working_directory, environ
from cosmos.exceptions import CosmosDbtRunError
from cosmos.log import get_logger

# Select the memoization decorator exposed by this module as `cache`.
if "pytest" not in sys.modules:  # pragma: no cover
    try:
        # Available since Python 3.9
        from functools import cache
    except ImportError:
        # Older interpreters: an unbounded lru_cache is used instead
        cache = lru_cache(maxsize=None)
else:
    # We set the cache limit to 0, so nothing gets cached by default when
    # running tests
    cache = lru_cache(maxsize=0)


logger = get_logger(__name__)

if TYPE_CHECKING: # pragma: no cover
from dbt.cli.main import dbtRunner, dbtRunnerResult


@cache
def is_available() -> bool:
    """
    Report whether dbt can be invoked programmatically through dbtRunner.

    :return: True when ``dbtRunner`` can be imported from ``dbt.cli.main`` (i.e. dbt-core is installed
        in the same Python virtualenv as Airflow), False when the import raises ``ImportError``.
    """
    try:
        from dbt.cli.main import dbtRunner  # noqa
    except ImportError:
        available = False
    else:
        available = True
    return available


@cache
def get_runner() -> dbtRunner:
    """
    Build a dbtRunner instance.

    dbt is imported inside the function body rather than at module import time.

    :return: a ``dbtRunner`` object.
    """
    from dbt.cli.main import dbtRunner as runner_class

    return runner_class()


def run_command(command: list[str], env: dict[str, str], cwd: str) -> dbtRunnerResult:
    """
    Run a dbt command in-process through dbtRunner.

    :param command: Full dbt command. The first item (the dbt executable path) is dropped before invoking.
    :param env: Mapping handed to the ``environ`` context manager for the duration of the invocation.
    :param cwd: Directory handed to ``change_working_directory`` for the duration of the invocation.
    :return: The object returned by ``dbtRunner.invoke``.
    """
    # Exclude the dbt executable path from the command
    dbt_args = command[1:]
    with change_working_directory(cwd):
        with environ(env):
            logger.info("Trying to run dbtRunner with:\n %s\n in %s", dbt_args, cwd)
            return get_runner().invoke(dbt_args)


def extract_message_by_status(
    result: dbtRunnerResult, status_levels: list[str] | None = None
) -> tuple[list[str], list[str]]:
    """
    Extracts node names and messages from a dbt runner result, filtered by status.

    This function iterates over the node results held by the dbtRunnerResult and keeps the ones whose
    status matches one of the status levels provided.

    :param result: dbtRunnerResult object containing the output to be parsed.
    :param status_levels: List of strings, where each string is a result status level.
        Defaults to ["warn"] when omitted or None.
    :return: two lists of strings, the first one containing the node names and the second one
        containing the node result messages. Both lists are empty when the result carries no
        per-node results (``result.result`` is None or has no ``results`` attribute).
    """
    # A None sentinel replaces the former mutable list default; an explicit empty list is
    # still honoured and matches nothing.
    levels = ["warn"] if status_levels is None else status_levels

    # Not every dbt command produces an object with `.results`; treat that as "nothing to report"
    # instead of failing with AttributeError.
    run_results = getattr(result.result, "results", None) or []

    matching = [node_result for node_result in run_results if node_result.status in levels]
    node_names = [str(node_result.node.name) for node_result in matching]
    node_results = [str(node_result.message) for node_result in matching]

    return node_names, node_results


def parse_number_of_warnings(result: dbtRunnerResult) -> int:
    """Counts the node results with status "warn" in a dbt runner result.

    This only works for dbtRunnerResult from invoking dbt build, compile, run, seed, snapshot, test,
    or run-operation.

    :param result: dbtRunnerResult whose ``result.results`` entries are inspected.
    :return: the number of entries whose status equals "warn".
    """
    return sum(1 for node_result in result.result.results if node_result.status == "warn")  # type: ignore


def handle_exception_if_needed(result: dbtRunnerResult) -> None:
    """
    Given a dbtRunnerResult, identify if it failed and raise an exception, if necessary.

    :param result: dbtRunnerResult returned by a dbtRunner invocation.
    :raises CosmosDbtRunError: if ``result.success`` is falsy. When ``result.exception`` is set, the
        error wraps it (and chains it as the cause, so its traceback is preserved); otherwise the
        error lists the nodes whose status is "error", "fail" or "runtime error".
    """
    # dbtRunnerResult has an attribute `success` that is False if the command failed.
    if result.success:
        return

    if result.exception:
        # Only real exception objects can be used as a cause; anything else would make the
        # `raise ... from` statement itself fail with TypeError.
        cause = result.exception if isinstance(result.exception, BaseException) else None
        raise CosmosDbtRunError(
            f"dbt invocation did not complete with unhandled error: {result.exception}"
        ) from cause

    node_names, node_results = extract_message_by_status(result, ["error", "fail", "runtime error"])
    error_message = "\n".join(f"{name}: {message}" for name, message in zip(node_names, node_results))
    raise CosmosDbtRunError(f"dbt invocation completed with errors: {error_message}")
4 changes: 4 additions & 0 deletions cosmos/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,9 @@ class CosmosValueError(ValueError):
"""Raised when a Cosmos config value is invalid."""


class CosmosDbtRunError(Exception):
    """Raised when a dbt invocation made through dbtRunner fails."""


class AirflowCompatibilityError(Exception):
    """Raised when Cosmos features are limited for the Airflow version being used."""
45 changes: 12 additions & 33 deletions cosmos/operators/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,18 +51,17 @@

from sqlalchemy.orm import Session

import cosmos.dbt.runner as dbt_runner
from cosmos.config import ProfileConfig
from cosmos.constants import (
OPENLINEAGE_PRODUCER,
)
from cosmos.dbt.parser.output import (
extract_dbt_runner_issues,
tatiana marked this conversation as resolved.
Show resolved Hide resolved
extract_freshness_warn_msg,
extract_log_issues,
parse_number_of_warnings_dbt_runner,
tatiana marked this conversation as resolved.
Show resolved Hide resolved
parse_number_of_warnings_subprocess,
)
from cosmos.dbt.project import change_working_directory, create_symlinks, environ
from cosmos.dbt.project import create_symlinks
from cosmos.hooks.subprocess import (
FullOutputSubprocessHook,
FullOutputSubprocessResult,
Expand Down Expand Up @@ -210,14 +209,12 @@ def _discover_invocation_mode(self) -> None:
be used since it is faster than subprocess. If dbtRunner is not available, it will fall back to subprocess.
This method is called at runtime to work in the environment where the operator is running.
"""
try:
from dbt.cli.main import dbtRunner # noqa
except ImportError:
self.invocation_mode = InvocationMode.SUBPROCESS
self.log.info("Could not import dbtRunner. Falling back to subprocess for invoking dbt.")
else:
if dbt_runner.is_available():
self.invocation_mode = InvocationMode.DBT_RUNNER
self.log.info("dbtRunner is available. Using dbtRunner for invoking dbt.")
else:
self.invocation_mode = InvocationMode.SUBPROCESS
self.log.info("Could not import dbtRunner. Falling back to subprocess for invoking dbt.")

def handle_exception_subprocess(self, result: FullOutputSubprocessResult) -> None:
if self.skip_exit_code is not None and result.exit_code == self.skip_exit_code:
Expand All @@ -228,13 +225,7 @@ def handle_exception_subprocess(self, result: FullOutputSubprocessResult) -> Non

def handle_exception_dbt_runner(self, result: dbtRunnerResult) -> None:
"""dbtRunnerResult has an attribute `success` that is False if the command failed."""
if not result.success:
if result.exception:
raise AirflowException(f"dbt invocation did not complete with unhandled error: {result.exception}")
else:
node_names, node_results = extract_dbt_runner_issues(result, ["error", "fail", "runtime error"])
error_message = "\n".join([f"{name}: {result}" for name, result in zip(node_names, node_results)])
raise AirflowException(f"dbt invocation completed with errors: {error_message}")
return dbt_runner.handle_exception_if_needed(result)

@provide_session
def store_compiled_sql(self, tmp_project_dir: str, context: Context, session: Session = NEW_SESSION) -> None:
Expand Down Expand Up @@ -393,24 +384,12 @@ def run_subprocess(self, command: list[str], env: dict[str, str], cwd: str) -> F

def run_dbt_runner(self, command: list[str], env: dict[str, str], cwd: str) -> dbtRunnerResult:
"""Invokes the dbt command programmatically."""
try:
from dbt.cli.main import dbtRunner
except ImportError:
if not dbt_runner.is_available():
raise ImportError(
tatiana marked this conversation as resolved.
Show resolved Hide resolved
"Could not import dbt core. Ensure that dbt-core >= v1.5 is installed and available in the environment where the operator is running."
)

if self._dbt_runner is None:
self._dbt_runner = dbtRunner()

# Exclude the dbt executable path from the command
cli_args = command[1:]
self.log.info("Trying to run dbtRunner with:\n %s\n in %s", cli_args, cwd)

with change_working_directory(cwd), environ(env):
result = self._dbt_runner.invoke(cli_args)

return result
return dbt_runner.run_command(command, env, cwd)

def _cache_package_lockfile(self, tmp_project_dir: Path) -> None:
project_dir = Path(self.project_dir)
Expand Down Expand Up @@ -723,7 +702,7 @@ def _handle_warnings(self, result: FullOutputSubprocessResult | dbtRunnerResult,
if self.invocation_mode == InvocationMode.SUBPROCESS:
self.extract_issues = extract_freshness_warn_msg
elif self.invocation_mode == InvocationMode.DBT_RUNNER:
self.extract_issues = extract_dbt_runner_issues
self.extract_issues = dbt_runner.extract_message_by_status

test_names, test_results = self.extract_issues(result)

Expand Down Expand Up @@ -789,8 +768,8 @@ def _set_test_result_parsing_methods(self) -> None:
self.extract_issues = lambda result: extract_log_issues(result.full_output)
self.parse_number_of_warnings = parse_number_of_warnings_subprocess
elif self.invocation_mode == InvocationMode.DBT_RUNNER:
self.extract_issues = extract_dbt_runner_issues
self.parse_number_of_warnings = parse_number_of_warnings_dbt_runner
self.extract_issues = dbt_runner.extract_message_by_status
self.parse_number_of_warnings = dbt_runner.parse_number_of_warnings

def execute(self, context: Context) -> None:
result = self.build_and_run_cmd(context=context, cmd_flags=self.add_cmd_flags())
Expand Down
109 changes: 109 additions & 0 deletions tests/dbt/test_runner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
import os
import shutil
import sys
import tempfile
from pathlib import Path
from unittest.mock import patch

import pytest

import cosmos.dbt.runner as dbt_runner
from cosmos.exceptions import CosmosDbtRunError

DBT_PROJECT_PATH = Path(__file__).parent.parent.parent / "dev/dags/dbt/jaffle_shop"


@pytest.fixture
def valid_dbt_project_dir():
    """
    Yields a temporary copy of the jaffle_shop dbt project without its ``logs`` and ``target`` folders.

    The temporary directory holding the copy is removed after the test finishes.
    """
    sandbox = Path(tempfile.mkdtemp())
    project_copy = sandbox / "jaffle_shop"
    shutil.copytree(DBT_PROJECT_PATH, project_copy)

    # Drop folders dbt may have generated in the source project
    for generated_folder in ("logs", "target"):
        shutil.rmtree(project_copy / generated_folder, ignore_errors=True)

    yield project_copy

    shutil.rmtree(sandbox, ignore_errors=True)  # delete directory


@pytest.fixture
def invalid_dbt_project_dir(valid_dbt_project_dir):
    """
    Turns the temporary dbt project into an invalid one, that will raise exceptions if attempted to be run.

    ``packages.yml`` is deleted and ``models/staging/stg_orders.sql`` is truncated to an empty file.
    """
    (valid_dbt_project_dir / "packages.yml").unlink()

    # Opening in write mode and closing right away leaves an empty model file behind
    emptied_model = valid_dbt_project_dir / "models/staging/stg_orders.sql"
    with open(str(emptied_model), "w"):
        pass

    return valid_dbt_project_dir


@patch.dict(sys.modules, {"dbt.cli.main": None})
def test_is_available_is_false():
    """is_available() is falsy when ``dbt.cli.main`` is mapped to None in ``sys.modules``."""
    available = dbt_runner.is_available()
    assert not available


@pytest.mark.integration
def test_is_available_is_true():
    """is_available() is truthy in the integration environment."""
    available = dbt_runner.is_available()
    assert available


@pytest.mark.integration
def test_get_runner():
    """get_runner() returns an instance of dbt's ``dbtRunner`` class."""
    from dbt.cli.main import dbtRunner

    assert isinstance(dbt_runner.get_runner(), dbtRunner)


@pytest.mark.integration
def test_run_command(valid_dbt_project_dir):
    """Running ``dbt deps`` on the valid project succeeds and needs no exception handling."""
    from dbt.cli.main import dbtRunnerResult

    outcome = dbt_runner.run_command(command=["dbt", "deps"], env=os.environ, cwd=valid_dbt_project_dir)

    assert isinstance(outcome, dbtRunnerResult)
    assert outcome.success
    assert outcome.exception is None
    assert outcome.result is None

    # A successful result must pass through the handler without raising
    assert dbt_runner.handle_exception_if_needed(outcome) is None


@pytest.mark.integration
def test_handle_exception_if_needed_after_exception(valid_dbt_project_dir):
    """A dbt invocation that ends with an exception is surfaced as CosmosDbtRunError."""
    # The following command will fail because we didn't run `dbt deps` in advance
    outcome = dbt_runner.run_command(command=["dbt", "ls"], env=os.environ, cwd=valid_dbt_project_dir)
    assert not outcome.success
    assert outcome.exception

    with pytest.raises(CosmosDbtRunError) as exc_info:
        dbt_runner.handle_exception_if_needed(outcome)

    err_msg = str(exc_info.value)
    expected_fragments = (
        "dbt invocation did not complete with unhandled error: Compilation Error",
        "dbt found 1 package(s) specified in packages.yml, but only 0 package(s) installed",
    )
    for fragment in expected_fragments:
        assert fragment in err_msg


@pytest.mark.integration
def test_handle_exception_if_needed_after_error(invalid_dbt_project_dir):
    """A dbt invocation that completes with node errors is surfaced as CosmosDbtRunError."""
    # The following command runs against the project that the fixture deliberately made invalid
    outcome = dbt_runner.run_command(command=["dbt", "run"], env=os.environ, cwd=invalid_dbt_project_dir)
    assert not outcome.success
    assert outcome.exception is None
    assert outcome.result

    with pytest.raises(CosmosDbtRunError) as exc_info:
        dbt_runner.handle_exception_if_needed(outcome)

    err_msg = str(exc_info.value)
    expected_fragments = (
        "dbt invocation completed with errors:",
        "stg_payments: Database Error in model stg_payments",
    )
    for fragment in expected_fragments:
        assert fragment in err_msg
Loading
Loading