Skip to content

Commit

Permalink
refacto: align data api with edsnlp
Browse files Browse the repository at this point in the history
  • Loading branch information
percevalw committed Feb 7, 2024
1 parent ec083ed commit 486feef
Show file tree
Hide file tree
Showing 24 changed files with 2,169 additions and 1,059 deletions.
6 changes: 3 additions & 3 deletions docs/trainable-pipes.md
Original file line number Diff line number Diff line change
Expand Up @@ -97,12 +97,12 @@ class MyComponent(TrainablePipe):
"my-feature": ...(doc),
}

def collate(self, batch, device: torch.device) -> Dict:
def collate(self, batch) -> Dict:
# Collate the features of the "embedding" subcomponent
# and the features of this component as well
return {
"embedding": self.embedding.collate(batch["embedding"], device),
"my-feature": torch.as_tensor(batch["my-feature"], device=device),
"embedding": self.embedding.collate(batch["embedding"]),
"my-feature": torch.as_tensor(batch["my-feature"]),
}

def forward(self, batch: Dict, supervision=False) -> Dict:
Expand Down
1 change: 1 addition & 0 deletions edspdf/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from .pipeline import Pipeline, load
from .registry import registry
from .structures import Box, Page, PDFDoc, Text, TextBox, TextProperties
from . import data

from . import utils # isort:skip

Expand Down
97 changes: 1 addition & 96 deletions edspdf/accelerators/base.py
Original file line number Diff line number Diff line change
@@ -1,97 +1,2 @@
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, Union

from ..structures import PDFDoc


class FromDictFieldsToDoc:
    """Callable converter that builds a ``PDFDoc`` from fields of a dict.

    Parameters
    ----------
    content_field:
        Key of the input dict whose value is passed as ``content``.
    id_field:
        Optional key whose value is passed as ``id``. When it is falsy,
        ``id`` is set to ``None``.
    """

    def __init__(self, content_field, id_field=None):
        self.content_field = content_field
        self.id_field = id_field

    def __call__(self, item):
        """Convert ``item`` if it is a dict, otherwise return it unchanged."""
        # Anything that is not a dict is passed through as-is.
        if not isinstance(item, dict):
            return item

        content = item[self.content_field]
        doc_id = None
        if self.id_field:
            doc_id = item[self.id_field]
        return PDFDoc(content=content, id=doc_id)


class ToDoc:
    """Validated type for converters applied to input items.

    Accepted values are a string (used as the ``content_field`` of a
    ``FromDictFieldsToDoc``), a dict of ``FromDictFieldsToDoc`` keyword
    arguments, or any callable.
    """

    @classmethod
    def __get_validators__(cls):
        """Yield the validator(s) used to check values of this type."""
        yield cls.validate

    @classmethod
    def validate(cls, value, config=None):
        """Normalize ``value`` into a callable or raise ``TypeError``."""
        # A bare string is shorthand for {"content_field": <string>}.
        if isinstance(value, str):
            value = {"content_field": value}
        # A dict holds the keyword arguments of the dict-based converter.
        if isinstance(value, dict):
            value = FromDictFieldsToDoc(**value)
        if not callable(value):
            raise TypeError(
                f"Invalid entry {value} ({type(value)}) for ToDoc, "
                f"expected string, a dict or a callable."
            )
        return value


# Source template of the generated converter; "X" is substituted with the
# comma-separated ``key: doc.attr`` entries of the returned dict.
FROM_DOC_TO_DICT_FIELDS_TEMPLATE = """
def fn(doc):
    return {X}
"""


class FromDocToDictFields:
    """Callable converter that extracts doc attributes into a dict.

    ``mapping`` maps output keys to attribute expressions that are
    evaluated as ``doc.<expression>`` by a function generated at
    construction time.
    """

    def __init__(self, mapping):
        self.mapping = mapping
        entries = [f"{key!r}: doc.{attr}" for key, attr in mapping.items()]
        source = FROM_DOC_TO_DICT_FIELDS_TEMPLATE.replace("X", ", ".join(entries))
        namespace = {}
        # NOTE(review): the mapping values are spliced verbatim into the
        # generated source and executed; only use trusted mappings.
        exec(source, namespace)
        self.fn = namespace["fn"]

    def __reduce__(self):
        # Pickle by mapping only: the generated function is rebuilt by
        # __init__ when the object is reconstructed.
        return FromDocToDictFields, (self.mapping,)

    def __call__(self, doc):
        """Return the dict of mapped fields extracted from ``doc``."""
        return self.fn(doc)


class FromDoc:
    """
    A FromDoc converter (from a PDFDoc to an arbitrary type) can be either:
    - a dict mapping field names to doc attributes
    - a callable that takes a PDFDoc and returns an arbitrary type
    """

    @classmethod
    def __get_validators__(cls):
        """Yield the validator(s) used to check values of this type."""
        yield cls.validate

    @classmethod
    def validate(cls, value, config=None):
        """Normalize ``value`` into a callable converter.

        A dict is wrapped in a ``FromDocToDictFields``; a callable is
        returned unchanged.

        Raises
        ------
        TypeError
            If ``value`` is neither a dict nor a callable.
        """
        if isinstance(value, dict):
            value = FromDocToDictFields(value)
        if callable(value):
            return value
        # The message previously named "ToDoc" (copied from the sibling
        # class), which pointed users at the wrong parameter.
        raise TypeError(
            f"Invalid entry {value} ({type(value)}) for FromDoc, "
            f"expected dict or callable"
        )


class Accelerator:
    """Base accelerator whose ``__call__`` is not implemented.

    Calling an instance of this class directly always raises
    ``NotImplementedError``.
    """

    def __call__(
        self,
        inputs: Iterable[Any],
        model: Any,
        to_doc: ToDoc = FromDictFieldsToDoc("content"),
        from_doc: FromDoc = lambda doc: doc,
    ):
        """Signature of an accelerator call.

        Parameters
        ----------
        inputs:
            Iterable of input items.
        model:
            Object handed to the accelerator along with the inputs
            (presumably the pipeline to run; confirm against callers).
        to_doc:
            Input converter; by default reads the "content" field of
            dict items.
        from_doc:
            Output converter; by default returns the doc unchanged.
        """
        raise NotImplementedError


# For static type checkers only: rebind the validator class names to plain
# unions describing the values their ``validate`` methods accept. The
# redefinition is intentional, hence the F811 suppressions.
if TYPE_CHECKING:
    ToDoc = Union[str, Dict[str, Any], Callable[[Any], PDFDoc]] # noqa: F811
    FromDoc = Union[Dict[str, Any], Callable[[PDFDoc], Any]] # noqa: F811
pass
Loading

0 comments on commit 486feef

Please sign in to comment.