-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: delegate inference logic to new accelerator parameter in pipe
- Loading branch information
Showing
6 changed files
with
249 additions
and
68 deletions.
There are no files selected for viewing
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, Union | ||
|
||
from ..structures import PDFDoc | ||
|
||
|
||
class FromDictFieldsToDoc:
    """
    Converter that builds a `PDFDoc` out of selected fields of a dict.

    Parameters
    ----------
    content_field:
        Key of the input dict that holds the document content.
    id_field:
        Key of the input dict that holds the document id. When falsy
        (the default), the document is created with `id=None`.
    """

    def __init__(self, content_field, id_field=None):
        self.content_field = content_field
        self.id_field = id_field

    def __call__(self, item):
        """Return a `PDFDoc` for dict inputs, or `item` itself otherwise."""
        # Anything that is not a dict is handed back untouched.
        if not isinstance(item, dict):
            return item

        content = item[self.content_field]
        doc_id = None
        if self.id_field:
            doc_id = item[self.id_field]
        return PDFDoc(content=content, id=doc_id)
|
||
|
||
class ToDoc:
    """
    Validator for "to document" converters.

    An accepted value is either:

    - a string, used as the content field of a `FromDictFieldsToDoc`
    - a dict of keyword arguments for `FromDictFieldsToDoc`
    - a callable, which is returned unchanged
    """

    @classmethod
    def __get_validators__(cls):
        # Validator-discovery hook (presumably consumed by pydantic-style
        # validation - confirm against the caller).
        yield cls.validate

    @classmethod
    def validate(cls, value, config=None):
        """Turn `value` into a converter, or raise `TypeError`."""
        if isinstance(value, str):
            return FromDictFieldsToDoc(value)
        if isinstance(value, dict):
            return FromDictFieldsToDoc(**value)
        if callable(value):
            return value
        raise TypeError(
            f"Invalid entry {value} ({type(value)}) for ToDoc, "
            f"expected string, a dict or a callable."
        )
|
||
|
||
def identity(x):
    """Return the argument unchanged."""
    return x
|
||
|
||
# Source template of the converter function. It is no longer evaluated (see
# FromDocToDictFields below) and is only kept so that existing imports of
# this name keep working.
FROM_DOC_TO_DICT_FIELDS_TEMPLATE = """\
def fn(doc):
    return {X}
"""


class FromDocToDictFields:
    """
    Converter that extracts attributes of a document into a dict.

    Parameters
    ----------
    mapping:
        Dict mapping each output key to the name of the document attribute
        to read. Dotted names (e.g. `"a.b"`) are resolved attribute by
        attribute.

    Attributes
    ----------
    mapping:
        The mapping received at construction time.
    fn:
        The function performing the extraction; `__call__` delegates to it.
    """

    def __init__(self, mapping):
        self.mapping = mapping
        # The previous implementation passed a `def` statement to `eval()`,
        # which only accepts expressions (SyntaxError), and interpolated the
        # keys unquoted. Plain attribute lookups need no code generation.
        fields = [
            (key, tuple(str(path).split("."))) for key, path in mapping.items()
        ]

        def fn(doc):
            result = {}
            for key, parts in fields:
                value = doc
                for part in parts:
                    value = getattr(value, part)
                result[key] = value
            return result

        self.fn = fn

    def __reduce__(self):
        # `fn` is a local closure and cannot be pickled: rebuild the converter
        # from its mapping instead.
        return FromDocToDictFields, (self.mapping,)

    def __call__(self, doc):
        """Return the dict of values extracted from `doc`."""
        return self.fn(doc)
|
||
|
||
class FromDoc:
    """
    A FromDoc converter (from a PDFDoc to an arbitrary type) can be either:

    - a dict mapping field names to doc attributes
    - a callable that takes a PDFDoc and returns an arbitrary type
    """

    @classmethod
    def __get_validators__(cls):
        # Validator-discovery hook (presumably consumed by pydantic-style
        # validation - confirm against the caller).
        yield cls.validate

    @classmethod
    def validate(cls, value, config=None):
        """
        Turn `value` into a converter.

        A dict is wrapped in a `FromDocToDictFields`, a callable is returned
        unchanged, and anything else raises a `TypeError`.
        """
        if isinstance(value, dict):
            return FromDocToDictFields(value)
        if callable(value):
            return value
        # The message used to mention "ToDoc", which pointed users at the
        # wrong parameter.
        raise TypeError(
            f"Invalid entry {value} ({type(value)}) for FromDoc, "
            f"expected dict or callable"
        )
