Option to set 'non_blocking' for to(device) in BatchEncoding and BatchFeature (huggingface#34883)

* Add an option to set 'non_blocking' for the to(device) operation, for performance improvements. Defaults to 'False', so behavior is unchanged.

* Enable 'non_blocking' in the to() operation of BatchFeature.

* Improve the docstring on the use of 'non_blocking'.

* Make 'non_blocking' a keyword-only argument.

Co-authored-by: Pavel Iakubovskii <[email protected]>

---------

Co-authored-by: Daniel Bogdoll <[email protected]>
Co-authored-by: Pavel Iakubovskii <[email protected]>
3 people authored Dec 9, 2024
1 parent 1452dc2 commit de8a0b7
Showing 2 changed files with 10 additions and 4 deletions.
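
A minimal usage sketch of the new keyword before the diffs (assumes a CUDA device is available; the checkpoint name is illustrative, and the copy only overlaps host work when the source tensors sit in pinned memory):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    encoding = tokenizer("Hello world", return_tensors="pt")

    # Default is unchanged: a blocking copy.
    encoding = encoding.to("cuda")

    # Opt in to an asynchronous host-to-device copy; the flag is keyword-only.
    encoding = encoding.to("cuda", non_blocking=True)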
4 changes: 3 additions & 1 deletion src/transformers/feature_extraction_utils.py
@@ -213,6 +213,7 @@ def to(self, *args, **kwargs) -> "BatchFeature":
                 Will be passed to the `to(...)` function of the tensors.
             kwargs (`Dict`, *optional*):
                 Will be passed to the `to(...)` function of the tensors.
+                To enable asynchronous data transfer, set the `non_blocking` flag in `kwargs` (defaults to `False`).
 
         Returns:
             [`BatchFeature`]: The same instance after modification.
@@ -222,6 +223,7 @@ def to(self, *args, **kwargs) -> "BatchFeature":
 
         new_data = {}
         device = kwargs.get("device")
+        non_blocking = kwargs.get("non_blocking", False)
         # Check if the args are a device or a dtype
         if device is None and len(args) > 0:
             # device should be always the first argument
@@ -241,7 +243,7 @@ def to(self, *args, **kwargs) -> "BatchFeature":
                 # cast and send to device
                 new_data[k] = v.to(*args, **kwargs)
             elif isinstance(v, torch.Tensor) and device is not None:
-                new_data[k] = v.to(device=device)
+                new_data[k] = v.to(device=device, non_blocking=non_blocking)
             else:
                 new_data[k] = v
         self.data = new_data
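
For BatchFeature, non_blocking is read from kwargs on the device-only path, while the dtype path forwards *args and **kwargs straight to torch.Tensor.to. A sketch under those assumptions (the random pixel_values tensor is a stand-in for real image-processor output):

    import torch
    from transformers import BatchFeature

    features = BatchFeature({"pixel_values": torch.randn(1, 3, 224, 224)})

    # Device-only move: non_blocking is picked up via kwargs.get(..., False).
    features = features.to(device="cuda", non_blocking=True)

    # Combined dtype + device cast: forwarded as v.to(*args, **kwargs).
    features = features.to("cuda", dtype=torch.float16, non_blocking=True)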
10 changes: 7 additions & 3 deletions src/transformers/tokenization_utils_base.py
@@ -799,12 +799,13 @@ def as_tensor(value, dtype=None):
 
         return self
 
-    def to(self, device: Union[str, "torch.device"]) -> "BatchEncoding":
+    def to(self, device: Union[str, "torch.device"], *, non_blocking: bool = False) -> "BatchEncoding":
         """
-        Send all values to device by calling `v.to(device)` (PyTorch only).
+        Send all values to device by calling `v.to(device, non_blocking=non_blocking)` (PyTorch only).
 
         Args:
             device (`str` or `torch.device`): The device to put the tensors on.
+            non_blocking (`bool`): Whether to perform the copy asynchronously.
 
         Returns:
             [`BatchEncoding`]: The same instance after modification.
@@ -816,7 +817,10 @@ def to(self, device: Union[str, "torch.device"]) -> "BatchEncoding":
         # Otherwise it passes the casts down and casts the LongTensor containing the token idxs
         # into a HalfTensor
         if isinstance(device, str) or is_torch_device(device) or isinstance(device, int):
-            self.data = {k: v.to(device=device) if isinstance(v, torch.Tensor) else v for k, v in self.data.items()}
+            self.data = {
+                k: v.to(device=device, non_blocking=non_blocking) if isinstance(v, torch.Tensor) else v
+                for k, v in self.data.items()
+            }
         else:
             logger.warning(f"Attempting to cast a BatchEncoding to type {str(device)}. This is not supported.")
         return self
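
The new argument only changes how the copy is enqueued; as with torch.Tensor.to itself, non_blocking=True falls back to a synchronous transfer for pageable host memory. A common pattern (a sketch, not part of this commit) pins the host tensors first:

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    batch = tokenizer(["Hello world"], return_tensors="pt")

    # Pin the host tensors so the H2D transfer can run on a CUDA copy engine
    # while the CPU keeps queuing work.
    batch.data = {k: v.pin_memory() for k, v in batch.data.items()}
    batch = batch.to("cuda", non_blocking=True)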
