Trainer: add predict with generate #32346
same comment here
You also need to add it for `prediction_step`.
Shouldn't we also update the config with `gen_kwargs`?
You mean add kwargs from `model.config` to the generation config? It shouldn't be necessary, because the base `model.generation_config` should contain all generation-related kwargs after the model is loaded. So we just need to make sure user-passed kwargs have higher priority than `trainer.generation_config`.
I'm talking about the `gen_kwargs` that you are passing in `predict`. I would expect `self.gen_config` to be updated when the user passes `gen_kwargs` to the `predict` function in all cases (important in the case where we pass a generate kwarg such as `synced_gpus`). By default it is equal to `self.model.generation_config`, but if the user passes one in `TrainingArguments`, it will be equal to `self.args.generation_config`.
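A minimal sketch of the merging being described (illustrative only, following the thread's `self.gen_config` naming; `GenerationConfig.update()` applies the fields it recognizes in place and returns the rest):

```python
import copy

from transformers import GenerationConfig

def resolve_gen_config(trainer_gen_config: GenerationConfig, **gen_kwargs):
    # trainer_gen_config is self.model.generation_config by default, or
    # self.args.generation_config when one was set in TrainingArguments.
    gen_config = copy.deepcopy(trainer_gen_config)
    # User-passed kwargs take priority over the trainer-level config.
    unused_kwargs = gen_config.update(**gen_kwargs)
    return gen_config, unused_kwargs
```

With this, `predict(..., max_new_tokens=32)` would override whatever the trainer-level config carries.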
Ah I see now, right, we should be updating it in any case.
I think that if you pass `synced_gpus` in `gen_kwargs`, the warning will appear since it will end up in `unused_kwargs`. Maybe do a `pop` instead. Also, this will trigger the warning in other places too.
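A hedged sketch of the `pop` suggestion (hypothetical helper; names follow the thread): loop-level arguments such as `synced_gpus` are pulled out of `gen_kwargs` before `update()` sees them, so they never land in `unused_kwargs`:

```python
def split_gen_kwargs(gen_config, gen_kwargs):
    # synced_gpus configures the generate loop, not the GenerationConfig,
    # so pop it first; otherwise update() returns it in unused_kwargs and
    # the warning fires here and in every other place this path runs.
    synced_gpus = gen_kwargs.pop("synced_gpus", False)
    unused_kwargs = gen_config.update(**gen_kwargs)
    return synced_gpus, unused_kwargs
```

`model.generate(..., generation_config=gen_config, synced_gpus=synced_gpus)` then receives the popped value directly.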
Shall we use `model` instead of `self.model` here? In `evaluation_loop()`, `self.model` is wrapped, and the wrapped `model` may not always be the same as `self.model`. I think this matters for the case when DeepSpeed ZeRO-3 is enabled and `eval_on_start` is set to true.
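For illustration, a stripped-down `prediction_step` showing the distinction (hypothetical body, assuming the thread's `self.gen_config`; the real method does much more):

```python
def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None):
    # Generate with the `model` handed in by evaluation_loop(): that is the
    # prepared/wrapped module, which may differ from self.model, e.g. under
    # DeepSpeed ZeRO-3 with eval_on_start=True.
    generated_tokens = model.generate(**inputs, generation_config=self.gen_config)
    # (loss, logits, labels) tuple shape, as in the Trainer API
    return None, generated_tokens, None
```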
For inference we don't wrap for distributed mode, but I changed it to `model` because there are some other steps that run before the model is returned. The original code was adapted from the seq2seq trainer, so I modified it there too: transformers/src/transformers/trainer.py, lines 1762 to 1765 in c409cd8.
I think we'd simplify things a bit if we also add a `generation_kwargs` argument, as this is incompatible with `generation_config`, and I don't think we want to merge both arguments into one. WDYT @muellerzr?
Hmm, maybe we can instead allow users to pass `generation_config` as a dict as well, and then we can make a config object out of it ourselves. I see that the `TrainerSeq2Seq` args also use a `config` arg, so I thought we could later merge the seq2seq args with the trainer args.
That would be better, I think! This way we won't need `**gen_kwargs` in the `evaluate` and `predict` functions. cc @muellerzr @gante
OK, now we can accept a dict or a config object in the training args.
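A minimal sketch of that normalization (hypothetical helper name, not the PR's actual code):

```python
from transformers import GenerationConfig

def to_generation_config(value):
    # Accept either a plain dict or an already-built GenerationConfig.
    if isinstance(value, dict):
        return GenerationConfig(**value)
    return value
```

So both `generation_config={"max_new_tokens": 64}` and `generation_config=GenerationConfig(max_new_tokens=64)` end up as the same kind of object in the training args.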