TIMO-ifies heuristic DWP predictor (#248)
Same behavior as the implementation in SPP, but
can be stood up as a TIMO service.
cmwilhelm authored May 17, 2023
1 parent 87c888b commit 33bbd81
Showing 7 changed files with 213 additions and 1 deletion.
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = 'mmda'
-version = '0.4.9'
+version = '0.5.0'
description = 'MMDA - multimodal document analysis'
authors = [
{name = 'Allen Institute for Artificial Intelligence', email = '[email protected]'},
@@ -42,6 +42,7 @@ ai2_internal = [
'./bibentry_predictor_mmda/data/*',
'./citation_mentions/data/*',
'./vila/test_fixtures/*',
+'./dwp_heuristic/test_fixtures/*',
'./figure_table_predictors/test_fixtures/*',
'./figure_table_predictors/test_fixtures.images/*',
'./shared_test_fixtures/*',
42 changes: 42 additions & 0 deletions src/ai2_internal/config.yaml
@@ -278,3 +278,45 @@ model_variants:
# Any additional sets of dependencies required by the model.
# These are the 'extras_require' keys in your setup.py.
extras_require: ["figure_table_predictors"]

dwp-heuristic:
# Class path to pydantic Instance implementation in <model_package_name>==<model_package_version>
instance: ai2_internal.dwp_heuristic.interface.Instance

# Class path to pydantic Prediction implementation in <model_package_name>==<model_package_version>
prediction: ai2_internal.dwp_heuristic.interface.Prediction

# Class path to Predictor implementation in <model_package_name>==<model_package_version>
predictor: ai2_internal.dwp_heuristic.interface.Predictor

# Class path to pydantic PredictorConfig implementation in <model_package_name>==<model_package_version>
predictor_config: ai2_internal.dwp_heuristic.interface.PredictorConfig

# One or more bash commands to execute as part of a RUN step in a Dockerfile.
# Leave this unset unless your model has special system requirements beyond
# those in your setup.py.
docker_run_commands: []

# Any additional sets of dependencies required by the model.
# These are the 'extras_require' keys in your setup.py.
extras_require: ["heuristic_predictors"]

# Full S3 path to tar.gz'ed artifacts archive, nullable
artifacts_s3_path: null

# Version of python required for model runtime, e.g. 3.7, 3.8, 3.9
python_version: 3.8

# Whether this model supports CUDA GPU acceleration
cuda: false

# One of the versions here: https://gitlab.com/nvidia/container-images/cuda/blob/master/doc/supported-tags.md#ubuntu2004, but less than 11.4.3.
# If cuda=True and cuda_version is unspecified, defaults to 11.4.2.
cuda_version: null

# Python path to a fn in <model_package_name>==<model_package_version> that
# returns a unittest.TestCase. Builder function receives a model container
# as its sole argument.
# Used by the TIMO toolchain to validate your model implementation and configuration.
integration_test: ai2_internal.dwp_heuristic.integration_test.TestInterfaceIntegration

4 changes: 4 additions & 0 deletions src/ai2_internal/dwp_heuristic/README.txt
@@ -0,0 +1,4 @@
This directory contains files required internally by the
Semantic Scholar product.

It can safely be ignored by external consumers of this repository.
Empty file.
90 changes: 90 additions & 0 deletions src/ai2_internal/dwp_heuristic/integration_test.py
@@ -0,0 +1,90 @@
"""
Write integration tests for your model interface code here.
The TestCase class below is supplied a `container` argument
to each test method. This `container` object is a proxy to the
Dockerized application running your model. It exposes a single method:
```
predict_batch(instances: List[Instance]) -> List[Prediction]
```
To test your code, create `Instance`s and make normal `TestCase`
assertions against the returned `Prediction`s.
e.g.
```
def test_prediction(self, container):
    instances = [Instance(), Instance()]
    predictions = container.predict_batch(instances)
    self.assertEqual(len(instances), len(predictions))
    self.assertEqual(predictions[0].field1, "asdf")
    self.assertGreaterEqual(predictions[1].field2, 2.0)
```
"""

import gzip
import json
import logging
import os
import sys
import unittest

from .interface import Instance, Prediction
from ai2_internal import api
from mmda.types.document import Document


FIXTURE_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "test_fixtures")


try:
from timo_interface import with_timo_container
except ImportError as e:
logging.warning("""
This test can only be run by a TIMO test runner. No tests will run.
You may need to add this file to your project's pytest exclusions.
""")
sys.exit(0)


def resolve(file: str) -> str:
return os.path.join(os.path.dirname(__file__), "test_fixtures", file)


def read_fixture_doc(filename):
path = resolve(filename)

with gzip.open(path, "r") as f:
doc_json = json.loads(f.read())

doc = Document.from_json(doc_json)

return doc


@with_timo_container
class TestInterfaceIntegration(unittest.TestCase):
def test__predictions(self, container):
doc_file = resolve("test_doc.json")
with open(doc_file) as f:
doc = Document.from_json(json.load(f))

tokens = [api.SpanGroup.from_mmda(sg) for sg in doc.tokens]
rows = [api.SpanGroup.from_mmda(sg) for sg in doc.rows]

instances = [Instance(
symbols=doc.symbols,
tokens=tokens,
rows=rows
)]

predictions = container.predict_batch(instances)

prediction = predictions[0]

self.assertTrue(len(prediction.words) <= len(tokens))
self.assertTrue(all([w.text is not None for w in prediction.words]))
74 changes: 74 additions & 0 deletions src/ai2_internal/dwp_heuristic/interface.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
from typing import List

from pydantic import BaseModel, BaseSettings, Field

from ai2_internal.api import SpanGroup
from mmda.predictors.heuristic_predictors.dictionary_word_predictor import (
DictionaryWordPredictor
)
from mmda.types.document import Document


class Instance(BaseModel):
"""
Inference input for a single paper.
"""
symbols: str = Field(description="Extracted PDF document text")
tokens: List[SpanGroup] = Field(description="The tokens to coerce into words")
rows: List[SpanGroup] = Field(description="Document rows, used as signal for determining word boundaries")


class Prediction(BaseModel):
"""
Inference output for a single paper.
"""
words: List[SpanGroup] = Field(description="Input tokens coerced into words. Includes cleaned-up text.")


class PredictorConfig(BaseSettings):
"""
no-op for this model
"""
pass


class Predictor:
"""
This class is instantiated at application startup as a singleton,
and is used by the TIMO framework to interface with the underlying
DWP predictor.
"""

_config: PredictorConfig
_artifacts_dir: str

def __init__(self, config: PredictorConfig, artifacts_dir: str):
self._config = config
self._artifacts_dir = artifacts_dir
self._load_model()

def _load_model(self) -> None:
self._predictor = DictionaryWordPredictor("/dev/null")

def predict_one(self, instance: Instance) -> Prediction:
doc = Document(instance.symbols)
doc.annotate(tokens=[t.to_mmda() for t in instance.tokens])
doc.annotate(rows=[r.to_mmda() for r in instance.rows])
words = self._predictor.predict(doc)

# RE: https://github.com/allenai/scholar/issues/36200
for word in words:
if word.text:
word.text = word.text.replace("\u0000", "")

return Prediction(
words=[SpanGroup.from_mmda(w) for w in words]
)

def predict_batch(self, instances: List[Instance]) -> List[Prediction]:
"""
Method called by the client application. One or more Instances will
be provided, and the caller expects a corresponding Prediction for
each one.
"""
return [self.predict_one(instance) for instance in instances]
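
For local experimentation outside the TIMO container, the `Predictor` can also be exercised directly. The sketch below is illustrative only (not part of this commit); it assumes you run from the repository root and reuses the `test_doc.json` fixture, which the integration test shows is a plain-JSON mmda document with `tokens` and `rows` annotations.

```
import json

from ai2_internal import api
from ai2_internal.dwp_heuristic.interface import Instance, Predictor, PredictorConfig
from mmda.types.document import Document

# Load an mmda document that already carries token and row annotations.
with open("src/ai2_internal/dwp_heuristic/test_fixtures/test_doc.json") as f:
    doc = Document.from_json(json.load(f))

instance = Instance(
    symbols=doc.symbols,
    tokens=[api.SpanGroup.from_mmda(sg) for sg in doc.tokens],
    rows=[api.SpanGroup.from_mmda(sg) for sg in doc.rows],
)

# The heuristic predictor loads no artifacts, so artifacts_dir can be any path.
predictor = Predictor(PredictorConfig(), artifacts_dir="/tmp")
prediction = predictor.predict_batch([instance])[0]
print(f"{len(prediction.words)} words from {len(instance.tokens)} tokens")
```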
1 change: 1 addition & 0 deletions src/ai2_internal/dwp_heuristic/test_fixtures/test_doc.json

Large diffs are not rendered by default.
