-
Notifications
You must be signed in to change notification settings - Fork 18
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
TIMO-ifies heuristic DWP predictor (#248)
Same behavior as implementation in SPP, but can be stood up as a TIMO service.
- Loading branch information
Showing
7 changed files
with
213 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,6 @@ | ||
[project] | ||
name = 'mmda' | ||
version = '0.4.9' | ||
version = '0.5.0' | ||
description = 'MMDA - multimodal document analysis' | ||
authors = [ | ||
{name = 'Allen Institute for Artificial Intelligence', email = '[email protected]'}, | ||
|
@@ -42,6 +42,7 @@ ai2_internal = [ | |
'./bibentry_predictor_mmda/data/*', | ||
'./citation_mentions/data/*', | ||
'./vila/test_fixtures/*', | ||
'./dwp_heuristic/test_fixtures/*', | ||
'./figure_table_predictors/test_fixtures/*', | ||
'./figure_table_predictors/test_fixtures.images/*', | ||
'./shared_test_fixtures/*', | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
This directory contains files required internally by the | ||
Semantic Scholar product. | ||
|
||
It can safely be ignored by external consumers of this repository. |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
""" | ||
Write integration tests for your model interface code here. | ||
The TestCase class below is supplied a `container` | ||
to each test method. This `container` object is a proxy to the | ||
Dockerized application running your model. It exposes a single method: | ||
``` | ||
predict_batch(instances: List[Instance]) -> List[Prediction] | ||
``` | ||
To test your code, create `Instance`s and make normal `TestCase` | ||
assertions against the returned `Prediction`s. | ||
e.g. | ||
``` | ||
def test_prediction(self, container): | ||
instances = [Instance(), Instance()] | ||
predictions = container.predict_batch(instances) | ||
self.assertEqual(len(instances), len(predictions)) | ||
self.assertEqual(predictions[0].field1, "asdf") | ||
self.assertGreaterEqual(predictions[1].field2, 2.0) | ||
``` | ||
""" | ||
|
||
import gzip | ||
import json | ||
import logging | ||
import os | ||
import sys | ||
import unittest | ||
|
||
from .interface import Instance, Prediction | ||
from ai2_internal import api | ||
from mmda.types.document import Document | ||
|
||
|
||
FIXTURE_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "test_fixtures") | ||
|
||
|
||
# `timo_interface` is imported optionally: when it cannot be imported, log a
# warning and exit with status 0 instead of failing on the missing module.
try:
    from timo_interface import with_timo_container
except ImportError:
    logging.warning("""
This test can only be run by a TIMO test runner. No tests will run.
You may need to add this file to your project's pytest exclusions.
""")
    sys.exit(0)
|
||
|
||
def resolve(file: str) -> str:
    """Return the absolute path of `file` inside this module's test_fixtures directory.

    The module path is made absolute first, so the result does not depend on
    the current working directory at call time.
    """
    module_dir = os.path.dirname(os.path.abspath(__file__))
    return os.path.join(module_dir, "test_fixtures", file)
|
||
|
||
def read_fixture_doc(filename):
    """Load a gzip-compressed JSON fixture and build a `Document` from it.

    `filename` is turned into a path with `resolve`; the file is decompressed,
    parsed as JSON, and handed to `Document.from_json`.
    """
    with gzip.open(resolve(filename), "r") as handle:
        payload = handle.read()

    return Document.from_json(json.loads(payload))
|
||
|
||
@with_timo_container
class TestInterfaceIntegration(unittest.TestCase):
    """Integration tests that call the model through the supplied `container`."""

    def test__predictions(self, container):
        """Check the words predicted for the `test_doc.json` fixture.

        Builds one `Instance` from the fixture document, sends it through
        `container.predict_batch`, and verifies that one prediction comes
        back, that it has no more words than input tokens, and that every
        word has text.
        """
        with open(resolve("test_doc.json")) as f:
            doc = Document.from_json(json.load(f))

        tokens = [api.SpanGroup.from_mmda(sg) for sg in doc.tokens]
        rows = [api.SpanGroup.from_mmda(sg) for sg in doc.rows]

        instances = [Instance(symbols=doc.symbols, tokens=tokens, rows=rows)]

        predictions = container.predict_batch(instances)

        # Fail with a clear message instead of an IndexError below when the
        # batch does not return exactly one prediction per instance.
        self.assertEqual(len(instances), len(predictions))

        prediction = predictions[0]

        # assertLessEqual reports both lengths on failure, unlike assertTrue.
        self.assertLessEqual(len(prediction.words), len(tokens))
        # A generator lets all() stop at the first word that lacks text.
        self.assertTrue(all(w.text is not None for w in prediction.words))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
from typing import List | ||
|
||
from pydantic import BaseModel, BaseSettings, Field | ||
|
||
from ai2_internal.api import SpanGroup | ||
from mmda.predictors.heuristic_predictors.dictionary_word_predictor import ( | ||
DictionaryWordPredictor | ||
) | ||
from mmda.types.document import Document | ||
|
||
|
||
class Instance(BaseModel):
    """
    Input to inference: the data for a single paper.
    """

    symbols: str = Field(
        description="Extracted PDF document text",
    )
    tokens: List[SpanGroup] = Field(
        description="The tokens to coerce into words",
    )
    rows: List[SpanGroup] = Field(
        description="Document rows, used as signal for determining word boundaries",
    )
|
||
|
||
class Prediction(BaseModel):
    """
    Output of inference: the result for a single paper.
    """

    words: List[SpanGroup] = Field(
        description="Input tokens coerced into words. Includes cleaned-up text.",
    )
|
||
|
||
class PredictorConfig(BaseSettings):
    """
    Predictor settings; a no-op for this model, which declares no fields.
    """
|
||
|
||
class Predictor:
    """
    Adapter the TIMO framework uses to talk to the underlying DWP
    predictor. It is instantiated once, as a singleton, at application
    startup.
    """

    _config: PredictorConfig
    _artifacts_dir: str

    def __init__(self, config: PredictorConfig, artifacts_dir: str):
        """Store the config and artifacts directory, then load the model."""
        self._config = config
        self._artifacts_dir = artifacts_dir
        self._load_model()

    def _load_model(self) -> None:
        """Create the `DictionaryWordPredictor` used by `predict_one`."""
        self._predictor = DictionaryWordPredictor("/dev/null")

    def predict_one(self, instance: Instance) -> Prediction:
        """
        Predict words for a single paper.

        Builds a `Document` from the instance's symbols, annotates it with
        the instance's tokens and then its rows, runs the predictor, and
        returns the resulting words wrapped in a `Prediction`.
        """
        document = Document(instance.symbols)

        token_groups = [token.to_mmda() for token in instance.tokens]
        document.annotate(tokens=token_groups)

        row_groups = [row.to_mmda() for row in instance.rows]
        document.annotate(rows=row_groups)

        predicted = self._predictor.predict(document)

        # RE: https://github.com/allenai/scholar/issues/36200
        # Remove NUL characters from every word that has non-empty text.
        for word in predicted:
            if not word.text:
                continue
            word.text = word.text.replace("\u0000", "")

        converted = [SpanGroup.from_mmda(word) for word in predicted]
        return Prediction(words=converted)

    def predict_batch(self, instances: List[Instance]) -> List[Prediction]:
        """
        Entry point called by the client application. Returns one
        Prediction per supplied Instance, in the same order.
        """
        return list(map(self.predict_one, instances))
Large diffs are not rendered by default.
Oops, something went wrong.