diff --git a/assets/view.jpg b/assets/view.jpg
new file mode 100644
index 000000000..6f44ebaba
Binary files /dev/null and b/assets/view.jpg differ
diff --git a/torchchat/generate.py b/torchchat/generate.py
index a9094aa40..397f9e801 100644
--- a/torchchat/generate.py
+++ b/torchchat/generate.py
@@ -24,15 +24,6 @@
 from PIL import Image
 
 
-# torchtune model definition dependencies
-from torchtune.data import Message, padded_collate_tiled_images_and_mask
-
-from torchtune.generation import sample as tune_sample
-from torchtune.models.llama3 import llama3_tokenizer
-
-from torchtune.models.llama3_2_vision._model_builders import llama3_2_vision_transform
-from torchtune.training import set_default_dtype
-
 from torchchat.cli.builder import (
     _initialize_model,
     _initialize_tokenizer,
@@ -43,6 +34,15 @@
 from torchchat.utils.build_utils import device_sync, set_precision
 from torchchat.utils.device_info import get_device_info
 
+# torchtune model definition dependencies
+from torchtune.data import Message, padded_collate_tiled_images_and_mask
+
+from torchtune.generation import sample as tune_sample
+from torchtune.models.llama3 import llama3_tokenizer
+
+from torchtune.models.llama3_2_vision._model_builders import llama3_2_vision_transform
+from torchtune.training import set_default_dtype
+
 
 class _ChatFormatter(ABC):
     def __init__(self, tokenizer):
@@ -179,8 +179,15 @@ def from_args(cls, args):
 
         # Validate that all image prompts exist before expensive model load
         if image_prompts := getattr(args, "image_prompts", None):
-            if not all(os.path.exists(image_prompt) for image_prompt in image_prompts):
-                raise RuntimeError(f"Image prompt {image_prompt} does not exist")
+            non_existent_image_prompts = [
+                image_prompt
+                for image_prompt in image_prompts
+                if (not os.path.exists(image_prompt))
+            ]
+            if len(non_existent_image_prompts):
+                raise RuntimeError(
+                    f"Image prompt {non_existent_image_prompts} does not exist"
+                )
 
         return cls(
             prompt=getattr(args, "prompt", ""),
@@ -938,7 +945,8 @@ def chat(
                 TransformerCrossAttentionLayer,
                 TransformerSelfAttentionLayer,
             )
-            decoder = self.model.model.decoder
+
+            decoder = self.model.model.decoder
             for m in reversed(list(decoder.modules())):
                 if isinstance(m, TransformerSelfAttentionLayer) or isinstance(
                     m, TransformerCrossAttentionLayer
@@ -984,7 +992,10 @@ def chat(
         # `is_torchtune_model` is a misnomer since it doesn't capture all
         # torchtune models (i.e. Flamingo)
        # See Issue: https://github.com/pytorch/torchchat/issues/1273
-        elif not generator_args.is_torchtune_model and self.model.config.model_type != ModelType.Flamingo:
+        elif (
+            not generator_args.is_torchtune_model
+            and self.model.config.model_type != ModelType.Flamingo
+        ):
             max_seq_length = min(
                 encoded.size(0) + generator_args.max_new_tokens,
                 (
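
Note on the `from_args` hunk: the removed code interpolated `image_prompt` into the error message, but that name is only bound inside the `all(...)` generator expression and is out of scope at the `raise` (Python 3 comprehension variables do not leak), so a missing file would have surfaced as a NameError rather than the intended message. The replacement gathers every missing path first and reports them all in one error. A minimal standalone sketch of the same pattern follows; the helper name `validate_image_prompts` is illustrative, not a function in torchchat:

    import os
    from typing import Sequence

    def validate_image_prompts(image_prompts: Sequence[str]) -> None:
        # Collect every missing path before raising, so the error names
        # all offending prompts instead of referencing an out-of-scope
        # generator variable as the old code did.
        missing = [path for path in image_prompts if not os.path.exists(path)]
        if missing:
            raise RuntimeError(f"Image prompt {missing} does not exist")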