diff --git a/nemo/collections/llm/api.py b/nemo/collections/llm/api.py index 71e006472db9..a9b3d4361f5b 100644 --- a/nemo/collections/llm/api.py +++ b/nemo/collections/llm/api.py @@ -436,7 +436,7 @@ def export_ckpt( def generate( path: Union[Path, str], prompts: list[str], - trainer: Optional[nl.Trainer] = None, + trainer: nl.Trainer, params_dtype: torch.dtype = torch.bfloat16, max_batch_size: int = 4, random_seed: Optional[int] = None,