Skip to content

Commit

Permalink
[hma] Patch similarity score (#1641)
Browse files Browse the repository at this point in the history
  • Loading branch information
juanmrad authored Oct 3, 2024
1 parent 505b431 commit 806f30e
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@
from werkzeug.exceptions import HTTPException

from OpenMediaMatch.blueprints.hashing import hash_media
from OpenMediaMatch.blueprints.matching import lookup_signal
from OpenMediaMatch.blueprints.matching import (
lookup_signal,
lookup_signal_with_distance,
)
from OpenMediaMatch.utils.flask_utils import api_error_handler

from OpenMediaMatch.utils import dev_utils
Expand Down Expand Up @@ -50,6 +53,11 @@ def query_media():
return signal_type_to_signal_map
abort(500, "Something went wrong while hashing the provided media.")

include_distance = bool(request.args.get("include_distance", False)) == True
lookup_signal_func = (
lookup_signal_with_distance if include_distance else lookup_signal
)

# Check if signal_type is an option in the map of hashes
signal_type_name = request.args.get("signal_type")
if signal_type_name is not None:
Expand All @@ -59,14 +67,14 @@ def query_media():
f"Requested signal type '{signal_type_name}' is not supported for the provided "
"media.",
)
return lookup_signal(
return lookup_signal_func(
signal_type_to_signal_map[signal_type_name], signal_type_name
)
return {
"matches": list(
itertools.chain(
*map(
lambda x: lookup_signal(x[1], x[0])["matches"],
lambda x: lookup_signal_func(x[1], x[0])["matches"],
signal_type_to_signal_map.items(),
),
)
Expand Down
53 changes: 47 additions & 6 deletions hasher-matcher-actioner/src/OpenMediaMatch/blueprints/matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from dataclasses import dataclass
import datetime
import random
import sys
import typing as t
import time

Expand All @@ -15,18 +16,31 @@
from werkzeug.exceptions import HTTPException

from threatexchange.signal_type.signal_base import SignalType
from threatexchange.signal_type.index import SignalTypeIndex
from threatexchange.signal_type.index import (
IndexMatchUntyped,
SignalSimilarityInfo,
SignalTypeIndex,
)

from OpenMediaMatch.background_tasks.development import get_apscheduler
from OpenMediaMatch.storage import interface
from OpenMediaMatch.blueprints import hashing
from OpenMediaMatch.utils.flask_utils import require_request_param, api_error_handler
from OpenMediaMatch.utils.flask_utils import (
api_error_handler,
require_request_param,
str_to_bool,
)
from OpenMediaMatch.persistence import get_storage

bp = Blueprint("matching", __name__)
bp.register_error_handler(HTTPException, api_error_handler)


class MatchWithDistance(t.TypedDict):
    """Shape of a single match entry when distance reporting is requested."""

    # Taken from the index match's ``metadata`` field.
    content_id: int
    # Taken from ``similarity_info.pretty_str()``, hence a string rather
    # than a number.
    distance: str


@dataclass
class _SignalIndexInMemoryCache:
signal_type: t.Type[SignalType]
Expand Down Expand Up @@ -94,15 +108,23 @@ def raw_lookup():
* Signal type (hash type)
* Signal value (the hash)
* Optional list of banks to restrict search to
* Optional include_distance (bool): whether or not to return distance values on match
Output:
* List of matching content items
* List of matches, each with content_id and, if requested, distance values
"""
signal = require_request_param("signal")
signal_type_name = require_request_param("signal_type")
return lookup_signal(signal, signal_type_name)
include_distance = str_to_bool(request.args.get("include_distance", "false"))
lookup_signal_func = (
lookup_signal_with_distance if include_distance else lookup_signal
)

return {"matches": lookup_signal_func(signal, signal_type_name)}

def lookup_signal(signal: str, signal_type_name: str) -> dict[str, list[int]]:

def query_index(
signal: str, signal_type_name: str
) -> t.Sequence[IndexMatchUntyped[SignalSimilarityInfo, int]]:
storage = get_storage()
signal_type = _validate_and_transform_signal_type(signal_type_name, storage)

Expand All @@ -118,7 +140,25 @@ def lookup_signal(signal: str, signal_type_name: str) -> dict[str, list[int]]:
current_app.logger.debug("[lookup_signal] querying index")
results = index.query(signal)
current_app.logger.debug("[lookup_signal] query complete")
return {"matches": [m.metadata for m in results]}
return results


def lookup_signal(signal: str, signal_type_name: str) -> list[int]:
    """Query the index for ``signal`` and return only the matches' metadata.

    Args:
        signal: The signal value (hash) to look up.
        signal_type_name: Name of the signal type whose index is queried.

    Returns:
        The ``metadata`` of every match returned by ``query_index``, in the
        order the index produced them.
    """
    content_ids: list[int] = []
    for match in query_index(signal, signal_type_name):
        content_ids.append(match.metadata)
    return content_ids


def lookup_signal_with_distance(
    signal: str, signal_type_name: str
) -> list[MatchWithDistance]:
    """Query the index for ``signal`` and return matches with their distance.

    Args:
        signal: The signal value (hash) to look up.
        signal_type_name: Name of the signal type whose index is queried.

    Returns:
        One entry per match returned by ``query_index``, each carrying the
        match's ``metadata`` as ``content_id`` and the string produced by
        ``similarity_info.pretty_str()`` as ``distance``.
    """
    matches: list[MatchWithDistance] = []
    for hit in query_index(signal, signal_type_name):
        entry: MatchWithDistance = {
            "content_id": hit.metadata,
            "distance": hit.similarity_info.pretty_str(),
        }
        matches.append(entry)
    return matches


def _validate_and_transform_signal_type(
Expand Down Expand Up @@ -300,6 +340,7 @@ def index_cache_is_stale() -> bool:

def _get_index(signal_type: t.Type[SignalType]) -> SignalTypeIndex[int] | None:
entry = _get_index_cache().get(signal_type.get_name())

if entry is None:
current_app.logger.debug("[lookup_signal] no cache, loading index")
return get_storage().get_signal_type_index(signal_type)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from threatexchange.signal_type.pdq.signal import PdqSignal


def test_raw_hash_add_to_match(app: Flask, client: FlaskClient):
def test_raw_hash_add_to_match_no_distance(app: Flask, client: FlaskClient):
bank_name = "TEST_BANK"
create_bank(client, bank_name)

Expand All @@ -40,3 +40,34 @@ def test_raw_hash_add_to_match(app: Flask, client: FlaskClient):
resp = client.get(f"/m/raw_lookup?signal_type=pdq&signal={hashes[-1]}")
assert resp.status_code == 200
assert resp.json == {"matches": [16]}


def test_raw_hash_add_to_match_with_distance(app: Flask, client: FlaskClient):
    """Raw lookup with ``include_distance=true`` returns content_id and distance.

    Banks every PDQ example hash, builds the indices by hand, then looks up
    the last example hash and checks the shape of the returned match.
    """
    bank_name = "TEST_BANK"
    create_bank(client, bank_name)

    # PDQ hashes
    hashes = PdqSignal.get_examples()
    for pdq in hashes:
        resp = client.post(f"/c/bank/{bank_name}/signal", json={"pdq": pdq})
        assert resp.status_code == 200

    # No background tasks in the test, so let's trigger it manually
    storage = get_storage()
    build_all_indices(storage, storage, storage)

    # Sanity check that the index is built
    resp = client.get("/m/index/status?signal_type=pdq")
    assert resp.status_code == 200
    all_build_info = resp.json
    assert "pdq" in all_build_info  # type: ignore
    build_info = all_build_info["pdq"]  # type: ignore
    assert build_info["present"] is True
    assert build_info["size"] == len(hashes)

    # Now match
    resp = client.get(
        f"/m/raw_lookup?signal_type=pdq&include_distance=true&signal={hashes[-1]}"
    )
    assert resp.status_code == 200
    assert resp.json == {"matches": [{"content_id": 16, "distance": "0"}]}

0 comments on commit 806f30e

Please sign in to comment.