From de8a0b7547451f8f7ef2c0ac1f338ba77c614cec Mon Sep 17 00:00:00 2001
From: Daniel Bogdoll
Date: Mon, 9 Dec 2024 05:29:04 -0500
Subject: [PATCH] Option to set 'non_blocking' for to(device) in BatchEncoding
 and BatchFeature (#34883)

* Option to set 'non_blocking' for to(device) operation for performance improvements. Defaults to 'false', thus no behavioral changes.

* Enabling non_blocking in to() operation of BatchFeature.

* Improved docstring on utilization of non_blocking

* Force non_blocking as keyword argument

Co-authored-by: Pavel Iakubovskii

---------

Co-authored-by: Daniel Bogdoll
Co-authored-by: Pavel Iakubovskii
---
 src/transformers/feature_extraction_utils.py |  4 +++-
 src/transformers/tokenization_utils_base.py  | 10 +++++++---
 2 files changed, 10 insertions(+), 4 deletions(-)

diff --git a/src/transformers/feature_extraction_utils.py b/src/transformers/feature_extraction_utils.py
index f3cde8180c1bd4..6e8007edbc0b78 100644
--- a/src/transformers/feature_extraction_utils.py
+++ b/src/transformers/feature_extraction_utils.py
@@ -213,6 +213,7 @@ def to(self, *args, **kwargs) -> "BatchFeature":
                 Will be passed to the `to(...)` function of the tensors.
             kwargs (`Dict`, *optional*):
                 Will be passed to the `to(...)` function of the tensors.
+                To enable asynchronous data transfer, set the `non_blocking` flag in `kwargs` (defaults to `False`).
 
         Returns:
             [`BatchFeature`]: The same instance after modification.
@@ -222,6 +223,7 @@ def to(self, *args, **kwargs) -> "BatchFeature":
 
         new_data = {}
         device = kwargs.get("device")
+        non_blocking = kwargs.get("non_blocking", False)
         # Check if the args are a device or a dtype
         if device is None and len(args) > 0:
             # device should be always the first argument
@@ -241,7 +243,7 @@ def to(self, *args, **kwargs) -> "BatchFeature":
                 # cast and send to device
                 new_data[k] = v.to(*args, **kwargs)
             elif isinstance(v, torch.Tensor) and device is not None:
-                new_data[k] = v.to(device=device)
+                new_data[k] = v.to(device=device, non_blocking=non_blocking)
             else:
                 new_data[k] = v
         self.data = new_data
diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py
index 0bfcc4aa303665..f4e5b9b3aaf314 100644
--- a/src/transformers/tokenization_utils_base.py
+++ b/src/transformers/tokenization_utils_base.py
@@ -799,12 +799,13 @@ def as_tensor(value, dtype=None):
 
         return self
 
-    def to(self, device: Union[str, "torch.device"]) -> "BatchEncoding":
+    def to(self, device: Union[str, "torch.device"], *, non_blocking: bool = False) -> "BatchEncoding":
         """
-        Send all values to device by calling `v.to(device)` (PyTorch only).
+        Send all values to device by calling `v.to(device, non_blocking=non_blocking)` (PyTorch only).
 
         Args:
             device (`str` or `torch.device`): The device to put the tensors on.
+            non_blocking (`bool`): Whether to perform the copy asynchronously.
 
         Returns:
             [`BatchEncoding`]: The same instance after modification.
@@ -816,7 +817,10 @@ def to(self, device: Union[str, "torch.device"]) -> "BatchEncoding":
         # Otherwise it passes the casts down and casts the LongTensor containing the token idxs
         # into a HalfTensor
         if isinstance(device, str) or is_torch_device(device) or isinstance(device, int):
-            self.data = {k: v.to(device=device) if isinstance(v, torch.Tensor) else v for k, v in self.data.items()}
+            self.data = {
+                k: v.to(device=device, non_blocking=non_blocking) if isinstance(v, torch.Tensor) else v
+                for k, v in self.data.items()
+            }
         else:
             logger.warning(f"Attempting to cast a BatchEncoding to type {str(device)}. This is not supported.")
         return self
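
A minimal usage sketch of the new keyword follows (not part of the patch; the checkpoint name and example inputs are placeholders). Note that `non_blocking=True` only yields a truly asynchronous host-to-device copy when the source CPU tensors sit in pinned memory; with ordinary pageable memory the transfer generally still synchronizes.

    import torch
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    batch = tokenizer(["an asynchronous transfer example"], return_tensors="pt")

    if torch.cuda.is_available():
        # non_blocking is keyword-only on BatchEncoding.to(), per this patch
        batch = batch.to("cuda", non_blocking=True)

    # BatchFeature.to() reads the same flag from **kwargs, e.g.:
    # features = image_processor(images, return_tensors="pt").to("cuda", non_blocking=True)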