Fixed issue with failing initialization when metadata is provided to span group (#126)

* fixed issue with failing initialization when metadata was provided

* forgot to add support for box groups
soldni authored Aug 24, 2022
1 parent 2880069 commit 499b7f3
Showing 3 changed files with 158 additions and 38 deletions.
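From the caller's side, the fix boils down to the round trip exercised by the new test at the bottom of this commit; a minimal sketch (the values are illustrative, not project fixtures):

    from mmda.types.annotation import BoxGroup, SpanGroup
    from mmda.types.metadata import Metadata

    # Constructing a group with a pre-built Metadata object no longer fails,
    # and the metadata-backed fields survive a JSON round trip.
    sg = SpanGroup(id=3, metadata=Metadata.from_json({"text": "test"}))
    assert SpanGroup.from_json(sg.to_json()) == sg

    bg = BoxGroup(metadata=Metadata.from_json({"text": "test", "id": 1}))
    assert BoxGroup.from_json(bg.to_json()) == bg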
66 changes: 34 additions & 32 deletions mmda/types/annotation.py
@@ -114,25 +114,27 @@ def to_json(self) -> Dict:

@classmethod
def from_json(cls, box_group_dict: Dict) -> "BoxGroup":
return BoxGroup(

if "metadata" in box_group_dict:
metadata_dict = box_group_dict["metadata"]
else:
# this fallback is necessary to ensure compatibility with box
# groups that were created before the metadata migration and
# therefore have "id", "type" in the root of the json dict instead.
metadata_dict = {
"id": box_group_dict.get("id", None),
"type": box_group_dict.get("type", None),
"text": box_group_dict.get("text", None)
}

return cls(
boxes=[
Box.from_json(box_coords=box_dict)
for box_dict in box_group_dict["boxes"]
# box_group_dict["boxes"] might not be present since we
# minimally serialize when running to_json()
for box_dict in box_group_dict.get("boxes", [])
],
metadata=Metadata.from_json(
box_group_dict.get(
"metadata",
# this fallback is necessary to ensure compatibility with
# box groups that were created before the metadata
# migration and therefore have "id", "type" in the
# root of the json dict instead.
{
key: box_group_dict[key]
for key in ("id", "type")
if key in box_group_dict
},
)
),
metadata=Metadata.from_json(metadata_dict),
uuid=box_group_dict.get("uuid", str(uuid4())),
)

@@ -165,7 +167,7 @@ def _text_span_group_getter(span_group: "SpanGroup") -> str:
return maybe_text if maybe_text else " ".join(span_group.symbols)


# NOTE[LucaS]: by using the store_field_in_metadata decorator, we are
# NOTE[@soldni]: by using the store_field_in_metadata decorator, we are
# able to store id and type in the metadata of BoxGroup, while keeping it
# accessible via SpanGroup.id and SpanGroup.type respectively. This is
# useful because it keeps backward compatibility with the old API, while
@@ -227,25 +229,25 @@ def from_json(cls, span_group_dict: Dict) -> "SpanGroup":
box_group = BoxGroup.from_json(box_group_dict=box_group_dict)
else:
box_group = None
return SpanGroup(

if "metadata" in span_group_dict:
metadata_dict = span_group_dict["metadata"]
else:
# this fallback is necessary to ensure compatibility with span
# groups that were created before the metadata migration and
# therefore have "id", "type" in the root of the json dict instead.
metadata_dict = {
"id": span_group_dict.get("id", None),
"type": span_group_dict.get("type", None),
"text": span_group_dict.get("text", None)
}

return cls(
spans=[
Span.from_json(span_dict=span_dict)
for span_dict in span_group_dict["spans"]
],
metadata=Metadata.from_json(
span_group_dict.get(
"metadata",
# this fallback is necessary to ensure compatibility with
# span groups that were created before the
# migration and therefore have "id", "type" in the
# root of the json dict instead.
{
key: span_group_dict[key]
for key in ("id", "type", "text")
if key in span_group_dict
},
)
),
metadata=Metadata.from_json(metadata_dict),
box_group=box_group,
uuid=span_group_dict.get("uuid", str(uuid4())),
)
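Both from_json methods now accept either serialization layout, and, as the NOTE above explains, id and type stay readable as plain attributes even though they live in metadata. A small sketch of the two dict shapes (hypothetical values, not repository fixtures):

    # post-migration layout: fields nested under a "metadata" key
    modern = {"spans": [], "metadata": {"id": 3, "type": "paragraph"}}
    # pre-migration layout: "id"/"type"/"text" sit at the root of the dict
    legacy = {"spans": [], "id": 3, "type": "paragraph"}

    assert SpanGroup.from_json(modern).id == SpanGroup.from_json(legacy).id == 3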
109 changes: 107 additions & 2 deletions mmda/types/metadata.py
@@ -1,5 +1,8 @@
from copy import deepcopy
from dataclasses import MISSING, Field, fields, is_dataclass
from functools import wraps
import inspect

from typing import (
Any,
Callable,
@@ -202,12 +205,37 @@ class MyDataclass:
to the metadata dictionary.
"""

def wrapper(
def wrapper_fn(
cls_: Type[T],
wrapper_field_name: str = field_name,
wrapper_getter_fn: Optional[Callable[[T], Any]] = getter_fn,
wrapper_setter_fn: Optional[Callable[[T, Any], None]] = setter_fn,
) -> Type[T]:
"""
This wrapper consists of three steps:
1. Basic checks to determine if a field can be stored in the metadata.
This includes checking that cls_ is a dataclass, that cls_ has a
metadata attribute, and that the field is a field of cls_.
2. Wrap the init method of cls_ to ensure that, if a field is specified
in the metadata, it is *NOT* overwritten by the default value of the
field. (keep reading through this function for more details)
3. Create getter and setter methods for the field; these are going to
override the original attribute and will be responsible for querying
the metadata for the value of the field.
Args:
cls_ (Type[T]): The dataclass to wrap.
wrapper_field_name (str): The name of the field to store in the
metadata.
wrapper_getter_fn (Optional[Callable[[T], Any]]): A function that
returns the value of the field. If None, the getter
is a simple lookup in the metadata dictionary.
wrapper_setter_fn (Optional[Callable[[T, Any], None]]): A function
that is used to set the value of the field. If None, the setter
is a simple addition to the metadata dictionary.
"""

# # # # # # # # # # # # STEP 1: BASIC CHECKS # # # # # # # # # # # # #
if not (is_dataclass(cls_)):
raise TypeError("add_deprecated_field only works on dataclasses")

@@ -233,7 +261,82 @@ def wrapper(
f"add_deprecated_field requires a `{wrapper_field_name}` field"
"in the dataclass."
)
# # # # # # # # # # # # # # END OF STEP 1 # # # # # # # # # # # # # # #

# # # # # # # # # # # # # STEP 2: WRAP INIT # # # # # # # # # # # # # #
# In the following comment, we explain the need for step 2.
#
# We want to make sure that if a field is specified in the metadata,
# the default value of the field provided during class annotation does
# not override it. For example, consider the following code:
#
# @store_field_in_metadata('field_a')
# @dataclass
# class MyDataclass:
# metadata: Metadata = field(default_factory=Metadata)
# field_a: int = 3
#
# If we don't wrap the __init__ method, the following code
# will print `3`
#
# d = MyDataclass(metadata={'field_a': 5})
# print(d.field_a)
#
# but if we do, it will work fine and print `5` as expected.
#
# The reason why this occurs is that the __init__ method generated
# by a dataclass uses the default value of the field to initialize the
# class if an explicit value is not provided.
#
# Our solution is rather simple: before calling the dataclass init,
# we look if:
# 1. A `metadata` argument is provided in the constructor, and
# 2. The `metadata` argument contains a field named `wrapper_field_name`
#
# To disable the auto-init, we have to do two things:
# 1. create a new dataclass that inherits from the original one,
# but with init=False for field wrapper_field_name
# 2. create a wrapper for the __init__ method of the new dataclass
# that, when called, calls the original __init__ method and then
# adds the field value to the metadata dict.

# This signature is going to be used to bind to the args/kwargs during
# init, which allows easy lookup of arguments/keyword arguments by
# name.
cls_signature = inspect.signature(cls_.__init__)

# We need to save the init method since we will override it.
cls_init_fn = cls_.__init__

@wraps(cls_init_fn)
def init_wrapper(self, *args, **kwargs):
# parse the arguments and keyword arguments
arguments = cls_signature.bind(self, *args, **kwargs).arguments

# this adds the metadata to kwargs if it is not already there
metadata = arguments.setdefault("metadata", Metadata())

# this is the main check:
# if (a) the metadata argument contains the field we are storing in
# the metadata, and (b) the field is not in args/kwargs, then we
# pass the field value in the metadata to the original init method
# to prevent it from being overwritten by its default value.
if (
wrapper_field_name in metadata
and wrapper_field_name not in arguments
):
arguments[wrapper_field_name] = metadata[wrapper_field_name]

# type: ignore is due to pylance not recognizing that the
# arguments in the signature contain a `self` key
cls_init_fn(**arguments) # type: ignore

setattr(cls_, "__init__", init_wrapper)
# # # # # # # # # # # # # # END OF STEP 2 # # # # # # # # # # # # # # #

# # # # # # # # # # # STEP 3: GETTERS & SETTERS # # # # # # # # # # # #

# We add the getter from here on:
if wrapper_getter_fn is None:
# create property for the deprecated field, as well as a setter
# that will add to the underlying metadata dict
@@ -261,6 +364,7 @@ def _wrapper_getter_fn(

field_property = property(wrapper_getter_fn)

# We add the setter from here on:
if wrapper_setter_fn is None:

def _wrapper_setter_fn(self: T, value: Any) -> None:
@@ -289,7 +393,8 @@ def _wrapper_setter_fn(self: T, value: Any) -> None:

# assign the property to the dataclass
setattr(cls_, wrapper_field_name, field_property)
# # # # # # # # # # # # # # END OF STEP 3 # # # # # # # # # # # # # # #

return cls_

return wrapper
return wrapper_fn
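The situation STEP 2 guards against is easiest to see with the decorator's own running example; a minimal sketch of the intended semantics (class and field names taken from the comment above; the assertions are illustrative and assume Metadata.from_json accepts arbitrary keys, as it does for id/type/text elsewhere in this commit):

    from dataclasses import dataclass, field

    from mmda.types.metadata import Metadata, store_field_in_metadata

    @store_field_in_metadata("field_a")
    @dataclass
    class MyDataclass:
        metadata: Metadata = field(default_factory=Metadata)
        field_a: int = 3

    # Without the wrapped __init__, the dataclass default (3) would clobber
    # the value carried in metadata; with this commit, the metadata value wins.
    d = MyDataclass(metadata=Metadata.from_json({"field_a": 5}))
    assert d.field_a == 5

    # The generated property setter (STEP 3) writes through to metadata.
    d.field_a = 7
    assert d.metadata["field_a"] == 7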
21 changes: 17 additions & 4 deletions tests/test_types/test_json_conversion.py
@@ -6,18 +6,31 @@
'''

import json
from pathlib import Path

from mmda.types.annotation import SpanGroup
from mmda.types.annotation import BoxGroup, SpanGroup
from mmda.types.document import Document
from mmda.parsers.pdfplumber_parser import PDFPlumberParser
from mmda.types.metadata import Metadata


PDFFILEPATH = "../fixtures/1903.10676.pdf"
PDFFILEPATH = Path(__file__).parent / "../fixtures/1903.10676.pdf"

def test_json_conversion():

def test_span_group_conversion():
sg = SpanGroup(id=3, metadata=Metadata.from_json({"text": "test"}))
sg2 = SpanGroup.from_json(sg.to_json())
assert sg2 == sg

bg = BoxGroup(metadata=Metadata.from_json({"text": "test", "id": 1}))
bg2 = BoxGroup.from_json(bg.to_json())
assert bg2 == bg


def test_doc_conversion():
pdfparser = PDFPlumberParser()

orig_doc = pdfparser.parse(input_pdf_path=PDFFILEPATH)
orig_doc = pdfparser.parse(input_pdf_path=str(PDFFILEPATH))

json_doc = json.dumps(orig_doc.to_json())
new_doc = Document.from_json(json.loads(json_doc))
