diff --git a/src/transformers/feature_extraction_utils.py b/src/transformers/feature_extraction_utils.py
index f3cde8180c1bd4..6e8007edbc0b78 100644
--- a/src/transformers/feature_extraction_utils.py
+++ b/src/transformers/feature_extraction_utils.py
@@ -213,6 +213,7 @@ def to(self, *args, **kwargs) -> "BatchFeature":
                 Will be passed to the `to(...)` function of the tensors.
             kwargs (`Dict`, *optional*):
                 Will be passed to the `to(...)` function of the tensors.
+                To enable asynchronous data transfer, set the `non_blocking` flag in `kwargs` (defaults to `False`).
 
         Returns:
             [`BatchFeature`]: The same instance after modification.
@@ -222,6 +223,7 @@ def to(self, *args, **kwargs) -> "BatchFeature":
 
         new_data = {}
         device = kwargs.get("device")
+        non_blocking = kwargs.get("non_blocking", False)
         # Check if the args are a device or a dtype
         if device is None and len(args) > 0:
             # device should be always the first argument
@@ -241,7 +243,7 @@ def to(self, *args, **kwargs) -> "BatchFeature":
                 # cast and send to device
                 new_data[k] = v.to(*args, **kwargs)
             elif isinstance(v, torch.Tensor) and device is not None:
-                new_data[k] = v.to(device=device)
+                new_data[k] = v.to(device=device, non_blocking=non_blocking)
             else:
                 new_data[k] = v
         self.data = new_data
diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py
index 0bfcc4aa303665..f4e5b9b3aaf314 100644
--- a/src/transformers/tokenization_utils_base.py
+++ b/src/transformers/tokenization_utils_base.py
@@ -799,12 +799,13 @@ def as_tensor(value, dtype=None):
 
         return self
 
-    def to(self, device: Union[str, "torch.device"]) -> "BatchEncoding":
+    def to(self, device: Union[str, "torch.device"], *, non_blocking: bool = False) -> "BatchEncoding":
         """
-        Send all values to device by calling `v.to(device)` (PyTorch only).
+        Send all values to device by calling `v.to(device, non_blocking=non_blocking)` (PyTorch only).
 
         Args:
             device (`str` or `torch.device`): The device to put the tensors on.
+            non_blocking (`bool`): Whether to perform the copy asynchronously.
 
         Returns:
             [`BatchEncoding`]: The same instance after modification.
@@ -816,7 +817,10 @@ def to(self, device: Union[str, "torch.device"]) -> "BatchEncoding":
         # Otherwise it passes the casts down and casts the LongTensor containing the token idxs
         # into a HalfTensor
         if isinstance(device, str) or is_torch_device(device) or isinstance(device, int):
-            self.data = {k: v.to(device=device) if isinstance(v, torch.Tensor) else v for k, v in self.data.items()}
+            self.data = {
+                k: v.to(device=device, non_blocking=non_blocking) if isinstance(v, torch.Tensor) else v
+                for k, v in self.data.items()
+            }
         else:
             logger.warning(f"Attempting to cast a BatchEncoding to type {str(device)}. This is not supported.")
         return self
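
Not part of the patch itself: a minimal usage sketch of the new flag. The checkpoint name, example inputs, and CUDA device below are illustrative assumptions; only the `non_blocking=True` arguments exercise the code paths added above.

```python
# Minimal sketch of the non_blocking flag added in this diff. The checkpoint,
# inputs, and CUDA device are illustrative assumptions, not part of the PR.
import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
batch = tokenizer(["a first sentence", "a second one"], padding=True, return_tensors="pt")

if torch.cuda.is_available():
    # non_blocking=True only overlaps the copy with host-side work when the
    # source tensors live in page-locked (pinned) host memory.
    batch.data = {k: v.pin_memory() for k, v in batch.data.items()}

    # BatchEncoding.to: non_blocking is keyword-only in the new signature.
    batch = batch.to("cuda", non_blocking=True)

    # For a BatchFeature the flag travels through **kwargs instead:
    #     features = features.to("cuda", non_blocking=True)
    # since BatchFeature.to reads it via kwargs.get("non_blocking", False).
```

Since the copy is queued on the current CUDA stream, any kernel launched afterwards on the same stream is ordered after the transfer, so the common forward-pass case needs no explicit synchronization.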