[REVIEW] Fix Padding Related Bugs: Crossfit (#66)
* Add crossfit bits

Signed-off-by: Vibhu Jawa <[email protected]>

* Add padding fixes

Signed-off-by: Vibhu Jawa <[email protected]>

* Fix test

Signed-off-by: Vibhu Jawa <[email protected]>

* Add docstrings

Signed-off-by: Vibhu Jawa <[email protected]>

* fix torch import

Signed-off-by: Vibhu Jawa <[email protected]>

* fix torch import

Signed-off-by: Vibhu Jawa <[email protected]>

* fix padding to only pad the last dim

Signed-off-by: Vibhu Jawa <[email protected]>

* fix padding tests

Signed-off-by: Vibhu Jawa <[email protected]>

* Add test for left/right

Signed-off-by: Vibhu Jawa <[email protected]>

* Skip test for cf_loader

Signed-off-by: Vibhu Jawa <[email protected]>

* Fix bugs in clipping

Signed-off-by: Vibhu Jawa <[email protected]>

* Fix bugs in clipping

Signed-off-by: Vibhu Jawa <[email protected]>

* Add early stopping to HF memory estimation

Signed-off-by: Vibhu Jawa <[email protected]>

* Fix copyright year

Signed-off-by: Vibhu Jawa <[email protected]>

* Add copyright year

Signed-off-by: Vibhu Jawa <[email protected]>

* Address last of Ryan's reviews

Signed-off-by: Vibhu Jawa <[email protected]>

* Skip loading model if it's already fitted

Signed-off-by: Vibhu Jawa <[email protected]>

* Use self.load_cfg instead of AutoConfig.from_pretrained

Signed-off-by: Vibhu Jawa <[email protected]>

* Use self.load_cfg instead of AutoConfig.from_pretrained

Signed-off-by: Vibhu Jawa <[email protected]>

* Fix memory_curve_utils and skip loading cfg/tokenizer here

Signed-off-by: Vibhu Jawa <[email protected]>

---------

Signed-off-by: Vibhu Jawa <[email protected]>
VibhuJawa authored Aug 5, 2024
1 parent 005e2fc commit 0cc2993
Showing 10 changed files with 615 additions and 98 deletions.
1 change: 1 addition & 0 deletions crossfit/__init__.py
@@ -84,6 +84,7 @@ def __call__(self, *args, **kwargs):
load_dataset = LazyLoader("crossfit.dataset.load.load_dataset")
embed = LazyLoader("crossfit.report.beir.embed.embed")
beir_report = LazyLoader("crossfit.report.beir.report.beir_report")
utils = LazyLoader("crossfit.utils")

__all__.extend(
[
89 changes: 89 additions & 0 deletions crossfit/backend/torch/hf/memory_curve_utils.py
@@ -0,0 +1,89 @@
# Copyright 2024 NVIDIA CORPORATION
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc

import joblib
import numpy as np
import torch
from sklearn.linear_model import LinearRegression
from tqdm import tqdm
from transformers import PreTrainedModel

from crossfit.utils.model_adapter import adapt_model_input


def fit_memory_estimate_curve(
model: PreTrainedModel,
path_or_name: str,
start_batch_size: int = 1,
end_batch_size: int = 2048,
batch_size_increment: int = 256,
start_seq_len: int = 1,
end_seq_len: int = 2048,
seq_len_increment: int = 64,
mem_model_path: str = None,
) -> LinearRegression:
print(f"Fitting memory estimate curve for model: {path_or_name}")

device = next(model.parameters()).device
X: list[list[int]] = []
y: list[float] = []

batch_size_pbar = tqdm(
range(start_batch_size, end_batch_size + 1, batch_size_increment), desc="Batch size"
)
for batch_size in batch_size_pbar:
seq_len_pbar = tqdm(
range(start_seq_len, end_seq_len + 1, seq_len_increment),
desc="Sequence length",
leave=False,
)
for seq_len in seq_len_pbar:
torch.cuda.reset_peak_memory_stats()

batch = {
"input_ids": torch.randint(1, 501, (batch_size, seq_len)).to(device=device),
"attention_mask": torch.ones((batch_size, seq_len)).to(device=device),
}

try:
_ = adapt_model_input(model, batch)
memory_used = torch.cuda.max_memory_allocated() / (1024**2) # Convert to MB
X.append([batch_size, seq_len, seq_len**2])
y.append(memory_used)

except RuntimeError as e:
if "out of memory" in str(e) or "out_of_memory" in str(e):
# Early stopping for this batch size
seq_len_pbar.close()
break
else:
raise e
finally:
del batch
if "outputs" in vars():
del outputs
gc.collect()
torch.cuda.empty_cache()

# Check if we've hit the memory limit for all sequence lengths
if seq_len == start_seq_len:
batch_size_pbar.close()
break

mem_model = LinearRegression().fit(np.array(X), np.array(y))
joblib.dump(mem_model, mem_model_path)

return mem_model
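
For context, the fitted regressor maps the features `[batch_size, seq_len, seq_len**2]` to peak GPU memory in MB. Below is a minimal sketch of querying the saved curve after the fact; the path is a placeholder (the real file is written under `CF_HOME/memory/<model>/mem_model.pkl`), and `HFModel.estimate_memory` performs the equivalent lookup internally.

```python
# Sketch only: query the memory curve saved by fit_memory_estimate_curve.
# "mem_model.pkl" is a placeholder path for this example.
import joblib
import numpy as np

mem_model = joblib.load("mem_model.pkl")

batch_size, seq_len = 64, 512
predicted_mb = mem_model.predict(np.array([[batch_size, seq_len, seq_len**2]]))[0]
print(f"Estimated peak GPU memory: {predicted_mb / 1024:.2f} GB")
```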
127 changes: 55 additions & 72 deletions crossfit/backend/torch/hf/model.py
@@ -19,24 +19,65 @@
import joblib
import numpy as np
import torch
from sklearn.linear_model import LinearRegression
from tqdm import tqdm
from transformers import AutoConfig, AutoModel, AutoTokenizer

from crossfit.backend.torch.hf.memory_curve_utils import fit_memory_estimate_curve
from crossfit.backend.torch.model import Model
from crossfit.dataset.home import CF_HOME
from crossfit.utils.model_adapter import adapt_model_input


class HFModel(Model):
def __init__(self, path_or_name: str, max_mem_gb: int = 16, training=False):
def __init__(
self,
path_or_name: str,
max_mem_gb: int = 16,
training: bool = False,
start_batch_size: int = 1,
end_batch_size: int = 2048,
batch_size_increment: int = 256,
start_seq_len: int = 1,
seq_len_increment: int = 64,
):
super().__init__(path_or_name, max_mem_gb)
self.start_batch_size = start_batch_size
self.end_batch_size = end_batch_size
self.batch_size_increment = batch_size_increment
self.start_seq_len = start_seq_len
self.seq_len_increment = seq_len_increment

if not training:
with torch.no_grad():
self.fit_memory_estimate_curve()
cache_dir = os.path.join(CF_HOME, "memory", self.load_cfg()._name_or_path)
os.makedirs(cache_dir, exist_ok=True)
mem_model_path = os.path.join(cache_dir, "mem_model.pkl")
if os.path.exists(mem_model_path):
self.mem = joblib.load(mem_model_path)
else:
self.fit_memory_estimate_curve()
model = self.load_model("cuda") if not training else None
end_seq_len = self.max_seq_length()
if not training:
with torch.no_grad():
self.mem = fit_memory_estimate_curve(
model=model,
path_or_name=self.path_or_name,
start_batch_size=start_batch_size,
end_batch_size=end_batch_size,
start_seq_len=start_seq_len,
end_seq_len=end_seq_len,
batch_size_increment=batch_size_increment,
seq_len_increment=seq_len_increment,
mem_model_path=mem_model_path,
)
else:
self.mem = fit_memory_estimate_curve(
model=model,
path_or_name=self.path_or_name,
start_batch_size=start_batch_size,
end_batch_size=end_batch_size,
start_seq_len=start_seq_len,
end_seq_len=end_seq_len,
batch_size_increment=batch_size_increment,
seq_len_increment=seq_len_increment,
mem_model_path=mem_model_path,
)

def load_on_worker(self, worker, device="cuda"):
worker.torch_model = self.load_model(device)
@@ -60,77 +101,19 @@ def load_tokenizer(self):
def load_cfg(self):
return AutoConfig.from_pretrained(self.path_or_name)

def fit_memory_estimate_curve(self, model=None):
remove_model = False
if model is None:
remove_model = True
model = self.load_model(device="cuda")

cache_dir = os.path.join(CF_HOME, "memory", self.load_cfg()._name_or_path)
mem_model_path = os.path.join(cache_dir, "mem_model.pkl")

if os.path.exists(mem_model_path):
self.mem = joblib.load(mem_model_path)

return self

print(f"Fitting memory estimate curve for model: {self.path_or_name}")

device = next(model.parameters()).device
X = []
y = []

max_seq = self.max_seq_length()
for batch_size in tqdm(range(2048, 0, -256)):
if batch_size <= 0:
continue

for seq_len in range(max_seq, 0, -64):
if seq_len <= 0:
continue

torch.cuda.reset_peak_memory_stats()

batch = {
"input_ids": torch.randint(1, 501, (batch_size, seq_len)).to(device=device),
"attention_mask": torch.ones((batch_size, seq_len)).to(device=device),
}

try:
_ = adapt_model_input(model, batch)
memory_used = torch.cuda.max_memory_allocated() / (1024**2) # Convert to MB
X.append([batch_size, seq_len, seq_len**2])
y.append(memory_used)

except RuntimeError as e:
if "out of memory" in str(e) or "out_of_memory" in str(e):
pass
else:
raise e
finally:
del batch
if "outputs" in vars():
del outputs
gc.collect()
torch.cuda.empty_cache()

self.mem = LinearRegression().fit(np.array(X), np.array(y))
os.makedirs(cache_dir, exist_ok=True)
joblib.dump(self.mem, mem_model_path)

if remove_model:
del model
gc.collect()
torch.cuda.empty_cache()

def estimate_memory(self, max_num_tokens: int, batch_size: int) -> int:
predicted_memory = self.mem.predict(
np.array([[batch_size, max_num_tokens, max_num_tokens**2]])
)
return predicted_memory[0] / 1024 # Convert from MB to GB

def max_seq_length(self) -> int:
return self.load_cfg().max_position_embeddings
max_seq_length = self.load_tokenizer().model_max_length
# Guard against the HF bug
# which sets max_seq_length to max(int) for some models
if max_seq_length > 1e5:
max_seq_length = self.load_cfg().max_position_embeddings
return max_seq_length


class SentenceTransformerModel(HFModel):
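
As a side note on the `max_seq_length` change above: it prefers the tokenizer's `model_max_length` and falls back to the config's `max_position_embeddings` when the tokenizer reports a huge sentinel value. A standalone sketch of the same guard follows; "bert-base-uncased" is only an example model, not something this commit pins.

```python
# Standalone sketch of the max_seq_length fallback used in HFModel above.
from transformers import AutoConfig, AutoTokenizer

def resolve_max_seq_length(path_or_name: str) -> int:
    max_len = AutoTokenizer.from_pretrained(path_or_name).model_max_length
    # Some tokenizers report a very large placeholder instead of a real limit.
    if max_len > 1e5:
        max_len = AutoConfig.from_pretrained(path_or_name).max_position_embeddings
    return max_len

print(resolve_max_seq_length("bert-base-uncased"))  # typically 512
```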
46 changes: 37 additions & 9 deletions crossfit/backend/torch/loader.py
@@ -1,4 +1,4 @@
# Copyright 2023 NVIDIA CORPORATION
# Copyright 2024 NVIDIA CORPORATION
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -22,6 +22,7 @@
from crossfit.data.array.conversion import convert_array
from crossfit.data.array.dispatch import crossarray
from crossfit.data.dataframe.dispatch import CrossFrame
from crossfit.op.tokenize import clip_tokens
from crossfit.utils.model_adapter import adapt_model_input

DEFAULT_BATCH_SIZE = 512
@@ -36,7 +37,14 @@ def __init__(self, data: Dict[str, torch.Tensor], batch_size: int, progress_bar=None):
def __init__(self, data: CrossFrame, batch_size: int, progress_bar=None):
...

def __init__(self, data, batch_size: int, progress_bar=None, max_seq_len=None):
def __init__(
self,
data,
batch_size: int,
progress_bar=None,
max_seq_len=None,
padding_side: str = "right",
):
self.data = CrossFrame(data).cast(torch.Tensor)
self.tensor_dict = self.data.to_dict()
self._batch_size = batch_size
@@ -45,6 +53,7 @@ def __init__(self, data, batch_size: int, progress_bar=None, max_seq_len=None):
self._to_map = []
self.progress_bar = progress_bar
self.max_seq_len = max_seq_len
self.padding_side = padding_side

def map(self, fn):
self._to_map.append(fn)
@@ -66,7 +75,10 @@ def __next__(self):

batch = {key: val[self.current_idx : end] for key, val in self.tensor_dict.items()}
if self.max_seq_len is not None:
batch = {key: val[:, : self.max_seq_len] for key, val in batch.items()}
if self.padding_side == "right":
batch = {key: val[:, : self.max_seq_len] for key, val in batch.items()}
else:
batch = {key: val[:, -self.max_seq_len :] for key, val in batch.items()}

self.current_idx += self.batch_size

@@ -97,14 +109,27 @@ def __init__(
self.to_ignore = to_ignore or []
self.to_ignore.append("seq_length")
self.model = model
tokenizer = self.model.load_tokenizer()
pad_token_id = tokenizer.pad_token_id
padding_side = tokenizer.padding_side

if padding_side not in ["right", "left"]:
raise ValueError("padding_side must be either 'right' or 'left'")

self.pad_token_id = pad_token_id
frame = CrossFrame(data).cast(torch.Tensor)
seq_length = (frame[sort_key] != 0).sum(axis=1)
seq_length = (frame[sort_key] != self.pad_token_id).sum(axis=1)
self.sorted_indices = seq_length.argsort(descending=True)
frame = frame.apply(lambda x: x[self.sorted_indices])
frame = frame.assign(seq_length=seq_length[self.sorted_indices])

super().__init__(frame, initial_batch_size, progress_bar=progress_bar)
super().__init__(
frame,
initial_batch_size,
progress_bar=progress_bar,
max_seq_len=self.model.max_seq_length(),
padding_side=padding_side,
)
self.splits = self._find_optimal_splits()

def sort_column(self, col):
@@ -128,8 +153,6 @@ def __next__(self):
else:
start = self.splits[self.current_idx - 1]

_tokens = self.tensor_dict["seq_length"]

end = min(self.splits[self.current_idx], self.num_rows)
while end > start:
try:
@@ -138,8 +161,13 @@
for key, val in self.tensor_dict.items()
if key not in self.to_ignore
}
clip_len = min(max(_tokens[start], _tokens[end - 1]), self.model.max_seq_length())
batch = {key: val[:, :clip_len] for key, val in batch.items()}
batch = clip_tokens(
token_o=batch,
max_length=self.max_seq_len,
padding_side=self.padding_side,
pad_token_id=self.pad_token_id,
return_type="pt",
)

for fn in self._to_map:
batch = adapt_model_input(fn, batch)
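
To make the left/right behaviour above concrete, here is a toy illustration of the two slicing paths added to `InMemoryLoader.__next__`: with right padding the leading tokens are kept, with left padding the trailing ones. The tensor values are invented for this example.

```python
# Toy illustration of padding_side-aware truncation; values are made up.
import torch

max_seq_len = 3
batch = {"input_ids": torch.tensor([[101, 7592, 2088, 102, 0, 0]])}

right_truncated = {k: v[:, :max_seq_len] for k, v in batch.items()}   # padding_side="right"
left_truncated = {k: v[:, -max_seq_len:] for k, v in batch.items()}   # padding_side="left"

print(right_truncated["input_ids"])  # tensor([[ 101, 7592, 2088]])
print(left_truncated["input_ids"])   # tensor([[102,   0,   0]])
```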
8 changes: 7 additions & 1 deletion crossfit/backend/torch/op/base.py
@@ -26,6 +26,7 @@
from crossfit.backend.torch.loader import DEFAULT_BATCH_SIZE, InMemoryLoader, SortedSeqLoader
from crossfit.backend.torch.model import Model
from crossfit.op.base import Op
from crossfit.utils.torch_utils import concat_and_pad_tensors


class Predictor(Op):
@@ -66,6 +67,7 @@ def call(self, data, partition_info=None):
loader = InMemoryLoader(
data[["input_ids", "attention_mask"]],
batch_size=self.batch_size,
padding_side=self.model.load_tokenizer().padding_side,
progress_bar=self.create_progress_bar(len(data), partition_info),
max_seq_len=self.model.max_seq_length(),
)
@@ -83,7 +85,11 @@ def call(self, data, partition_info=None):
all_outputs_ls.append(output)

out = cudf.DataFrame(index=index)
outputs = cp.asarray(torch.cat(all_outputs_ls, dim=0))
outputs = cp.asarray(
concat_and_pad_tensors(
all_outputs_ls, pad_token_id=loader.pad_token_id, padding_side=loader.padding_side
)
)
_index = loader.sort_column(index.values) if self.sorted_data_loader else index
if len(outputs.shape) <= 2:
out[self.pred_output_col] = create_list_series_from_1d_or_2d_ar(outputs, _index)
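
The implementation of `concat_and_pad_tensors` lives in `crossfit/utils/torch_utils.py` and is not part of the hunks shown here, so the following is only a plausible sketch of what such a helper does: pad each batch's output along the last dimension to a common width before concatenating, since the sorted batches are clipped to different sequence lengths.

```python
# Plausible sketch only; the real helper in crossfit/utils/torch_utils.py
# may differ in details.
import torch
import torch.nn.functional as F

def concat_and_pad_tensors(tensor_list, pad_token_id=0, padding_side="right"):
    # Pad every tensor along its last dimension to the longest one, then
    # concatenate along the batch dimension.
    max_len = max(t.shape[-1] for t in tensor_list)
    padded = []
    for t in tensor_list:
        pad_width = max_len - t.shape[-1]
        pad = (0, pad_width) if padding_side == "right" else (pad_width, 0)
        padded.append(F.pad(t, pad, value=pad_token_id))
    return torch.cat(padded, dim=0)

a = torch.tensor([[1, 2, 3]])
b = torch.tensor([[4, 5]])
print(concat_and_pad_tensors([a, b]))
# tensor([[1, 2, 3],
#         [4, 5, 0]])
```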