|
||
|
||
class Accelerator:
    """
    Base class of accelerators.

    Subclasses implement `__call__`; this base implementation always raises
    `NotImplementedError`.
    """

    def __call__(
        self,
        inputs: Iterable[Any],
        model: Any,
        to_doc: ToDoc = FromDictFieldsToDoc("content"),
        from_doc: FromDoc = lambda doc: doc,
        component_cfg: Dict[str, Dict[str, Any]] = None,
    ):
        """
        Run `model` over `inputs`.

        Parameters
        ----------
        inputs:
            The items to process.
        model:
            The model to apply.
        to_doc:
            Converter applied to inputs; defaults to reading the "content"
            field of dict inputs.
        from_doc:
            Converter applied to outputs; defaults to returning the document
            unchanged.
        component_cfg:
            Per-component configuration (semantics are defined by subclasses).
        """
        raise NotImplementedError()
|
||
|
||
# For static type checkers only: rebind ToDoc / FromDoc to unions of the raw
# values their `validate` methods accept. This shadows the validator classes
# defined earlier in the module, hence the F811 (redefinition) suppressions.
if TYPE_CHECKING:
    ToDoc = Union[str, Dict[str, Any], Callable[[Any], PDFDoc]]  # noqa: F811
    FromDoc = Union[Dict[str, Any], Callable[[PDFDoc], Any]]  # noqa: F811
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
from typing import Any, Dict, Iterable | ||
|
||
import torch | ||
|
||
from ..registry import registry | ||
from ..utils.collections import batchify | ||
from .base import Accelerator, FromDictFieldsToDoc, FromDoc, ToDoc | ||
|
||
|
||
@registry.accelerator.register("simple")
class SimpleAccelerator(Accelerator):
    """
    This is the simplest accelerator: it batches the documents and processes
    each batch on the main process (the one calling `.pipe()`).

    Examples
    --------

    ```python
    docs = list(pipeline.pipe([content1, content2, ...]))
    ```

    or, if you want to override the model defined batch size

    ```python
    docs = list(pipeline.pipe([content1, content2, ...], batch_size=8))
    ```

    which is equivalent to passing a confit dict

    ```python
    docs = list(
        pipeline.pipe(
            [content1, content2, ...],
            accelerator={
                "@accelerator": "simple",
                "batch_size": 8,
            },
        )
    )
    ```

    or the instantiated accelerator directly

    ```python
    from edspdf.accelerators.simple import SimpleAccelerator

    accelerator = SimpleAccelerator(batch_size=8)
    docs = list(pipeline.pipe([content1, content2, ...], accelerator=accelerator))
    ```

    If you have a GPU, make sure to move the model to the appropriate device before
    calling `.pipe()`. If you have multiple GPUs, use the
    [multiprocessing][edspdf.accelerators.multiprocessing.MultiprocessingAccelerator]
    accelerator instead.

    ```python
    pipeline.to("cuda")
    docs = list(pipeline.pipe([content1, content2, ...]))
    ```

    Parameters
    ----------
    batch_size: int
        The number of documents to process in each batch.
    """

    def __init__(
        self,
        *,
        batch_size: int = 32,
    ):
        self.batch_size = batch_size

    def __call__(
        self,
        inputs: Iterable[Any],
        model: Any,
        to_doc: ToDoc = FromDictFieldsToDoc("content"),
        from_doc: FromDoc = lambda doc: doc,
        component_cfg: Dict[str, Dict[str, Any]] = None,
    ):
        """
        Lazily convert, batch and process `inputs`, yielding converted outputs.

        Each input goes through `to_doc`, batches of `self.batch_size` items
        are run through every non-disabled pipe of `model.pipeline`, and each
        resulting document is passed to `from_doc` before being yielded.
        """
        if from_doc is None:
            # Explicit `None` means "hand the documents back as they are".
            def _passthrough(doc):
                return doc

            from_doc = _passthrough

        converted = map(to_doc, inputs)
        for batch in batchify(converted, batch_size=self.batch_size):
            # Inference only: no gradients, model cache on, eval mode.
            with torch.no_grad(), model.cache(), model.train(False):
                for name, pipe in model.pipeline:
                    if name in model._disabled:
                        continue
                    if hasattr(pipe, "batch_process"):
                        batch = pipe.batch_process(batch)
                    else:
                        # Pipes without batch support are applied doc by doc.
                        batch = [pipe(doc) for doc in batch]  # type: ignore
            for doc in batch:
                yield from_doc(doc)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters