refacto: update trainable components and data api
percevalw committed Jun 14, 2024
1 parent c5b1a84 commit 06f142a
Showing 10 changed files with 471 additions and 213 deletions.
7 changes: 7 additions & 0 deletions changelog.md
@@ -2,10 +2,17 @@

## Unreleased

### Changed

- Default to fp16 when running inference on GPU
- Support an `inputs` parameter in the `TrainablePipe.postprocess(...)` method (as in edsnlp)
- We now check that the user isn't trying to write a single file in a split fashion (i.e. with `write_in_worker=True` or `num_rows_per_file` set) and raise an error if they are (see the sketch after this changelog excerpt)

### Fixed

- Batches full of empty content boxes no longer crash the `huggingface-embedding` component
- Ensure models are always loaded in non-training mode
- Improved performance of `edspdf.data` methods over a filesystem (`fs` parameter)

## v0.9.1

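To make the new single-file check concrete, here is a minimal sketch of the decision the writer makes. The helper name and error message are illustrative, not the library's own code; compare with the `ParquetWriter.__init__` changes further down in this diff.

```python
# Sketch of the "single file vs sharded dataset" decision described above.
# resolve_write_mode is an illustrative helper, not part of edspdf.
import sys
from pathlib import Path


def resolve_write_mode(path, num_rows_per_file=None, write_in_worker=False):
    looks_like_dir = Path(path).suffix == ""
    if looks_like_dir or num_rows_per_file is not None:
        # Directory-like target: shard the output into several parquet files.
        return "dataset", num_rows_per_file or 8192
    if write_in_worker:
        # A single output file cannot be written from several workers.
        raise ValueError("write_in_worker cannot be set when writing to a single file")
    # Single-file target: everything ends up in one parquet file.
    return "single-file", sys.maxsize


print(resolve_write_mode("out/"))          # -> ('dataset', 8192)
print(resolve_write_mode("out.parquet"))   # -> ('single-file', sys.maxsize)
# resolve_write_mode("out.parquet", write_in_worker=True) would raise
```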
9 changes: 7 additions & 2 deletions docs/trainable-pipes.md
@@ -53,7 +53,7 @@ Additionally, there is a fifth method:
Here is an example of a trainable component:

```python
from typing import Any, Dict, Iterable, Sequence
from typing import Any, Dict, Iterable, Sequence, List

import torch
from tqdm import tqdm
@@ -114,7 +114,12 @@ class MyComponent(TrainablePipe):

return output

def postprocess(self, docs: Sequence[PDFDoc], output: Dict) -> Sequence[PDFDoc]:
def postprocess(
self,
docs: Sequence[PDFDoc],
output: Dict,
inputs: List[Dict[str, Any]],
) -> Sequence[PDFDoc]:
# Annotate the docs with the outputs of the forward method
...
return docs
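As a complement to the documentation snippet above, here is a deliberately tiny sketch of how the dicts returned by `preprocess` come back through the new `inputs` argument (one un-collated dict per document, as in edsnlp). The field name is made up, and a real component would also implement `forward`, `collate`, etc.

```python
from typing import Any, Dict, List, Sequence

from edspdf import TrainablePipe
from edspdf.structures import PDFDoc


class ToyPipe(TrainablePipe):
    """Toy component, only meant to illustrate the new postprocess signature."""

    def preprocess(self, doc: PDFDoc) -> Dict[str, Any]:
        # One dict per document; postprocess receives these back,
        # un-collated, through its `inputs` argument.
        return {"num_boxes": len(doc.text_boxes)}

    def postprocess(
        self,
        docs: Sequence[PDFDoc],
        output: Dict,
        inputs: List[Dict[str, Any]],
    ) -> Sequence[PDFDoc]:
        for doc, doc_inputs in zip(docs, inputs):
            # doc_inputs is the dict produced by preprocess for this doc;
            # use it together with `output` to annotate the document.
            ...
        return docs
```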
116 changes: 60 additions & 56 deletions edspdf/data/parquet.py
@@ -1,8 +1,9 @@
import os
import sys
from itertools import chain
from pathlib import Path
from typing import Any, Callable, Dict, Iterable, List, Optional, TypeVar, Union

import fsspec
import pyarrow.dataset
import pyarrow.fs
import pyarrow.parquet
@@ -17,6 +18,7 @@
from edspdf.lazy_collection import LazyCollection
from edspdf.structures import PDFDoc, registry
from edspdf.utils.collections import dl_to_ld, flatten, ld_to_dl
from edspdf.utils.filesystem import FileSystem, normalize_fs_path


class ParquetReader(BaseReader):
@@ -27,27 +29,15 @@ def __init__(
path: Union[str, Path],
*,
read_in_worker: bool,
filesystem: Optional[pyarrow.fs.FileSystem] = None,
filesystem: Optional[FileSystem] = None,
):
super().__init__()
# Either the filesystem has not been passed
# or the path is a URL (e.g. s3://) => we need to infer the filesystem
fs_path = path
if filesystem is None or (isinstance(path, str) and "://" in path):
path = (
path
if isinstance(path, Path) or "://" in path
else f"file://{os.path.abspath(path)}"
)
inferred_fs, fs_path = pyarrow.fs.FileSystem.from_uri(path)
filesystem = filesystem or inferred_fs
assert inferred_fs.type_name == filesystem.type_name, (
f"Protocol {inferred_fs.type_name} in path does not match "
f"filesystem {filesystem.type_name}"
)
filesystem, path = normalize_fs_path(filesystem, path)
self.read_in_worker = read_in_worker
self.dataset = pyarrow.dataset.dataset(
fs_path, format="parquet", filesystem=filesystem
path, format="parquet", filesystem=filesystem
)

def read_main(self):
@@ -60,17 +50,14 @@ def read_main(self):
return (
(line, 1)
for f in fragments
for batch in f.to_table().to_batches(1024)
for line in dl_to_ld(batch.to_pydict())
for line in dl_to_ld(f.to_table().to_pydict())
)

def read_worker(self, tasks):
if self.read_in_worker:
tasks = list(
chain.from_iterable(
dl_to_ld(batch.to_pydict())
for task in tasks
for batch in task.to_table().to_batches(1024)
dl_to_ld(task.to_table().to_pydict()) for task in tasks
)
)
return tasks
@@ -82,47 +69,55 @@ def read_worker(self, tasks):
class ParquetWriter(BaseWriter):
def __init__(
self,
*,
path: Union[str, Path],
num_rows_per_file: int,
num_rows_per_file: Optional[int] = None,
overwrite: bool,
write_in_worker: bool,
accumulate: bool = True,
filesystem: Optional[pyarrow.fs.FileSystem] = None,
filesystem: Optional[FileSystem] = None,
):
super().__init__()
fs_path = path
if filesystem is None or (isinstance(path, str) and "://" in path):
path = (
path
if isinstance(path, Path) or "://" in path
else f"file://{os.path.abspath(path)}"
)
inferred_fs, fs_path = pyarrow.fs.FileSystem.from_uri(path)
filesystem = filesystem or inferred_fs
assert inferred_fs.type_name == filesystem.type_name, (
f"Protocol {inferred_fs.type_name} in path does not match "
f"filesystem {filesystem.type_name}"
)
path = fs_path
filesystem, path = normalize_fs_path(filesystem, path)
# Check that filesystem has the same protocol as indicated by path
filesystem.create_dir(fs_path, recursive=True)
looks_like_dir = Path(path).suffix == ""
if looks_like_dir or num_rows_per_file is not None:
num_rows_per_file = num_rows_per_file or 8192
filesystem.makedirs(path, exist_ok=True)
save_as_dataset = True
else:
assert (
num_rows_per_file is None
), "num_rows_per_file should not be set when writing to a single file"
assert (
write_in_worker is False
), "write_in_worker cannot be set when writing to a single file"
save_as_dataset = False
num_rows_per_file = sys.maxsize
if overwrite is False:
dataset = pyarrow.dataset.dataset(
fs_path, format="parquet", filesystem=filesystem
)
if len(list(dataset.get_fragments())):
raise FileExistsError(
f"Directory {fs_path} already exists and is not empty. "
"Use overwrite=True to overwrite."
if save_as_dataset:
dataset = pyarrow.dataset.dataset(
path, format="parquet", filesystem=filesystem
)
self.filesystem = filesystem
if len(list(dataset.get_fragments())):
raise FileExistsError(
f"Directory {path} already exists and is not empty. "
"Use overwrite=True to overwrite."
)
else:
if filesystem.exists(path):
raise FileExistsError(
f"File {path} already exists. Use overwrite=True to overwrite."
)
self.filesystem: fsspec.AbstractFileSystem = filesystem
self.path = path
self.save_as_dataset = save_as_dataset
self.write_in_worker = write_in_worker
self.batch = []
self.num_rows_per_file = num_rows_per_file
self.closed = False
self.finalized = False
self.accumulate = accumulate
self.accumulate = (not self.save_as_dataset) and accumulate
if not self.accumulate:
self.finalize = super().finalize

@@ -162,13 +157,22 @@ def finalize(self):
return self.write_worker([], last=True)

def write_main(self, fragments: Iterable[List[Union[pyarrow.Table, Path]]]):
for table in flatten(fragments):
if not self.write_in_worker:
pyarrow.parquet.write_to_dataset(
table=table,
root_path=self.path,
filesystem=self.filesystem,
)
tables = list(flatten(fragments))
if self.save_as_dataset:
for table in tables:
if not self.write_in_worker:
pyarrow.parquet.write_to_dataset(
table=table,
root_path=self.path,
filesystem=self.filesystem,
)
else:
pyarrow.parquet.write_table(
table=pyarrow.concat_tables(tables),
where=self.path,
filesystem=self.filesystem,
)

return pyarrow.dataset.dataset(
self.path, format="parquet", filesystem=self.filesystem
)
@@ -202,7 +206,7 @@ def write_parquet(
path: Union[str, Path],
*,
write_in_worker: bool = False,
num_rows_per_file: int = 1024,
num_rows_per_file: Optional[int] = None,
overwrite: bool = False,
filesystem: Optional[pyarrow.fs.FileSystem] = None,
accumulate: bool = True,
Expand All @@ -216,7 +220,7 @@ def write_parquet(

return data.write(
ParquetWriter(
path,
path=path,
num_rows_per_file=num_rows_per_file,
overwrite=overwrite,
write_in_worker=write_in_worker,
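Usage of the reworked writer would look roughly like this, assuming `docs` is a lazy collection of `PDFDoc` records and that `edspdf.data.write_parquet` keeps the signature shown in this diff (paths are placeholders).

```python
# Usage sketch only; `docs` is assumed to be a LazyCollection of documents.
import edspdf

# Directory-like target -> sharded parquet dataset (roughly 8192 rows per
# file by default, since num_rows_per_file now defaults to None).
edspdf.data.write_parquet(docs, "s3://bucket/output/", overwrite=True)

# Sharding can still be controlled explicitly.
edspdf.data.write_parquet(docs, "output/", num_rows_per_file=1024)

# File-like target -> all tables are concatenated into a single parquet file;
# combining a single-file target with write_in_worker=True now raises.
edspdf.data.write_parquet(docs, "output/docs.parquet")
```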
31 changes: 31 additions & 0 deletions edspdf/lazy_collection.py
@@ -325,6 +325,37 @@ def to(self, device: Union[str, Optional["torch.device"]] = None): # noqa F821
pipe.to(device)
return self

def train(self, mode=True):
"""
Enables training mode on pytorch modules

Parameters
----------
mode: bool
Whether to enable training or not
"""

class context:
def __enter__(self):
pass

def __exit__(ctx_self, type, value, traceback):
for name, proc in procs:
proc.train(was_training[name])

procs = [x for x in self.torch_components() if hasattr(x[1], "train")]
was_training = {name: proc.training for name, proc in procs}
for name, proc in procs:
proc.train(mode)

return context()

def eval(self):
"""
Enables evaluation mode on pytorch modules
"""
return self.train(False)

def worker_copy(self):
return LazyCollection(
reader=self.reader.worker_copy(),
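A short usage sketch for the new `train`/`eval` helpers, assuming `lc` is a `LazyCollection` whose pipeline contains torch components: calling them switches modes immediately, and `train(...)` also returns a context manager that restores each component's previous mode on exit.

```python
# Usage sketch only; `lc` is assumed to be a LazyCollection with torch components.
lc.eval()  # put every torch component in evaluation mode

with lc.train(True):
    ...  # components are in training mode inside this block
# on exit, each component is restored to the mode it had before train(True)
```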
9 changes: 7 additions & 2 deletions edspdf/pipes/classifiers/trainable.py
@@ -1,7 +1,7 @@
import json
import os
from pathlib import Path
from typing import Any, Dict, Iterable, Sequence, Set
from typing import Any, Dict, Iterable, List, Sequence, Set

import torch
import torch.nn.functional as F
@@ -201,7 +201,12 @@ def forward(self, batch: Dict) -> Dict:

return output

def postprocess(self, docs: Sequence[PDFDoc], output: Dict) -> Sequence[PDFDoc]:
def postprocess(
self,
docs: Sequence[PDFDoc],
output: Dict,
inputs: List[Dict[str, Any]],
) -> Sequence[PDFDoc]:
for b, label in zip(
(b for doc in docs for b in doc.text_boxes),
output["labels"].tolist(),