Replies: 2 comments
-
The inputs should be exactly the same as before. The only thing that needs to change is the mask for the loss, which ignores the prefix and only computes the loss on the answer. By the way, you can compute that mask without a loop:

```python
token_indices = mx.arange(mask_width)[None, :]
mask = mx.logical_and(token_indices >= input_lengths[:, None], token_indices < lengths[:, None])
```
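As a sanity check of that broadcasting trick, here is a small sketch using NumPy in place of `mx` (the semantics of `arange`, comparison broadcasting, and `logical_and` are the same); the values of `input_lengths`, `lengths`, and `mask_width` are made-up stand-ins for what the trainer would supply:

```python
import numpy as np

# Stand-ins for the batch values: two sequences whose prompts are
# 3 and 2 tokens long, with total (unpadded) lengths 5 and 4.
input_lengths = np.array([3, 2])
lengths = np.array([5, 4])
mask_width = 6  # padded batch width

# Same vectorized construction as with mx: broadcast a row of token
# positions against per-sequence column vectors of lengths.
token_indices = np.arange(mask_width)[None, :]
mask = np.logical_and(token_indices >= input_lengths[:, None],
                      token_indices < lengths[:, None])

# Row 0 keeps positions 3 and 4; row 1 keeps positions 2 and 3.
print(mask.astype(int))
# [[0 0 0 1 1 0]
#  [0 0 1 1 0 0]]
```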
-
Thank you
-
I'm trying to implement the equivalent of HF's ability to train on completions only using MLX. Looking at the default implementation of iterate_batches and default loss in mlx_lm.tuner.trainer, it looks as if the tokens are being set to zero for the padding suffix used to ensure each sequence of tokens in the batch is of the same maximal length. Then, in default_loss, a boolean mask is used to avoid penalizing the model for not generating the padding.
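To make the padding-mask idea concrete, here is a hedged sketch in NumPy (standing in for mx): `ce` represents hypothetical per-token cross-entropy values and `lengths` the real (unpadded) sequence lengths; neither is mlx_lm's actual variable naming. Padded positions are simply excluded from the average:

```python
import numpy as np

# Hypothetical per-token cross-entropy for a batch of 2 sequences,
# padded to width 4; real lengths are 3 and 2.
ce = np.array([[0.5, 0.2, 0.1, 9.9],
               [0.4, 0.3, 9.9, 9.9]])  # 9.9 marks padding positions
lengths = np.array([3, 2])

# Boolean mask: True at real-token positions, False at padding.
mask = np.arange(ce.shape[1])[None, :] < lengths[:, None]

# Average the loss over real tokens only, so the model is never
# penalized for what it produces at padded positions.
loss = (ce * mask).sum() / mask.sum()
print(float(loss))  # 0.3
```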
In my attempt to generalize from this approach, I'm using the module below (which uses #391 to pass a custom loss and batching function) to test this on the SQL generation training dataset included with mlx-lm.
It assumes, with the following training text as an example:
that the input is
and the output is
The custom iterate_batches function calculates the length of the tokenized 'input' for each record in the batch, fills in zeros for the tokens up to the length of the input as well as the padding suffix (leaving only the completion ids with non-zero tokens), and passes a list of the input lengths along with the batch and the full lengths to the custom loss function. The custom loss function then calculates a mask for ignoring the inputs and the padding suffix.
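A minimal sketch of the batching scheme just described, using NumPy in place of mx; the function name `make_batch`, the use of token id 0 for both the zeroed prompt and the padding, and the toy token lists are all assumptions for illustration, not mlx_lm's actual API:

```python
import numpy as np

def make_batch(input_token_lists, full_token_lists):
    """Zero out prompt tokens, pad to a common width, and return the
    batch plus the per-sequence input lengths and full lengths."""
    input_lengths = np.array([len(t) for t in input_token_lists])
    lengths = np.array([len(t) for t in full_token_lists])
    batch = np.zeros((len(full_token_lists), lengths.max()), dtype=np.int64)
    for i, toks in enumerate(full_token_lists):
        # Keep only the completion ids; the prompt prefix and the
        # padding suffix both stay zero, as described above.
        batch[i, input_lengths[i]:lengths[i]] = toks[input_lengths[i]:lengths[i]]
    return batch, input_lengths, lengths

# Toy example: prompts of lengths 2 and 1, full sequences of 4 and 3.
batch, in_lens, lens = make_batch([[7, 8], [5]], [[7, 8, 3, 4], [5, 6, 9]])
print(batch)
# [[0 0 3 4]
#  [0 6 9 0]]
```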
However, when I run this, I'm getting NaN error values:
But when I change the custom batching function to fill in the actual values of the tokenized input (rather than using zeros, as is the case for the suffix), i.e., from
to
Then I get proper loss values:
Is there a reason why using zeros for the front of the token sequence (the part that will be masked out, the same as the suffix padding) would cause NaN loss values? Note that the custom batching function is not performing the token shift (described here and implemented in mlx_lm's default iterate_batches method), and I'm not sure whether that is related to the cause of this issue, but changing the loss method to the following to perform the same shift did not address it:
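For reference, the next-token shift in question can be sketched like this (NumPy stand-in for mx; the names `batch`, `mask`, and the toy values are assumptions): position t is used to predict token t + 1, so the inputs drop the last column, the targets drop the first, and the loss mask must be shifted along with the targets:

```python
import numpy as np

batch = np.array([[10, 11, 12, 13, 0]])  # one padded sequence
mask = np.array([[False, False, True, True, False]])  # loss on answer only

# Standard next-token shift: the model sees batch[:, :-1] and is
# scored against batch[:, 1:].
inputs = batch[:, :-1]
targets = batch[:, 1:]
target_mask = mask[:, 1:]  # shift the mask the same way as the targets

print(inputs.tolist())       # [[10, 11, 12, 13]]
print(targets.tolist())      # [[11, 12, 13, 0]]
print(target_mask.tolist())  # [[False, True, True, False]]
```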
Any insight to help my understanding of this would be greatly appreciated. Thank you for such a great software framework!