diff --git a/pyproject.toml b/pyproject.toml
index b61437e3..455f8477 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -33,6 +33,7 @@ exclude = ["images*"]
 
 [project.optional-dependencies]
 huggingface = [
+    "unsloth_zoo",
     "packaging",
     "tyro",
     "transformers>=4.44.2",
@@ -210,6 +211,7 @@ colab-ampere-torch220 = [
     "flash-attn>=2.6.3",
 ]
 colab-new = [
+    "unsloth_zoo",
     "packaging",
     "tyro",
     "transformers>=4.44.2",
diff --git a/unsloth/__init__.py b/unsloth/__init__.py
index e7db41ce..abee9c9e 100644
--- a/unsloth/__init__.py
+++ b/unsloth/__init__.py
@@ -27,6 +27,13 @@
 #     pass
 # pass
 
+# Check for unsloth_zoo
+try:
+    import unsloth_zoo
+except:
+    raise ImportError("Unsloth: Please install unsloth_zoo via `pip install unsloth-zoo`")
+pass
+
 # Unsloth currently does not work on multi GPU setups - sadly we are a 2 brother team so
 # enabling it will require much more work, so we have to prioritize. Please understand!
 # We do have a beta version, which you can contact us about!
@@ -124,7 +131,7 @@ def is_bf16_supported(): return SUPPORTS_BFLOAT16
 
     # Try linking cuda folder, or everything in local
     if len(possible_cudas) == 0:
-        os.system(f"ldconfig /usr/local/")
+        os.system("ldconfig /usr/local/")
     else:
         find_number = re.compile(r"([\d\.]{2,})")
         latest_cuda = np.argsort([float(find_number.search(x).group(1)) for x in possible_cudas])[::-1][0]
diff --git a/unsloth/chat_templates.py b/unsloth/chat_templates.py
index 9b5f9ff2..cab6130d 100644
--- a/unsloth/chat_templates.py
+++ b/unsloth/chat_templates.py
@@ -35,7 +35,9 @@
 from .tokenizer_utils import *
 from .models._utils import patch_tokenizer
 import re
-
+from unsloth_zoo.dataset_utils import (
+    train_on_responses_only,
+)
 CHAT_TEMPLATES = {}
 
 # =========================================== Unsloth
@@ -910,7 +912,7 @@ def get_chat_template(
     # Check fast tokenizer
     if not is_fast_tokenizer:
         print(
-            f"Unsloth: Not a fast tokenizer, so can't process it as of yet :(\n"\
+            "Unsloth: Not a fast tokenizer, so can't process it as of yet :(\n"\
             "Please log a Github issue if you want this as a new feature!\n"\
             "Your chat template will still work, but it won't add or edit tokens."
         )
@@ -1236,7 +1238,7 @@ def __convert_to_sharegpt__(examples):
 
     n_extensions = max(conversation_extension-1, 0)
     if n_extensions == 0: return dataset
-    dataset = dataset.rename_columns({"conversations" : f"conversations0"})
+    dataset = dataset.rename_columns({"conversations" : "conversations0"})
     all_shuffled = [dataset]
     for j in range(1, n_extensions+1):
         shuffled = dataset.shuffle(seed = random_state+j).rename_columns({"conversations0" : f"conversations{j}"})
@@ -1254,7 +1256,7 @@ def __convert_to_sharegpt__(examples):
         f"in zip({', '.join(f'conversations{j}__' for j in range(n_extensions))}):\n"
     function += f"{' '*8}convos.append("\
         f"{'+'.join(f'conversations{j}' for j in range(n_extensions))})\n"
-    function += f"{' '*4}return " + "{ " + f"'conversations' : convos" + " }"
+    function += f"{' '*4}return " + "{ " + "'conversations' : convos" + " }"
 
     # Map function
     exec(function, globals())
@@ -1812,194 +1814,6 @@ def formatting_prompts_func(examples):
 pass
 
 
-# From https://www.geeksforgeeks.org/longest-common-substring-array-strings/
-# Longest Common Substring in an Array of Strings
-def _longest_common_substring(arr):
-    n = len(arr)
-    s = arr[0]
-    l = len(s)
-    res = ""
-    for i in range(l):
-        for j in range(i + 1, l + 1):
-            stem = s[i:j]
-            k = 1
-            for k in range(1, n):
-                if stem not in arr[k]:
-                    break
-            if (k + 1 == n and len(res) < len(stem)):
-                res = stem
-    return res
-pass
-
-
-def _find_common_token_ids(component, tokenizer):
-    """
-    \n### User:\n\n
-    \n\n### User:\n\n
-    etc
-    we need to find the middle most repeatted part.
-    Tokenizers can tokenize newlines or spaces as 1 token!
-    """
-    right_text = ""
-    if component.endswith (" "): right_text = " "
-    elif component.endswith("\n"): right_text = "\n"
-    left_text = ""
-    if component.startswith (" "): left_text = " "
-    elif component.startswith("\n"): left_text = "\n"
-    stripped = component.strip()
-
-    # Add current pieces and also newlines
-    all_input_ids = []
-    for left in range(3):
-        for right in range(3):
-            x = left*left_text + stripped + right*right_text
-            x = tokenizer(x, add_special_tokens = False).input_ids
-            all_input_ids.append(x)
-
-            x = left*"\n" + stripped + right*"\n"
-            x = tokenizer(x, add_special_tokens = False).input_ids
-            all_input_ids.append(x)
-        pass
-    pass
-    substring = _longest_common_substring([str(x + [0]) for x in all_input_ids])
-    substring = substring.split(", ")[:-1]
-    substring = [int(x) for x in substring]
-
-    # Also get rest of tokenized string
-    original = tokenizer(component, add_special_tokens = False).input_ids
-    # Get optional left and right
-    for j in range(len(original)):
-        if original[j : j + len(substring)] == substring: break
-    optional_left = original[:j]
-    optional_right = original[j+len(substring):]
-    return substring, optional_left, optional_right
-pass
-
-
-def train_on_responses_only(
-    trainer,
-    instruction_part = None,
-    response_part = None,
-):
-    """
-    Trains only on responses and not on the instruction by masking out
-    the labels with -100 for the instruction part.
-    """
-    tokenizer = trainer.tokenizer
-
-    if not hasattr(tokenizer, "_unsloth_input_part") or \
-        not hasattr(tokenizer, "_unsloth_output_part"):
-
-        if instruction_part is None or response_part is None:
-            raise ValueError("Unsloth: instruction_part and response_part must be given!")
-        pass
-    elif (instruction_part is not None or response_part is not None) and \
-        (hasattr(tokenizer, "_unsloth_input_part") or hasattr(tokenizer, "_unsloth_output_part")):
-
-        raise ValueError("Unsloth: Your tokenizer already has instruction and response parts set - do not give custom ones!")
-    else:
-        instruction_part = tokenizer._unsloth_input_part
-        response_part = tokenizer._unsloth_output_part
-    pass
-
-    # Get most common tokens since tokenizers can tokenize stuff differently!
-    Q_must, Q_left, Q_right = _find_common_token_ids(instruction_part, tokenizer)
-    A_must, A_left, A_right = _find_common_token_ids(response_part, tokenizer)
-
-    # Store some temporary stuff
-    A_first = A_must[0]
-    len_A_must = len(A_must)
-    A_left_reversed = A_left[::-1]
-    A_right_forward = A_right
-
-    Q_first = Q_must[0]
-    len_Q_must = len(Q_must)
-    Q_left_reversed = Q_left[::-1]
-    Q_right_forward = Q_right
-
-    def _train_on_responses_only(examples):
-        input_ids_ = examples["input_ids"]
-        all_labels = []
-
-        for input_ids in input_ids_:
-            n = len(input_ids)
-            labels = [-100] * n
-            n_minus_1 = n - 1
-            j = 0
-            while j < n:
-                # Find <assistant>
-                if (input_ids[j] == A_first) and \
-                    (input_ids[j : (k := j + len_A_must)] == A_must):
-
-                    # Now backtrack to get previous optional tokens
-                    for optional_left in A_left_reversed:
-                        if j < 1: break
-                        if optional_left == input_ids[j-1]: j -= 1
-                        else: break
-                    pass
-                    # And forwards look as well
-                    for optional_right in A_right_forward:
-                        if k >= n_minus_1: break
-                        if optional_right == input_ids[k+1]: k += 1
-                        else: break
-                    pass
-                    # assistant_j = j
-                    assistant_k = k
-
-                    j = assistant_k
-                    # Given <assistant>, now find next user
-                    while j < n:
-                        # Find <user>
-                        # Also accept last final item if assistant is the last turn
-                        if (j == n_minus_1) or \
-                            ((input_ids[j] == Q_first) and \
-                            (input_ids[j : (k := j + len_Q_must)] == Q_must)):
-
-                            # Now backtrack to get previous optional tokens
-                            for optional_left in Q_left_reversed:
-                                if j < 1: break
-                                if optional_left == input_ids[j-1]: j -= 1
-                                else: break
-                            pass
-                            # And forwards look as well
-                            for optional_right in Q_right_forward:
-                                if k >= n_minus_1: break
-                                if optional_right == input_ids[k+1]: k += 1
-                                else: break
-                            pass
-                            user_j = j
-                            # Account for last item
-                            if user_j != n_minus_1:
-                                # user_k = k
-                                # j = user_k
-                                j = k
-                            else:
-                                user_j = n
-                                k = n
-                            pass
-                            # Now copy input_ids to labels
-                            labels[assistant_k : user_j] = input_ids[assistant_k : user_j]
-                            # print(assistant_j, assistant_k, user_j, user_k)
-                            break
-                        pass
-                        j += 1
-                    pass
-                pass
-                j += 1
-            pass
-            all_labels.append(labels)
-        pass
-        return { "labels" : all_labels }
-    pass
-
-    if hasattr(trainer, "train_dataset") and trainer.train_dataset is not None:
-        trainer.train_dataset = trainer.train_dataset.map(_train_on_responses_only, batched = True)
-    if hasattr(trainer, "eval_dataset") and trainer.eval_dataset is not None:
-        trainer.eval_dataset = trainer.eval_dataset .map(_train_on_responses_only, batched = True)
-    return trainer
-pass
-
-
 def create_stopping_criteria(tokenizer, stop_word = "eos_token"):
     class StoppingCriteriaSub(StoppingCriteria):
         __slots__ = "stop_token", "single_match", "length",
diff --git a/unsloth/kernels/cross_entropy_loss.py b/unsloth/kernels/cross_entropy_loss.py
index 1fec5d7a..5abed6a3 100644
--- a/unsloth/kernels/cross_entropy_loss.py
+++ b/unsloth/kernels/cross_entropy_loss.py
@@ -355,6 +355,7 @@ def fast_cross_entropy_loss(
     labels,
     logit_softcapping = 0,
     logit_scaling = 0,
+    n_items = None,
 ):
     """
     Arguments:
@@ -372,7 +373,8 @@ def fast_cross_entropy_loss(
         logit_softcapping,
         logit_scaling,
     )
-    n_items = torch.count_nonzero(labels != -100)
+    if n_items is None:
+        n_items = torch.count_nonzero(labels != -100)
     return loss.sum() / n_items
 pass
 
@@ -409,6 +411,7 @@ def fast_cross_entropy_loss(
                 labels = shift_labels,
                 logit_softcapping = logit_softcapping,
                 logit_scaling = logit_scaling,
+                n_items = kwargs.get("n_items", None),
             )
         else:
             if logit_scaling != 0:
diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py
index b14bb391..aa7a69c9 100644
--- a/unsloth/models/_utils.py
+++ b/unsloth/models/_utils.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-__version__ = "2024.9.post4"
+__version__ = "2024.10.0"
 
 __all__ = [
     "prepare_model_for_kbit_training",
diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py
index a2453301..4cd512a9 100644
--- a/unsloth/models/llama.py
+++ b/unsloth/models/llama.py
@@ -975,13 +975,14 @@ def _CausalLM_fast_forward(
                 # Fixes https://github.com/unslothai/unsloth/issues/10
                 self.extra_ignored_labels = torch.full((self.max_seq_length, 1), -100, device = "cuda:0")
             pass
-            
+
             shift_labels = torch.hstack((labels[..., 1:], self.extra_ignored_labels[:labels.shape[0]]))
             loss = fast_cross_entropy_loss(
                 logits = shift_logits,
                 labels = shift_labels,
                 logit_softcapping = logit_softcapping,
                 logit_scaling = logit_scaling,
+                n_items = kwargs.get("n_items", None),
             )
         else:
             if logit_scaling != 0:
@@ -2019,8 +2020,8 @@ def get_peft_model(
         if loftq_config == {}:
             from peft import LoftQConfig
             logger.warning_once(
-                f"Unsloth: init_lora_weights = `loftq` is set, but `loftq_config` is None.\n"\
-                f"We shall use `loftq_config = LoftQConfig(loftq_bits = 4, loftq_iter = 1)`."
+                "Unsloth: init_lora_weights = `loftq` is set, but `loftq_config` is None.\n"\
+                "We shall use `loftq_config = LoftQConfig(loftq_bits = 4, loftq_iter = 1)`."
             )
             loftq_config = LoftQConfig(loftq_bits = 4, loftq_iter = 1)
         pass
diff --git a/unsloth/save.py b/unsloth/save.py
index dce30c4c..3760e232 100644
--- a/unsloth/save.py
+++ b/unsloth/save.py
@@ -555,7 +555,7 @@ def unsloth_save_model(
                 # max_ram = max(max_ram - W.nbytes, 0)
             else:
                 # Save to Disk
-                logger.warning_once(f"We will save to Disk and not RAM now.")
+                logger.warning_once("We will save to Disk and not RAM now.")
                 filename = os.path.join(temporary_location, f"{name}.pt")
                 torch.save(W, filename, pickle_module = pickle, pickle_protocol = pickle.HIGHEST_PROTOCOL,) # weights_only = True weirdly fails?
@@ -1460,7 +1460,7 @@ def fix_tokenizer_bos_token(tokenizer):
 
         fix_bos_token = True
         logger.warning(
-            f"Unsloth: ##### The current model auto adds a BOS token.\n"\
+            "Unsloth: ##### The current model auto adds a BOS token.\n"\
             "Unsloth: ##### Your chat template has a BOS token. We shall remove it temporarily."
         )
 
@@ -1671,7 +1671,7 @@ def unsloth_save_pretrained_gguf(
 
     if fix_bos_token:
         logger.warning(
-            f"Unsloth: ##### The current model auto adds a BOS token.\n"\
+            "Unsloth: ##### The current model auto adds a BOS token.\n"\
             "Unsloth: ##### We removed it in GGUF's chat template for you."
         )
     pass
@@ -1867,7 +1867,7 @@ def unsloth_push_to_hub_gguf(
 
     if fix_bos_token:
         logger.warning(
-            f"Unsloth: ##### The current model auto adds a BOS token.\n"\
+            "Unsloth: ##### The current model auto adds a BOS token.\n"\
             "Unsloth: ##### We removed it in GGUF's chat template for you."
         )
     pass
diff --git a/unsloth/tokenizer_utils.py b/unsloth/tokenizer_utils.py
index 196e4961..63d07c92 100644
--- a/unsloth/tokenizer_utils.py
+++ b/unsloth/tokenizer_utils.py
@@ -26,6 +26,15 @@
 import gc
 import subprocess
 
+from unsloth_zoo.tokenizer_utils import (
+    mean_of_trained_tokens,
+    add_new_tokens,
+    fix_untrained_tokens,
+)
+from unsloth_zoo.training_utils import (
+    fix_zero_training_loss,
+)
+
 __all__ = [
     "load_correct_tokenizer",
     "fix_sentencepiece_tokenizer",
@@ -807,347 +816,6 @@ def check_tokenizer(
 pass
 
 
-@torch.inference_mode
-def fix_untrained_tokens(model, tokenizer, train_dataset, eps = 1e-16):
-    """
-    Llama-3 for eg has untrained vectors in the base model.
-    These include <|eot_id|>, <|start_header_id|>, <|end_header_id|>
-    We reset them to the mean of the rest of the tokens
-    """
-    embedding_matrix = model.get_input_embeddings ().weight
-    lm_head_matrix = model.get_output_embeddings().weight
-
-    # Ignore some model checks for now
-    if model.config._name_or_path in IGNORED_TOKENIZER_NAMES:
-        return
-    pass
-
-    # Get untrained tokens
-    indicator_untrained1 = torch.amax(embedding_matrix, axis = 1) <= eps
-    # Check lm_head as well
-
-    # Does NOT work for Llama 3.1!!
-    indicator_untrained2 = torch.amax(lm_head_matrix, axis = 1) <= eps
-
-    # We instead check for repeated vectors
-    lm_head_where = torch.where(indicator_untrained1)[0]
-    lm_head_bad = lm_head_matrix[lm_head_where]
-    lm_head_bad = lm_head_bad.cpu().float().numpy().round(3)
-    from collections import Counter
-    counter = Counter()
-    for row in lm_head_bad: counter[hash(row.data.tobytes())] += 1
-    counter = Counter({k: c for k, c in counter.items() if c >= 2})
-
-    lm_head_where = lm_head_where.cpu().numpy()
-    final_bad_lm_head = []
-    for j, row in enumerate(lm_head_bad):
-        if hash(row.data.tobytes()) in counter:
-            final_bad_lm_head.append(lm_head_where[j])
-    indicator_untrained2 = indicator_untrained2 | torch.zeros_like(indicator_untrained2)
-    indicator_untrained2[final_bad_lm_head] = True
-
-    # Combine both checks
-    indicator_untrained = indicator_untrained1 & indicator_untrained2
-
-    where_untrained = torch.where(indicator_untrained)[0]
-    n_untrained = where_untrained.shape[0]
-    n_trained = embedding_matrix.shape[0] - n_untrained
-
-    # Get set and actual tokens
-    where_untrained = where_untrained.tolist()
-    if len(where_untrained) == 0: return
-
-    # Remove untrained indices where it's longer
-
-    where_untrained_set = frozenset(where_untrained)
-    actual_bad_tokens = tokenizer.convert_ids_to_tokens(where_untrained)
-    # Remove None items in actual_bad_tokens
-    actual_bad_tokens = [x for x in actual_bad_tokens if x is not None]
-
-    # Check if tokenizer and training datasets have bad tokens
-    if_bad_first = False
-    if_bad_second = False
-    # Check tokenizer's chat template for any untrained tokens
-    chat_template = getattr(tokenizer, "chat_template", None)
-    if chat_template is not None:
-        if_bad_first = any(x in chat_template for x in actual_bad_tokens)
-    pass
-
-    # Check the first 250, last 250 input_ids
-    size_dataset = len(train_dataset)
-    size = min(size_dataset, 250)
-    for j in range(size):
-        input_ids = train_dataset[j]
-        if "input_ids" in input_ids:
-            input_ids = input_ids["input_ids"]
-            if_bad = any(item in where_untrained_set for item in input_ids)
-            if if_bad:
-                if_bad_second = True
-                break
-            pass
-        pass
-    pass
-
-    # Check last 250
-    if not if_bad_second:
-        left = max(size_dataset-250, 0)
-        for j in range(left, size_dataset):
-            input_ids = train_dataset[j]
-            if "input_ids" in input_ids:
-                input_ids = input_ids["input_ids"]
-                if_bad = any(item in where_untrained_set for item in input_ids)
-                if if_bad:
-                    if_bad_second = True
-                    break
-                pass
-            pass
-        pass
-    pass
-
-    # Check if bad tokens exists!
-    if not if_bad_first and not if_bad_second: return
-
-    # Check if lm_head / embed_token are trainable!
-    bad_not_trainable = False
-    if not embedding_matrix.requires_grad: bad_not_trainable = True
-    if not lm_head_matrix .requires_grad: bad_not_trainable = True
-
-    if bad_not_trainable:
-
-        final_bad_items = []
-
-        # Re-check the first 250, last 250 input_ids
-        size_dataset = len(train_dataset)
-        size = min(size_dataset, 250)
-        for j in range(size):
-            input_ids = train_dataset[j]
-            if "input_ids" in input_ids:
-                input_ids = input_ids["input_ids"]
-                for item in input_ids:
-                    if item in where_untrained_set: final_bad_items.append(item)
-            pass
-        pass
-
-        # Re-check last 250
-        left = max(size_dataset-250, 0)
-        for j in range(left, size_dataset):
-            input_ids = train_dataset[j]
-            if "input_ids" in input_ids:
-                input_ids = input_ids["input_ids"]
-                for item in input_ids:
-                    if item in where_untrained_set: final_bad_items.append(item)
-            pass
-        pass
-
-        raise ValueError(
-            f'Unsloth: Untrained tokens of [{list(set(final_bad_items))}] found, but embed_tokens & lm_head not trainable, causing NaNs. '\
-            'Restart then add `embed_tokens` & `lm_head` to '\
-            '`FastLanguageModel.get_peft_model(target_modules = [..., "embed_tokens", "lm_head",]). `'\
-            'Are you using the `base` model? Instead, use the `instruct` version to silence this warning.',
-        )
-    pass
-
-    # Count all the possible bad tokens
-    final_counts = np.zeros(max(len(tokenizer), embedding_matrix.shape[0]), dtype = np.int64)
-    def mapping(examples):
-        input_ids = examples["input_ids"]
-        counter = np.fromiter(itertools.chain.from_iterable(input_ids), dtype = np.int32)
-        np.add.at(final_counts, counter, 1)
-    pass
-    train_dataset.map(mapping, batched = True, desc = "Counting untrained tokens")
-
-    # Get sum of all items
-    sum_embedding = torch.sum(embedding_matrix, dtype = torch.float32, axis = 0)
-    sum_lm_head = torch.sum(lm_head_matrix, dtype = torch.float32, axis = 0)
-
-    # Remove bad tokens
-    sum_embedding -= torch.sum(embedding_matrix[where_untrained], dtype = torch.float32, axis = 0)
-    sum_lm_head -= torch.sum(lm_head_matrix [where_untrained], dtype = torch.float32, axis = 0)
-
-    # Find correct average by dividing by sum of trained tokens
-    mean_embedding = (sum_embedding / n_trained)
-    mean_lm_head = (sum_lm_head / n_trained)
-
-    # Scale each to be equal to 1/max_frequency. Also set some to 0 if none seen
-    scaling = final_counts[where_untrained] / max(final_counts.max(), 1)
-    scaling = torch.tensor(scaling, device = mean_embedding.device).unsqueeze(1)
-    mean_embedding = mean_embedding.repeat((n_untrained, 1,)) * scaling
-    mean_lm_head = mean_lm_head .repeat((n_untrained, 1,)) * scaling
-    where_null = scaling.ravel() == 0
-    mean_embedding[where_null] = 0
-    mean_lm_head [where_null] = 0
-
-    # Set them to the mean
-    logger.warning(
-        "Unsloth: Setting embed_tokens & lm_head untrained tokens to "\
-        "mean(trained) to counteract NaNs during training."
-    )
-    embedding_matrix[where_untrained] = mean_embedding.to(embedding_matrix.dtype)
-    lm_head_matrix [where_untrained] = mean_lm_head .to(lm_head_matrix .dtype)
-
-    # Clean up
-    for _ in range(3):
-        gc.collect()
-        torch.cuda.empty_cache()
-    pass
-    return
-pass
-
-
-@torch.inference_mode
-def mean_of_trained_tokens(model, eps = 1e-16):
-    """
-    Llama-3 for eg has untrained vectors in the base model.
-    These include <|eot_id|>, <|start_header_id|>, <|end_header_id|>
-    We reset them to the mean of the rest of the tokens
-    """
-    embedding_matrix = model.get_input_embeddings ().weight.clone()
-    lm_head_matrix = model.get_output_embeddings().weight.clone()
-
-    # Get untrained tokens
-    indicator_untrained = torch.amax(embedding_matrix, axis = 1) <= eps
-    where_untrained = torch.where(indicator_untrained)[0]
-    n_untrained = where_untrained.shape[0]
-    n_trained = embedding_matrix.shape[0] - n_untrained
-    # if n_untrained != 0:
-    #     print(
-    #         f"Unsloth: Not an error, but your model has {n_untrained} untrained tokens.\n"\
-    #         "We shall set them to the mean of the other trained tokens."
-    #     )
-    # pass
-
-    # Get sum of all items
-    sum_embedding = torch.sum(embedding_matrix, dtype = torch.float32, axis = 0)
-    sum_lm_head = torch.sum(lm_head_matrix, dtype = torch.float32, axis = 0)
-
-    # Remove bad tokens
-    sum_embedding -= torch.sum(embedding_matrix[where_untrained], dtype = torch.float32, axis = 0)
-    sum_lm_head -= torch.sum(lm_head_matrix [where_untrained], dtype = torch.float32, axis = 0)
-
-    # Find correct average by dividing by sum of trained tokens
-    mean_embedding = (sum_embedding / n_trained)
-    mean_lm_head = (sum_lm_head / n_trained)
-
-    return mean_embedding, mean_lm_head
-pass
-
-
-@torch.inference_mode
-def add_new_tokens(
-    model,
-    tokenizer,
-    new_tokens = [],
-    method = "mean",
-    interpolation = 0.5,
-):
-    """
-    Smartly resizes the tokenizer and adds new tokens to the model.
-    We also disregard untrained tokens by removing them from the mean calculation.
-    """
-    assert(isinstance(new_tokens, (list, tuple)))
-    assert(len(new_tokens) > 0)
-    assert(method == "mean" or method == "interpolation")
-    assert(interpolation >= 0 and interpolation <= 1)
-
-    # Check if tokens already exist
-    overlapping_tokens = set(new_tokens) & set(tokenizer.vocab.keys())
-    if len(overlapping_tokens) != 0:
-        print(
-            f"Unsloth: You're adding new_tokens = {new_tokens}\n"\
-            f"There are tokens which are overlapping = {list(overlapping_tokens)}\n"\
-            f"We shall safely ignore these overlapping tokens."
-        )
-        new_tokens = [x for x in new_tokens if x not in overlapping_tokens]
-    pass
-
-    # Get mean of trained tokens
-    # mean_embedding, mean_lm_head = fix_untrained_tokens(model)
-
-    # Weirdly be careful reserved tokens can pop out
-    mean_embedding, mean_lm_head = mean_of_trained_tokens(model)
-    mean_embedding = mean_embedding.to(torch.float32)
-    mean_lm_head = mean_lm_head .to(torch.float32)
-
-    # Add tokens!
-    old_length = len(tokenizer)
-    tokenizer.add_tokens(new_tokens)
-    model.resize_token_embeddings(len(tokenizer))
-
-    # If we use interpolation, we interpolate between the mean embeddings and
-    # the Word2Vec sum of the other vectors
-    embedding_matrix = model.get_input_embeddings ().weight
-    lm_head_matrix = model.get_output_embeddings().weight
-
-    if method == "interpolation":
-        print(
-            "Unsloth: You are using interpolation to add new tokens.\n"\
-            f"We shall set new tokens = mean(embeddings)*{1-interpolation} + mean(new_tokens)*{interpolation}"
-        )
-        for j, token in enumerate(new_tokens):
-            input_ids = tokenizer(token, add_special_tokens = False).input_ids
-            mean_embedding_token = embedding_matrix[input_ids].mean(axis = 0, dtype = torch.float32)
-            mean_lm_head_token = lm_head_matrix [input_ids].mean(axis = 0, dtype = torch.float32)
-
-            # Interpolate
-            mean_embedding_token = mean_embedding*(1-interpolation) + mean_embedding_token*interpolation
-            mean_lm_head_token = mean_lm_head *(1-interpolation) + mean_lm_head_token *interpolation
-
-            # Set the new vector
-            embedding_matrix[old_length+j] = mean_embedding_token
-            lm_head_matrix [old_length+j] = mean_lm_head_token
-        pass
-    else:
-        # Now set the new tokens to the mean!
-        embedding_matrix[old_length:] = mean_embedding
-        lm_head_matrix [old_length:] = mean_lm_head
-    pass
-
-    # We set a flag to say we need to train embeddings
-    internal_model = model
-    while hasattr(internal_model, "model"):
-        internal_model._need_to_train_embeddings = True
-        internal_model = internal_model.model
-    pass
-    internal_model._need_to_train_embeddings = True
-
-    return
-pass
-
-
-@torch.inference_mode
-def fix_zero_training_loss(model, tokenizer, train_dataset):
-    """
-    Sometimes the labels get masked by all -100s, causing the loss
-    to be 0. We check for this!
-    """
-    if len(train_dataset) == 0: return
-
-    row = train_dataset[0]
-    if type(row) is dict and "labels" in row:
-
-        # Check the first 100 rows
-        seen_bad = 0
-        seen_good = 0
-        for i, row in enumerate(train_dataset):
-            try: check_tokens = list(set(row["labels"]))
-            except: continue
-            if len(check_tokens) == 1 and check_tokens[0] == -100: seen_bad += 1
-            else: seen_good += 1
-            if i >= 100: break
-        pass
-
-        # Check ratio
-        if seen_bad / (seen_bad + seen_good) >= 0.9:
-            logger.warning(
-                "Unsloth: Most labels in your dataset are -100. Training losses will be 0.\n"\
-                "For example, are you sure you used `train_on_responses_only` correctly?\n"\
-                "Or did you mask our tokens incorrectly? Maybe this is intended?"
-            )
-        pass
-    pass
-pass
-
-
 def check_nvidia():
     # Unsloth doesn't work yet on AMD devices - we're working on it!
     output = np.array([0,])
@@ -1260,7 +928,7 @@ def patch_sft_trainer_tokenizer():
         "    torch.cuda.empty_cache()\n"\
         "pass\n"\
         "\n"\
-        "fix_untrained_tokens(self.model, self.tokenizer, self.train_dataset, eps = 1e-16)\n\n"\
+        "fix_untrained_tokens(self.model, self.tokenizer, self.train_dataset, IGNORED_TOKENIZER_NAMES, eps = 1e-16)\n\n"\
         "fix_zero_training_loss(self.model, self.tokenizer, self.train_dataset)\n\n"
 
     # Add NEFTune since it doesn't seem to work?? We need to manually inject it
diff --git a/unsloth/trainer.py b/unsloth/trainer.py
index 45616ca6..c9c0ca2d 100644
--- a/unsloth/trainer.py
+++ b/unsloth/trainer.py
@@ -22,10 +22,12 @@
     from transformers import TrainingArguments
 pass
 from . import is_bfloat16_supported
+from unsloth_zoo.training_utils import unsloth_train
 
 __all__ = [
     "UnslothTrainingArguments",
     "UnslothTrainer",
+    "unsloth_train",
 ]