diff --git a/edspdf/accelerators/__init__.py b/edspdf/accelerators/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/edspdf/accelerators/base.py b/edspdf/accelerators/base.py
new file mode 100644
index 00000000..a72650a2
--- /dev/null
+++ b/edspdf/accelerators/base.py
@@ -0,0 +1,108 @@
+from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, Optional, Union
+
+from ..structures import PDFDoc
+
+
+class FromDictFieldsToDoc:
+    def __init__(self, content_field, id_field=None):
+        self.content_field = content_field
+        self.id_field = id_field
+
+    def __call__(self, item):
+        if isinstance(item, dict):
+            return PDFDoc(
+                content=item[self.content_field],
+                id=item[self.id_field] if self.id_field else None,
+            )
+        return item
+
+
+class ToDoc:
+    @classmethod
+    def __get_validators__(cls):
+        yield cls.validate
+
+    @classmethod
+    def validate(cls, value, config=None):
+        if isinstance(value, str):
+            return FromDictFieldsToDoc(value)
+        elif isinstance(value, dict):
+            return FromDictFieldsToDoc(**value)
+        elif callable(value):
+            return value
+        else:
+            raise TypeError(
+                f"Invalid entry {value} ({type(value)}) for ToDoc, "
+                f"expected a string, a dict or a callable."
+            )
+
+
+def identity(x):
+    return x
+
+
+FROM_DOC_TO_DICT_FIELDS_TEMPLATE = """\
+def fn(doc):
+    return {X}
+"""
+
+
+class FromDocToDictFields:
+    def __init__(self, mapping):
+        self.mapping = mapping
+        # Quote the keys so the generated source is {'out_field': doc.attr, ...}
+        # and not {out_field: doc.attr, ...}, which would raise a NameError.
+        dict_fields = ", ".join(f"{k!r}: doc.{v}" for k, v in mapping.items())
+        # `def` is a statement, so the template must be exec-ed (not eval-ed)
+        # into a namespace, from which the generated function is retrieved.
+        local_vars = {}
+        exec(FROM_DOC_TO_DICT_FIELDS_TEMPLATE.replace("X", dict_fields), local_vars)
+        self.fn = local_vars["fn"]
+
+    def __reduce__(self):
+        return FromDocToDictFields, (self.mapping,)
+
+    def __call__(self, doc):
+        return self.fn(doc)
+
+
+class FromDoc:
+    """
+    A FromDoc converter (from a PDFDoc to an arbitrary type) can be either:
+
+    - a dict mapping output field names to doc attributes
+    - a callable that takes a PDFDoc and returns an arbitrary type
+    """
+
+    @classmethod
+    def __get_validators__(cls):
+        yield cls.validate
+
+    @classmethod
+    def validate(cls, value, config=None):
+        if isinstance(value, dict):
+            return FromDocToDictFields(value)
+        elif callable(value):
+            return value
+        else:
+            raise TypeError(
+                f"Invalid entry {value} ({type(value)}) for FromDoc, "
+                f"expected a dict or a callable."
+            )
+
+
+class Accelerator:
+    def __call__(
+        self,
+        inputs: Iterable[Any],
+        model: Any,
+        to_doc: ToDoc = FromDictFieldsToDoc("content"),
+        from_doc: FromDoc = lambda doc: doc,
+        component_cfg: Optional[Dict[str, Dict[str, Any]]] = None,
+    ):
+        raise NotImplementedError()
+
+
+if TYPE_CHECKING:
+    ToDoc = Union[str, Dict[str, Any], Callable[[Any], PDFDoc]]  # noqa: F811
+    FromDoc = Union[Dict[str, Any], Callable[[PDFDoc], Any]]  # noqa: F811
diff --git a/edspdf/accelerators/simple.py b/edspdf/accelerators/simple.py
new file mode 100644
index 00000000..219af4c5
--- /dev/null
+++ b/edspdf/accelerators/simple.py
@@ -0,0 +1,113 @@
+from typing import Any, Dict, Iterable, Optional
+
+import torch
+
+from ..registry import registry
+from ..utils.collections import batchify
+from .base import Accelerator, FromDictFieldsToDoc, FromDoc, ToDoc
+
+
+@registry.accelerator.register("simple")
+class SimpleAccelerator(Accelerator):
+    """
+    This is the simplest accelerator: it batches the documents and processes
+    each batch on the main process (the one calling `.pipe()`).
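+
+    The `to_doc` and `from_doc` arguments of `pipe()` control how inputs are
+    converted to `PDFDoc` objects and back. As a sketch, assuming dict inputs
+    with a hypothetical "pdf_bytes" key, and an aggregator pipe that fills the
+    `aggregated_texts` attribute:
+
+    ```python
+    docs = list(
+        pipeline.pipe(
+            [{"pdf_bytes": content1}, {"pdf_bytes": content2}],
+            to_doc={"content_field": "pdf_bytes"},
+            from_doc={"texts": "aggregated_texts"},
+        )
+    )
+    # each output is then a dict such as {"texts": doc.aggregated_texts}
+    ```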
+
+    Examples
+    --------
+
+    ```python
+    docs = list(pipeline.pipe([content1, content2, ...]))
+    ```
+
+    or, if you want to override the batch size defined by the model:
+
+    ```python
+    docs = list(pipeline.pipe([content1, content2, ...], batch_size=8))
+    ```
+
+    which is equivalent to passing a confit dict:
+
+    ```python
+    docs = list(
+        pipeline.pipe(
+            [content1, content2, ...],
+            accelerator={
+                "@accelerator": "simple",
+                "batch_size": 8,
+            },
+        )
+    )
+    ```
+
+    or the instantiated accelerator directly:
+
+    ```python
+    from edspdf.accelerators.simple import SimpleAccelerator
+
+    accelerator = SimpleAccelerator(batch_size=8)
+    docs = list(pipeline.pipe([content1, content2, ...], accelerator=accelerator))
+    ```
+
+    If you have a GPU, make sure to move the model to the appropriate device
+    before calling `.pipe()`. If you have multiple GPUs, use the
+    [multiprocessing][edspdf.accelerators.multiprocessing.MultiprocessingAccelerator]
+    accelerator instead.
+
+    ```python
+    pipeline.to("cuda")
+    docs = list(pipeline.pipe([content1, content2, ...]))
+    ```
+
+    Parameters
+    ----------
+    batch_size: int
+        The number of documents to process in each batch.
+    """
+
+    def __init__(
+        self,
+        *,
+        batch_size: int = 32,
+    ):
+        self.batch_size = batch_size
+
+    def __call__(
+        self,
+        inputs: Iterable[Any],
+        model: Any,
+        to_doc: ToDoc = FromDictFieldsToDoc("content"),
+        from_doc: FromDoc = lambda doc: doc,
+        component_cfg: Optional[Dict[str, Dict[str, Any]]] = None,
+    ):
+        if from_doc is None:
+
+            def from_doc(doc):
+                return doc
+
+        docs = (to_doc(doc) for doc in inputs)
+        for batch in batchify(docs, batch_size=self.batch_size):
+            with torch.no_grad(), model.cache(), model.train(False):
+                for name, pipe in model.pipeline:
+                    if name not in model._disabled:
+                        if hasattr(pipe, "batch_process"):
+                            batch = pipe.batch_process(batch)
+                        else:
+                            batch = [pipe(doc) for doc in batch]  # type: ignore
+            yield from (from_doc(doc) for doc in batch)
diff --git a/edspdf/pipeline.py b/edspdf/pipeline.py
index 395608ee..932e3160 100644
--- a/edspdf/pipeline.py
+++ b/edspdf/pipeline.py
@@ -23,7 +23,7 @@
     Union,
 )
 
-from confit import Config
+from confit import Config, validate_arguments
 from confit.errors import ConfitValidationError, patch_errors
 from confit.utils.collections import join_path, split_path
 from confit.utils.xjson import Reference
@@ -31,6 +31,7 @@
 
 import edspdf
 
+from .accelerators.base import Accelerator, FromDoc, ToDoc
 from .registry import CurriedFactory, registry
 from .structures import PDFDoc
 from .utils.collections import (
@@ -239,51 +240,19 @@ def add_pipe(
         self._components.insert(insertion_idx, (name, pipe))
         return pipe
 
-    def make_doc(self, content: bytes) -> PDFDoc:
-        """
-        Create a PDFDoc from text.
-
-        Parameters
-        ----------
-        content: bytes
-            The bytes content to create the PDFDoc from.
-
-        Returns
-        -------
-        PDFDoc
-        """
-        return PDFDoc(content=content)
-
-    def _ensure_doc(self, text: Union[bytes, PDFDoc]) -> PDFDoc:
-        """
-        Ensure that the input is a PDFDoc.
-
-        Parameters
-        ----------
-        text: Union[str, PDFDoc]
-            The text to create the PDFDoc from, or a PDFDoc.
-
-        Returns
-        -------
-        PDFDoc
-        """
-        return text if isinstance(text, PDFDoc) else self.make_doc(text)
-
-    def __call__(self, text: Union[str, PDFDoc]) -> PDFDoc:
+    def __call__(self, doc: Any) -> PDFDoc:
         """
         Apply each component successively on a document.
 
         Parameters
         ----------
-        text: Union[str, PDFDoc]
-            The text to create the PDFDoc from, or a PDFDoc.
+        doc: Any
+            The document to process, or a raw input accepted by the first component.
 
         Returns
         -------
         PDFDoc
         """
-        doc = self._ensure_doc(text)
-
         with self.cache():
             for name, pipe in self.pipeline:
                 if name in self._disabled:
@@ -298,11 +267,15 @@ def __call__(self, text: Union[str, PDFDoc]) -> PDFDoc:
 
         return doc
 
+    @validate_arguments
     def pipe(
         self,
-        texts: Iterable[Union[str, PDFDoc]],
+        inputs: Any,
         batch_size: Optional[int] = None,
-        component_cfg: Dict[str, Dict[str, Any]] = None,
+        *,
+        accelerator: Optional[Union[str, Accelerator]] = None,
+        to_doc: Optional[ToDoc] = None,
+        from_doc: FromDoc = lambda doc: doc,
     ) -> Iterable[PDFDoc]:
         """
         Process a stream of documents by applying each component successively on
@@ -310,48 +283,53 @@
 
         Parameters
         ----------
-        texts: Iterable[Union[str, PDFDoc]]
-            The texts to create the Docs from, or Docs directly.
+        inputs: Iterable[Union[str, PDFDoc]]
+            The inputs to create the PDFDocs from, or the PDFDocs directly.
         batch_size: Optional[int]
             The batch size to use. If not provided, the batch size of the
             pipeline object will be used.
-        component_cfg: Dict[str, Dict[str, Any]]
-            The arguments to pass to the components when processing the documents.
+        accelerator: Optional[Union[str, Accelerator]]
+            The accelerator to use for processing the documents. If not provided,
+            the default "simple" accelerator will be used.
+        to_doc: ToDoc
+            The function used to convert the inputs to PDFDoc objects. By default,
+            the `content` field of the inputs will be used if dict-like objects are
+            provided, otherwise the inputs will be passed directly to the pipeline.
+        from_doc: FromDoc
+            The function used to convert the PDFDoc objects back to outputs. By
+            default, the PDFDoc objects will be returned directly.
 
         Returns
         -------
         Iterable[PDFDoc]
         """
-        import torch
-
-        if component_cfg is None:
-            component_cfg = {}
         if batch_size is None:
             batch_size = self.batch_size
 
-        docs = (self._ensure_doc(text) for text in texts)
-
-        was_training = {name: proc.training for name, proc in self.trainable_pipes()}
-        for name, proc in self.trainable_pipes():
-            proc.train(False)
-
-        with torch.no_grad():
-            for batch in batchify(docs, batch_size=batch_size):
-                with self.cache():
-                    for name, pipe in self.pipeline:
-                        if name not in self._disabled:
-                            kwargs = component_cfg.get(name, {})
-                            if hasattr(pipe, "batch_process"):
-                                batch = pipe.batch_process(batch, **kwargs)
-                            else:
-                                batch = [
-                                    pipe(doc, **kwargs) for doc in batch  # type: ignore
-                                ]
-
-                yield from batch
-
-        for name, proc in self.trainable_pipes():
-            proc.train(was_training[name])
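+        # The accelerator can be given as a registry name, a confit dict or an
+        # Accelerator instance; strings and dicts are resolved through the
+        # `accelerator` registry. For example (sketch), these are equivalent:
+        #   pipeline.pipe(inputs, accelerator="simple", batch_size=8)
+        #   pipeline.pipe(inputs, accelerator={"@accelerator": "simple", "batch_size": 8})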
+        if accelerator is None:
+            accelerator = {"@accelerator": "simple", "batch_size": batch_size}
+        if isinstance(accelerator, str):
+            accelerator = {"@accelerator": accelerator, "batch_size": batch_size}
+        if isinstance(accelerator, dict):
+            accelerator = Config(accelerator).resolve(registry=registry)
+
+        kwargs = {
+            "inputs": inputs,
+            "model": self,
+            "to_doc": to_doc,
+            "from_doc": from_doc,
+        }
+        for k, v in list(kwargs.items()):
+            if v is None:
+                del kwargs[k]
+
+        with self.train(False):
+            return accelerator(**kwargs)
 
     @contextmanager
     def cache(self):
diff --git a/edspdf/registry.py b/edspdf/registry.py
index a7333b74..641b0eb2 100644
--- a/edspdf/registry.py
+++ b/edspdf/registry.py
@@ -219,3 +219,6 @@ class registry(RegistryCollection):
     factory = factories = FactoryRegistry(("edspdf", "factories"), entry_points=True)
     misc = Registry(("edspdf", "misc"), entry_points=True)
     adapter = Registry(("edspdf", "adapter"), entry_points=True)
+    accelerator = Registry(("edspdf", "accelerator"), entry_points=True)
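+    # Like the other registries, this one can be extended by third-party packages
+    # through the "edspdf_accelerator" entry-point group declared in pyproject.toml.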
diff --git a/pyproject.toml b/pyproject.toml
index 2d7d8073..c5e81df5 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -89,6 +89,10 @@ docs = [
 # Aggregators
 "simple-aggregator" = "edspdf.pipes.aggregators.simple:SimpleAggregator"
 
+[project.entry-points."edspdf_accelerator"]
+"simple" = "edspdf.accelerators.simple:SimpleAccelerator"
+"multiprocessing" = "edspdf.accelerators.multiprocessing:MultiprocessingAccelerator"
+
 [project.entry-points."mkdocs.plugins"]
 "bibtex" = "docs.scripts.bibtex:BibTexPlugin"
 
@@ -119,6 +123,9 @@ color = true
 omit-covered-files = false
 
 [tool.coverage.report]
+omit = [
+    "edspdf/accelerators/multi_gpu.py",
+]
 exclude_also = [
     "def __repr__",
     "if __name__ == .__main__.:",