CELE-119 feat: Add populate_db endpoint and connect it to ingestion script
afonsobspinto committed Dec 19, 2024
1 parent 15d329b commit 5ca9fa6
Showing 9 changed files with 193 additions and 24 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -5,3 +5,4 @@ cloud-harness/
.vscode/
node_modules
secret.json
data/
5 changes: 4 additions & 1 deletion applications/visualizer/backend/.gitignore
@@ -406,4 +406,7 @@ poetry.toml
# LSP config files
pyrightconfig.json

# End of https://www.toptal.com/developers/gitignore/api/node,python,django
# End of https://www.toptal.com/developers/gitignore/api/node,python,django


static/
26 changes: 25 additions & 1 deletion applications/visualizer/backend/api/api.py
@@ -1,12 +1,18 @@
from io import StringIO
import sys
from collections import defaultdict
from typing import Iterable, Optional

from ninja import NinjaAPI, Router, Query, Schema
from ninja.pagination import paginate, PageNumberPagination
from ninja.errors import HttpError

from django.shortcuts import aget_object_or_404
from django.db.models import Q
from django.db.models.manager import BaseManager
from django.conf import settings
from django.core.management import call_command


from .utils import get_dataset_viewer_config, to_list

@@ -16,8 +22,9 @@
    Neuron as NeuronModel,
    Connection as ConnectionModel,
)
from .decorators.streaming import with_stdout_streaming
from .services.connectivity import query_nematode_connections

from .authenticators.basic_auth_super_user import basic_auth_superuser

class ErrorMessage(Schema):
    detail: str
@@ -237,6 +244,23 @@ def get_connections(
# )


## Ingestion


@api.get("/populate_db", auth=basic_auth_superuser, tags=["ingestion"])
@with_stdout_streaming
def populate_db(request):
    try:
        print("Starting DB population...\n")
        call_command("migrate")
        call_command("populatedb")
    except Exception as e:
        raise HttpError(500, f"Error populating DB: {e}")


## Healthcheck


@api.get("/live", tags=["healthcheck"])
async def live(request):
"""Test if application is healthy"""
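Note: the new /populate_db endpoint is protected by superuser basic auth and streams its stdout back as plain text. A minimal client sketch (not part of this commit; the URL and credentials are placeholders) of consuming that stream:

import niquests  # drop-in replacement for requests, added as a dependency in this commit

# Placeholder values for illustration only.
url = "http://localhost:8000/api/populate_db"
auth = ("superuser-name", "superuser-password")

with niquests.get(url, auth=auth, stream=True, timeout=None) as r:
    r.raise_for_status()
    for line in r.iter_lines():
        if line:  # skip keep-alive blank lines
            print(line.decode("utf-8"))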
0 changes: 0 additions & 0 deletions applications/visualizer/backend/api/authenticators/__init__.py
Empty file.
13 changes: 13 additions & 0 deletions applications/visualizer/backend/api/authenticators/basic_auth_super_user.py
@@ -0,0 +1,13 @@
from ninja.security import HttpBasicAuth
from django.contrib.auth import authenticate as django_authenticate


class BasicAuthSuperUser(HttpBasicAuth):
    def authenticate(self, request, username, password):
        # Authenticate user with Django's built-in authenticate function
        user = django_authenticate(request, username=username, password=password)
        if user and user.is_superuser:  # Ensure the user is a superuser
            return user
        return None

basic_auth_superuser = BasicAuthSuperUser()
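For the ingestion script further down to authenticate against this class, a Django superuser whose username and password match the service account's client_id and private_key_id has to exist. A possible provisioning sketch (an assumption, not part of this commit):

import json

from django.contrib.auth import get_user_model


def create_ingestion_superuser(gcp_credentials_path: str) -> None:
    """Create a superuser whose credentials mirror the GCP service-account key (assumed convention)."""
    with open(gcp_credentials_path) as f:
        creds = json.load(f)

    get_user_model().objects.create_superuser(
        username=creds["client_id"],
        email="",
        password=creds["private_key_id"],
    )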
0 changes: 0 additions & 0 deletions applications/visualizer/backend/api/decorators/__init__.py
Empty file.
61 changes: 61 additions & 0 deletions applications/visualizer/backend/api/decorators/streaming.py
@@ -0,0 +1,61 @@
import asyncio
import sys
import threading
from queue import Queue
from functools import wraps
from django.http import StreamingHttpResponse

def with_stdout_streaming(func):
    """
    A decorator that:
    - Runs the decorated function in a separate thread,
    - Captures anything it prints to stdout,
    - Streams that output asynchronously line-by-line as it's produced.
    """
    @wraps(func)
    def wrapper(request, *args, **kwargs):
        q = Queue()

        def run_func():
            # Redirect sys.stdout
            old_stdout = sys.stdout

            class QueueWriter:
                def write(self, data):
                    if data:
                        q.put(data)  # Push data into the thread-safe queue

                def flush(self):
                    pass  # For compatibility with print

            sys.stdout = QueueWriter()

            try:
                func(request, *args, **kwargs)
            except Exception as e:
                q.put(f"Error: {e}\n")
            finally:
                # Signal completion
                q.put(None)
                sys.stdout = old_stdout

        # Run the function in a background thread
        t = threading.Thread(target=run_func)
        t.start()

        # Async generator to yield lines from the queue
        async def line_generator():
            while True:
                line = await asyncio.to_thread(q.get)  # Get item from thread-safe queue
                if line is None:  # End signal
                    break
                yield line
                await asyncio.sleep(0)  # Yield control to event loop

        # Return a streaming response that sends data asynchronously
        response = StreamingHttpResponse(line_generator(), content_type="text/plain")
        response['Cache-Control'] = 'no-cache'
        response['X-Accel-Buffering'] = 'no'  # Disable nginx buffering if using nginx
        return response

    return wrapper
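A toy usage sketch (hypothetical view and import path, not part of this commit): each print call inside the decorated function is pushed onto the queue by the background thread and sent to the client as a line of the streaming response while the work is still running.

import time

from api.decorators.streaming import with_stdout_streaming  # assumed import path


@with_stdout_streaming
def long_task(request):
    for step in range(3):
        print(f"step {step} done")  # streamed to the client as it is printed
        time.sleep(1)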
110 changes: 88 additions & 22 deletions ingestion/ingestion/ingest.py
@@ -14,6 +14,7 @@
from google.api_core.exceptions import PreconditionFailed
from google.cloud import storage
from pydantic import ValidationError
import niquests
from tqdm import tqdm

from ingestion.cli import ask, type_directory, type_file
@@ -47,8 +48,12 @@
logger = logging.getLogger(__name__)


def _done_message(dataset_name: str, dry_run: bool = False) -> str:
    return f"==> Done {'upload simulation for' if dry_run else 'uploading'} dataset '{dataset_name}'! ✨"
def _done_message(dataset_name: str | None, dry_run: bool) -> str:
    """Generate a completion message for the ingestion process."""
    if dataset_name:
        return f"==> Done {'upload simulation for' if dry_run else 'uploading'} dataset '{dataset_name}'! ✨"
    else:
        return "==> Ingestion completed! ✨"


def add_flags(parser: ArgumentParser):
@@ -108,6 +113,18 @@ def env_or(name: str, default: str) -> str:
        ),
    )

    parser.add_argument(
        "--populate-db",
        action="store_true",
        help="Trigger DB population via the API endpoint",
    )

    parser.add_argument(
        "--populate-db-url",
        default="https://celegans.dev.metacell.us/api/populate_db",
        help="The API URL to trigger DB population",
    )


def add_add_dataset_flags(parser: ArgumentParser):
    parser.add_argument(
@@ -460,6 +477,49 @@ def upload_em_tiles(
pbar.close()


def trigger_populate_db(args):
    try:
        api_url = args.populate_db_url

        # Load service account credentials from gcp_credentials
        with open(args.gcp_credentials, "r") as f:
            gcp_creds = json.load(f)

        client_id = gcp_creds.get("client_id")
        private_key_id = gcp_creds.get("private_key_id")

        if not client_id or not private_key_id:
            print(
                "Error: Could not extract client_id or private_key_id from gcp_credentials",
                file=sys.stderr,
            )
            return

        # Make a GET request to the streaming endpoint with basic auth
        r = niquests.get(f"{api_url}", auth=(client_id, private_key_id), stream=True, timeout=None)

        if r.status_code == 200:
            for line in r.iter_lines():
                # filter out keep-alive new lines
                if line:
                    decoded_line = line.decode("utf-8")
                    print(decoded_line)
        else:
            print(
                f"Error triggering DB population: {r.status_code} {r.text}",
                file=sys.stderr,
            )

    except FileNotFoundError as e:
        print(f"Error: Credentials file not found. {e}", file=sys.stderr)
    except json.JSONDecodeError as e:
        print(f"Error: Invalid JSON in the credentials file. {e}", file=sys.stderr)
    except niquests.RequestException as e:
        print(f"Error: Failed to make a request to the server. {e}", file=sys.stderr)
    except Exception as e:
        print(f"Unexpected error: {e}", file=sys.stderr)


def ingest_cmd(args: Namespace):
"""Runs the ingestion command."""

@@ -471,7 +531,7 @@
bucket = storage_client.get_bucket(args.gcp_bucket)
rs = RemoteStorage(bucket, dry_run=args.dry_run)

dataset_id = args.id
dataset_id = getattr(args, "id", None)
overwrite = args.overwrite

if args.prune:
@@ -485,29 +545,35 @@
elif dry_run:
logger.info(f"skipped prunning files from the bucket")

if args.data:
validate_and_upload_data(dataset_id, args.data, rs, overwrite=overwrite)
elif dry_run:
logger.warning(f"skipping neurons data validation and upload")
if dataset_id:
if args.data:
validate_and_upload_data(dataset_id, args.data, rs, overwrite=overwrite)
elif dry_run:
logger.warning(f"skipping neurons data validation and upload")

if args.segmentation:
upload_segmentations(dataset_id, args.segmentation, rs, overwrite=overwrite)
elif dry_run:
logger.warning("skipping segmentation upload: flag not set")
if args.segmentation:
upload_segmentations(dataset_id, args.segmentation, rs, overwrite=overwrite)
elif dry_run:
logger.warning("skipping segmentation upload: flag not set")

if args.synapses:
upload_synapses(dataset_id, args.synapses, rs, overwrite=overwrite)
elif dry_run:
logger.warning("skipping synapses upload: flag not set")
if args.synapses:
upload_synapses(dataset_id, args.synapses, rs, overwrite=overwrite)
elif dry_run:
logger.warning("skipping synapses upload: flag not set")

if paths := getattr(args, "3d"):
upload_3d(dataset_id, paths, rs, overwrite=overwrite)
elif dry_run:
logger.warning("skipping 3D files upload: flag not set")
if paths := getattr(args, "3d"):
upload_3d(dataset_id, paths, rs, overwrite=overwrite)
elif dry_run:
logger.warning("skipping 3D files upload: flag not set")

if args.em:
upload_em_tiles(dataset_id, args.em, rs, overwrite=overwrite)
elif dry_run:
logger.warning("skipping EM tiles upload: flag not set")

if args.em:
upload_em_tiles(dataset_id, args.em, rs, overwrite=overwrite)
if args.populate_db:
trigger_populate_db(args)
elif dry_run:
logger.warning("skipping EM tiles upload: flag not set")
logger.warning("skipping populate DB: flag not set")

print(_done_message(dataset_id, dry_run))
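With the new flags, a run that passes --populate-db finishes by calling trigger_populate_db after the uploads. A rough sketch (assumed module path and credentials file) of invoking the helper directly, e.g. for a one-off repopulation without re-uploading data:

from argparse import Namespace

from ingestion.ingest import trigger_populate_db  # assumed import path

args = Namespace(
    populate_db_url="https://celegans.dev.metacell.us/api/populate_db",
    gcp_credentials="secret.json",  # assumed path to the service-account key
)
trigger_populate_db(args)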
1 change: 1 addition & 0 deletions ingestion/pyproject.toml
@@ -24,6 +24,7 @@ dependencies = [
"google-cloud-storage==2.18.2",
"tqdm==4.66.5",
"pillow==10.4.0",
"niquests==3.7.2",

# extraction dependencies
"diplib==3.5.1",
