-
Notifications
You must be signed in to change notification settings - Fork 0
/
processors.py
102 lines (81 loc) · 4.04 KB
/
processors.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
from contextlib import contextmanager
from transformers.feature_extraction_utils import FeatureExtractionMixin
import math
import torch
from torch import nn
# Build the processor
class TrOCRProcessor:
    """Bundle a feature extractor and a tokenizer behind a single object.

    Calling the instance forwards to the *current* processor. That is the
    feature extractor by default and the tokenizer while execution is inside
    the ``as_target_processor()`` context.
    """

    def __init__(self, feature_extractor, tokenizer):
        """Store both components and make the feature extractor current.

        Args:
            feature_extractor: Must be a ``FeatureExtractionMixin`` instance.
            tokenizer: Object used for ``decode`` / ``batch_decode`` and as
                the callable inside ``as_target_processor()``.

        Raises:
            ValueError: If ``feature_extractor`` is not a
                ``FeatureExtractionMixin`` instance.
        """
        if not isinstance(feature_extractor, FeatureExtractionMixin):
            # Interpolate the expected class itself; `.__class__` of a class is
            # its metaclass and would print an unhelpful "<class 'type'>".
            raise ValueError(
                f"`feature_extractor` has to be of type {FeatureExtractionMixin}, but is {type(feature_extractor)}"
            )

        self.feature_extractor = feature_extractor
        self.tokenizer = tokenizer
        self.current_processor = self.feature_extractor

    def __call__(self, *args, **kwargs):
        """Forward all arguments to whichever processor is currently active."""
        return self.current_processor(*args, **kwargs)

    def batch_decode(self, *args, **kwargs):
        """Forward all arguments to ``tokenizer.batch_decode``."""
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """Forward all arguments to ``tokenizer.decode``."""
        return self.tokenizer.decode(*args, **kwargs)

    @contextmanager
    def as_target_processor(self):
        """Temporarily route ``__call__`` to the tokenizer.

        The feature extractor is restored on exit even when the ``with`` body
        raises, so a failed call cannot leave the processor in tokenizer mode.
        """
        self.current_processor = self.tokenizer
        try:
            yield
        finally:
            self.current_processor = self.feature_extractor
class TrOCRSinusoidalPositionalEmbedding(nn.Module):
    """This module produces sinusoidal positional embeddings of any length.

    The embedding table is kept in ``self.weights`` (a plain tensor, not a
    parameter) and is regenerated with more rows whenever a forward pass needs
    positions beyond the current table size.
    """

    def __init__(self, num_positions: int, embedding_dim: int, padding_idx=None):
        """Build the initial table.

        Args:
            num_positions: Number of rows in the initial embedding table.
            embedding_dim: Width of every embedding vector.
            padding_idx: Token id treated as padding; its table row is zeroed.
                ``None`` means the input contains no padding tokens.
        """
        super().__init__()
        self.offset = 2
        self.embedding_dim = embedding_dim
        self.padding_idx = padding_idx
        self.weights = self.get_embedding(num_positions, embedding_dim, padding_idx)
        # Only the dtype/device of this buffer are used (see `forward`); a zero
        # tensor avoids storing uninitialised memory in the state dict.
        self.register_buffer("_float_tensor", torch.zeros(1, dtype=torch.float))

    @staticmethod
    def get_embedding(num_embeddings: int, embedding_dim: int, padding_idx=None):
        """
        Build sinusoidal embeddings. This matches the implementation in tensor2tensor, but differs slightly from the
        description in Section 3.5 of "Attention Is All You Need".

        Returns a float tensor of shape ``(num_embeddings, embedding_dim)``: sines in the first half of the
        columns, cosines in the second half, plus one zero column when ``embedding_dim`` is odd. The row at
        ``padding_idx`` (if given) is all zeros.
        """
        half_dim = embedding_dim // 2
        # max(..., 1) keeps embedding_dim 2 and 3 from dividing by zero; for
        # half_dim >= 2 the value is unchanged.
        emb = math.log(10000) / max(half_dim - 1, 1)
        emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
        emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0)
        # Concatenating along dim=1 already yields (num_embeddings, 2 * half_dim).
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
        if embedding_dim % 2 == 1:
            # zero pad
            emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
        if padding_idx is not None:
            emb[padding_idx, :] = 0

        return emb

    def create_position_ids_from_input_ids(
        self, input_ids: torch.Tensor, padding_idx, past_key_values_length: int = 0
    ):
        """
        Replace non-padding symbols with their position numbers. Position numbers begin at ``padding_idx + 1``;
        padding symbols map to ``padding_idx`` itself. With ``padding_idx=None`` every token counts and positions
        start at ``past_key_values_length``.

        Returns a ``torch.long`` tensor with the same shape as ``input_ids``.
        """
        if padding_idx is None:
            seq_len = input_ids.size(1)
            positions = torch.arange(
                past_key_values_length,
                past_key_values_length + seq_len,
                dtype=torch.long,
                device=input_ids.device,
            )
            # contiguous() so callers may flatten the result with .view(-1).
            return positions.unsqueeze(0).expand_as(input_ids).contiguous()

        mask = input_ids.ne(padding_idx).int()
        # Running count of real tokens, shifted by the cache length and zeroed
        # again at padding positions.
        incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask
        return incremental_indices.long() + padding_idx

    @torch.no_grad()
    def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0):
        """Return positional embeddings of shape ``(bsz, seq_len, embedding_dim)`` for 2-D ``input_ids``.

        ``past_key_values_length`` is the number of positions already consumed by earlier decoding steps.
        """
        bsz, seq_len = input_ids.size()
        # Create the position ids from the input token ids. Any padded tokens remain padded.
        position_ids = self.create_position_ids_from_input_ids(input_ids, self.padding_idx, past_key_values_length).to(
            input_ids.device
        )

        # expand embeddings if needed; the largest position id produced above is
        # first_position + past_key_values_length + seq_len - 1.
        first_position = 0 if self.padding_idx is None else self.padding_idx + 1
        max_pos = first_position + past_key_values_length + seq_len
        if self.weights is None or max_pos > self.weights.size(0):
            # recompute/expand embeddings if needed
            self.weights = self.get_embedding(max_pos, self.embedding_dim, self.padding_idx)
        # Match the dtype/device the module currently lives on.
        self.weights = self.weights.to(self._float_tensor)

        x = self.weights.index_select(0, position_ids.view(-1)).view(bsz, seq_len, self.weights.size(-1)).detach()

        return x
def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
    """
    Build decoder inputs by moving every token one slot to the right.

    Column 0 of the result holds `decoder_start_token_id`, columns 1.. hold the
    first `seq_len - 1` columns of `input_ids`, and any `-100` entries are
    replaced by `pad_token_id`. `input_ids` itself is not modified.

    Raises:
        ValueError: if `decoder_start_token_id` or `pad_token_id` is None.
    """
    shifted = input_ids.new_zeros(input_ids.size())
    # Everything except the last column moves one step right.
    shifted[:, 1:].copy_(input_ids[:, :-1])

    if decoder_start_token_id is None:
        raise ValueError("Make sure to set the decoder_start_token_id attribute of the model's configuration.")
    shifted[:, 0] = decoder_start_token_id

    if pad_token_id is None:
        raise ValueError("Make sure to set the pad_token_id attribute of the model's configuration.")

    # -100 marks ignored label positions; swap those for the pad token id.
    ignored = shifted == -100
    return shifted.masked_fill(ignored, pad_token_id)