From 6e23c8c015587b25abba4d357911db4b3168ee56 Mon Sep 17 00:00:00 2001 From: Krystle Salazar Date: Thu, 25 Apr 2024 14:11:35 -0400 Subject: [PATCH] Extend `update_license_url` tasks timeout to a day and a half --- catalog/dags/maintenance/add_license_url.py | 24 ++++++++------------- 1 file changed, 9 insertions(+), 15 deletions(-) diff --git a/catalog/dags/maintenance/add_license_url.py b/catalog/dags/maintenance/add_license_url.py index 5479214a0ab..dfb4d97f332 100644 --- a/catalog/dags/maintenance/add_license_url.py +++ b/catalog/dags/maintenance/add_license_url.py @@ -54,16 +54,14 @@ def run_sql( @task -def get_license_groups( - query: str, dag_task: AbstractOperator = None -) -> list[tuple[str, str]]: +def get_license_groups(query: str, ti=None) -> list[tuple[str, str]]: """ Get license groups of rows that don't have a `license_url` in their `meta_data` field. :return: List of (license, version) tuples. """ - license_groups = run_sql(query, dag_task=dag_task) + license_groups = run_sql(query, dag_task=ti.task) total_nulls = sum(group[2] for group in license_groups) licenses_detailed = "\n".join( @@ -84,18 +82,14 @@ def get_license_groups( return [(group[0], group[1]) for group in license_groups] -@task(max_active_tis_per_dag=1) -def update_license_url( - license_group: tuple[str, str], - batch_size: int, - dag_task: AbstractOperator = None, -) -> int: +@task(max_active_tis_per_dag=1, execution_timeout=timedelta(hours=36)) +def update_license_url(license_group: tuple[str, str], batch_size: int, ti=None) -> int: """ Add license_url to meta_data batching all records with the same license. :param license_group: tuple of license and version :param batch_size: number of records to update in one update statement - :param dag_task: automatically passed by Airflow, used to set the execution timeout. + :param ti: automatically passed by Airflow, used to set the execution timeout. """ license_, version = license_group license_info = get_license_info_from_license_pair(license_, version) @@ -135,7 +129,7 @@ def update_license_url( method="run", handler=RETURN_ROW_COUNT, autocommit=True, - dag_task=dag_task, + dag_task=ti.task, ) total_updated += updated_count logger.info(f"Updated {total_updated} rows with {license_url}.") @@ -144,18 +138,18 @@ def update_license_url( @task(trigger_rule=TriggerRule.ALL_DONE) -def report_completion(updated, query: str, dag_task: AbstractOperator = None): +def report_completion(updated, query: str, ti=None): """ Check for null in `meta_data` and send a message to Slack with the statistics of the DAG run. :param updated: total number of records updated :param query: SQL query to get the count of records left with `license_url` as NULL - :param dag_task: automatically passed by Airflow, used to set the execution timeout. + :param ti: automatically passed by Airflow, used to set the execution timeout. """ total_updated = sum(updated) if updated else 0 - license_groups = run_sql(query, dag_task=dag_task) + license_groups = run_sql(query, dag_task=ti.task) total_nulls = sum(group[2] for group in license_groups) licenses_detailed = "\n".join( f"{group[0]} \t{group[1]} \t{group[2]}" for group in license_groups