diff --git a/catalog/dags/maintenance/add_license_url.py b/catalog/dags/maintenance/add_license_url.py
index 3eb36de3535..5479214a0ab 100644
--- a/catalog/dags/maintenance/add_license_url.py
+++ b/catalog/dags/maintenance/add_license_url.py
@@ -1,239 +1,244 @@
 """
 # Add license URL
-Add `license_url` to all rows that have `NULL` in their `meta_data` fields.
-This PR sets the meta_data value to "{license_url: https://... }", where the
+Add `license_url` to rows without one in their `meta_data` fields.
+This DAG merges the `meta_data` value with "{license_url: https://... }", where the
 url is constructed from the `license` and `license_version` columns.
 
-This is a maintenance DAG that should be run once. If all the null values in
-the `meta_data` column are updated, the DAG will only run the first and the
-last step, logging the statistics.
+This is a maintenance DAG that should be run once.
 """
-import csv
 import logging
-from collections import defaultdict
 from datetime import timedelta
-from tempfile import NamedTemporaryFile
 from textwrap import dedent
 
-from airflow.models import DAG
+from airflow.decorators import dag, task
+from airflow.exceptions import AirflowSkipException
 from airflow.models.abstractoperator import AbstractOperator
-from airflow.operators.python import PythonOperator
-from airflow.providers.amazon.aws.hooks.s3 import S3Hook
+from airflow.models.param import Param
+from airflow.utils.state import State
 from airflow.utils.trigger_rule import TriggerRule
 from psycopg2._json import Json
 
-from common.constants import DAG_DEFAULT_ARGS, POSTGRES_CONN_ID, XCOM_PULL_TEMPLATE
+from common import slack
+from common.constants import DAG_DEFAULT_ARGS, POSTGRES_CONN_ID
 from common.licenses import get_license_info_from_license_pair
-from common.slack import send_message
 from common.sql import RETURN_ROW_COUNT, PostgresHook
-from providers.provider_dag_factory import AWS_CONN_ID, OPENVERSE_BUCKET
 
 DAG_ID = "add_license_url"
-UPDATE_LICENSE_URL = "update_license_url"
-FINAL_REPORT = "final_report"
-
-ALERT_EMAIL_ADDRESSES = ""
 
 logger = logging.getLogger(__name__)
 
-base_url = "https://creativecommons.org/"
-
 
-def get_null_counts(
-    postgres_conn_id: str,
-    task: AbstractOperator,
-) -> int:
+def run_sql(
+    sql: str,
+    log_sql: bool = True,
+    method: str = "get_records",
+    handler: callable = None,
+    autocommit: bool = False,
+    postgres_conn_id: str = POSTGRES_CONN_ID,
+    task: AbstractOperator = None,
+):
     postgres = PostgresHook(
         postgres_conn_id=postgres_conn_id,
         default_statement_timeout=PostgresHook.get_execution_timeout(task),
+        log_sql=log_sql,
     )
-    null_meta_data_count = postgres.get_first(
-        dedent("SELECT COUNT(*) from image WHERE meta_data IS NULL;")
-    )[0]
-    return null_meta_data_count
+    if method == "get_records":
+        return postgres.get_records(sql)
+    elif method == "get_first":
+        return postgres.get_first(sql)
+    else:
+        return postgres.run(sql, autocommit=autocommit, handler=handler)
 
 
-def update_license_url(
-    postgres_conn_id: str, s3_bucket, aws_conn_id, task: AbstractOperator
-) -> dict[str, int]:
-    """Add license_url to meta_data batching all records with the same license.
-
-    :param aws_conn_id: AWS connection id
-    :param s3_bucket: the bucket to upload the invalid items TSV to
-    :param task: automatically passed by Airflow, used to set the execution timeout
-    :param postgres_conn_id: Postgres connection id
+@task
+def get_license_groups(
+    query: str, task: AbstractOperator = None
+) -> list[tuple[str, str]]:
     """
+    Get license groups of rows that don't have a `license_url` in their
+    `meta_data` field.
 
-    logger.info("Getting image records with NULL in meta_data.")
-    postgres = PostgresHook(
-        postgres_conn_id=postgres_conn_id,
-        default_statement_timeout=PostgresHook.get_execution_timeout(task),
+    :return: List of (license, version) tuples.
+    """
+    license_groups = run_sql(query, task=task)
+
+    total_nulls = sum(group[2] for group in license_groups)
+    licenses_detailed = "\n".join(
+        f"{group[0]} \t{group[1]} \t{group[2]}" for group in license_groups
     )
-    select_query = dedent(
-        """
-        SELECT identifier, license, license_version
-        FROM image WHERE meta_data IS NULL;"""
+    message = f"""
+Starting `{DAG_ID}` DAG. Found {len(license_groups)} license groups with {total_nulls}
+records left without `license_url` in `meta_data`.\nCount per license-version:
+{licenses_detailed}
+    """
+    slack.send_message(
+        message,
+        username="Airflow DAG Data Normalization - license_url",
+        dag_id=DAG_ID,
     )
-    records_with_null_in_metadata = postgres.get_records(select_query)
-    logger.info(f"{len(records_with_null_in_metadata)} records found.")
 
-    # Dictionary with license pair as key and list of identifiers as value
-    records_to_update = defaultdict(list)
+    return [(group[0], group[1]) for group in license_groups]
 
-    for result in records_with_null_in_metadata:
-        identifier, license_, version = result
-        # Some CC0 and PDM licenses are stored as uppercase in the database
-        license_ = license_.lower()
-        records_to_update[(license_, version)].append(identifier)
 
-    total_updated = 0
-    updated_by_license = {}
-
-    invalid_items = []
-
-    for (license_, version), identifiers in records_to_update.items():
-        *_, license_url = get_license_info_from_license_pair(license_, version)
-        if license_url is None:
-            logger.info(f"No license pair ({license_}, {version}) in the license map.")
-            for identifier in identifiers:
-                invalid_items.append(
-                    {
-                        "license": license_,
-                        "license_version": version,
-                        "identifier": identifier,
-                    }
-                )
-            continue
-        logger.info(
-            f"{len(identifiers):4} items will be updated "
-            f"with {license_url} and {license_}."
-        )
-        license_url_dict = {"license_url": license_url}
-        update_query = dedent(
-            f"""
-            UPDATE image
-            SET meta_data = {Json(license_url_dict)}, license='{license_}'
-            WHERE identifier IN ({','.join([f"'{r}'" for r in identifiers])});
-            """
+@task(max_active_tis_per_dag=1)
+def update_license_url(
+    license_group: tuple[str, str],
+    batch_size: int,
+    task: AbstractOperator = None,
+) -> int:
+    """
+    Add license_url to meta_data, batching all records with the same license.
+
+    :param license_group: tuple of license and version
+    :param batch_size: number of records to update in one update statement
+    :param task: automatically passed by Airflow, used to set the execution timeout.
+    """
+    license_, version = license_group
+    license_info = get_license_info_from_license_pair(license_, version)
+    if license_info is None:
+        raise AirflowSkipException(
+            f"No license pair ({license_}, {version}) in the license map."
        )
-        updated_count: int = postgres.run(
-            update_query, autocommit=True, handler=RETURN_ROW_COUNT
+    *_, license_url = license_info
+
+    logger.info(
+        f"Will add `license_url` in `meta_data` for records with license "
+        f"{license_} {version} to {license_url}."
+    )
+    license_url_dict = {"license_url": license_url}
+
+    # Merge the new license_url into the existing metadata. COALESCE guards
+    # against rows whose meta_data is NULL, since `jsonb || NULL` yields NULL.
+    update_query = dedent(
+        f"""
+        UPDATE image
+        SET meta_data = ({Json(license_url_dict)}::jsonb || COALESCE(meta_data, '{{}}'::jsonb)), updated_on = now()
+        WHERE identifier IN (
+            SELECT identifier
+            FROM image
+            WHERE license = '{license_}' AND license_version = '{version}'
+                AND meta_data->>'license_url' IS NULL
+            LIMIT {batch_size}
+            FOR UPDATE SKIP LOCKED
+        );
+        """
+    )
+    total_updated = 0
+    updated_count = 1
+    while updated_count:
+        updated_count = run_sql(
+            update_query,
+            log_sql=total_updated == 0,
+            method="run",
+            handler=RETURN_ROW_COUNT,
+            autocommit=True,
+            task=task,
        )
-        logger.info(f"{updated_count} records updated with {license_url}.")
-        if updated_count:
-            updated_by_license[license_url] = updated_count
         total_updated += updated_count
-    logger.info(f"Updated {total_updated} rows")
 
-    # Save the invalid_items to S3 as a TSV
-    save_to_s3(aws_conn_id, invalid_items, s3_bucket)
-    return updated_by_license
+    logger.info(f"Updated {total_updated} rows with {license_url}.")
+    return total_updated
 
 
-def save_to_s3(aws_conn_id, invalid_items, s3_bucket):
-    """
-    Save the records with invalid license pairs to S3.
-    :param aws_conn_id: AWS connection id
-    :param invalid_items: The list of dictionaries with the invalid items
-    :param s3_bucket: S3 bucket
+@task(trigger_rule=TriggerRule.ALL_DONE)
+def report_completion(updated, query: str, task: AbstractOperator = None):
     """
-    if not invalid_items:
-        return
+    Check for rows still missing `license_url` in `meta_data` and send a message
+    to Slack with the statistics of the DAG run.
 
-    s3_key = "invalid_items.tsv"
+    :param updated: list of per-license-group counts of updated records
+    :param query: SQL query to get the count of records left with `license_url` as NULL
+    :param task: automatically passed by Airflow, used to set the execution timeout.
+    """
+    total_updated = sum(updated) if updated else 0
 
-    with NamedTemporaryFile(mode="w+", encoding="utf-8") as f:
-        tsv_writer = csv.DictWriter(
-            f, delimiter="\t", fieldnames=["license", "license_version", "identifier"]
-        )
-        tsv_writer.writeheader()
-        for item in invalid_items:
-            tsv_writer.writerow(item)
-        f.seek(0)
-        logger.info(f"Uploading the invalid items to {s3_bucket}:{s3_key}")
-        with open(f.name) as fp:
-            logger.info(fp.read())
-        s3 = S3Hook(aws_conn_id=aws_conn_id)
-        s3.load_file(f.name, s3_key, bucket_name=s3_bucket, replace=True)
-
-
-def final_report(
-    postgres_conn_id: str,
-    updated_by_license: dict[str, int] | None,
-    task: AbstractOperator = None,
-):
-    """Check for null in `meta_data` and send a message to Slack
-    with the statistics of the DAG run.
+    license_groups = run_sql(query, task=task)
+    total_nulls = sum(group[2] for group in license_groups)
+    licenses_detailed = "\n".join(
+        f"{group[0]} \t{group[1]} \t{group[2]}" for group in license_groups
    )
 
-    :param postgres_conn_id: Postgres connection id.
-    :param updated_by_license: stringified JSON with the number of records updated
-    for each license_url. If `update_license_url` was skipped, this will be "None".
-    :param task: automatically passed by Airflow, used to set the execution timeout.
+    message = f"""
+`{DAG_ID}` DAG run completed.
+Updated {total_updated} record(s) with `license_url` in the
+`meta_data` field. Found {len(license_groups)} license groups with {total_nulls} record(s) left pending.
     """
-    null_meta_data_count = get_null_counts(postgres_conn_id, task)
+    if total_nulls != 0:
+        message += f"\nCount per license-version:\n{licenses_detailed}"
 
-    if not updated_by_license:
-        updated_message = "No records were updated."
-    else:
-        formatted_item_count = "".join(
-            [
-                f"{license_url}: {count} rows\n"
-                for license_url, count in updated_by_license.items()
-            ]
-        )
-        updated_message = f"Update statistics:\n{formatted_item_count}"
-
-    message = f"""
-`add_license_url` DAG run completed.
-{updated_message}
-Now, there are {null_meta_data_count} records with NULL meta_data left.
-"""
-    send_message(
+    slack.send_message(
         message,
         username="Airflow DAG Data Normalization - license_url",
         dag_id=DAG_ID,
     )
-    logger.info(message)
 
 
+@task(trigger_rule=TriggerRule.ALL_DONE)
+def report_failed_license_pairs(dag_run=None):
+    """
+    Send a message to Slack with the license-version pairs that could not be found
+    in the license map.
+    """
+    skipped_tasks = [
+        ti
+        for ti in dag_run.get_task_instances(state=State.SKIPPED)
+        if "update_license_url" in ti.task_id
+    ]
+
+    if not skipped_tasks:
+        raise AirflowSkipException
+
+    message = (
+        f"""
+        One or more license pairs could not be found in the license map while running
+        the `{DAG_ID}` DAG. See the logs for more details:
+        """
+    ) + "\n".join(
+        f" - <{ti.log_url}|{ti.task_id}>" for ti in skipped_tasks[:5]
+    )
+
+    slack.send_alert(
+        message,
+        username="Airflow DAG Data Normalization - license_url",
+        dag_id=DAG_ID,
+    )
 
 
-dag = DAG(
+@dag(
     dag_id=DAG_ID,
+    schedule=None,
+    catchup=False,
+    tags=["data_normalization"],
+    doc_md=__doc__,
     default_args={
         **DAG_DEFAULT_ARGS,
         "retries": 0,
         "execution_timeout": timedelta(hours=5),
     },
-    schedule=None,
-    catchup=False,
-    doc_md=__doc__,
-    tags=["data_normalization"],
     render_template_as_native_obj=True,
+    params={
+        "batch_size": Param(
+            default=10_000,
+            type="integer",
+            description="The number of records to update per batch.",
+        ),
+    },
 )
-
-with dag:
-    update_license_url = PythonOperator(
-        task_id=UPDATE_LICENSE_URL,
-        python_callable=update_license_url,
-        op_kwargs={
-            "postgres_conn_id": POSTGRES_CONN_ID,
-            "s3_bucket": OPENVERSE_BUCKET,
-            "aws_conn_id": AWS_CONN_ID,
-        },
-    )
-    final_report = PythonOperator(
-        task_id=FINAL_REPORT,
-        python_callable=final_report,
-        trigger_rule=TriggerRule.ALL_DONE,
-        op_kwargs={
-            "postgres_conn_id": POSTGRES_CONN_ID,
-            "updated_by_license": XCOM_PULL_TEMPLATE.format(
-                update_license_url.task_id, "return_value"
-            ),
-        },
+def add_license_url():
+    query = dedent("""
+        SELECT license, license_version, count(identifier)
+        FROM image WHERE meta_data->>'license_url' IS NULL
+        GROUP BY license, license_version
+    """)
+
+    license_groups = get_license_groups(query)
+    updated = update_license_url.partial(batch_size="{{ params.batch_size }}").expand(
+        license_group=license_groups
     )
+    report_completion(updated, query)
+    updated >> report_failed_license_pairs()
+
 
-    # update_license_url only updates the images if there are records
-    # with NULL meta_data that have valid license pairs.
-    update_license_url >> final_report
+add_license_url()
diff --git a/documentation/catalog/reference/DAGs.md b/documentation/catalog/reference/DAGs.md
index 3f272d477bf..f91f763b1b6 100644
--- a/documentation/catalog/reference/DAGs.md
+++ b/documentation/catalog/reference/DAGs.md
@@ -184,13 +184,11 @@
 The following is documentation associated with each DAG (where available):
 
 #### Add license URL
 
-Add `license_url` to all rows that have `NULL` in their `meta_data` fields. This
-PR sets the meta_data value to "{license_url: https://... }", where the url is
-constructed from the `license` and `license_version` columns.
+Add `license_url` to rows without one in their `meta_data` fields. This DAG
+merges the `meta_data` value with "{license_url: https://... }", where the url
+is constructed from the `license` and `license_version` columns.
 
-This is a maintenance DAG that should be run once. If all the null values in the
-`meta_data` column are updated, the DAG will only run the first and the last
-step, logging the statistics.
+This is a maintenance DAG that should be run once.
 
 ----
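
Reviewer note (not part of the patch): the update relies on Postgres `jsonb` concatenation, where `a || b` keeps `b`'s value on conflicting keys and `x || NULL` yields `NULL` (hence the `COALESCE` in the update query above). Below is a minimal standalone sketch of the same batched merge, useful for trying the technique against a scratch database before running the DAG. The DSN, function name, and `BATCH_SIZE` value are illustrative assumptions, not values taken from this PR.

```python
# Standalone sketch of the batched jsonb merge, assuming a scratch Postgres
# database with an `image` table shaped like the catalog's: identifier,
# license, license_version, meta_data (jsonb), updated_on.
import psycopg2
from psycopg2.extras import Json

BATCH_SIZE = 10_000  # mirrors the DAG's default `batch_size` param

UPDATE_QUERY = """
UPDATE image
SET meta_data = (%s::jsonb || COALESCE(meta_data, '{}'::jsonb)),
    updated_on = now()
WHERE identifier IN (
    SELECT identifier
    FROM image
    WHERE license = %s AND license_version = %s
        AND meta_data->>'license_url' IS NULL
    LIMIT %s
    FOR UPDATE SKIP LOCKED
);
"""


def merge_license_url(conn, license_: str, version: str, license_url: str) -> int:
    """Merge {"license_url": license_url} into meta_data for one license group."""
    total = 0
    with conn.cursor() as cur:
        while True:
            # Existing keys win on conflict because meta_data is the right
            # operand of ||; the WHERE clause guarantees license_url is absent.
            cur.execute(
                UPDATE_QUERY,
                (Json({"license_url": license_url}), license_, version, BATCH_SIZE),
            )
            conn.commit()
            if cur.rowcount == 0:
                break
            total += cur.rowcount
    return total


# Example usage (hypothetical DSN):
# conn = psycopg2.connect("postgresql://deploy@localhost:5432/openledger")
# merge_license_url(conn, "by", "4.0", "https://creativecommons.org/licenses/by/4.0/")
```

The `LIMIT ... FOR UPDATE SKIP LOCKED` subquery is what keeps each statement to a bounded batch and lets the loop coexist with concurrent writers; the loop terminates once an update matches zero rows.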