diff --git a/ai2_internal/bibentry_detection_predictor/__init__.py b/ai2_internal/bibentry_detection_predictor/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/ai2_internal/bibentry_detection_predictor/data/26bab3c52aa8ff37dc3e155ffbcb506aa1f6.pdf b/ai2_internal/bibentry_detection_predictor/data/26bab3c52aa8ff37dc3e155ffbcb506aa1f6.pdf new file mode 100644 index 00000000..e75b128f Binary files /dev/null and b/ai2_internal/bibentry_detection_predictor/data/26bab3c52aa8ff37dc3e155ffbcb506aa1f6.pdf differ diff --git a/ai2_internal/bibentry_detection_predictor/data/vila_span_groups.json b/ai2_internal/bibentry_detection_predictor/data/vila_span_groups.json new file mode 100644 index 00000000..0a15aac4 --- /dev/null +++ b/ai2_internal/bibentry_detection_predictor/data/vila_span_groups.json @@ -0,0 +1 @@ +{"pdf": "26bab3c52aa8ff37dc3e155ffbcb506aa1f6.pdf", "vila_span_groups": [{"spans": [{"start": 0, "end": 116}], "type": "Header", "uuid": "1ed23454-6f59-4124-9c6d-088dfe5629da"}, {"spans": [{"start": 117, "end": 301}], "type": "Paragraph", "uuid": "095b6ffe-6f28-4887-97d4-14d528a865a0"}, {"spans": [{"start": 302, "end": 399}], "type": "Title", "uuid": "0513f91c-d87b-4a94-88e9-49a03067251a"}, {"spans": [{"start": 400, "end": 507}], "type": "Author", "uuid": "27ef4887-2d93-4519-8d74-543b5c1be8ee"}, {"spans": [{"start": 508, "end": 1255}], "type": "Abstract", "uuid": "caa7d0ca-6d6f-4077-b778-f313d5f3dbb5"}, {"spans": [{"start": 1256, "end": 1320}], "type": "Keywords", "uuid": "f8fc5abc-d5bf-417a-8517-5df1eb623bed"}, {"spans": [{"start": 1321, "end": 1333}], "type": "Section", "uuid": "f1c28f82-a9ca-44a6-9d59-3612f4294793"}, {"spans": [{"start": 1334, "end": 1828}], "type": "Paragraph", "uuid": "349323a7-2c66-48ca-994d-54545c46fca3"}, {"spans": [{"start": 1829, "end": 1850}], "type": "Section", "uuid": "8bea65b2-9e48-4417-aeea-7b2cda3741ea"}, {"spans": [{"start": 1851, "end": 4203}], "type": "Paragraph", "uuid": "ba38ee31-ddc8-4687-9b16-fd15ff66a811"}, {"spans": [{"start": 4204, "end": 4232}], "type": "Bibliography", "uuid": "9b24fde4-21b5-41ea-855f-72aa7b0f63ca"}, {"spans": [{"start": 4233, "end": 4234}], "type": "Paragraph", "uuid": "0492f4a9-0297-457f-9cae-3c1a72f5c2e3"}, {"spans": [{"start": 4235, "end": 4237}], "type": "Bibliography", "uuid": "5c846795-8a7f-4097-a8cb-4499b5632988"}, {"spans": [{"start": 4238, "end": 4241}], "type": "Paragraph", "uuid": "1272c2d7-01bd-4e90-9bd7-ad1d3d6a7048"}, {"spans": [{"start": 4242, "end": 4257}], "type": "Bibliography", "uuid": "be86bb58-0ff3-4857-9986-dfde02aec550"}, {"spans": [{"start": 4258, "end": 4262}], "type": "Footer", "uuid": "00dfba00-2007-4b67-b4d7-269bb73766e3"}, {"spans": [{"start": 4263, "end": 9408}], "type": "Paragraph", "uuid": "39269155-dddc-49c2-872c-47b27846667d"}, {"spans": [{"start": 9409, "end": 9437}], "type": "Bibliography", "uuid": "60a00abf-3a76-4102-88a6-300ae6b58b05"}, {"spans": [{"start": 9438, "end": 9439}], "type": "Paragraph", "uuid": "2be26eef-b033-4e70-a188-cc2d827b57ec"}, {"spans": [{"start": 9440, "end": 9442}], "type": "Bibliography", "uuid": "9cbc9281-374e-44cd-a4b1-da07ab2d1001"}, {"spans": [{"start": 9443, "end": 9446}], "type": "Paragraph", "uuid": "a24d9c5c-05a0-4759-8ccc-c05a35e45b32"}, {"spans": [{"start": 9447, "end": 9462}], "type": "Bibliography", "uuid": "6b697903-c2b9-42bc-9952-784e25868b70"}, {"spans": [{"start": 9463, "end": 9467}], "type": "Footer", "uuid": "ce968cf0-0457-4d40-ac3b-56891e340d89"}, {"spans": [{"start": 9468, "end": 12658}], "type": "Paragraph", "uuid": 
"97a8fe7a-92db-4005-ae19-040a788fc9cb"}, {"spans": [{"start": 12659, "end": 12660}], "type": "Header", "uuid": "91598648-60bf-4cad-9070-46e17aade785"}, {"spans": [{"start": 12661, "end": 12772}], "type": "Paragraph", "uuid": "02a736ba-50c0-4978-ae4b-81e79682a7e7"}, {"spans": [{"start": 12773, "end": 12928}], "type": "Caption", "uuid": "4e536525-c562-4946-bad4-d9cbadd8a54a"}, {"spans": [{"start": 12929, "end": 12958}], "type": "Table", "uuid": "d0db6692-17ac-4e5c-9c31-d928d04fb405"}, {"spans": [{"start": 12959, "end": 12962}], "type": "Caption", "uuid": "1edfd927-8685-4cc3-956b-fac069140e71"}, {"spans": [{"start": 12963, "end": 13704}], "type": "Table", "uuid": "7f85f4ea-4465-407f-b3d5-004728475e91"}, {"spans": [{"start": 13705, "end": 13738}], "type": "Table", "uuid": "92a36d57-2877-47e2-bd75-125081190b30"}, {"spans": [{"start": 13739, "end": 13792}], "type": "Bibliography", "uuid": "7a14d753-3ddb-48a7-a3d1-9d2a28adf670"}, {"spans": [{"start": 13793, "end": 13797}], "type": "Paragraph", "uuid": "64a40c23-3a0a-46a7-b471-4dd816394b04"}, {"spans": [{"start": 13798, "end": 13977}], "type": "Caption", "uuid": "e21146f0-a983-4bdc-8e86-f337cb7d3d3d"}, {"spans": [{"start": 13978, "end": 14847}], "type": "Table", "uuid": "a7479a05-4d68-4bab-90db-4d1394aafe88"}, {"spans": [{"start": 14848, "end": 15646}], "type": "Paragraph", "uuid": "e2a24a1d-0b96-43c7-a95d-685cb0fce4b5"}, {"spans": [{"start": 15647, "end": 15652}], "type": "Table", "uuid": "1e913424-db4a-4b86-a16b-fc3aba58da05"}, {"spans": [{"start": 15653, "end": 17666}], "type": "Paragraph", "uuid": "54ebaab7-9891-4641-85fc-4605daaf1854"}, {"spans": [{"start": 17667, "end": 17720}], "type": "Bibliography", "uuid": "1787b0fc-9b0a-463a-85b7-4db89f279659"}, {"spans": [{"start": 17721, "end": 17725}], "type": "Footer", "uuid": "30bdcdbb-8a6d-4794-93a1-a4bfd32b058f"}, {"spans": [{"start": 17726, "end": 17729}], "type": "Paragraph", "uuid": "5d118ab4-326d-4e57-92c4-9970731a16da"}, {"spans": [{"start": 17730, "end": 17829}], "type": "Bibliography", "uuid": "58a238c5-6239-4678-b8ba-61c0856e5078"}, {"spans": [{"start": 17830, "end": 18221}], "type": "Paragraph", "uuid": "a3440f56-3498-404d-b505-1106c1e9e525"}, {"spans": [{"start": 18222, "end": 18233}], "type": "Section", "uuid": "e8887279-6f8a-4172-83aa-99c32835ccc2"}, {"spans": [{"start": 18234, "end": 22386}], "type": "Bibliography", "uuid": "605304ac-ad21-47be-9ec3-aa61cb03ab95"}, {"spans": [{"start": 22387, "end": 22440}], "type": "Bibliography", "uuid": "c88faee7-4503-42d4-bfe5-8fe24aaa87c6"}, {"spans": [{"start": 22441, "end": 22445}], "type": "Footer", "uuid": "11fce7d0-9e5a-4629-bfde-a1f032666037"}, {"spans": [{"start": 22446, "end": 24816}], "type": "Bibliography", "uuid": "87e43fd8-cc46-4439-8c3d-8366be53f8e7"}]} \ No newline at end of file diff --git a/ai2_internal/bibentry_detection_predictor/integration_test.py b/ai2_internal/bibentry_detection_predictor/integration_test.py new file mode 100644 index 00000000..d7f69ae8 --- /dev/null +++ b/ai2_internal/bibentry_detection_predictor/integration_test.py @@ -0,0 +1,99 @@ +""" +Write integration tests for your model interface code here. + +The TestCase class below is supplied a `container` +to each test method. This `container` object is a proxy to the +Dockerized application running your model. It exposes a single method: + +``` +predict_batch(instances: List[Instance]) -> List[Prediction] +``` + +To test your code, create `Instance`s and make normal `TestCase` +assertions against the returned `Prediction`s. + +e.g. 
+
+```
+def test_prediction(self, container):
+    instances = [Instance(), Instance()]
+    predictions = container.predict_batch(instances)
+
+    self.assertEqual(len(instances), len(predictions))
+
+    self.assertEqual(predictions[0].field1, "asdf")
+    self.assertGreaterEqual(predictions[1].field2, 2.0)
+```
+"""
+
+import json
+import logging
+import os
+import pathlib
+import sys
+import unittest
+
+from .interface import Instance
+
+from mmda.parsers.pdfplumber_parser import PDFPlumberParser
+from mmda.rasterizers.rasterizer import PDF2ImageRasterizer
+from mmda.types import api
+from mmda.types.image import tobase64
+
+try:
+    from timo_interface import with_timo_container
+except ImportError:
+    logging.warning("""
+    This test can only be run by a TIMO test runner. No tests will run.
+    You may need to add this file to your project's pytest exclusions.
+    """)
+    sys.exit(0)
+
+pdf = "26bab3c52aa8ff37dc3e155ffbcb506aa1f6.pdf"
+
+
+def resolve(file: str) -> str:
+    return os.path.join(pathlib.Path(os.path.dirname(__file__)), "data", file)
+
+
+@with_timo_container
+class TestInterfaceIntegration(unittest.TestCase):
+
+    def get_images(self):
+        rasterizer = PDF2ImageRasterizer()
+        return rasterizer.rasterize(str(resolve(pdf)), dpi=72)
+
+    def test__predictions(self, container):
+        doc = PDFPlumberParser(split_at_punctuation=True).parse(resolve(pdf))
+
+        tokens = [api.SpanGroup.from_mmda(sg) for sg in doc.tokens]
+        rows = [api.SpanGroup.from_mmda(sg) for sg in doc.rows]
+        pages = [api.SpanGroup.from_mmda(sg) for sg in doc.pages]
+
+        page_images = self.get_images()
+        encoded_page_images = [tobase64(img) for img in page_images]
+
+        doc.annotate_images(page_images)
+
+        with open(resolve("vila_span_groups.json")) as f:
+            vila_span_groups = [api.SpanGroup(**sg) for sg in json.load(f)["vila_span_groups"]]
+
+        instances = [Instance(
+            symbols=doc.symbols,
+            tokens=tokens,
+            rows=rows,
+            pages=pages,
+            page_images=encoded_page_images,
+            vila_span_groups=vila_span_groups)]
+
+        predictions = container.predict_batch(instances)
+
+        for bib_entry in predictions[0].bib_entry_boxes:
+            self.assertEqual(bib_entry.type, "bib_entry")
+
+        for raw_box in predictions[0].raw_bib_entry_boxes:
+            self.assertEqual(raw_box.type, "raw_model_prediction")
+
+        number_of_found_bib_boxes = 31
+        self.assertEqual(len(predictions[0].bib_entry_boxes), number_of_found_bib_boxes)
+        self.assertEqual(len(predictions[0].raw_bib_entry_boxes), number_of_found_bib_boxes)
diff --git a/ai2_internal/bibentry_detection_predictor/interface.py b/ai2_internal/bibentry_detection_predictor/interface.py
new file mode 100644
index 00000000..173ffa81
--- /dev/null
+++ b/ai2_internal/bibentry_detection_predictor/interface.py
@@ -0,0 +1,124 @@
+"""
+This file contains the classes required by Semantic Scholar's
+TIMO tooling.
+
+You must provide a wrapper around your model, as well
+as a definition of the objects it expects, and those it returns.
+"""
+
+from typing import List
+
+from pydantic import BaseModel, BaseSettings, Field
+
+from mmda.predictors.d2_predictors.bibentry_detection_predictor import BibEntryDetectionPredictor
+from mmda.types import api, image
+from mmda.types.document import Document
+
+
+class Instance(BaseModel):
+    """
+    Describes one Instance over which the model performs inference.
+
+    The fields below are the document annotations and page images
+    this model requires as input.
+
+    To learn more about declaring pydantic model fields, please see:
+    https://pydantic-docs.helpmanual.io/
+    """
+
+    symbols: str
+    tokens: List[api.SpanGroup]
+    rows: List[api.SpanGroup]
+    pages: List[api.SpanGroup]
+    vila_span_groups: List[api.SpanGroup]
+    page_images: List[str] = Field(description="List of base64-encoded page images")
+
+
+class Prediction(BaseModel):
+    """
+    Describes the outcome of inference for one Instance.
+    """
+    bib_entry_boxes: List[api.BoxGroup]
+    raw_bib_entry_boxes: List[api.BoxGroup]
+
+
+class PredictorConfig(BaseSettings):
+    """
+    Configuration required by the model to do its work.
+    Uninitialized fields will be set via environment variables.
+
+    These serve as a record of the ENV
+    vars the consuming application needs to set.
+    """
+
+    BIB_ENTRY_DETECTION_PREDICTOR_SCORE_THRESHOLD: float = Field(default=0.88, description="Confidence score threshold; model detections scoring below this value are discarded")
+    BIB_ENTRY_DETECTION_MIN_VILA_BIB_ROWS: int = Field(default=2, description="A Bibliography VILA SpanGroup must span more than this many rows to qualify as a bibliography section")
+
+
+class Predictor:
+    """
+    Interface onto your underlying model.
+
+    This class is instantiated at application startup as a singleton.
+    You should initialize your model inside of it, and implement
+    prediction methods.
+
+    If you specified an artifacts.tar.gz for your model, it will
+    have been extracted to `artifacts_dir`, provided as a constructor
+    arg below.
+    """
+
+    _config: PredictorConfig
+    _artifacts_dir: str
+
+    def __init__(self, config: PredictorConfig, artifacts_dir: str):
+        self._config = config
+        self._artifacts_dir = artifacts_dir
+        self._load_model()
+
+    def _load_model(self) -> None:
+        """
+        Perform whatever start-up operations are required to get your
+        model ready for inference. This operation is performed only once
+        during the application life-cycle.
+        """
+        self._predictor = BibEntryDetectionPredictor(self._artifacts_dir, self._config.BIB_ENTRY_DETECTION_PREDICTOR_SCORE_THRESHOLD)
+
+    def predict_one(self, inst: Instance) -> Prediction:
+        """
+        Produces a single Prediction for the provided Instance,
+        using the underlying model.
+        """
+        doc = Document(symbols=inst.symbols)
+        doc.annotate(tokens=[sg.to_mmda() for sg in inst.tokens])
+        doc.annotate(rows=[sg.to_mmda() for sg in inst.rows])
+        doc.annotate(pages=[sg.to_mmda() for sg in inst.pages])
+        images = [image.frombase64(im) for im in inst.page_images]
+        doc.annotate_images(images)
+        doc.annotate(vila_span_groups=[sg.to_mmda() for sg in inst.vila_span_groups])
+
+        processed_bib_entry_box_groups, original_box_groups = self._predictor.predict(doc, self._config.BIB_ENTRY_DETECTION_MIN_VILA_BIB_ROWS)
+
+        prediction = Prediction(
+            bib_entry_boxes=[api.BoxGroup.from_mmda(bg) for bg in processed_bib_entry_box_groups],
+            raw_bib_entry_boxes=[api.BoxGroup.from_mmda(bg) for bg in original_box_groups])
+
+        return prediction
+
+    def predict_batch(self, instances: List[Instance]) -> List[Prediction]:
+        """
+        Method called by the client application. One or more Instances will
+        be provided, and the caller expects a corresponding Prediction for
+        each one.
+
+        If your model gets performance benefits from batching during inference,
+        implement that here, explicitly.
+
+        Otherwise, you can leave this method as-is and just implement
+        `predict_one()` above. The default implementation here passes
+        each Instance into `predict_one()`, one at a time.
+
+        The size of the batches passed into this method is configurable
+        via environment variable by the calling application.
+        """
+        return [self.predict_one(instance) for instance in instances]
diff --git a/ai2_internal/config.yaml b/ai2_internal/config.yaml
index 8b1ebd51..c93337f5 100644
--- a/ai2_internal/config.yaml
+++ b/ai2_internal/config.yaml
@@ -167,3 +167,51 @@ model_variants:
   # One or more bash commands to execute as part of a RUN step in a Dockerfile.
   docker_run_commands: []
+
+
+  bibentry_detection_predictor:
+    # Class path to the pydantic Instance implementation in this variant's interface module
+    instance: ai2_internal.bibentry_detection_predictor.interface.Instance
+
+    # Class path to the pydantic Prediction implementation in this variant's interface module
+    prediction: ai2_internal.bibentry_detection_predictor.interface.Prediction
+
+    # Class path to the Predictor implementation in this variant's interface module
+    predictor: ai2_internal.bibentry_detection_predictor.interface.Predictor
+
+    # Class path to the pydantic PredictorConfig implementation in this variant's interface module
+    predictor_config: ai2_internal.bibentry_detection_predictor.interface.PredictorConfig
+
+    # Full S3 path to tar.gz'ed artifacts archive, nullable
+    artifacts_s3_path: s3://ai2-s2-mmda/models/bibliography-entries/outputs/p3_2xl/bibs/silver_dataset_8k/one_category/pln_mask_rcnn_X_101_32x8d_FPN_3x/archive.tar.gz
+
+    # Version of python required for model runtime, e.g. 3.7, 3.8, 3.9
+    python_version: "3.8"
+
+    # Whether this model supports CUDA GPU acceleration
+    cuda: true
+
+    # One of the versions here: https://gitlab.com/nvidia/container-images/cuda/blob/master/doc/supported-tags.md#ubuntu2004, but less than 11.4.3.
+    # If cuda=True and cuda_version is unspecified, defaults to 11.4.2.
+    cuda_version: "11.1.1"
+
+    # Python path to a fn in this variant's package that
+    # returns a unittest.TestCase. The builder function receives a model container
+    # as its sole argument.
+    # Used by the TIMO toolchain to validate your model implementation and configuration.
+    integration_test: ai2_internal.bibentry_detection_predictor.integration_test.TestInterfaceIntegration
+
+    # One or more bash commands to execute as part of a RUN step in a Dockerfile AFTER extras_require installation.
+    # Leave this unset unless your model has special system requirements beyond
+    # those in your setup.py.
+
+    # Pre-install dependencies from setup.py so that detectron2 installs successfully
+    docker_run_commands: [ "apt-get update && apt-get install -y poppler-utils libgl1",
+                           "pip install layoutparser",
+                           "pip install torch==1.8.0",
+                           "pip install torchvision==0.9.0",
+                           "pip install 'detectron2@git+https://github.com/facebookresearch/detectron2.git@v0.6#egg=detectron2'"]
+
+    # Any additional sets of dependencies required by the model.
+    # These are the 'extras_require' keys in your setup.py.
+    extras_require: [ "bibentry_detection_predictor" ]
diff --git a/mmda/predictors/d2_predictors/__init__.py b/mmda/predictors/d2_predictors/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/mmda/predictors/d2_predictors/bibentry_detection_predictor.py b/mmda/predictors/d2_predictors/bibentry_detection_predictor.py
new file mode 100644
index 00000000..8b7e8b47
--- /dev/null
+++ b/mmda/predictors/d2_predictors/bibentry_detection_predictor.py
@@ -0,0 +1,181 @@
+from functools import reduce
+import itertools
+from typing import List, Tuple
+
+import layoutparser as lp
+
+from mmda.predictors.base_predictors.base_predictor import BasePredictor
+from mmda.types.annotation import BoxGroup
+from mmda.types.box import Box
+from mmda.types.document import Document
+from mmda.types.names import Pages, Images, Tokens
+from mmda.types.span import Span
+
+
+def union(block1, block2):
+    # Smallest rectangle covering both blocks
+    x11, y11, x12, y12 = block1.coordinates
+    x21, y21, x22, y22 = block2.coordinates
+
+    return lp.Rectangle(min(x11, x21), min(y11, y21), max(x12, x22), max(y12, y22))
+
+
+def union_blocks(blocks):
+    return reduce(union, blocks)
+
+
+def make_rect(box: Box, page_width, page_height):
+    # Convert a relative mmda Box to an absolute layoutparser Rectangle
+    box = box.get_absolute(page_width, page_height)
+    rect = lp.elements.Rectangle(x_1=box.l, y_1=box.t, x_2=(box.l + box.w), y_2=(box.t + box.h))
+
+    return rect
+
+
+def tighten_boxes(bib_box_group, page_tokens, page_width, page_height):
+    # Shrink each detected box to the union of the tokens whose centers fall inside it
+    page_token_rects = [make_rect(span.box, page_width, page_height) for span in page_tokens]
+    page_tokens_as_layout = lp.elements.Layout(blocks=page_token_rects)
+
+    new_boxes = []
+    for box in bib_box_group.boxes:
+        abs_box = box.get_absolute(page_width, page_height)
+        rect = lp.elements.Rectangle(
+            abs_box.l,
+            abs_box.t,
+            abs_box.l + abs_box.w,
+            abs_box.t + abs_box.h
+        )
+        new_rect = union_blocks(page_tokens_as_layout.filter_by(rect, center=True))
+        new_boxes.append(
+            Box(l=new_rect.x_1,
+                t=new_rect.y_1,
+                w=new_rect.width,
+                h=new_rect.height,
+                page=box.page).get_relative(
+                page_width=page_width,
+                page_height=page_height,
+            )
+        )
+    new_box_group = BoxGroup(
+        boxes=new_boxes,
+        id=bib_box_group.id,
+        type="bib_entry"
+    )
+    return new_box_group
+
+
+class BibEntryDetectionPredictor(BasePredictor):
+    REQUIRED_BACKENDS = ["layoutparser", "detectron2"]
+    REQUIRED_DOCUMENT_FIELDS = [Pages, Images, Tokens]
+
+    def __init__(self, artifacts_dir: str, threshold: float = 0.88):
+        label_map = {0: "bibentry"}
+
+        self.model = lp.Detectron2LayoutModel(
+            config_path=f"{artifacts_dir}/archive/config.yaml",
+            model_path=f"{artifacts_dir}/archive/model_final.pth",
+            label_map=label_map,
+            extra_config=["MODEL.ROI_HEADS.SCORE_THRESH_TEST", threshold]
+        )
+
+    def postprocess(self,
+                    model_outputs: lp.Layout,
+                    page_tokens: List[Span],
+                    page_index: int,
+                    image: "PIL.Image") -> Tuple[List[BoxGroup], List[BoxGroup]]:
+        """Convert the model outputs for a single page image into the mmda format
+
+        Args:
+            model_outputs (lp.Layout):
+                The layout detection results from the layoutparser model for
+                a page image
+            page_tokens (List[Span]):
+                List of the Document's Token spans for this Page
+            page_index (int):
+                The index of the current page, used for creating the
+                `Box` object
+            image (PIL.Image):
+                The image of the current page, used for converting
+                to relative coordinates for the box objects
+
+        Returns:
+            Tuple[List[BoxGroup], List[BoxGroup]]:
+                A tuple of the BoxGroups for detected bibentry boxes tightened around
+                tokens, and the BoxGroups containing the originally detected, unprocessed
+                model output boxes.
+        """
+        id_counter = itertools.count()
+        original_box_groups: List[BoxGroup] = []
+        page_width, page_height = image.size
+
+        for ele in model_outputs:
+            model_output_box = Box.from_coordinates(
+                x1=ele.block.x_1,
+                y1=ele.block.y_1,
+                x2=ele.block.x_2,
+                y2=ele.block.y_2,
+                page=page_index
+            ).get_relative(
+                page_width=page_width,
+                page_height=page_height,
+            )
+
+            current_id = next(id_counter)
+
+            original_box_groups.append(
+                BoxGroup(
+                    boxes=[model_output_box],
+                    id=current_id,
+                    type="raw_model_prediction"
+                )
+            )
+
+        processed_box_groups: List[BoxGroup] = []
+        for o_box_group in original_box_groups:
+            tightened_box_group = tighten_boxes(o_box_group, page_tokens, page_width, page_height)
+            processed_box_groups.append(tightened_box_group)
+
+        return processed_box_groups, original_box_groups
+
+    def predict(self, doc: Document, min_vila_bib_rows: int) -> Tuple[List[BoxGroup], List[BoxGroup]]:
+        """Returns a list of BoxGroups for the bibentry boxes detected on pages identified
+        as bib-containing via a VILA heuristic (pages whose "Bibliography" VILA SpanGroups
+        span more than min_vila_bib_rows rows), and a second list of BoxGroups for the
+        original model output boxes from those same pages.
+
+        Args:
+            doc (Document):
+                The input document object containing all required annotations
+            min_vila_bib_rows (int):
+                A Bibliography VILA SpanGroup must span more than this many rows
+                to qualify as a bibliography section
+
+        Returns:
+            Tuple[List[BoxGroup], List[BoxGroup]]:
+                A tuple of the BoxGroups containing bibentry boxes tightened around
+                tokens, and the BoxGroups containing the originally detected, unprocessed
+                model output boxes.
+        """
+        bib_entries: List[BoxGroup] = []
+        original_model_output: List[BoxGroup] = []
+
+        vila_bib_sgs = [sg for sg in doc.vila_span_groups
+                        if sg.type == "Bibliography" and len(sg.rows) > min_vila_bib_rows]
+        vila_bib_pgs = {sg.rows[0].spans[0].box.page for sg in vila_bib_sgs}
+        vila_bib_pg_to_image = {page_index: doc.images[page_index] for page_index in vila_bib_pgs}
+
+        for page_index, image in vila_bib_pg_to_image.items():
+            model_outputs: lp.Layout = self.model.detect(image)
+            page_tokens: List[Span] = list(
+                itertools.chain.from_iterable(
+                    token_span_group.spans
+                    for token_span_group in doc.pages[page_index].tokens
+                )
+            )
+
+            bib_entry_box_groups, og_box_groups = self.postprocess(
+                model_outputs,
+                page_tokens,
+                page_index,
+                image
+            )
+
+            bib_entries.extend(bib_entry_box_groups)
+            original_model_output.extend(og_box_groups)
+
+        return bib_entries, original_model_output
diff --git a/setup.py b/setup.py
index 3033884f..96322fc1 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
     name="mmda",
     description="mmda",
-    version="0.0.26.1",
+    version="0.0.28",
     url="https://www.github.com/allenai/mmda",
     python_requires=">= 3.7",
     packages=find_namespace_packages(include=["mmda*", "ai2_internal*"]),
@@ -24,10 +24,12 @@
         "mention_predictor": ["transformers[torch]", "optimum[onnxruntime]"],
         "mention_predictor_gpu": ["transformers[torch]", "optimum[onnxruntime-gpu]"],
         "bibentry_predictor": ["transformers", "unidecode", "torch"],
+        "bibentry_detection_predictor": ["layoutparser", "torch==1.8.0", "torchvision==0.9.0"],
         "citation_links": ["numpy", "thefuzz[speedup]", "sklearn", "xgboost"],
     },
     include_package_data=True,
     package_data={
+        "ai2_internal.bibentry_detection_predictor.data": ["*"],
         "ai2_internal.citation_mentions.data": ["*"],
         "ai2_internal.vila.test_fixtures": ["*"],
         "ai2_internal.shared_test_fixtures": ["*"]
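
For a quick local smoke test outside the TIMO container, the pieces added in this diff can be wired together directly. The sketch below is illustrative only, not part of the change set: `paper.pdf`, `vila_span_groups.json`, and the `artifacts/` directory are placeholder paths, with `artifacts/archive/` assumed to hold the extracted `config.yaml` and `model_final.pth` from `artifacts_s3_path`; the threshold and row-count values mirror the PredictorConfig defaults.

```
import json

from mmda.parsers.pdfplumber_parser import PDFPlumberParser
from mmda.predictors.d2_predictors.bibentry_detection_predictor import BibEntryDetectionPredictor
from mmda.rasterizers.rasterizer import PDF2ImageRasterizer
from mmda.types import api

# Parse the PDF (placeholder path) to get symbols, tokens, rows, and pages,
# then attach rasterized page images, as the integration test above does.
doc = PDFPlumberParser(split_at_punctuation=True).parse("paper.pdf")
doc.annotate_images(PDF2ImageRasterizer().rasterize("paper.pdf", dpi=72))

# VILA span groups tell the predictor which pages contain the bibliography.
with open("vila_span_groups.json") as f:  # placeholder path
    vila = [api.SpanGroup(**sg).to_mmda() for sg in json.load(f)["vila_span_groups"]]
doc.annotate(vila_span_groups=vila)

# "artifacts" (placeholder) must contain the extracted archive/ directory.
predictor = BibEntryDetectionPredictor("artifacts", threshold=0.88)
bib_entry_boxes, raw_boxes = predictor.predict(doc, min_vila_bib_rows=2)
print(f"{len(bib_entry_boxes)} tightened boxes, {len(raw_boxes)} raw boxes")
```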