Trainer: add predict with generate #32346

Open · wants to merge 18 commits into base: main
Changes from 4 commits
1 change: 1 addition & 0 deletions src/transformers/models/idefics/configuration_idefics.py
@@ -236,6 +236,7 @@ class IdeficsConfig(PretrainedConfig):

model_type = "idefics"
is_composition = False
keys_to_ignore_at_inference = ["past_key_values"]

def __init__(
self,
1 change: 1 addition & 0 deletions src/transformers/models/idefics2/configuration_idefics2.py
@@ -213,6 +213,7 @@ class Idefics2Config(PretrainedConfig):

model_type = "idefics2"
is_composition = True
keys_to_ignore_at_inference = ["past_key_values"]

def __init__(
self,
1 change: 1 addition & 0 deletions src/transformers/models/llava/configuration_llava.py
@@ -72,6 +72,7 @@ class LlavaConfig(PretrainedConfig):

model_type = "llava"
is_composition = False
keys_to_ignore_at_inference = ["past_key_values"]

def __init__(
self,
@@ -77,6 +77,7 @@ class LlavaNextConfig(PretrainedConfig):

model_type = "llava_next"
is_composition = False
keys_to_ignore_at_inference = ["past_key_values"]

def __init__(
self,
@@ -84,6 +84,7 @@ class LlavaNextVideoConfig(PretrainedConfig):

model_type = "llava_next_video"
is_composition = True
keys_to_ignore_at_inference = ["past_key_values"]

def __init__(
self,
@@ -99,6 +99,7 @@ class LlavaNextVideoConfig(PretrainedConfig):

model_type = "llava_next_video"
is_composition = True
keys_to_ignore_at_inference = ["past_key_values"]

def __init__(
self,
@@ -74,6 +74,7 @@ class PaliGemmaConfig(PretrainedConfig):

model_type = "paligemma"
is_composition = False
keys_to_ignore_at_inference = ["past_key_values"]

def __init__(
self,
@@ -75,6 +75,7 @@ class VideoLlavaConfig(PretrainedConfig):

model_type = "video_llava"
is_composition = False
keys_to_ignore_at_inference = ["past_key_values"]

def __init__(
self,
1 change: 1 addition & 0 deletions src/transformers/models/vipllava/configuration_vipllava.py
@@ -71,6 +71,7 @@ class VipLlavaConfig(PretrainedConfig):

model_type = "vipllava"
is_composition = False
keys_to_ignore_at_inference = ["past_key_values"]

def __init__(
self,
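These config additions feed the new prediction path: when no explicit `ignore_keys` are passed, the Trainer's evaluation loop falls back to this attribute, so cached `past_key_values` are not gathered along with the logits. A rough sketch of that lookup (simplified, not the exact Trainer code; the checkpoint name is only an example):

```python
from transformers import LlavaForConditionalGeneration

# Simplified sketch of how the Trainer resolves `ignore_keys` from the model config.
model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf")
ignore_keys = getattr(model.config, "keys_to_ignore_at_inference", [])
print(ignore_keys)  # ["past_key_values"] once the config change above is in place
```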
116 changes: 111 additions & 5 deletions src/transformers/trainer.py
@@ -61,7 +61,12 @@
from .debug_utils import DebugOption, DebugUnderflowOverflow
from .feature_extraction_sequence_utils import SequenceFeatureExtractor
from .hyperparameter_search import ALL_HYPERPARAMETER_SEARCH_BACKENDS, default_hp_search_backend
from .integrations.deepspeed import deepspeed_init, deepspeed_load_checkpoint, is_deepspeed_available
from .integrations.deepspeed import (
deepspeed_init,
deepspeed_load_checkpoint,
is_deepspeed_available,
is_deepspeed_zero3_enabled,
)
from .integrations.tpu import tpu_spmd_dataloader
from .modelcard import TrainingSummary
from .modeling_utils import PreTrainedModel, load_sharded_checkpoint
@@ -305,9 +310,12 @@ class Trainer:
The arguments to tweak for training. Will default to a basic instance of [`TrainingArguments`] with the
`output_dir` set to a directory named *tmp_trainer* in the current directory if not provided.
data_collator (`DataCollator`, *optional*):
The function to use to form a batch from a list of elements of `train_dataset` or `eval_dataset`. Will
The function to use to form a batch from a list of elements of `train_dataset`. Will
default to [`default_data_collator`] if no `tokenizer` is provided, an instance of
[`DataCollatorWithPadding`] otherwise.
eval_data_collator (`DataCollator`, *optional*):
The function to use to form a batch from a list of elements of `eval_dataset` or `test_dataset`. Will
default to `data_collator` if no `eval_data_collator` is provided.
train_dataset (Union[`torch.utils.data.Dataset`, `torch.utils.data.IterableDataset`, `datasets.Dataset`], *optional*):
The dataset to use for training. If it is a [`~datasets.Dataset`], columns not accepted by the
`model.forward()` method are automatically removed.
@@ -379,6 +387,7 @@ def __init__(
model: Union[PreTrainedModel, nn.Module] = None,
args: TrainingArguments = None,
data_collator: Optional[DataCollator] = None,
eval_data_collator: Optional[DataCollator] = None,
train_dataset: Optional[Union[Dataset, IterableDataset, "datasets.Dataset"]] = None,
eval_dataset: Optional[Union[Dataset, Dict[str, Dataset], "datasets.Dataset"]] = None,
tokenizer: Optional[PreTrainedTokenizerBase] = None,
@@ -523,6 +532,7 @@ def __init__(
else default_data_collator
)
self.data_collator = data_collator if data_collator is not None else default_collator
self.eval_data_collator = eval_data_collator if eval_data_collator is not None else data_collator
self.train_dataset = train_dataset
self.eval_dataset = eval_dataset
self.tokenizer = tokenizer
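For context, a hedged sketch of how the new `eval_data_collator` argument could be wired up together with `predict_with_generate`; the collator behaviour and the `tokenizer`/`model`/dataset variables are illustrative assumptions, not part of this diff:

```python
from transformers import Trainer, TrainingArguments

def train_collate(features):
    # Teacher-forcing batch: full prompt + answer (assumes each feature has a "text" field).
    batch = tokenizer([f["text"] for f in features], return_tensors="pt", padding=True)
    batch["labels"] = batch["input_ids"].clone()
    return batch

def eval_collate(features):
    # Same keys as training, plus prompt-only tensors for `generate()` (assumes a "prompt" field).
    batch = train_collate(features)
    prompts = tokenizer([f["prompt"] for f in features], return_tensors="pt", padding=True)
    batch["generation_input_ids"] = prompts["input_ids"]
    batch["generation_attention_mask"] = prompts["attention_mask"]
    return batch

trainer = Trainer(
    model=model,
    args=TrainingArguments(output_dir="out", predict_with_generate=True),
    data_collator=train_collate,
    eval_data_collator=eval_collate,
    train_dataset=train_ds,
    eval_dataset=eval_ds,
)
```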
@@ -961,7 +971,7 @@ def get_eval_dataloader(self, eval_dataset: Optional[Union[str, Dataset]] = None
if eval_dataset is not None
else self.eval_dataset
)
data_collator = self.data_collator
data_collator = self.eval_data_collator if self.eval_data_collator is not None else self.data_collator

if is_datasets_available() and isinstance(eval_dataset, datasets.Dataset):
eval_dataset = self._remove_unused_columns(eval_dataset, description="evaluation")
@@ -1003,7 +1013,7 @@ def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
The test dataset to use. If it is a [`~datasets.Dataset`], columns not accepted by the
`model.forward()` method are automatically removed. It must implement `__len__`.
"""
data_collator = self.data_collator
data_collator = self.eval_data_collator if self.eval_data_collator is not None else self.data_collator

if is_datasets_available() and isinstance(test_dataset, datasets.Dataset):
test_dataset = self._remove_unused_columns(test_dataset, description="test")
@@ -3600,6 +3610,7 @@ def evaluate(
eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
ignore_keys: Optional[List[str]] = None,
metric_key_prefix: str = "eval",
**gen_kwargs,
) -> Dict[str, float]:
"""
Run evaluation and return metrics.
Expand Down Expand Up @@ -3634,6 +3645,8 @@ def evaluate(
metric_key_prefix (`str`, *optional*, defaults to `"eval"`):
An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named
"eval_bleu" if the prefix is "eval" (default)
gen_kwargs:
Additional `generate`-specific kwargs.

Returns:
A dictionary containing the evaluation loss and the potential metrics computed from the predictions. The
@@ -3649,10 +3662,28 @@
eval_dataset=_eval_dataset if override else eval_dataset_name,
ignore_keys=ignore_keys,
metric_key_prefix=f"{metric_key_prefix}_{eval_dataset_name}",
**gen_kwargs,
)
metrics.update(dataset_metrics)
return metrics

# Set generation-related kwargs
if self.args.predict_with_generate:
if self.args.generation_config is not None:
gen_config = self.args.generation_config
self.gen_config = copy.deepcopy(gen_config) # copy so we don't modify args.gen_config in-place
unused_kwargs = self.gen_config.update(**gen_kwargs)
if unused_kwargs:
logger.warning_once(
f"The following generation-related kwargs were passed to `evaluate` but are not used by `generate()`: "
f"{', '.join(unused_kwargs.keys())}. "
"Make sure there are no typos in the passed kwargs, or do not pass unused kwargs."
)
else:
# We assume the model can generate if predict-with-generate is True
# Therefore, generation_config should be available
self.gen_config = self.model.generation_config

Member: same comment here

# memory metrics - must set up as early as possible
self._memory_tracker.start()
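A hedged usage sketch of the new keyword path, assuming a `trainer` built on this branch with `predict_with_generate=True`; the kwarg values are arbitrary:

```python
# gen_kwargs passed to `evaluate` are merged into a copy of the generation config.
metrics = trainer.evaluate(
    max_new_tokens=64,
    num_beams=4,
)
print(metrics)  # e.g. {"eval_loss": ..., "eval_bleu": ...}, depending on compute_metrics
```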

@@ -3700,7 +3731,11 @@ def evaluate(
return output.metrics

def predict(
self, test_dataset: Dataset, ignore_keys: Optional[List[str]] = None, metric_key_prefix: str = "test"
self,
test_dataset: Dataset,
ignore_keys: Optional[List[str]] = None,
metric_key_prefix: str = "test",
**gen_kwargs,
Member: you also need to add it for `prediction_step`

) -> PredictionOutput:
"""
Run prediction and return predictions and potential metrics.
Expand All @@ -3718,6 +3753,8 @@ def predict(
metric_key_prefix (`str`, *optional*, defaults to `"test"`):
An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named
"test_bleu" if the prefix is "test" (default)
gen_kwargs:
Additional `generate`-specific kwargs.

<Tip>

@@ -3734,6 +3771,23 @@
- metrics (`Dict[str, float]`, *optional*): The potential dictionary of metrics (if the dataset contained
labels).
"""
# Set generation-related kwargs
if self.args.predict_with_generate:
if self.args.generation_config is not None:
gen_config = self.args.generation_config
self.gen_config = copy.deepcopy(gen_config) # copy so we don't modify args.gen_config in-place
unused_kwargs = self.gen_config.update(**gen_kwargs)
if unused_kwargs:
logger.warning_once(
f"The following generation-related kwargs were passed to `predict` but are not used by `generate()`: "
f"{', '.join(unused_kwargs.keys())}. "
"Make sure there are no typos in the passed kwargs, or do not pass unused kwargs."
)
else:
# We assume the model can generate if predict-with-generate is True
# Therefore, generation_config should be available
self.gen_config = self.model.generation_config
Member: shouldn't we also update the config with gen_kwargs?

Member (Author): you mean add kwargs from model.config to the generation config? It shouldn't be necessary, because the base model.generation_config should contain all generation-related kwargs after the model is loaded. So we just need to make sure user-passed kwargs have higher priority than trainer.generation_config.

Member: I'm talking about the gen_kwargs that you are passing in predict. I would expect self.gen_config to be updated whenever the user passes gen_kwargs to the predict function, in all cases (important when passing a generate kwarg such as synced_gpus). By default it is equal to self.model.generation_config, but if the user passes one in TrainingArguments, it will be equal to self.args.generation_config.

Member (Author): Ah, I see now, right, we should be updating it in any case.

# memory metrics - must set up as early as possible
self._memory_tracker.start()
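The corresponding hedged sketch for `predict`, tying in the `synced_gpus` point from the discussion above; `test_ds` is a placeholder dataset:

```python
# User-passed gen_kwargs take priority over args.generation_config and model.generation_config.
predictions = trainer.predict(
    test_ds,
    metric_key_prefix="test",
    max_new_tokens=32,
    synced_gpus=True,  # e.g. relevant under DeepSpeed ZeRO-3
)
print(predictions.predictions.shape)  # generated token ids rather than raw logits
```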

@@ -4001,6 +4055,7 @@ def prediction_step(
inputs: Dict[str, Union[torch.Tensor, Any]],
prediction_loss_only: bool,
ignore_keys: Optional[List[str]] = None,
**gen_kwargs,
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
"""
Perform an evaluation step on `model` using `inputs`.
@@ -4020,12 +4075,29 @@
ignore_keys (`List[str]`, *optional*):
A list of keys in the output of your model (if it is a dictionary) that should be ignored when
gathering predictions.
gen_kwargs:
Additional `generate`-specific kwargs.

Return:
Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss,
logits and labels (each being optional).
"""
has_labels = False if len(self.label_names) == 0 else all(inputs.get(k) is not None for k in self.label_names)

# Priority: gen_kwargs > args.gen_config > model.generation_config > default GenerationConfig()
if self.args.predict_with_generate:
gen_config = self.gen_config
default_synced_gpus = True if is_deepspeed_zero3_enabled() else False
synced_gpus = gen_kwargs.get("synced_gpus", default_synced_gpus)
if len(gen_kwargs) > 0:
unused_kwargs = gen_config.update(**gen_kwargs)
if unused_kwargs:
logger.warning_once(
"The following generation-related kwargs were passed to `prediction_step` but are not "
f"used by `generate()`: {', '.join(unused_kwargs.keys())}. "
"Make sure there are no typos in the passed kwargs, or do not pass unused kwargs."
)

Comment on lines +4236 to +4241
Member (@SunMarc, Aug 30, 2024): I think that if you pass synced_gpus in gen_kwargs, the warning will appear, since it will be in unused_kwargs. Maybe do a pop instead. Also, this will trigger the warning in other places too.
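A hedged sketch of the fix suggested above; the helper name `resolve_synced_gpus` is made up, only the pop-before-update idea is the point:

```python
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled

def resolve_synced_gpus(gen_kwargs, gen_config):
    # Pop `synced_gpus` before updating the config so it never lands in
    # `unused_kwargs` and never triggers the warning.
    synced_gpus = gen_kwargs.pop("synced_gpus", is_deepspeed_zero3_enabled())
    if gen_kwargs:
        gen_config.update(**gen_kwargs)
    return synced_gpus
```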

# For CLIP-like models capable of returning loss values.
# If `return_loss` is not specified or being `None` in `inputs`, we check if the default value of `return_loss`
# is `True` in `model.forward`.
@@ -4049,6 +4121,37 @@
else:
labels = None

# If `generation_input_ids` was passed in the inputs, the model can generate and we need to modify
# the input keys. Otherwise, we don't know the prompt to generate from.
if self.args.predict_with_generate and not prediction_loss_only:
generation_inputs = inputs.copy()
if "generation_input_ids" in generation_inputs:
# get inputs that are related to text and contain only generation prompt
generation_only_inputs = {
k.replace("generation_", ""): v for k, v in generation_inputs.items() if "generation_" in k
}

# get common inputs that are not related to text, e.g. pixel-values
gen_keys = generation_only_inputs.keys()
generation_inputs_common = {
k: v
for k, v in generation_inputs.items()
if k.replace("generation_", "") not in gen_keys and "generation" not in k
}
generated_tokens = self.model.generate(
**generation_inputs_common,
**generation_only_inputs,
generation_config=gen_config,
synced_gpus=synced_gpus,
)

Review comment: Shall we use model instead of self.model here? In the evaluation_loop(), self.model is wrapped, and the wrapped model may not always be the same as self.model. I think this is for the case when DeepSpeed ZeRO-3 is enabled and eval_on_start is set to true.

Member (Author): For inference we don't wrap the model for distributed mode, but I changed it to model because there are some other steps that run before the model is returned. The original code was adapted from the Seq2Seq trainer, so I modified it there too.

# Note: in torch.distributed mode, there's no point in wrapping the model
# inside a DistributedDataParallel as we'll be under `no_grad` anyways.
if not training:
return model
else:
raise ValueError(
"`predict_with_generate` is set to `True` but no inputs were passed for generation. "
"Make sure you have `generation_input_ids` and `generation_attention_mask`."
)

# clean up generation-related input tensors from the inputs used for the loss, if any, before calling `forward`
inputs = {k: v for k, v in inputs.items() if "generation_" not in k}
with torch.no_grad():
if is_sagemaker_mp_enabled():
raw_outputs = smp_forward_only(model, inputs)
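To make the `generation_`-prefix splitting earlier in this hunk concrete, a hedged, self-contained illustration on a made-up batch (tensor shapes are arbitrary):

```python
import torch

batch = {
    "input_ids": torch.tensor([[1, 5, 6, 7]]),           # prompt + answer, used for the loss
    "attention_mask": torch.tensor([[1, 1, 1, 1]]),
    "labels": torch.tensor([[-100, -100, 6, 7]]),
    "pixel_values": torch.zeros(1, 3, 336, 336),          # shared, non-text input
    "generation_input_ids": torch.tensor([[1, 5]]),       # prompt only, for `generate()`
    "generation_attention_mask": torch.tensor([[1, 1]]),
}

generation_only_inputs = {
    k.replace("generation_", ""): v for k, v in batch.items() if "generation_" in k
}
gen_keys = generation_only_inputs.keys()
generation_inputs_common = {
    k: v for k, v in batch.items()
    if k.replace("generation_", "") not in gen_keys and "generation" not in k
}

print(sorted(generation_only_inputs))    # ['attention_mask', 'input_ids']
print(sorted(generation_inputs_common))  # ['labels', 'pixel_values'] (note: labels also reach generate)
```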
@@ -4094,6 +4197,9 @@
if prediction_loss_only:
return (loss, None, None)

if self.args.predict_with_generate and not prediction_loss_only:
return (loss, generated_tokens, labels)

logits = nested_detach(logits)
if len(logits) == 1:
logits = logits[0]
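Since `prediction_step` now returns generated token ids as the predictions, a `compute_metrics` for this path has to decode them. A hedged sketch; the module-level `tokenizer` and the `evaluate`-library BLEU metric are assumptions, not part of this PR:

```python
import numpy as np
import evaluate

bleu = evaluate.load("bleu")

def compute_metrics(eval_pred):
    preds, labels = eval_pred.predictions, eval_pred.label_ids
    # -100 is the usual ignore index in labels; swap it for the pad token before decoding.
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    pred_text = tokenizer.batch_decode(preds, skip_special_tokens=True)
    label_text = tokenizer.batch_decode(labels, skip_special_tokens=True)
    return {"bleu": bleu.compute(predictions=pred_text, references=[[t] for t in label_text])["bleu"]}
```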
21 changes: 21 additions & 0 deletions src/transformers/training_args.py
@@ -28,6 +28,7 @@
from packaging import version

from .debug_utils import DebugOption
from .generation import GenerationConfig
from .trainer_utils import (
EvaluationStrategy,
FSDPOption,
@@ -789,6 +790,12 @@ class TrainingArguments:

eval_use_gather_object (`bool`, *optional*, defaults to `False`):
Whether to recursively gather objects in a nested list/tuple/dictionary of objects from all devices. This should only be enabled if users are not just returning tensors, and it is actively discouraged by PyTorch.
predict_with_generate (`bool`, *optional*, defaults to `False`):
Whether to use `generate` to calculate generative metrics (ROUGE, BLEU).
generation_config ([`~generation.GenerationConfig`], *optional*):
The [`~generation.GenerationConfig`] object that will be used during generation if `predict_with_generate` is set to `True`.
Arguments set in this GenerationConfig will have higher priority than the model's generation config. Anything not set by this config
will fall back to `model.generation_config` by default.
"""

framework = "pt"
@@ -1496,6 +1503,20 @@ class TrainingArguments:
},
)

predict_with_generate: bool = field(
default=False, metadata={"help": "Whether to use `generate` to calculate generative metrics (ROUGE, BLEU)."}
)
generation_config: Optional[GenerationConfig] = field(
default=None,
metadata={
"help": (
"The GenerationConfig that will be used during prediction. Args from this config "
"will have higher priority than the model's generation config. Anything not set by this config "
"will fall back to `model.generation_config`."
)
},
)
Member: I think we simplify things a bit if we also add a generation_kwargs argument, as this is incompatible with generation_config, and I don't think we want to merge both arguments into one. WDYT @muellerzr?

Member (Author): Hmm, maybe we can then allow users to pass generation_config as a dict as well, and make a Config object of it ourselves. I see that the Seq2Seq trainer args also use a config arg, so I thought we could later merge the seq2seq args with TrainingArguments.

Member: That would be better, I think! This way, we won't need **gen_kwargs in the evaluate and predict functions. cc @muellerzr @gante

Member (Author): OK, now we can accept a dict or a config object in the training args.
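Following up on that thread, a hedged sketch of the two accepted forms; the field names follow this PR and the values are arbitrary:

```python
from transformers import GenerationConfig, TrainingArguments

# As a GenerationConfig object...
args = TrainingArguments(
    output_dir="out",
    predict_with_generate=True,
    generation_config=GenerationConfig(max_new_tokens=64, num_beams=4),
)

# ...or, per the discussion above, as a plain dict that gets turned into a GenerationConfig.
args = TrainingArguments(
    output_dir="out",
    predict_with_generate=True,
    generation_config={"max_new_tokens": 64, "num_beams": 4},
)
```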


def __post_init__(self):
# Parse in args that could be `dict` sent in from the CLI as a string
for field in _VALID_DICT_FIELDS: