From 0cc29937115f9858272c51e8faffe31155915901 Mon Sep 17 00:00:00 2001 From: Vibhu Jawa Date: Mon, 5 Aug 2024 15:53:06 -0700 Subject: [PATCH] [REVIEW] Fix Padding Related Bugs: `Crossfit` (#66) * Add crossfit bits Signed-off-by: Vibhu Jawa * Add padding fixes Signed-off-by: Vibhu Jawa * Fix test Signed-off-by: Vibhu Jawa * Add docstrings Signed-off-by: Vibhu Jawa * fix torch import Signed-off-by: Vibhu Jawa * fix torch import Signed-off-by: Vibhu Jawa * fix padding to only pad the last dim Signed-off-by: Vibhu Jawa * fix padding tests Signed-off-by: Vibhu Jawa * Add test for left/right Signed-off-by: Vibhu Jawa * Skip test for cf_loader Signed-off-by: Vibhu Jawa * Fix bugs in clipping Signed-off-by: Vibhu Jawa * Fix bugs in clipping Signed-off-by: Vibhu Jawa * Add early stopping to HF memory estimation Signed-off-by: Vibhu Jawa * Fix copyright year Signed-off-by: Vibhu Jawa * Add copyright year Signed-off-by: Vibhu Jawa * Address last of Ryan's reviews Signed-off-by: Vibhu Jawa * Skip loading model if it's already fitted Signed-off-by: Vibhu Jawa * Use self.load_cfg instead of AutoConfig.from_pretrained Signed-off-by: Vibhu Jawa * Use self.load_cfg instead of AutoConfig.from_pretrained Signed-off-by: Vibhu Jawa * Fix memory_curve_utils and skip loading cfg/tokenizer here Signed-off-by: Vibhu Jawa --------- Signed-off-by: Vibhu Jawa --- crossfit/__init__.py | 1 + .../backend/torch/hf/memory_curve_utils.py | 89 ++++++++++++ crossfit/backend/torch/hf/model.py | 127 ++++++++---------- crossfit/backend/torch/loader.py | 46 +++++-- crossfit/backend/torch/op/base.py | 8 +- crossfit/op/tokenize.py | 78 +++++++++-- crossfit/utils/torch_utils.py | 94 +++++++++++++ tests/op/test_loader.py | 55 ++++++++ tests/op/test_tokenize.py | 108 ++++++++++++++- tests/test_torch_utils.py | 107 +++++++++++++++ 10 files changed, 615 insertions(+), 98 deletions(-) create mode 100644 crossfit/backend/torch/hf/memory_curve_utils.py create mode 100644 crossfit/utils/torch_utils.py create mode 100644 tests/op/test_loader.py create mode 100644 tests/test_torch_utils.py diff --git a/crossfit/__init__.py b/crossfit/__init__.py index bf0141e..5e563f0 100644 --- a/crossfit/__init__.py +++ b/crossfit/__init__.py @@ -84,6 +84,7 @@ def __call__(self, *args, **kwargs): load_dataset = LazyLoader("crossfit.dataset.load.load_dataset") embed = LazyLoader("crossfit.report.beir.embed.embed") beir_report = LazyLoader("crossfit.report.beir.report.beir_report") +utils = LazyLoader("crossfit.utils") __all__.extend( [ diff --git a/crossfit/backend/torch/hf/memory_curve_utils.py b/crossfit/backend/torch/hf/memory_curve_utils.py new file mode 100644 index 0000000..dfbe47e --- /dev/null +++ b/crossfit/backend/torch/hf/memory_curve_utils.py @@ -0,0 +1,89 @@ +# Copyright 2024 NVIDIA CORPORATION +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import gc + +import joblib +import numpy as np +import torch +from sklearn.linear_model import LinearRegression +from tqdm import tqdm +from transformers import PreTrainedModel + +from crossfit.utils.model_adapter import adapt_model_input + + +def fit_memory_estimate_curve( + model: PreTrainedModel, + path_or_name: str, + start_batch_size: int = 1, + end_batch_size: int = 2048, + batch_size_increment: int = 256, + start_seq_len: int = 1, + end_seq_len: int = 2048, + seq_len_increment: int = 64, + mem_model_path: str = None, +) -> LinearRegression: + print(f"Fitting memory estimate curve for model: {path_or_name}") + + device = next(model.parameters()).device + X: list[list[int]] = [] + y: list[float] = [] + + batch_size_pbar = tqdm( + range(start_batch_size, end_batch_size + 1, batch_size_increment), desc="Batch size" + ) + for batch_size in batch_size_pbar: + seq_len_pbar = tqdm( + range(start_seq_len, end_seq_len + 1, seq_len_increment), + desc="Sequence length", + leave=False, + ) + for seq_len in seq_len_pbar: + torch.cuda.reset_peak_memory_stats() + + batch = { + "input_ids": torch.randint(1, 501, (batch_size, seq_len)).to(device=device), + "attention_mask": torch.ones((batch_size, seq_len)).to(device=device), + } + + try: + _ = adapt_model_input(model, batch) + memory_used = torch.cuda.max_memory_allocated() / (1024**2) # Convert to MB + X.append([batch_size, seq_len, seq_len**2]) + y.append(memory_used) + + except RuntimeError as e: + if "out of memory" in str(e) or "out_of_memory" in str(e): + # Early stopping for this batch size + seq_len_pbar.close() + break + else: + raise e + finally: + del batch + if "outputs" in vars(): + del outputs + gc.collect() + torch.cuda.empty_cache() + + # Check if we've hit the memory limit for all sequence lengths + if seq_len == start_seq_len: + batch_size_pbar.close() + break + + mem_model = LinearRegression().fit(np.array(X), np.array(y)) + joblib.dump(mem_model, mem_model_path) + + return mem_model diff --git a/crossfit/backend/torch/hf/model.py b/crossfit/backend/torch/hf/model.py index 20bee31..df72a38 100644 --- a/crossfit/backend/torch/hf/model.py +++ b/crossfit/backend/torch/hf/model.py @@ -19,24 +19,65 @@ import joblib import numpy as np import torch -from sklearn.linear_model import LinearRegression -from tqdm import tqdm from transformers import AutoConfig, AutoModel, AutoTokenizer +from crossfit.backend.torch.hf.memory_curve_utils import fit_memory_estimate_curve from crossfit.backend.torch.model import Model from crossfit.dataset.home import CF_HOME -from crossfit.utils.model_adapter import adapt_model_input class HFModel(Model): - def __init__(self, path_or_name: str, max_mem_gb: int = 16, training=False): + def __init__( + self, + path_or_name: str, + max_mem_gb: int = 16, + training: bool = False, + start_batch_size: int = 1, + end_batch_size: int = 2048, + batch_size_increment: int = 256, + start_seq_len: int = 1, + seq_len_increment: int = 64, + ): super().__init__(path_or_name, max_mem_gb) + self.start_batch_size = start_batch_size + self.end_batch_size = end_batch_size + self.batch_size_increment = batch_size_increment + self.start_seq_len = start_seq_len + self.seq_len_increment = seq_len_increment - if not training: - with torch.no_grad(): - self.fit_memory_estimate_curve() + cache_dir = os.path.join(CF_HOME, "memory", self.load_cfg()._name_or_path) + os.makedirs(cache_dir, exist_ok=True) + mem_model_path = os.path.join(cache_dir, "mem_model.pkl") + if os.path.exists(mem_model_path): + self.mem = joblib.load(mem_model_path) 
else: - self.fit_memory_estimate_curve() + model = self.load_model("cuda") if not training else None + end_seq_len = self.max_seq_length() + if not training: + with torch.no_grad(): + self.mem = fit_memory_estimate_curve( + model=model, + path_or_name=self.path_or_name, + start_batch_size=start_batch_size, + end_batch_size=end_batch_size, + start_seq_len=start_seq_len, + end_seq_len=end_seq_len, + batch_size_increment=batch_size_increment, + seq_len_increment=seq_len_increment, + mem_model_path=mem_model_path, + ) + else: + self.mem = fit_memory_estimate_curve( + model=model, + path_or_name=self.path_or_name, + start_batch_size=start_batch_size, + end_batch_size=end_batch_size, + start_seq_len=start_seq_len, + end_seq_len=end_seq_len, + batch_size_increment=batch_size_increment, + seq_len_increment=seq_len_increment, + mem_model_path=mem_model_path, + ) def load_on_worker(self, worker, device="cuda"): worker.torch_model = self.load_model(device) @@ -60,69 +101,6 @@ def load_tokenizer(self): def load_cfg(self): return AutoConfig.from_pretrained(self.path_or_name) - def fit_memory_estimate_curve(self, model=None): - remove_model = False - if model is None: - remove_model = True - model = self.load_model(device="cuda") - - cache_dir = os.path.join(CF_HOME, "memory", self.load_cfg()._name_or_path) - mem_model_path = os.path.join(cache_dir, "mem_model.pkl") - - if os.path.exists(mem_model_path): - self.mem = joblib.load(mem_model_path) - - return self - - print(f"Fitting memory estimate curve for model: {self.path_or_name}") - - device = next(model.parameters()).device - X = [] - y = [] - - max_seq = self.max_seq_length() - for batch_size in tqdm(range(2048, 0, -256)): - if batch_size <= 0: - continue - - for seq_len in range(max_seq, 0, -64): - if seq_len <= 0: - continue - - torch.cuda.reset_peak_memory_stats() - - batch = { - "input_ids": torch.randint(1, 501, (batch_size, seq_len)).to(device=device), - "attention_mask": torch.ones((batch_size, seq_len)).to(device=device), - } - - try: - _ = adapt_model_input(model, batch) - memory_used = torch.cuda.max_memory_allocated() / (1024**2) # Convert to MB - X.append([batch_size, seq_len, seq_len**2]) - y.append(memory_used) - - except RuntimeError as e: - if "out of memory" in str(e) or "out_of_memory" in str(e): - pass - else: - raise e - finally: - del batch - if "outputs" in vars(): - del outputs - gc.collect() - torch.cuda.empty_cache() - - self.mem = LinearRegression().fit(np.array(X), np.array(y)) - os.makedirs(cache_dir, exist_ok=True) - joblib.dump(self.mem, mem_model_path) - - if remove_model: - del model - gc.collect() - torch.cuda.empty_cache() - def estimate_memory(self, max_num_tokens: int, batch_size: int) -> int: predicted_memory = self.mem.predict( np.array([[batch_size, max_num_tokens, max_num_tokens**2]]) @@ -130,7 +108,12 @@ def estimate_memory(self, max_num_tokens: int, batch_size: int) -> int: return predicted_memory[0] / 1024 # Convert from MB to GB def max_seq_length(self) -> int: - return self.load_cfg().max_position_embeddings + max_seq_length = self.load_tokenizer().model_max_length + # Guard against the HF bug + # which sets max_seq_length to max(int) for some models + if max_seq_length > 1e5: + max_seq_length = self.load_cfg().max_position_embeddings + return max_seq_length class SentenceTransformerModel(HFModel): diff --git a/crossfit/backend/torch/loader.py b/crossfit/backend/torch/loader.py index f523367..3514df2 100644 --- a/crossfit/backend/torch/loader.py +++ b/crossfit/backend/torch/loader.py @@ -1,4 +1,4 @@ -# 
Copyright 2023 NVIDIA CORPORATION +# Copyright 2024 NVIDIA CORPORATION # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -22,6 +22,7 @@ from crossfit.data.array.conversion import convert_array from crossfit.data.array.dispatch import crossarray from crossfit.data.dataframe.dispatch import CrossFrame +from crossfit.op.tokenize import clip_tokens from crossfit.utils.model_adapter import adapt_model_input DEFAULT_BATCH_SIZE = 512 @@ -36,7 +37,14 @@ def __init__(self, data: Dict[str, torch.Tensor], batch_size: int, progress_bar= def __init__(self, data: CrossFrame, batch_size: int, progress_bar=None): ... - def __init__(self, data, batch_size: int, progress_bar=None, max_seq_len=None): + def __init__( + self, + data, + batch_size: int, + progress_bar=None, + max_seq_len=None, + padding_side: str = "right", + ): self.data = CrossFrame(data).cast(torch.Tensor) self.tensor_dict = self.data.to_dict() self._batch_size = batch_size @@ -45,6 +53,7 @@ def __init__(self, data, batch_size: int, progress_bar=None, max_seq_len=None): self._to_map = [] self.progress_bar = progress_bar self.max_seq_len = max_seq_len + self.padding_side = padding_side def map(self, fn): self._to_map.append(fn) @@ -66,7 +75,10 @@ def __next__(self): batch = {key: val[self.current_idx : end] for key, val in self.tensor_dict.items()} if self.max_seq_len is not None: - batch = {key: val[:, : self.max_seq_len] for key, val in batch.items()} + if self.padding_side == "right": + batch = {key: val[:, : self.max_seq_len] for key, val in batch.items()} + else: + batch = {key: val[:, -self.max_seq_len :] for key, val in batch.items()} self.current_idx += self.batch_size @@ -97,14 +109,27 @@ def __init__( self.to_ignore = to_ignore or [] self.to_ignore.append("seq_length") self.model = model + tokenizer = self.model.load_tokenizer() + pad_token_id = tokenizer.pad_token_id + padding_side = tokenizer.padding_side + if padding_side not in ["right", "left"]: + raise ValueError("padding_side must be either 'right' or 'left'") + + self.pad_token_id = pad_token_id frame = CrossFrame(data).cast(torch.Tensor) - seq_length = (frame[sort_key] != 0).sum(axis=1) + seq_length = (frame[sort_key] != self.pad_token_id).sum(axis=1) self.sorted_indices = seq_length.argsort(descending=True) frame = frame.apply(lambda x: x[self.sorted_indices]) frame = frame.assign(seq_length=seq_length[self.sorted_indices]) - super().__init__(frame, initial_batch_size, progress_bar=progress_bar) + super().__init__( + frame, + initial_batch_size, + progress_bar=progress_bar, + max_seq_len=self.model.max_seq_length(), + padding_side=padding_side, + ) self.splits = self._find_optimal_splits() def sort_column(self, col): @@ -128,8 +153,6 @@ def __next__(self): else: start = self.splits[self.current_idx - 1] - _tokens = self.tensor_dict["seq_length"] - end = min(self.splits[self.current_idx], self.num_rows) while end > start: try: @@ -138,8 +161,13 @@ def __next__(self): for key, val in self.tensor_dict.items() if key not in self.to_ignore } - clip_len = min(max(_tokens[start], _tokens[end - 1]), self.model.max_seq_length()) - batch = {key: val[:, :clip_len] for key, val in batch.items()} + batch = clip_tokens( + token_o=batch, + max_length=self.max_seq_len, + padding_side=self.padding_side, + pad_token_id=self.pad_token_id, + return_type="pt", + ) for fn in self._to_map: batch = adapt_model_input(fn, batch) diff --git a/crossfit/backend/torch/op/base.py b/crossfit/backend/torch/op/base.py 
index c7883e3..2c86fa1 100644 --- a/crossfit/backend/torch/op/base.py +++ b/crossfit/backend/torch/op/base.py @@ -26,6 +26,7 @@ from crossfit.backend.torch.loader import DEFAULT_BATCH_SIZE, InMemoryLoader, SortedSeqLoader from crossfit.backend.torch.model import Model from crossfit.op.base import Op +from crossfit.utils.torch_utils import concat_and_pad_tensors class Predictor(Op): @@ -66,6 +67,7 @@ def call(self, data, partition_info=None): loader = InMemoryLoader( data[["input_ids", "attention_mask"]], batch_size=self.batch_size, + padding_side=self.model.load_tokenizer().padding_side, progress_bar=self.create_progress_bar(len(data), partition_info), max_seq_len=self.model.max_seq_length(), ) @@ -83,7 +85,11 @@ def call(self, data, partition_info=None): all_outputs_ls.append(output) out = cudf.DataFrame(index=index) - outputs = cp.asarray(torch.cat(all_outputs_ls, dim=0)) + outputs = cp.asarray( + concat_and_pad_tensors( + all_outputs_ls, pad_token_id=loader.pad_token_id, padding_side=loader.padding_side + ) + ) _index = loader.sort_column(index.values) if self.sorted_data_loader else index if len(outputs.shape) <= 2: out[self.pred_output_col] = create_list_series_from_1d_or_2d_ar(outputs, _index) diff --git a/crossfit/op/tokenize.py b/crossfit/op/tokenize.py index e4d659e..6a64855 100644 --- a/crossfit/op/tokenize.py +++ b/crossfit/op/tokenize.py @@ -1,4 +1,4 @@ -# Copyright 2023 NVIDIA CORPORATION +# Copyright 2023-2024 NVIDIA CORPORATION # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,7 +13,7 @@ # limitations under the License. import os from enum import Enum -from typing import Optional, Union +from typing import Dict, Optional, Union import cudf import cupy as cp @@ -31,6 +31,7 @@ class TokenizerType(Enum): SUBWORD = 1 SENTENCE_PIECE = 2 + DEFAULT = 3 class Tokenizer(Op): @@ -55,8 +56,10 @@ def __init__( GPUTokenizer.from_pretrained(self.model) def tokenize_strings(self, sentences, max_length=None): - if self.tokenizer_type == TokenizerType.SENTENCE_PIECE: + if self.tokenizer_type in [TokenizerType.SENTENCE_PIECE, TokenizerType.DEFAULT]: tokenizer = self.model.load_tokenizer() + self.padding_side = tokenizer.padding_side + self.pad_token_id = tokenizer.pad_token_id if isinstance(sentences, cudf.Series): sentences = sentences.to_arrow().to_pylist() @@ -81,6 +84,8 @@ def tokenize_strings(self, sentences, max_length=None): tokenizer = GPUTokenizer.from_pretrained(self.model) worker.tokenizer = tokenizer + self.padding_side = tokenizer.padding_side + self.pad_token_id = tokenizer.pad_token_id return worker.tokenizer( sentences, max_length=max_length or self.max_length, @@ -110,7 +115,13 @@ def call_column(self, data): text = text.str.slice(0, self.max_chars) tokenized_data = self.tokenize_strings(text).copy() - tokenized_data = clip_tokens(tokenized_data, max_length=self.max_length, return_type="cp") + tokenized_data = clip_tokens( + tokenized_data, + max_length=self.max_length, + padding_side=self.padding_side, + pad_token_id=self.pad_token_id, + return_type="cp", + ) input_ids = create_list_series_from_1d_or_2d_ar( tokenized_data["input_ids"].astype("int32"), data.index @@ -173,6 +184,8 @@ def _convert_to_tokenizer_type( tokenizer_type = TokenizerType.SENTENCE_PIECE elif tokenizer_type in ["subword", "bert", TokenizerType.SUBWORD]: tokenizer_type = TokenizerType.SUBWORD + elif tokenizer_type in ["default", TokenizerType.DEFAULT]: + tokenizer_type = TokenizerType.DEFAULT return tokenizer_type 
@@ -180,6 +193,16 @@ class GPUTokenizer(SubwordTokenizer): def __init__(self, hash_file: str, do_lower_case: bool = True, config=None): super().__init__(str(hash_file), do_lower_case=do_lower_case) self.config = config or {"_name_or_path": hash_file} + self.padding_side = self.config.get("padding_side", "right") + self.pad_token_id = self.config.get("pad_token_id", 0) + if self.padding_side != "right": + raise ValueError( + f"Only right padding is supported for GPUTokenizer, got {self.padding_side}" + ) + if self.pad_token_id != 0: + raise ValueError( + f"Only pad_token_id=0 is supported for GPUTokenizer, got {self.pad_token_id}" + ) @classmethod def get_tokenizer_config(cls, name): @@ -224,17 +247,48 @@ def from_pretrained(cls, name, cache_dir=None): return cls(hashed_vocab_path, config=config) -def clip_tokens(token_o, max_length, return_type="pt"): +def clip_tokens( + token_o: Dict[str, Union[cp.ndarray, torch.Tensor]], + max_length: int, + padding_side: str, + pad_token_id: int, + return_type: str = "pt", +) -> Dict[str, Union[cp.ndarray, torch.Tensor]]: + # Verify non-empty max_length, padding_side, and pad_token_id + if not max_length: + raise ValueError("max_length cannot be empty or zero.") + if not padding_side: + raise ValueError("padding_side cannot be empty.") + if pad_token_id is None: + raise ValueError("pad_token_id cannot be None.") + + # Check if input_ids is a cupy array, if not convert to cupy array if not isinstance(token_o["input_ids"], cp.ndarray): token_o = {k: cp.asarray(v) for k, v in token_o.items()} - clip_len = max_length - int((token_o["input_ids"][:, ::-1] != 0).argmax(1).min()) - token_o["input_ids"] = _cast_to_appropriate_type( - token_o["input_ids"][:, :clip_len], return_type - ) - token_o["attention_mask"] = _cast_to_appropriate_type( - token_o["attention_mask"][:, :clip_len], return_type - ) + # Clip the input_ids and attention_mask based on the padding side + # max_length = min(max_length, token_o["input_ids"].shape[1]) + total_indices = token_o["input_ids"].shape[1] + if padding_side == "right": + clip_len = total_indices - int( + (token_o["input_ids"][:, ::-1] != pad_token_id).argmax(1).min() + ) + clip_len = min(clip_len, max_length) + token_o["input_ids"] = _cast_to_appropriate_type( + token_o["input_ids"][:, :clip_len], return_type + ) + token_o["attention_mask"] = _cast_to_appropriate_type( + token_o["attention_mask"][:, :clip_len], return_type + ) + else: + clip_len = total_indices - int((token_o["input_ids"] != pad_token_id).argmax(1).min()) + clip_len = min(clip_len, max_length) + token_o["input_ids"] = _cast_to_appropriate_type( + token_o["input_ids"][:, -clip_len:], return_type + ) + token_o["attention_mask"] = _cast_to_appropriate_type( + token_o["attention_mask"][:, -clip_len:], return_type + ) if "metadata" in token_o: del token_o["metadata"] diff --git a/crossfit/utils/torch_utils.py b/crossfit/utils/torch_utils.py new file mode 100644 index 0000000..18132c2 --- /dev/null +++ b/crossfit/utils/torch_utils.py @@ -0,0 +1,94 @@ +# Copyright 2024 NVIDIA CORPORATION +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Union + +import torch +import torch.nn.functional as F + + +def pad_tensors( + tensor_list: List[torch.Tensor], + pad_token_id: Union[int, float] = 0, + padding_side: str = "right", +) -> List[torch.Tensor]: + """ + Pad a list of tensors to the same shape. + + This function takes a list of tensors with potentially different shapes and pads them + to match the largest dimensions across all tensors in the list. The padding is applied + to the end of each dimension. + + Args: + tensor_list (List[torch.Tensor]): A list of tensors to be padded. + pad_token_id (Union[int, float], optional): The value to use for padding. Defaults to 0. + + Returns: + List[torch.Tensor]: A list of padded tensors, all with the same shape. + """ + if padding_side not in ["right", "left"]: + raise ValueError("padding_side must be either 'right' or 'left'") + # Find the maximum size for each dimension except the batch dimension + max_sizes = list(tensor_list[0].shape) + for tensor in tensor_list: + for dim in range(1, len(tensor.shape)): + if tensor.shape[dim] > max_sizes[dim]: + max_sizes[dim] = tensor.shape[dim] + + # Pad each tensor to the maximum size for each dimension + padded_tensors = [] + for tensor in tensor_list: + pad_sizes = [] + for i in range(len(tensor.shape) - 1, 0, -1): + pad_size = max_sizes[i] - tensor.shape[i] + if padding_side == "right": + pad_sizes.extend([0, pad_size]) + else: + pad_sizes.extend([pad_size, 0]) + + # Apply padding + padded_tensor = F.pad(tensor, pad_sizes, mode="constant", value=pad_token_id) + padded_tensors.append(padded_tensor) + + return padded_tensors + + +def concat_and_pad_tensors( + all_outputs_ls: List[torch.Tensor], + pad_token_id: Union[int, float] = 0, + padding_side: str = "right", +) -> torch.Tensor: + """ + Concatenate a list of tensors after padding them to the same shape. + + This function first pads all input tensors to the same shape using the `pad_tensors` + function, then concatenates them along the first dimension. + + Args: + all_outputs_ls (List[torch.Tensor]): A list of tensors to be padded and concatenated. + pad_token_id (Union[int, float], optional): The value to use for padding. Defaults to 0. + + Returns: + torch.Tensor: A single tensor resulting from the concatenation of all padded input tensors. + + """ + # Ensure all tensors are on the same device + device = all_outputs_ls[0].device + all_outputs_ls = [tensor.to(device) for tensor in all_outputs_ls] + + # Pad the tensors + padded_outputs = pad_tensors(all_outputs_ls, pad_token_id, padding_side) + + # Concatenate the padded tensors + return torch.cat(padded_outputs, dim=0) diff --git a/tests/op/test_loader.py b/tests/op/test_loader.py new file mode 100644 index 0000000..38a0dcf --- /dev/null +++ b/tests/op/test_loader.py @@ -0,0 +1,55 @@ +# Copyright 2024 NVIDIA CORPORATION +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import pytest + +transformers = pytest.importorskip("transformers") +torch = pytest.importorskip("torch") +cf_loader = pytest.importorskip("crossfit.backend.torch.loader") + + +def test_padding_side_right(): + sample_data_for_padding = { + "input_ids": torch.tensor([[1, 2, 3, 0, 0], [4, 5, 0, 0, 0], [6, 7, 8, 9, 0]]) + } + + loader = cf_loader.InMemoryLoader( + sample_data_for_padding, batch_size=2, max_seq_len=3, padding_side="right" + ) + batches = list(loader) + + expected_batch_1 = {"input_ids": torch.tensor([[1, 2, 3], [4, 5, 0]])} + expected_batch_2 = {"input_ids": torch.tensor([[6, 7, 8]])} + + assert len(batches) == 2 + assert torch.equal(batches[0]["input_ids"], expected_batch_1["input_ids"]) + assert torch.equal(batches[1]["input_ids"], expected_batch_2["input_ids"]) + + +def test_padding_side_left(): + sample_data_for_padding = { + "input_ids": torch.tensor([[0, 0, 1, 2, 3], [0, 0, 4, 5, 6], [0, 6, 7, 8, 9]]) + } + + loader = cf_loader.InMemoryLoader( + sample_data_for_padding, batch_size=2, max_seq_len=3, padding_side="left" + ) + batches = list(loader) + + expected_batch_1 = {"input_ids": torch.tensor([[1, 2, 3], [4, 5, 6]])} + expected_batch_2 = {"input_ids": torch.tensor([[7, 8, 9]])} + + assert len(batches) == 2 + assert torch.equal(batches[0]["input_ids"], expected_batch_1["input_ids"]) + assert torch.equal(batches[1]["input_ids"], expected_batch_2["input_ids"]) diff --git a/tests/op/test_tokenize.py b/tests/op/test_tokenize.py index bc79bc2..6123208 100644 --- a/tests/op/test_tokenize.py +++ b/tests/op/test_tokenize.py @@ -1,27 +1,54 @@ +# Copyright 2024 NVIDIA CORPORATION +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import numpy as np import pytest +cp = pytest.importorskip("cupy") cudf = pytest.importorskip("cudf") dask_cudf = pytest.importorskip("dask_cudf") transformers = pytest.importorskip("transformers") +torch = pytest.importorskip("torch") import crossfit as cf # noqa: E402 from crossfit import op # noqa: E402 +cf_loader = pytest.importorskip("crossfit.backend.torch.loader") + def test_tokenizer_sentence_piece(model_name="microsoft/deberta-v3-base"): model = cf.HFModel(model_name) tokenizer = op.Tokenizer(model, cols=["text"], tokenizer_type="spm") + input_strings = ["hello world", "this is a sentence"] ddf = dask_cudf.from_cudf( - cudf.DataFrame({"text": ["hello world", "this is a sentence"]}), - npartitions=2, + cudf.DataFrame({"text": input_strings}), + npartitions=1, ) results = tokenizer(ddf) results = results.compute() hf_tokenizer = transformers.AutoTokenizer.from_pretrained(model_name) + tokenized_strings = hf_tokenizer.batch_encode_plus( + input_strings, return_tensors="pt", padding="longest" + ) assert isinstance(results, cudf.DataFrame) - assert results["input_ids"][0] == hf_tokenizer(["hello world"])["input_ids"][0] - assert results["input_ids"][1] == hf_tokenizer(["this is a sentence"])["input_ids"][0] + np.testing.assert_equal( + np.asarray(results["input_ids"][0]), tokenized_strings["input_ids"][0].numpy() + ) + np.testing.assert_equal( + np.asarray(results["input_ids"][1]), tokenized_strings["input_ids"][1].numpy() + ) def test_tokenizer_max_chars(model_name="sentence-transformers/all-MiniLM-L6-v2"): @@ -44,3 +71,76 @@ def test_tokenizer_max_chars(model_name="sentence-transformers/all-MiniLM-L6-v2" assert results1["input_ids"][0] == results2["input_ids"][0] assert results1["input_ids"][1] == results2["input_ids"][1] + + +def test_tokenizer_padded(model_name="microsoft/deberta-v3-base"): + model = cf.HFModel(model_name) + tokenizer = op.Tokenizer(model, cols=["text"], tokenizer_type="spm") + input_strings = ["hello world", "this is a sentence"] + ddf = dask_cudf.from_cudf( + cudf.DataFrame({"text": input_strings}), + npartitions=1, + ) + results = tokenizer(ddf) + results = results.compute() + + hf_tokenizer = transformers.AutoTokenizer.from_pretrained(model_name) + tokenized_strings = hf_tokenizer.batch_encode_plus( + input_strings, return_tensors="pt", padding="longest" + ) + assert isinstance(results, cudf.DataFrame) + np.testing.assert_equal( + np.asarray(results["input_ids"][0]), tokenized_strings["input_ids"][0].numpy() + ) + np.testing.assert_equal( + np.asarray(results["input_ids"][1]), tokenized_strings["input_ids"][1].numpy() + ) + + +def test_clip_tokens_right_padding(): + input_ids = cp.array([[1, 2, 3, 0, 0], [1, 2, 3, 4, 0]]) + attention_mask = cp.array([[1, 1, 1, 0, 0], [1, 1, 1, 1, 0]]) + token_o = {"input_ids": input_ids, "attention_mask": attention_mask} + + result = cf_loader.clip_tokens(token_o, max_length=4, padding_side="right", pad_token_id=0) + + assert isinstance(result["input_ids"], torch.Tensor) + assert isinstance(result["attention_mask"], torch.Tensor) + assert result["input_ids"].shape == (2, 4) + assert result["attention_mask"].shape == (2, 4) + assert torch.equal(result["input_ids"].to("cpu"), torch.tensor([[1, 2, 3, 0], [1, 2, 3, 4]])) + assert torch.equal( + result["attention_mask"].to("cpu"), torch.tensor([[1, 1, 1, 0], [1, 1, 1, 1]]) + ) + + +def test_clip_tokens_left_padding(): + input_ids = cp.array([[0, 0, 1, 2, 3], [0, 1, 2, 3, 4]]) + attention_mask = cp.array([[0, 0, 1, 1, 1], [0, 1, 1, 1, 1]]) + token_o = {"input_ids": input_ids, 
"attention_mask": attention_mask} + + result = cf_loader.clip_tokens(token_o, max_length=4, padding_side="left", pad_token_id=0) + + assert isinstance(result["input_ids"], torch.Tensor) + assert isinstance(result["attention_mask"], torch.Tensor) + assert result["input_ids"].shape == (2, 4) + assert result["attention_mask"].shape == (2, 4) + assert torch.equal(result["input_ids"].to("cpu"), torch.tensor([[0, 1, 2, 3], [1, 2, 3, 4]])) + assert torch.equal( + result["attention_mask"].to("cpu"), torch.tensor([[0, 1, 1, 1], [1, 1, 1, 1]]) + ) + + +def test_clip_tokens_no_clipping_needed(): + input_ids = cp.array([[1, 2, 3], [4, 5, 6]]) + attention_mask = cp.array([[1, 1, 1], [1, 1, 1]]) + token_o = {"input_ids": input_ids, "attention_mask": attention_mask} + + result = cf_loader.clip_tokens(token_o, max_length=4, padding_side="right", pad_token_id=0) + + assert isinstance(result["input_ids"], torch.Tensor) + assert isinstance(result["attention_mask"], torch.Tensor) + assert result["input_ids"].shape == (2, 3) + assert result["attention_mask"].shape == (2, 3) + assert torch.equal(result["input_ids"].to("cpu"), torch.tensor([[1, 2, 3], [4, 5, 6]])) + assert torch.equal(result["attention_mask"].to("cpu"), torch.tensor([[1, 1, 1], [1, 1, 1]])) diff --git a/tests/test_torch_utils.py b/tests/test_torch_utils.py new file mode 100644 index 0000000..39fdf8c --- /dev/null +++ b/tests/test_torch_utils.py @@ -0,0 +1,107 @@ +# Copyright 2024 NVIDIA CORPORATION +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import pytest + +torch = pytest.importorskip("torch") + + +def test_pad_tensors_2d(): + from crossfit.utils.torch_utils import pad_tensors + + # Test with 2D tensors + tensor1 = torch.tensor([[1, 2], [3, 4]]) + tensor2 = torch.tensor([[5, 6, 7], [8, 9, 10], [11, 12, 13]]) + tensor_list = [tensor1, tensor2] + + padded_tensors = pad_tensors(tensor_list) + + assert len(padded_tensors) == 2 + assert padded_tensors[0].shape == (2, 3) + assert padded_tensors[1].shape == (3, 3) + assert torch.all(padded_tensors[0] == torch.tensor([[1, 2, 0], [3, 4, 0]])) + assert torch.all(padded_tensors[1] == tensor2) + + +def test_pad_tensors_3d(): + from crossfit.utils.torch_utils import pad_tensors + + # Test with 3D tensors + tensor1 = torch.rand(2, 3, 4) + tensor2 = torch.rand(3, 3, 5) + tensor_list = [tensor1, tensor2] + + padded_tensors = pad_tensors(tensor_list) + + assert len(padded_tensors) == 2 + assert padded_tensors[0].shape == (2, 3, 5) + assert padded_tensors[1].shape == (3, 3, 5) + + +def test_pad_tensors_custom_value(): + from crossfit.utils.torch_utils import pad_tensors + + # Test with custom pad value + tensor1 = torch.tensor([[1, 2], [3, 4]]) + tensor2 = torch.tensor([[5, 6, 7]]) + tensor_list = [tensor1, tensor2] + + padded_tensors = pad_tensors(tensor_list, pad_token_id=-1) + + assert torch.all(padded_tensors[0] == torch.tensor([[1, 2, -1], [3, 4, -1]])) + assert torch.all(padded_tensors[1] == torch.tensor([[5, 6, 7]])) + + +def test_concat_padded_tensors(): + from crossfit.utils.torch_utils import concat_and_pad_tensors + + tensor1 = torch.tensor([[1, 2], [3, 4]]) + tensor2 = torch.tensor([[5, 6, 7], [8, 9, 10]]) + all_outputs_ls = [tensor1, tensor2] + + result = concat_and_pad_tensors(all_outputs_ls) + + expected_result = torch.tensor([[1, 2, 0], [3, 4, 0], [5, 6, 7], [8, 9, 10]]) + + assert torch.all(result == expected_result) + + +def test_concat_padded_tensors_custom_value(): + from crossfit.utils.torch_utils import concat_and_pad_tensors + + tensor1 = torch.tensor([[1, 2], [3, 4]]) + tensor2 = torch.tensor([[5, 6, 7], [8, 9, 10]]) + all_outputs_ls = [tensor1, tensor2] + + result = concat_and_pad_tensors(all_outputs_ls, pad_token_id=-1) + + expected_result = torch.tensor([[1, 2, -1], [3, 4, -1], [5, 6, 7], [8, 9, 10]]) + + assert torch.all(result == expected_result) + + +def test_concat_padded_tensors_different_devices(): + from crossfit.utils.torch_utils import concat_and_pad_tensors + + if torch.cuda.is_available(): + tensor1 = torch.tensor([[1, 2], [3, 4]], device="cuda") + tensor2 = torch.tensor([[5, 6, 7], [8, 9, 10]], device="cpu") + all_outputs_ls = [tensor1, tensor2] + + result = concat_and_pad_tensors(all_outputs_ls) + + assert result.device == tensor1.device + assert result.shape == (4, 3) + else: + pytest.skip("CUDA not available, skipping test_concat_padded_tensors_different_devices")
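
A note on the padding-aware clipping this patch introduces in clip_tokens and the loaders: a batch is trimmed to the longest real (non-pad) sequence it contains, capped at max_length, and columns are dropped from the pad side, i.e. from the right for right-padded tokenizers and from the left for left-padded ones. The following standalone sketch (plain PyTorch; clip_batch is a hypothetical helper, not the crossfit implementation) reproduces the behaviour exercised by test_clip_tokens_right_padding and test_clip_tokens_left_padding:

import torch

def clip_batch(input_ids, attention_mask, max_length, pad_token_id=0, padding_side="right"):
    # Longest real (non-pad) sequence in the batch, capped at max_length.
    seq_lengths = (input_ids != pad_token_id).sum(dim=1)
    clip_len = min(int(seq_lengths.max()), max_length)
    if padding_side == "right":
        # Real tokens sit on the left, so keep the first clip_len columns.
        return input_ids[:, :clip_len], attention_mask[:, :clip_len]
    # Left padding: real tokens sit on the right, so keep the last clip_len columns.
    return input_ids[:, -clip_len:], attention_mask[:, -clip_len:]

# Right padding: pad tokens trail each row, so columns are dropped from the right.
ids = torch.tensor([[1, 2, 3, 0, 0], [1, 2, 3, 4, 0]])
print(clip_batch(ids, (ids != 0).long(), max_length=4)[0])  # rows [1, 2, 3, 0] and [1, 2, 3, 4]

# Left padding: pad tokens lead each row, so columns are dropped from the left.
ids_left = torch.tensor([[0, 0, 1, 2, 3], [0, 1, 2, 3, 4]])
print(clip_batch(ids_left, (ids_left != 0).long(), max_length=4, padding_side="left")[0])  # rows [0, 1, 2, 3] and [1, 2, 3, 4]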
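
On the Predictor change in crossfit/backend/torch/op/base.py: because each batch coming out of the loader is now clipped to its own longest sequence, the per-batch outputs can have different widths and a bare torch.cat would raise a size-mismatch error. The new concat_and_pad_tensors helper pads every batch to the widest one (with the tokenizer's pad_token_id, on the tokenizer's padding side) before concatenating. A small usage sketch, assuming a crossfit build that includes this patch (the values mirror test_concat_padded_tensors):

import torch
from crossfit.utils.torch_utils import concat_and_pad_tensors

# Two output batches whose trailing dimension differs because each batch was
# clipped to its own longest sequence.
batch_a = torch.tensor([[1, 2], [3, 4]])
batch_b = torch.tensor([[5, 6, 7], [8, 9, 10]])

# The narrower batch is right-padded with pad_token_id, then both are
# concatenated along the batch dimension.
out = concat_and_pad_tensors([batch_a, batch_b], pad_token_id=0, padding_side="right")
print(out)  # rows [1, 2, 0], [3, 4, 0], [5, 6, 7], [8, 9, 10]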
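
On fit_memory_estimate_curve: the function sweeps a grid of batch sizes and sequence lengths, records torch.cuda.max_memory_allocated() after each forward pass, breaks out of a sweep early on the first out-of-memory error, and fits a LinearRegression on the features [batch_size, seq_len, seq_len**2]; estimate_memory later predicts megabytes for a candidate shape and divides by 1024 to report gigabytes. A minimal sketch of that modelling step with synthetic numbers (no GPU or model required; the "measurements" below are made up):

import numpy as np
from sklearn.linear_model import LinearRegression

# Synthetic stand-ins for peak-memory readings (MB): memory grows with batch size
# and roughly quadratically with sequence length.
X, y = [], []
for batch_size in (8, 16, 32):
    for seq_len in (128, 256, 512):
        X.append([batch_size, seq_len, seq_len**2])
        y.append(50 + 2.0 * batch_size + 0.05 * seq_len + 1e-4 * seq_len**2)

mem_model = LinearRegression().fit(np.array(X), np.array(y))

# estimate_memory-style query: predicted MB for batch_size=64 and 384 tokens,
# converted to GB by dividing by 1024, as the patch does.
predicted_mb = mem_model.predict(np.array([[64, 384, 384**2]]))[0]
print(predicted_mb / 1024)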
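
The HFModel constructor now exposes the sweep bounds used for the memory estimation and caches the fitted regressor under CF_HOME/memory/<model>/mem_model.pkl, so the sweep runs only once per model and is joblib-loaded on later constructions. A hedged usage sketch, assuming a GPU and a crossfit build with this patch (the model name and bounds are illustrative):

import crossfit as cf

# Smaller sweep bounds make the one-off calibration faster; on later constructions
# the cached mem_model.pkl is loaded instead of re-running the sweep.
model = cf.HFModel(
    "sentence-transformers/all-MiniLM-L6-v2",
    start_batch_size=1,
    end_batch_size=256,
    batch_size_increment=32,
    start_seq_len=1,
    seq_len_increment=64,
)
print(model.estimate_memory(max_num_tokens=256, batch_size=64))  # predicted GB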