Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/cele 119 #79

Merged
merged 8 commits into from
Jan 15, 2025
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@ cloud-harness/
.vscode/
node_modules
secret.json
data/
5 changes: 4 additions & 1 deletion applications/visualizer/backend/.gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -406,4 +406,7 @@ poetry.toml
# LSP config files
pyrightconfig.json

# End of https://www.toptal.com/developers/gitignore/api/node,python,django
# End of https://www.toptal.com/developers/gitignore/api/node,python,django


static/
25 changes: 25 additions & 0 deletions applications/visualizer/backend/api/api.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,18 @@
from io import StringIO
import sys
from collections import defaultdict
from typing import Iterable, Optional

from ninja import NinjaAPI, Router, Query, Schema
from ninja.pagination import paginate, PageNumberPagination
from ninja.errors import HttpError

from django.shortcuts import aget_object_or_404
from django.db.models import Q
from django.db.models.manager import BaseManager
from django.conf import settings
from django.core.management import call_command


from .utils import get_dataset_viewer_config, to_list

Expand All @@ -16,7 +22,9 @@
Neuron as NeuronModel,
Connection as ConnectionModel,
)
from .decorators.streaming import with_stdout_streaming
from .services.connectivity import query_nematode_connections
from .authenticators.basic_auth_super_user import basic_auth_superuser


