
Huggingface #17

Merged (9 commits, Sep 1, 2023)
14 changes: 13 additions & 1 deletion changelog.md
@@ -1,8 +1,20 @@
# Changelog

## v0.7.1 (Pending)
## Unreleased

### Added

- Add multi-modal transformers (`huggingface-embedding`) with windowing options
- Add `render_page` option to `pdfminer` extractor, for multi-modal PDF features

### Changed

- Updated API to follow EDS-NLP's refactoring
- Updated `confit` to 0.4.2 (better errors) and `foldedtensor` to 0.3.0 (better multiprocess support)
- Better test coverage

### Fixed

- Fixed `attrs` dependency only being installed in dev mode

## v0.7.0
2 changes: 1 addition & 1 deletion docs/assets/termynal/termynal.css
@@ -85,7 +85,7 @@ a[data-terminal-control] {

[data-ty="input"]:before,
[data-ty-prompt]:before {
margin-right: 0.75em;
margin-right: 0.72em;
color: var(--color-text-subtle);
}

3 changes: 3 additions & 0 deletions docs/pipes/embeddings/assets/transformer-windowing.svg
295 changes: 196 additions & 99 deletions docs/recipes/training.md
@@ -3,69 +3,108 @@
In this chapter, we'll see how to train a deep-learning-based classifier to better classify the lines of a
document and extract its text.

The architecture of the trainable classifier of this recipe is described in the following figure:
![Architecture of the trainable classifier](resources/deep-learning-architecture.svg)

## Step-by-step walkthrough

Training supervised models consists of feeding batches of samples taken from a training corpus
to a model instantiated from a given architecture, and optimizing the learnable weights of the
model to decrease a given loss. The process of training a pipeline with EDS-PDF is as follows:
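Schematically, this loop can be sketched as follows (a toy illustration with a linear model and a squared loss — not EDS-PDF code):

```python
import numpy as np

rng = np.random.default_rng(42)
X = rng.normal(size=(64, 3))            # a toy "training corpus"
y = X @ np.array([1.0, -2.0, 0.5])      # targets produced by a known rule
w = np.zeros(3)                         # learnable weights

for step in range(200):
    grad = 2 * X.T @ (X @ w - y) / len(X)   # gradient of the mean squared loss
    w -= 0.1 * grad                         # one optimization step

loss = float(np.mean((X @ w - y) ** 2))     # the loss has decreased to near zero
```

EDS-PDF wraps exactly this pattern — batching, forward pass, loss, optimizer step — around the pipeline's trainable components.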

1. We first start by seeding the random states and instantiating a new trainable pipeline
```python
from edspdf import Pipeline
from edspdf.utils.random import set_seed
1. We first start by seeding the random states and instantiating a new trainable pipeline. Here we show two example pipelines: the first based on a custom embedding architecture, and the second on a pre-trained HuggingFace transformer model.

=== "Custom architecture"

set_seed(42)

model = Pipeline()
model.add_pipe("pdfminer-extractor", name="extractor")
model.add_pipe(
"box-transformer",
name="embedding",
config={
"num_heads": 4,
"dropout_p": 0.1,
"activation": "gelu",
"init_resweight": 0.01,
"head_size": 16,
"attention_mode": ["c2c", "c2p", "p2c"],
"n_layers": 1,
"n_relative_positions": 64,
"embedding": {
"@factory": "embedding-combiner",
The architecture of the trainable classifier of this recipe is described in the following figure:
![Architecture of the trainable classifier](resources/deep-learning-architecture.svg)

```{ .python .annotate }
from edspdf import Pipeline
from edspdf.utils.random import set_seed

set_seed(42)

model = Pipeline()
model.add_pipe("pdfminer-extractor", name="extractor") # (1)
model.add_pipe(
"box-transformer",
name="embedding",
config={
"num_heads": 4,
"dropout_p": 0.1,
"text_encoder": {
"@factory": "sub-box-cnn-pooler",
"out_channels": 64,
"kernel_sizes": (3, 4, 5),
"embedding": {
"@factory": "simple-text-embedding",
"activation": "gelu",
"init_resweight": 0.01,
"head_size": 16,
"attention_mode": ["c2c", "c2p", "p2c"],
"n_layers": 1,
"n_relative_positions": 64,
"embedding": {
"@factory": "embedding-combiner",
"dropout_p": 0.1,
"text_encoder": {
"@factory": "sub-box-cnn-pooler",
"out_channels": 64,
"kernel_sizes": (3, 4, 5),
"embedding": {
"@factory": "simple-text-embedding",
"size": 72,
},
},
"layout_encoder": {
"@factory": "box-layout-embedding",
"n_positions": 64,
"x_mode": "learned",
"y_mode": "learned",
"w_mode": "learned",
"h_mode": "learned",
"size": 72,
},
},
"layout_encoder": {
"@factory": "box-layout-embedding",
"n_positions": 64,
"x_mode": "learned",
"y_mode": "learned",
"w_mode": "learned",
"h_mode": "learned",
"size": 72,
},
},
},
)
model.add_pipe(
"trainable-classifier",
name="classifier",
config={
"embedding": model.get_pipe("embedding"),
"labels": [],
},
)
```
)
model.add_pipe(
"trainable-classifier",
name="classifier",
config={
"embedding": model.get_pipe("embedding"),
"labels": [],
},
)
```

1. You can choose between multiple extractors, such as "pdfminer-extractor", "mupdf-extractor" or "poppler-extractor" (the latter does not support rendering images). See the [list of extractors](/pipes/extractors) for more details.

=== "Pre-trained HuggingFace transformer"

```{ .python .annotate }
model = Pipeline()
model.add_pipe(
"mupdf-extractor",
name="extractor",
config={
"render_pages": True,
},
) # (1)
model.add_pipe(
"huggingface-embedding",
name="embedding",
config={
"model": "microsoft/layoutlmv3-base",
"use_image": False,
"window": 128,
"stride": 64,
"line_pooling": "mean",
},
)
model.add_pipe(
"trainable-classifier",
name="classifier",
config={
"embedding": model.get_pipe("embedding"),
"labels": [],
},
)
```

1. You can choose between multiple extractors, such as "pdfminer-extractor", "mupdf-extractor" or "poppler-extractor" (the latter does not support rendering images). See the [list of extractors](/pipes/extractors) for more details.
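The `window` and `stride` options control how long documents are split into overlapping windows before being fed to the transformer, so that lines near a window edge still receive context from the neighbouring window. A minimal sketch of this windowing logic (illustrative only — not the library's actual implementation):

```python
def make_windows(n_tokens: int, window: int, stride: int):
    """Split a sequence of n_tokens into overlapping (start, end) spans."""
    if n_tokens <= window:
        return [(0, n_tokens)]
    starts = range(0, n_tokens - window + stride, stride)
    return [(start, min(start + window, n_tokens)) for start in starts]

# A 300-token document with window=128 and stride=64:
print(make_windows(300, window=128, stride=64))
# → [(0, 128), (64, 192), (128, 256), (192, 300)]
```

With `window=128` and `stride=64`, consecutive windows overlap by 64 tokens; a line that falls in several windows gets several embeddings, which are then pooled back into one (here via `line_pooling = "mean"`).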

2. We then load and adapt (i.e., convert into `PDFDoc` objects) the training and validation datasets, which are often a combination of JSON and PDF files. The recommended way of doing this is to write a Python generator of `PDFDoc` objects.
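For illustration, such a generator can be sketched as follows — the adapter name (`my-segmentation-adapter` in the configs below) and the final conversion into `PDFDoc` objects are library- and annotation-format-specific, so this only shows the file-pairing pattern:

```python
import json
from pathlib import Path


def make_segmentation_adapter(path: str):
    """Pair each JSON annotation file with its PDF and yield raw samples.

    In a real adapter, each (pdf_bytes, annotations) pair would then be
    converted into a PDFDoc (e.g., by running the pipeline's extractor
    and aligning the annotated labels with the extracted boxes).
    """

    def generator():
        for json_file in sorted(Path(path).glob("*.json")):
            annotations = json.loads(json_file.read_text())
            pdf_bytes = json_file.with_suffix(".pdf").read_bytes()
            yield pdf_bytes, annotations

    return generator
```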
@@ -369,55 +408,113 @@

At the end of the training, the pipeline is ready to use (with the `.pipe` method).

## Configuration

To decouple the configuration and the code of our training script, let's define a configuration file where we will describe **both** our training parameters and the pipeline.

```toml title="config.cfg"
# This is the equivalent of the API-based declaration at the beginning of the tutorial
[pipeline]
components = ["extractor","classifier"]
components_config = ${components}

[components]

[components.extractor]
@factory = pdfminer-extractor

[components.classifier]
@factory = trainable-classifier

[components.classifier.embedding]
@factory = box-embedding
size = 72
dropout_p = 0.1

[components.classifier.embedding.text_encoder]
@factory = "box-text-embedding"

[components.classifier.embedding.text_encoder.pooler]
@factory = "cnn-pooler"
out_channels = 64
kernel_sizes = [3,4,5]

[components.classifier.embedding.layout_encoder]
@factory = "box-layout-embedding"
n_positions = 64
x_mode = sin
y_mode = sin
w_mode = sin
h_mode = sin

# This is where we define the training script parameters
# the "train" section refers to the name of the command in the training script
[train]
model = ${pipeline}
train_data = {"@adapter": "my-segmentation-adapter", "path": "data/train"}
val_data = {"@adapter": "my-segmentation-adapter", "path": "data/val"}
max_steps = 1000
seed = 42
lr = 3e-4
batch_size = 4
To decouple the configuration and the code of our training script, let's define a configuration file where we will describe **both** our training parameters and the pipeline. You can either write the config of the pipeline by hand, or generate it from an instantiated pipeline by running:

```python
print(pipeline.config.to_str())
```

=== "Custom architecture"

```toml title="config.cfg"
# This is the equivalent of the API-based declaration at the beginning of the tutorial
[pipeline]
pipeline = ["extractor", "embedding", "classifier"]
disabled = []
components = ${components}

[components]

[components.extractor]
@factory = "pdfminer-extractor"

[components.embedding]
@factory = "box-transformer"
num_heads = 4
dropout_p = 0.1
activation = "gelu"
init_resweight = 0.01
head_size = 16
attention_mode = ["c2c", "c2p", "p2c"]
n_layers = 1
n_relative_positions = 64

[components.embedding.embedding]
@factory = "embedding-combiner"
dropout_p = 0.1

[components.embedding.embedding.text_encoder]
@factory = "sub-box-cnn-pooler"
out_channels = 64
kernel_sizes = (3, 4, 5)

[components.embedding.embedding.text_encoder.embedding]
@factory = "simple-text-embedding"
size = 72

[components.embedding.embedding.layout_encoder]
@factory = "box-layout-embedding"
n_positions = 64
x_mode = "learned"
y_mode = "learned"
w_mode = "learned"
h_mode = "learned"
size = 72

[components.classifier]
@factory = "trainable-classifier"
embedding = ${components.embedding}
labels = []

# This is where we define the training script parameters
# the "train" section refers to the name of the command in the training script
[train]
model = ${pipeline}
train_data = {"@adapter": "my-segmentation-adapter", "path": "data/train"}
val_data = {"@adapter": "my-segmentation-adapter", "path": "data/val"}
max_steps = 1000
seed = 42
lr = 3e-4
batch_size = 4
```

=== "Pre-trained HuggingFace transformer"

```toml title="config.cfg"
[pipeline]
pipeline = ["extractor", "embedding", "classifier"]
disabled = []
components = ${components}

[components]

[components.extractor]
@factory = "mupdf-extractor"
render_pages = true

[components.embedding]
@factory = "huggingface-embedding"
model = "microsoft/layoutlmv3-base"
use_image = false
window = 128
stride = 64
line_pooling = "mean"

[components.classifier]
@factory = "trainable-classifier"
embedding = ${components.embedding}
labels = []

[train]
model = ${pipeline}
max_steps = 1000
lr = 5e-5
seed = 42
train_data = {"@adapter": "my-segmentation-adapter", "path": "data/train"}
val_data = {"@adapter": "my-segmentation-adapter", "path": "data/val"}
batch_size = 8
```

and update our training script to use the pipeline and the data adapters defined in the configuration file instead of the Python declaration:

```diff
Expand Down Expand Up @@ -497,5 +594,5 @@ def train_my_model(
That's it! We can now call the training script with the configuration file as a parameter, and override some of its default values:

```bash
python train.py --config config.cfg --seed 43
python train.py --config config.cfg --components.extractor.extract_styles=true --seed 43
```
14 changes: 3 additions & 11 deletions edspdf/layers/relative_attention.py
@@ -33,13 +33,10 @@ def arange_at_dim(n, dim, ndim):


class GroupedLinear(torch.nn.Module):
def __init__(self, input_size, output_size, bias=True, n_groups=1):
def __init__(self, input_size, output_size, n_groups=1):
super().__init__()
self.n_groups = n_groups
if bias:
self.bias = torch.nn.Parameter(torch.zeros(n_groups, output_size))
else:
self.bias = None
self.bias = torch.nn.Parameter(torch.zeros(n_groups, output_size))
self.weight = torch.nn.Parameter(
torch.stack(
[
@@ -50,12 +47,7 @@ def __init__(self, input_size, output_size, bias=True, n_groups=1):
)
)

def forward(self, x, reshape=True):
if not reshape:
x = torch.einsum("...ni,nio->...no", x, self.weight)
if self.bias is not None:
x = x + self.bias
return x
def forward(self, x):
(*base_shape, dim) = x.shape
x = x.reshape(*base_shape, self.n_groups, dim // self.n_groups)
x = torch.einsum("...ni,nio->...no", x, self.weight)
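For reference, the grouped projection performed by `GroupedLinear.forward` above — splitting the feature dimension into `n_groups` and applying one weight matrix per group via `einsum` — can be sketched in NumPy (illustrative shapes; the real layer operates on `torch` tensors):

```python
import numpy as np


def grouped_linear(x, weight, bias):
    """x: (..., dim); weight: (n_groups, dim // n_groups, out); bias: (n_groups, out)."""
    n_groups = weight.shape[0]
    *base_shape, dim = x.shape
    x = x.reshape(*base_shape, n_groups, dim // n_groups)
    # Each group of input features is projected by its own weight matrix
    return np.einsum("...ni,nio->...no", x, weight) + bias


rng = np.random.default_rng(0)
x = rng.normal(size=(2, 5, 8))       # (batch, boxes, dim)
weight = rng.normal(size=(4, 2, 3))  # 4 groups, 8 // 4 = 2 inputs, 3 outputs each
bias = np.zeros((4, 3))
out = grouped_linear(x, weight, bias)
print(out.shape)  # → (2, 5, 4, 3)
```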