Skip to content

Commit

Permalink
feat: delegate inference logic to new accelerator parameter in pipe
Browse files Browse the repository at this point in the history
  • Loading branch information
percevalw committed Sep 7, 2023
1 parent 2d117f3 commit 46c4d46
Show file tree
Hide file tree
Showing 6 changed files with 249 additions and 68 deletions.
Empty file added edspdf/accelerators/__init__.py
Empty file.
102 changes: 102 additions & 0 deletions edspdf/accelerators/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, Union

from ..structures import PDFDoc


class FromDictFieldsToDoc:
def __init__(self, content_field, id_field=None):
self.content_field = content_field
self.id_field = id_field

def __call__(self, item):
if isinstance(item, dict):
return PDFDoc(
content=item[self.content_field],
id=item[self.id_field] if self.id_field else None,
)
return item


class ToDoc:
@classmethod
def __get_validators__(cls):
yield cls.validate

@classmethod
def validate(cls, value, config=None):
if isinstance(value, str):
return FromDictFieldsToDoc(value)
elif isinstance(value, dict):
return FromDictFieldsToDoc(**value)
elif callable(value):
return value
else:
raise TypeError(
f"Invalid entry {value} ({type(value)}) for ToDoc, "
f"expected string, a dict or a callable."
)


def identity(x):
return x


FROM_DOC_TO_DICT_FIELDS_TEMPLATE = """\
def fn(doc):
return {X}
"""


class FromDocToDictFields:
def __init__(self, mapping):
self.mapping = mapping
dict_fields = ", ".join(f"{k}: doc.{v}" for k, v in mapping.items())
self.fn = eval(FROM_DOC_TO_DICT_FIELDS_TEMPLATE.replace("X", dict_fields))

def __reduce__(self):
return FromDocToDictFields, (self.mapping,)

def __call__(self, doc):
return self.fn(doc)


class FromDoc:
"""
A FromDoc converter (from a PDFDoc to an arbitrary type) can be either:
- a dict mapping field names to doc attributes
- a callable that takes a PDFDoc and returns an arbitrary type
"""

@classmethod
def __get_validators__(cls):
yield cls.validate

@classmethod
def validate(cls, value, config=None):
if isinstance(value, dict):
return FromDocToDictFields(value)
elif callable(value):
return value
else:
raise TypeError(
f"Invalid entry {value} ({type(value)}) for ToDoc, "
f"expected dict or callable"
)


class Accelerator:
def __call__(
self,
inputs: Iterable[Any],
model: Any,
to_doc: ToDoc = FromDictFieldsToDoc("content"),
from_doc: FromDoc = lambda doc: doc,
component_cfg: Dict[str, Dict[str, Any]] = None,
):
raise NotImplementedError()


if TYPE_CHECKING:
ToDoc = Union[str, Dict[str, Any], Callable[[Any], PDFDoc]] # noqa: F811
FromDoc = Union[Dict[str, Any], Callable[[PDFDoc], Any]] # noqa: F811
97 changes: 97 additions & 0 deletions edspdf/accelerators/simple.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
from typing import Any, Dict, Iterable

import torch

from ..registry import registry
from ..utils.collections import batchify
from .base import Accelerator, FromDictFieldsToDoc, FromDoc, ToDoc


@registry.accelerator.register("simple")
class SimpleAccelerator(Accelerator):
"""
This is the simplest accelerator which batches the documents and process each batch
on the main process (the one calling `.pipe()`).
Examples
--------
```python
docs = list(pipeline.pipe([content1, content2, ...]))
```
or, if you want to override the model defined batch size
```python
docs = list(pipeline.pipe([content1, content2, ...], batch_size=8))
```
which is equivalent to passing a confit dict
```python
docs = list(
pipeline.pipe(
[content1, content2, ...],
accelerator={
"@accelerator": "simple",
"batch_size": 8,
},
)
)
```
or the instantiated accelerator directly
```python
from edspdf.accelerators.simple import SimpleAccelerator
accelerator = SimpleAccelerator(batch_size=8)
docs = list(pipeline.pipe([content1, content2, ...], accelerator=accelerator))
```
If you have a GPU, make sure to move the model to the appropriate device before
calling `.pipe()`. If you have multiple GPUs, use the
[multiprocessing][edspdf.accelerators.multiprocessing.MultiprocessingAccelerator]
accelerator instead.
```python
pipeline.to("cuda")
docs = list(pipeline.pipe([content1, content2, ...]))
```
Parameters
----------
batch_size: int
The number of documents to process in each batch.
"""

def __init__(
self,
*,
batch_size: int = 32,
):
self.batch_size = batch_size

def __call__(
self,
inputs: Iterable[Any],
model: Any,
to_doc: ToDoc = FromDictFieldsToDoc("content"),
from_doc: FromDoc = lambda doc: doc,
component_cfg: Dict[str, Dict[str, Any]] = None,
):
if from_doc is None:

def from_doc(doc):
return doc

docs = (to_doc(doc) for doc in inputs)
for batch in batchify(docs, batch_size=self.batch_size):
with torch.no_grad(), model.cache(), model.train(False):
for name, pipe in model.pipeline:
if name not in model._disabled:
if hasattr(pipe, "batch_process"):
batch = pipe.batch_process(batch)
else:
batch = [pipe(doc) for doc in batch] # type: ignore
yield from (from_doc(doc) for doc in batch)
110 changes: 42 additions & 68 deletions edspdf/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,15 @@
Union,
)

from confit import Config
from confit import Config, validate_arguments
from confit.errors import ConfitValidationError, patch_errors
from confit.utils.collections import join_path, split_path
from confit.utils.xjson import Reference
from tqdm import tqdm

import edspdf