class ErrorMessage(Schema):
Expand Down Expand Up @@ -237,6 +245,23 @@ def get_connections(
# )


## Ingestion


@api.get("/populate_db", auth=basic_auth_superuser, tags=["ingestion"])
@with_stdout_streaming
def populate_db(request):
    """Run the ``migrate`` and ``populatedb`` management commands.

    The endpoint is guarded by ``basic_auth_superuser`` and wrapped in
    ``with_stdout_streaming``, so whatever the commands print is sent back to
    the client as the response body.

    Raises:
        HttpError: with status 500 when either management command fails.
    """
    try:
        print("Starting DB population...\n")
        call_command("migrate")
        call_command("populatedb")
    except Exception as e:
        # HttpError takes (status_code, message); calling it with the status
        # alone raises TypeError and hides the real failure. Chain the cause
        # so the original traceback is preserved.
        raise HttpError(500, f"DB population failed: {e}") from e


## Healthcheck


@api.get("/live", tags=["healthcheck"])
async def live(request):
"""Test if application is healthy"""
Expand Down
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from ninja.security import HttpBasicAuth
from django.contrib.auth import authenticate as django_authenticate


class BasicAuthSuperUser(HttpBasicAuth):
    """HTTP Basic authenticator that only accepts Django superusers."""

    def authenticate(self, request, username, password):
        """Return the matching superuser, or ``None`` to reject the request.

        The credentials are checked with Django's ``authenticate``; a user
        that authenticates but is not a superuser is rejected as well.
        """
        candidate = django_authenticate(
            request, username=username, password=password
        )
        if not candidate:
            return None
        return candidate if candidate.is_superuser else None


# Module-level authenticator instance, meant to be passed as ``auth=`` to
# endpoints that must only be reachable by superusers.
basic_auth_superuser = BasicAuthSuperUser()
Empty file.
59 changes: 59 additions & 0 deletions applications/visualizer/backend/api/decorators/streaming.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import asyncio
import sys
import threading
from queue import Queue
from functools import wraps
from django.http import StreamingHttpResponse


# ``sys.stdout`` is process-wide. Without serialization, two overlapping runs
# would save/restore it out of order and leave it pointing at a writer whose
# queue nobody drains anymore.
_STDOUT_REDIRECT_LOCK = threading.Lock()


def with_stdout_streaming(func):
    """
    A decorator that:
    - Runs the decorated function in a separate thread,
    - Captures anything it prints to stdout,
    - Streams that output asynchronously, chunk by chunk, as it's produced.

    The wrapped view returns a ``StreamingHttpResponse`` immediately; the
    return value of ``func`` is discarded. An exception raised by ``func`` is
    not propagated: its message is appended to the stream as ``Error: ...``.

    Only one decorated function captures stdout at a time; concurrent calls
    wait for the running one to finish. While a capture is active, output
    printed by *any* thread of the process ends up in the stream, because the
    redirection replaces ``sys.stdout`` itself.
    """

    @wraps(func)
    def wrapper(request, *args, **kwargs):
        q = Queue()

        class QueueWriter:
            """Minimal file-like object forwarding written text to the queue."""

            def write(self, data):
                if data:
                    q.put(data)
                # File-like ``write`` reports how many characters it consumed.
                return len(data) if data else 0

            def flush(self):
                pass  # For compatibility with print

        def run_func():
            try:
                with _STDOUT_REDIRECT_LOCK:
                    old_stdout = sys.stdout
                    sys.stdout = QueueWriter()
                    try:
                        func(request, *args, **kwargs)
                    except Exception as e:
                        q.put(f"Error: {e}\n")
                    finally:
                        # Restore before signalling completion so the response
                        # never finishes while stdout is still redirected.
                        sys.stdout = old_stdout
            finally:
                # Signal completion, whatever happened above.
                q.put(None)

        # Run the function in a background thread
        t = threading.Thread(target=run_func)
        t.start()

        # Async generator to yield chunks from the queue
        async def line_generator():
            while True:
                # ``Queue.get`` blocks, so wait for it off the event loop.
                line = await asyncio.to_thread(q.get)
                if line is None:  # End signal
                    break
                yield line

        # Return a streaming response that sends data asynchronously
        return StreamingHttpResponse(line_generator(), content_type="text/plain")

    return wrapper
110 changes: 87 additions & 23 deletions ingestion/ingestion/ingest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
import json
import logging
import os
import subprocess
import sys
import tempfile
from argparse import ArgumentParser, Namespace
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime, timedelta, timezone
Expand Down Expand Up @@ -47,8 +47,12 @@
logger = logging.getLogger(__name__)


def _done_message(dataset_name: str, dry_run: bool = False) -> str:
return f"==> Done {'upload simulation for' if dry_run else 'uploading'} dataset '{dataset_name}'! ✨"
def _done_message(dataset_name: str | None, dry_run: bool = False) -> str:
"""Generate a completion message for the ingestion process."""
if dataset_name:
return f"==> Done {'upload simulation for' if dry_run else 'uploading'} dataset '{dataset_name}'! ✨"
else:
return "==> Ingestion completed! ✨"


def add_flags(parser: ArgumentParser):
Expand Down Expand Up @@ -108,6 +112,18 @@ def env_or(name: str, default: str) -> str:
),
)

parser.add_argument(
"--populate-db",
action="store_true",
help="Trigger DB population via the API endpoint",
)

parser.add_argument(
"--populate-db-url",
default="https://celegans.dev.metacell.us/api/populate_db",
help="The API URL to trigger DB population",
)


