diff --git a/src/noteburst/exceptions.py b/src/noteburst/exceptions.py index 10a788d..429c3c1 100644 --- a/src/noteburst/exceptions.py +++ b/src/noteburst/exceptions.py @@ -11,6 +11,7 @@ __all__ = [ "TaskError", "NbexecTaskError", + "NbexecTaskTimeoutError", "NoteburstClientRequestError", "NoteburstError", ] @@ -38,6 +39,14 @@ class NbexecTaskError(TaskError): task_name = "nbexec" +class NbexecTaskTimeoutError(NbexecTaskError): + """Error raised when a notebook execution task times out.""" + + @classmethod + def from_exception(cls, exc: Exception) -> Self: + return cls(f"{cls.task_name} timeout error\n\n{exc!s}") + + class NoteburstClientRequestError(ClientRequestError): """Error related to the API client.""" diff --git a/src/noteburst/handlers/v1/handlers.py b/src/noteburst/handlers/v1/handlers.py index 8760e89..4907b55 100644 --- a/src/noteburst/handlers/v1/handlers.py +++ b/src/noteburst/handlers/v1/handlers.py @@ -68,6 +68,7 @@ async def post_nbexec( ipynb=request_data.get_ipynb_as_str(), kernel_name=request_data.kernel_name, enable_retry=request_data.enable_retry, + timeout=request_data.timeout, ) logger.info("Finished enqueing an nbexec task", job_id=job_metadata.id) response_data = await NotebookResponse.from_job_metadata( diff --git a/src/noteburst/handlers/v1/models.py b/src/noteburst/handlers/v1/models.py index fd3b211..ee77667 100644 --- a/src/noteburst/handlers/v1/models.py +++ b/src/noteburst/handlers/v1/models.py @@ -3,13 +3,14 @@ from __future__ import annotations import json -from datetime import datetime +from datetime import datetime, timedelta from typing import Annotated, Any from arq.jobs import JobStatus from fastapi import Request from pydantic import AnyHttpUrl, BaseModel, Field from safir.arq import JobMetadata, JobResult +from safir.pydantic import HumanTimedelta from noteburst.jupyterclient.jupyterlab import ( NotebookExecutionErrorModel, @@ -172,6 +173,17 @@ class PostNotebookRequest(BaseModel): kernel_name: Annotated[str, 
kernel_name_field] + timeout: HumanTimedelta = Field( + default_factory=lambda: timedelta(seconds=300), + title="Timeout for notebook execution.", + description=( + "The timeout is a human-readable duration string. For " + "example, '5m' is 5 minutes, '1h' is 1 hour, '1d' is 1 day. " + "If the notebook execution does not complete within this time, " + "the job is marked as failed." + ), + ) + enable_retry: Annotated[ bool, Field( diff --git a/src/noteburst/worker/functions/nbexec.py b/src/noteburst/worker/functions/nbexec.py index c008edc..74230bc 100644 --- a/src/noteburst/worker/functions/nbexec.py +++ b/src/noteburst/worker/functions/nbexec.py @@ -4,14 +4,16 @@ from __future__ import annotations +import asyncio import json import sys +from datetime import timedelta from typing import Any, cast from arq import Retry from safir.slack.blockkit import SlackCodeBlock, SlackTextField -from noteburst.exceptions import NbexecTaskError +from noteburst.exceptions import NbexecTaskError, NbexecTaskTimeoutError from noteburst.jupyterclient.jupyterlab import JupyterClient, JupyterError @@ -21,6 +23,7 @@ async def nbexec( ipynb: str, kernel_name: str = "LSST", enable_retry: bool = True, + timeout: timedelta | None = None, # noqa: ASYNC109 ) -> str: """Execute a notebook, as an asynchronous arq worker task. 
@@ -54,10 +57,15 @@ async def nbexec( parsed_notebook = json.loads(ipynb) logger.debug("Got ipynb", ipynb=parsed_notebook) try: - execution_result = await jupyter_client.execute_notebook( - parsed_notebook, kernel_name=kernel_name + execution_result = await asyncio.wait_for( + jupyter_client.execute_notebook( + parsed_notebook, kernel_name=kernel_name + ), + timeout=timeout.total_seconds() if timeout else None, ) logger.info("nbexec finished", error=execution_result.error) + except TimeoutError as e: + raise NbexecTaskTimeoutError.from_exception(e) from e except JupyterError as e: logger.exception("nbexec error", jupyter_status=e.status) if "slack" in ctx and "slack_message_factory" in ctx: