Skip to content

Commit

Permalink
add embeddings_for_photos endpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
pleary committed Nov 15, 2024
1 parent a08c8f9 commit e2cb70b
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 1 deletion.
67 changes: 66 additions & 1 deletion lib/inat_inferrer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import time
import magic
import tensorflow as tf
import pandas as pd
import h3
Expand All @@ -8,6 +7,14 @@
import os
import tifffile
import numpy as np
import urllib
import hashlib
import magic
import aiohttp
import aiofiles
import aiofiles.os
import asyncio

from PIL import Image
from lib.tf_gp_elev_model import TFGeoPriorModelElev
from lib.vision_inferrer import VisionInferrer
Expand Down Expand Up @@ -610,6 +617,64 @@ def limit_leaf_scores_that_include_humans(self, leaf_scores):
# otherwise return no results
return leaf_scores.head(0)

async def embeddings_for_photos(self, photos, max_workers=5):
    """Compute embeddings for a batch of photos using concurrent workers.

    Every photo is placed on a queue that is drained by up to
    ``max_workers`` ``embeddings_worker_task`` coroutines sharing a single
    ``aiohttp.ClientSession``.

    Args:
        photos: iterable of photo entries; each entry is handed unchanged
            to ``embeddings_worker_task``.
        max_workers: upper bound on the number of concurrent workers.

    Returns:
        The dict the workers filled in. Empty when ``photos`` is empty.
    """
    response = {}
    if not photos:
        # nothing to do: avoid opening an HTTP session for an empty batch
        return response

    async with aiohttp.ClientSession() as session:
        queue = asyncio.Queue()
        for photo in photos:
            queue.put_nowait(photo)

        # Wait on the worker tasks themselves rather than on queue.join():
        # join() never returns if every worker has died from an exception
        # while items are still queued. Workers are respawned until the
        # queue is drained, so one failure cannot strand the other photos.
        while not queue.empty():
            pending = queue.qsize()
            worker_count = max(1, min(max_workers, pending))
            workers = [
                asyncio.create_task(
                    self.embeddings_worker_task(queue, response, session)
                )
                for _ in range(worker_count)
            ]
            outcomes = await asyncio.gather(*workers, return_exceptions=True)
            for outcome in outcomes:
                if isinstance(outcome, BaseException):
                    print("`embeddings_for_photos` worker failed")
                    print(repr(outcome))
            if queue.qsize() >= pending:
                # no progress was made this round; stop instead of spinning
                break
    return response

async def embeddings_worker_task(self, queue, response, session):
    """Drain ``queue``, storing one embedding per photo in ``response``.

    Args:
        queue: ``asyncio.Queue`` of photo entries; each entry is indexed
            with ``"url"`` and ``"id"``.
        response: dict updated in place, mapping ``photo["id"]`` to the
            value returned by ``embedding_for_photo``.
        session: HTTP session passed through to ``embedding_for_photo``.

    A failure on one photo is logged and skipped so the worker keeps
    draining the queue. Letting the exception escape would end this worker
    and could leave a caller blocked in ``queue.join()`` forever.
    """
    while not queue.empty():
        photo = await queue.get()
        try:
            embedding = await self.embedding_for_photo(photo["url"], session)
            response[photo["id"]] = embedding
        except Exception as e:
            # worker boundary: report the failure and move on to the next photo
            print("`embeddings_worker_task` failed for a photo")
            print(repr(e))
        finally:
            # always balance the get() so queue.join() can complete
            queue.task_done()

async def embedding_for_photo(self, url, session):
    """Download one photo and return its embedding as a plain list.

    Args:
        url: photo URL, or ``None``.
        session: HTTP session passed through to ``download_photo_async``.

    Returns:
        ``self.vision_inferrer.signature_for_image(image).tolist()`` for
        the downloaded image, or ``None`` when ``url`` is ``None``, the
        download yields no file, or the HTTP request fails.
    """
    if url is None:
        return None

    try:
        cache_path = await self.download_photo_async(url, session)
        if cache_path is None:
            return None
        image = InatInferrer.prepare_image_for_inference(cache_path)
    except (aiohttp.ClientError, urllib.error.HTTPError):
        # The download goes through aiohttp, whose request failures are
        # aiohttp.ClientError subclasses; the urllib error is kept so any
        # existing behaviour that relied on it is unchanged.
        return None
    return self.vision_inferrer.signature_for_image(image).tolist()