def add_add_dataset_flags(parser: ArgumentParser):
parser.add_argument(
Expand Down Expand Up @@ -460,6 +476,48 @@ def upload_em_tiles(
pbar.close()


def trigger_populate_db(args):
    """Call the populate-DB API endpoint and echo its streamed output.

    ``client_id`` and ``private_key_id`` are read from the JSON file at
    ``args.gcp_credentials`` and sent as HTTP Basic credentials in a ``curl``
    request to ``args.populate_db_url``. The response is printed line by line
    as it arrives.

    This is best-effort: every failure is reported on stderr and the function
    returns ``None`` without raising.
    """
    try:
        api_url = args.populate_db_url

        # Load service account credentials from gcp_credentials
        with open(args.gcp_credentials, "r", encoding="utf-8") as f:
            gcp_creds = json.load(f)

        client_id = gcp_creds.get("client_id")
        private_key_id = gcp_creds.get("private_key_id")

        if not client_id or not private_key_id:
            print(
                "Error: Could not extract client_id or private_key_id from gcp_credentials",
                file=sys.stderr,
            )
            return

        # NOTE(review): credentials given via ``-u`` are visible in the
        # process list while curl runs; consider feeding them through
        # ``curl --config -`` on stdin instead.
        command = [
            "stdbuf",
            "-oL",  # Force line buffering
            "curl",
            "-N",  # Disable buffering in curl
            "-sS",  # No progress meter, but still report errors
            "--fail",  # HTTP errors (e.g. 401) become a non-zero exit code
            "-u",
            f"{client_id}:{private_key_id}",
            f"{api_url}",
        ]

        interrupted = False
        # stderr is inherited rather than piped: a piped stderr that is never
        # read can fill up and stall curl, and it hides curl's error messages.
        with subprocess.Popen(command, stdout=subprocess.PIPE, text=True) as proc:
            try:
                for line in proc.stdout:
                    print(line, end="", flush=True)  # Real-time output
            except KeyboardInterrupt:
                interrupted = True
                proc.terminate()
                print("\nStreaming interrupted by user.", file=sys.stderr)

        # Leaving the ``with`` block waits for curl, so the exit code is set.
        if proc.returncode and not interrupted:
            print(
                f"Error: populate DB request failed (curl exit code {proc.returncode})",
                file=sys.stderr,
            )
    except Exception as e:
        print(f"An error occurred: {e}", file=sys.stderr)


def ingest_cmd(args: Namespace):
"""Runs the ingestion command."""

Expand All @@ -471,7 +529,7 @@ def ingest_cmd(args: Namespace):
bucket = storage_client.get_bucket(args.gcp_bucket)
rs = RemoteStorage(bucket, dry_run=args.dry_run)

dataset_id = args.id
dataset_id = getattr(args, "id", None)
overwrite = args.overwrite

if args.prune:
Expand All @@ -485,29 +543,35 @@ def ingest_cmd(args: Namespace):
elif dry_run:
logger.info(f"skipped prunning files from the bucket")

if args.data:
validate_and_upload_data(dataset_id, args.data, rs, overwrite=overwrite)
elif dry_run:
logger.warning(f"skipping neurons data validation and upload")
if dataset_id:
if args.data:
validate_and_upload_data(dataset_id, args.data, rs, overwrite=overwrite)
elif dry_run:
logger.warning(f"skipping neurons data validation and upload")

if args.segmentation:
upload_segmentations(dataset_id, args.segmentation, rs, overwrite=overwrite)
elif dry_run:
logger.warning("skipping segmentation upload: flag not set")
if args.segmentation:
upload_segmentations(dataset_id, args.segmentation, rs, overwrite=overwrite)
elif dry_run:
logger.warning("skipping segmentation upload: flag not set")

if args.synapses:
upload_synapses(dataset_id, args.synapses, rs, overwrite=overwrite)
elif dry_run:
logger.warning("skipping synapses upload: flag not set")
if args.synapses:
upload_synapses(dataset_id, args.synapses, rs, overwrite=overwrite)
elif dry_run:
logger.warning("skipping synapses upload: flag not set")

if paths := getattr(args, "3d"):
upload_3d(dataset_id, paths, rs, overwrite=overwrite)
elif dry_run:
logger.warning("skipping 3D files upload: flag not set")
if paths := getattr(args, "3d"):
upload_3d(dataset_id, paths, rs, overwrite=overwrite)
elif dry_run:
logger.warning("skipping 3D files upload: flag not set")

if args.em:
upload_em_tiles(dataset_id, args.em, rs, overwrite=overwrite)
elif dry_run:
logger.warning("skipping EM tiles upload: flag not set")

if args.em:
upload_em_tiles(dataset_id, args.em, rs, overwrite=overwrite)
if args.populate_db:
trigger_populate_db(args)
elif dry_run:
logger.warning("skipping EM tiles upload: flag not set")
logger.warning("skipping populate DB: flag not set")

print(_done_message(dataset_id, dry_run))
Loading