from .accelerators.base import Accelerator, FromDoc, ToDoc
from .registry import CurriedFactory, registry
from .structures import PDFDoc
from .utils.collections import (
Expand Down Expand Up @@ -239,51 +240,19 @@ def add_pipe(
self._components.insert(insertion_idx, (name, pipe))
return pipe

def make_doc(self, content: bytes) -> PDFDoc:
"""
Create a PDFDoc from text.
Parameters
----------
content: bytes
The bytes content to create the PDFDoc from.
Returns
-------
PDFDoc
"""
return PDFDoc(content=content)

def _ensure_doc(self, text: Union[bytes, PDFDoc]) -> PDFDoc:
"""
Ensure that the input is a PDFDoc.
Parameters
----------
text: Union[str, PDFDoc]
The text to create the PDFDoc from, or a PDFDoc.
Returns
-------
PDFDoc
"""
return text if isinstance(text, PDFDoc) else self.make_doc(text)

def __call__(self, text: Union[str, PDFDoc]) -> PDFDoc:
def __call__(self, doc: Any) -> PDFDoc:
"""
Apply each component successively on a document.
Parameters
----------
text: Union[str, PDFDoc]
The text to create the PDFDoc from, or a PDFDoc.
doc: Union[str, PDFDoc]
The doc to create the PDFDoc from, or a PDFDoc.
Returns
-------
PDFDoc
"""
doc = self._ensure_doc(text)

with self.cache():
for name, pipe in self.pipeline:
if name in self._disabled:
Expand All @@ -298,60 +267,65 @@ def __call__(self, text: Union[str, PDFDoc]) -> PDFDoc:

return doc

@validate_arguments
def pipe(
self,
texts: Iterable[Union[str, PDFDoc]],
inputs: Any,
batch_size: Optional[int] = None,
component_cfg: Dict[str, Dict[str, Any]] = None,
*,
accelerator: Optional[Union[str, Accelerator]] = None,
to_doc: Optional[ToDoc] = None,
from_doc: FromDoc = lambda doc: doc,
) -> Iterable[PDFDoc]:
"""
Process a stream of documents by applying each component successively on
batches of documents.
Parameters
----------
texts: Iterable[Union[str, PDFDoc]]
The texts to create the Docs from, or Docs directly.
inputs: Iterable[Union[str, PDFDoc]]
The inputs to create the PDFDocs from, or the PDFDocs directly.
batch_size: Optional[int]
The batch size to use. If not provided, the batch size of the pipeline
object will be used.
component_cfg: Dict[str, Dict[str, Any]]
The arguments to pass to the components when processing the documents.
accelerator: Optional[Union[str, Accelerator]]
The accelerator to use for processing the documents. If not provided,
the default accelerator will be used.
to_doc: ToDoc
The function to use to convert the inputs to PDFDoc objects. By default,
the `content` field of the inputs will be used if dict-like objects are
provided, otherwise the inputs will be passed directly to the pipeline.
from_doc: FromDoc
The function to use to convert the PDFDoc objects to outputs. By default,
the PDFDoc objects will be returned directly.
Returns
-------
Iterable[PDFDoc]
"""
import torch

if component_cfg is None:
component_cfg = {}
if batch_size is None:
batch_size = self.batch_size

docs = (self._ensure_doc(text) for text in texts)

was_training = {name: proc.training for name, proc in self.trainable_pipes()}
for name, proc in self.trainable_pipes():
proc.train(False)

with torch.no_grad():
for batch in batchify(docs, batch_size=batch_size):
with self.cache():
for name, pipe in self.pipeline:
if name not in self._disabled:
kwargs = component_cfg.get(name, {})
if hasattr(pipe, "batch_process"):
batch = pipe.batch_process(batch, **kwargs)
else:
batch = [
pipe(doc, **kwargs) for doc in batch # type: ignore
]

yield from batch

for name, proc in self.trainable_pipes():
proc.train(was_training[name])
if accelerator is None:
accelerator = {"@accelerator": "simple", "batch_size": batch_size}
if isinstance(accelerator, str):
accelerator = {"@accelerator": accelerator, "batch_size": batch_size}
if isinstance(accelerator, dict):
accelerator = Config(accelerator).resolve(registry=registry)

kwargs = {
"inputs": inputs,
"model": self,
"to_doc": to_doc,
"from_doc": from_doc,
}
for k, v in list(kwargs.items()):
if v is None:
del kwargs[k]

with self.train(False):
return accelerator(**kwargs)

@contextmanager
def cache(self):
Expand Down
1 change: 1 addition & 0 deletions edspdf/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,3 +219,4 @@ class registry(RegistryCollection):
factory = factories = FactoryRegistry(("edspdf", "factories"), entry_points=True)
misc = Registry(("edspdf", "misc"), entry_points=True)
adapter = Registry(("edspdf", "adapter"), entry_points=True)
accelerator = Registry(("edspdf", "accelerator"), entry_points=True)
7 changes: 7 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,10 @@ docs = [
# Aggregators
"simple-aggregator" = "edspdf.pipes.aggregators.simple:SimpleAggregator"

[project.entry-points."edspdf_accelerator"]
"simple" = "edspdf.accelerators.simple:SimpleAccelerator"
"multiprocessing" = "edspdf.accelerators.multiprocessing:MultiprocessingAccelerator"

[project.entry-points."mkdocs.plugins"]

"bibtex" = "docs.scripts.bibtex:BibTexPlugin"
Expand Down Expand Up @@ -119,6 +123,9 @@ color = true
omit-covered-files = false

[tool.coverage.report]
omit = [
"edspdf/accelerators/multi_gpu.py",
]
exclude_also = [
"def __repr__",
"if __name__ == .__main__.:",
Expand Down

0 comments on commit 46c4d46

Please sign in to comment.