async def download_photo_async(self, url, session):
    """Download ``url`` into the upload folder, reusing a cached copy.

    The cache file is named ``download-<md5 of url>.jpg`` inside
    ``self.upload_folder``. A file that is not detected as ``image/jpeg``
    is converted to RGB and saved back over the same path.

    Args:
        url: photo URL (a ``str``).
        session: HTTP session object providing ``get``.

    Returns:
        Path of the cached file, or ``None`` when no file could be
        obtained (non-200 response, timeout, or client error).
    """
    # md5 is used only to derive a stable cache file name from the URL
    checksum = hashlib.md5(url.encode()).hexdigest()
    cache_path = os.path.join(self.upload_folder, "download-" + checksum) + ".jpg"
    if await aiofiles.os.path.exists(cache_path):
        return cache_path
    try:
        async with session.get(url, timeout=10) as resp:
            if resp.status == 200:
                # Read the whole body BEFORE creating the file: if the read
                # times out or fails, no empty file is left behind to be
                # mistaken for a valid cached download on later calls.
                body = await resp.read()
                async with aiofiles.open(cache_path, mode="wb") as f:
                    await f.write(body)
    except asyncio.TimeoutError as e:
        print("`download_photo_async` timed out")
        print(e)
    except aiohttp.ClientError as e:
        print("`download_photo_async` failed")
        print(e)
    if not await aiofiles.os.path.exists(cache_path):
        return None
    mime_type = magic.from_file(cache_path, mime=True)
    if mime_type != "image/jpeg":
        # convert() produces a new in-memory image, so the source file can
        # be closed before it is overwritten
        with Image.open(cache_path) as im:
            rgb_im = im.convert("RGB")
        rgb_im.save(cache_path)
    return cache_path

@staticmethod
def prepare_image_for_inference(file_path):
image = Image.open(file_path)
Expand Down
8 changes: 8 additions & 0 deletions lib/inat_vision_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ def __init__(self, config):
self.h3_04_bounds_route, methods=["GET"])
self.app.add_url_rule("/geo_scores_for_taxa", "geo_scores_for_taxa",
self.geo_scores_for_taxa_route, methods=["POST"])
self.app.add_url_rule("/embeddings_for_photos", "embeddings_for_photos",
self.embeddings_for_photos_route, methods=["POST"])
self.app.add_url_rule("/build_info", "build_info", self.build_info_route, methods=["GET"])

def setup_inferrer(self, config):
Expand Down Expand Up @@ -96,6 +98,12 @@ def geo_scores_for_taxa_route(self):
for obs in request.json["observations"]
}

async def embeddings_for_photos_route(self):
    # POST handler: reads the "photos" entry of the JSON request body,
    # asks the inferrer for their embeddings, and returns the result.
    # The elapsed wall-clock time is printed in milliseconds.
    started_at = time.time()
    photos = request.json["photos"]
    embeddings = await self.inferrer.embeddings_for_photos(photos)
    elapsed_ms = (time.time() - started_at) * 1000.
    print("embeddings_for_photos_route Time: %0.2fms" % elapsed_ms)
    return embeddings

def index_route(self):
form = ImageForm()
if "observation_id" in request.args:
Expand Down
8 changes: 8 additions & 0 deletions lib/vision_inferrer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,16 @@ def prepare_tf_model(self):
assert device.device_type != "GPU"

self.vision_model = tf.keras.models.load_model(self.model_path, compile=False)
self.signature_model = tf.keras.Model(
inputs=self.vision_model.inputs,
outputs=self.vision_model.get_layer("global_average_pooling2d_5").output
)
self.signature_model.compile()

# run the vision model on an image object (usually the output of
# prepare_image_for_inference) and return the first entry of the result
def process_image(self, image):
    tensor = tf.convert_to_tensor(image)
    predictions = self.vision_model(tensor, training=False)
    return predictions[0]

# run the signature model on an image object and return the first entry
# of the result converted with .numpy()
def signature_for_image(self, image):
    tensor = tf.convert_to_tensor(image)
    signatures = self.signature_model(tensor, training=False)
    return signatures[0].numpy()

0 comments on commit e2cb70b

Please sign in to comment.