Skip to content

Commit

Permalink
Add batch_size param
Browse files Browse the repository at this point in the history
  • Loading branch information
krysal committed Apr 16, 2024
1 parent d5f0f2d commit ac1c6db
Showing 1 changed file with 14 additions and 2 deletions.
16 changes: 14 additions & 2 deletions catalog/dags/maintenance/add_license_url.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from airflow.decorators import dag, task
from airflow.models.abstractoperator import AbstractOperator
from airflow.models.param import Param
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
from psycopg2._json import Json

Expand Down Expand Up @@ -91,13 +92,15 @@ def get_license_groups(
@task
def update_license_url(
license_group: tuple[str, str],
batch_size: int = 10_000,
batch_size: int,
dag_task: AbstractOperator = None,
postgres_conn_id: str = POSTGRES_CONN_ID,
) -> int:
"""
Add license_url to meta_data batching all records with the same license.
:param license_group: tuple of license and version
:param batch_size: number of records to update in one update statement
:param dag_task: automatically passed by Airflow, used to set the execution timeout
:param postgres_conn_id: Postgres connection id
"""
Expand Down Expand Up @@ -211,10 +214,19 @@ def save_to_s3(
"execution_timeout": timedelta(hours=5),
},
render_template_as_native_obj=True,
params={
"batch_size": Param(
default=10_000,
type="integer",
description="The number of records to update per batch.",
),
},
)
def add_license_url():
license_groups = get_license_groups()
updated = update_license_url.expand(license_group=license_groups)
updated = update_license_url.expand(
license_group=license_groups, batch_size="{{ params.batch_size }}"
)
# save_to_s3(invalid_items)
final_report(updated)

Expand Down

0 comments on commit ac1c6db

Please sign in to comment.