extend tests for vlm
eaidova committed Nov 5, 2024
1 parent c96bb24 · commit 42005a6
Showing 7 changed files with 127 additions and 26 deletions.
optimum/intel/openvino/modeling_base.py (1 addition & 1 deletion)
@@ -778,7 +778,7 @@ def __init__(
for inputs in self.model.inputs
}
self.ov_config = ov_config or {**self.parent_model.ov_config}
-        self.request = None
+        self.request = None if not self.parent_model._compile_only else self.model
self._model_name = model_name
self.config = self.parent_model.config
self._model_dir = Path(model_dir or parent_model._model_save_dir)
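Note: with this change, a model part created from a parent loaded with compile_only=True starts out with its already-compiled model as the request, instead of None. A minimal sketch of the resulting dispatch, assuming `model` is either an ov.Model or an ov.CompiledModel (the helper name is illustrative, not the library's API):

import openvino as ov

def make_request(model, compile_only: bool, core: ov.Core, device: str = "CPU"):
    # compile_only path: `model` is already an ov.CompiledModel, which is
    # directly callable, so it can stand in for the request until an
    # infer request is created from it.
    if compile_only:
        return model
    # regular path: compile the ov.Model first, then create an infer request
    return core.compile_model(model, device).create_infer_request()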
optimum/intel/openvino/modeling_visual_language.py (35 additions & 19 deletions)
@@ -10,15 +10,31 @@
from huggingface_hub import hf_hub_download
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
from openvino._offline_transformations import apply_moc_transformations, compress_model_transformation
-from transformers import AutoConfig, GenerationConfig, GenerationMixin, PretrainedConfig
+from transformers import AutoConfig, AutoModelForCausalLM, GenerationConfig, GenerationMixin, PretrainedConfig
from transformers.modeling_outputs import BaseModelOutputWithPooling

from ...exporters.openvino import main_export
from ...exporters.openvino.stateful import ensure_stateful_is_available, model_has_input_output_name
from .configuration import OVConfig, OVWeightQuantizationConfig
from .modeling_base import OVBaseModel, OVModelPart
from .modeling_decoder import CausalLMOutputWithPast, OVModelForCausalLM
-from .utils import TemporaryDirectory
+from .utils import (
+    OV_LANGUAGE_MODEL_NAME,
+    OV_TEXT_EMBEDDINGS_MODEL_NAME,
+    OV_VISION_EMBEDDINGS_MODEL_NAME,
+    TemporaryDirectory,
+)


+try:
+    from transformers import LlavaForConditionalGeneration
+except ImportError:
+    LlavaForConditionalGeneration = None
+
+try:
+    from transformers import LlavaNextForConditionalGeneration
+except ImportError:
+    LlavaNextForConditionalGeneration = None


logger = logging.getLogger(__name__)
@@ -54,13 +70,8 @@ def __init__(

def compile(self):
if self.request is None:
-            if self._compile_only:
-                self.request = self.model.create_infer_request()
-            else:
-                logger.info(f"Compiling the Language model to {self._device} ...")
-                self.request = self._compile_model(
-                    self.model, self._device, self.ov_config, self.model_save_dir
-                ).create_infer_request()
+            logger.info(f"Compiling the Language model to {self._device} ...")
+            super().compile()
self._compile_text_emb()

def _compile_text_emb(self):
@@ -233,6 +244,7 @@ def forward(self, image_feature, pos_embed, key_padding_mask):
class OVModelForVisualCausalLM(OVBaseModel, GenerationMixin):
export_feature = "image-text-to-text"
additional_parts = []
+    auto_model_class = AutoModelForCausalLM

def __init__(
self,
@@ -316,11 +328,7 @@ def _save_pretrained(self, save_directory: Union[str, Path]):
The directory where to save the model files.
"""
src_files = [self.lm_model, self.text_embdings_model, self.vision_embeddings_model]
-        dst_file_names = [
-            "openvino_language_model.xml",
-            "openvino_text_embeddings_model.xml",
-            "openvino_vision_embeddings_model.xml",
-        ]
+        dst_file_names = [OV_LANGUAGE_MODEL_NAME, OV_TEXT_EMBEDDINGS_MODEL_NAME, OV_VISION_EMBEDDINGS_MODEL_NAME]
for part in self.additional_parts:
model = getattr(self, f"{part}_model", None)
if model is not None:
@@ -407,13 +415,17 @@ def _from_pretrained(
if os.path.isdir(model_id):
model_save_dir = Path(model_id)
        model_file_names = {
-            "language_model": "openvino_language_model.xml",
-            "text_embeddings": "openvino_text_embeddings_model.xml",
-            "vision_embeddings": "openvino_vision_embeddings_model.xml",
+            "language_model": OV_LANGUAGE_MODEL_NAME,
+            "language_model_bin": OV_LANGUAGE_MODEL_NAME.replace(".xml", ".bin"),
+            "text_embeddings": OV_TEXT_EMBEDDINGS_MODEL_NAME,
+            "text_embeddings_bin": OV_TEXT_EMBEDDINGS_MODEL_NAME.replace(".xml", ".bin"),
+            "vision_embeddings": OV_VISION_EMBEDDINGS_MODEL_NAME,
+            "vision_embeddings_bin": OV_VISION_EMBEDDINGS_MODEL_NAME.replace(".xml", ".bin"),
}

for part in model_cls.additional_parts:
model_file_names[part] = f"openvino_{part}_model.xml"
+            model_file_names[part + "_bin"] = f"openvino_{part}_model.bin"
model_cls = MODEL_TYPE_TO_CLS_MAPPING[config.model_type]
quantization_config = model_cls._prepare_weight_quantization_config(quantization_config, load_in_8bit)
compile_only = kwargs.get("compile_only", False)
@@ -713,6 +725,8 @@ def can_generate(self):


class _OVLlavaForCausalLM(OVModelForVisualCausalLM):
+    auto_model_class = LlavaForConditionalGeneration

def get_vision_embeddings(self, pixel_values, input_ids=None, **kwargs):
if input_ids is not None and input_ids.shape[1] == 1:
return None
@@ -882,6 +896,8 @@ def _filter_unattended_tokens(self, input_ids, attention_mask, past_key_values):


class _OVLlavaNextForCausalLM(_OVLlavaForCausalLM):
+    auto_model_class = LlavaNextForConditionalGeneration

# Adopted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava_next/modeling_llava_next.py#L655
def pack_image_features(self, image_features, image_sizes, image_newline=None):
from transformers.models.llava_next.modeling_llava_next import get_anyres_image_grid_shape, unpad_image
@@ -1150,7 +1166,7 @@ def get_text_embeddings(self, input_ids, **kwargs):
return super().get_text_embeddings(for_inputs_embeds_ids, **kwargs)


-class _OvInternVLForCausalLM(OVModelForVisualCausalLM):
+class _OVInternVLForCausalLM(OVModelForVisualCausalLM):
def get_vision_embeddings(self, pixel_values, input_ids=None, **kwargs):
if input_ids is not None and input_ids.shape[1] == 1:
return None
@@ -1568,7 +1584,7 @@ def get_multimodal_embeddings(
MODEL_TYPE_TO_CLS_MAPPING = {
    "llava": _OVLlavaForCausalLM,
    "llava_next": _OVLlavaNextForCausalLM,
-    "internvl_chat": _OvInternVLForCausalLM,
    "minicpmv": _OVMiniCPMVForCausalLM,
    "llava-qwen2": _OVNanoLlavaForCausalLM,
+    "internvl_chat": _OVInternVLForCausalLM,
}
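For context, a minimal usage sketch of the entry point this mapping serves; from_pretrained dispatches on config.model_type to one of the classes above. The checkpoint is the tiny test model used in the tests below, and the prompt format follows the llava convention:

import requests
from PIL import Image
from transformers import AutoProcessor

from optimum.intel import OVModelForVisualCausalLM

model_id = "katuni4ka/tiny-random-llava-ov"
processor = AutoProcessor.from_pretrained(model_id)
# Resolves to _OVLlavaForCausalLM via MODEL_TYPE_TO_CLS_MAPPING["llava"]
model = OVModelForVisualCausalLM.from_pretrained(model_id)

image = Image.open(
    requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw
)
inputs = processor(images=image, text="<image>\nWhat is shown in this image?", return_tensors="pt")
generated = model.generate(**inputs, max_new_tokens=10)
print(processor.batch_decode(generated, skip_special_tokens=True)[0])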
optimum/intel/openvino/utils.py (4 additions & 0 deletions)
@@ -42,6 +42,9 @@
OV_ENCODER_NAME = "openvino_encoder_model.xml"
OV_DECODER_NAME = "openvino_decoder_model.xml"
OV_DECODER_WITH_PAST_NAME = "openvino_decoder_with_past_model.xml"
+OV_TEXT_EMBEDDINGS_MODEL_NAME = "openvino_text_embeddings_model.xml"
+OV_LANGUAGE_MODEL_NAME = "openvino_language_model.xml"
+OV_VISION_EMBEDDINGS_MODEL_NAME = "openvino_vision_embeddings_model.xml"

OV_TOKENIZER_NAME = "openvino_tokenizer{}.xml"
OV_DETOKENIZER_NAME = "openvino_detokenizer{}.xml"
@@ -116,6 +119,7 @@
    "token-classification": "OVModelForTokenClassification",
    "question-answering": "OVModelForQuestionAnswering",
    "image-classification": "OVModelForImageClassification",
+    "image-text-to-text": "OVModelForVisualCausalLM",
"audio-classification": "OVModelForAudioClassification",
"stable-diffusion": "OVStableDiffusionPipeline",
"stable-diffusion-xl": "OVStableDiffusionXLPipeline",
tests/openvino/test_export.py (14 additions & 3 deletions)
@@ -41,12 +41,14 @@
OVModelForSequenceClassification,
OVModelForSpeechSeq2Seq,
OVModelForTokenClassification,
+    OVModelForVisualCausalLM,
OVStableDiffusion3Pipeline,
OVStableDiffusionPipeline,
OVStableDiffusionXLImg2ImgPipeline,
OVStableDiffusionXLPipeline,
)
from optimum.intel.openvino.modeling_base import OVBaseModel
+from optimum.intel.openvino.modeling_visual_language import MODEL_TYPE_TO_CLS_MAPPING
from optimum.intel.openvino.utils import TemporaryDirectory
from optimum.intel.utils.import_utils import _transformers_version, is_transformers_version
from optimum.utils.save_utils import maybe_load_preprocessors
@@ -70,12 +72,13 @@ class ExportModelTest(unittest.TestCase):
        "stable-diffusion-xl": OVStableDiffusionXLPipeline,
        "stable-diffusion-xl-refiner": OVStableDiffusionXLImg2ImgPipeline,
        "latent-consistency": OVLatentConsistencyModelPipeline,
+        "llava": OVModelForVisualCausalLM,
}

if is_transformers_version(">=", "4.45"):
SUPPORTED_ARCHITECTURES.update({"stable-diffusion-3": OVStableDiffusion3Pipeline, "flux": OVFluxPipeline})

-    GENERATIVE_MODELS = ("pix2struct", "t5", "bart", "gpt2", "whisper")
+    GENERATIVE_MODELS = ("pix2struct", "t5", "bart", "gpt2", "whisper", "llava")

def _openvino_export(
self,
@@ -93,6 +96,10 @@ def _openvino_export(
model_class = TasksManager.get_model_class_for_task(task, library=library_name)
model = model_class(f"hf_hub:{model_name}", pretrained=True, exportable=True)
TasksManager.standardize_model_attributes(model_name, model, library_name=library_name)
+        elif model_type == "llava":
+            model = MODEL_TYPE_TO_CLS_MAPPING[model_type].auto_model_class.from_pretrained(
+                model_name, **loading_kwargs
+            )
else:
model = auto_model.auto_model_class.from_pretrained(model_name, **loading_kwargs)

@@ -135,8 +142,12 @@ def test_export_with_custom_gen_config(self, model_type):
task = auto_model.export_feature
model_name = MODEL_NAMES[model_type]
loading_kwargs = {"attn_implementation": "eager"} if model_type in SDPA_ARCHS_ONNX_EXPORT_NOT_SUPPORTED else {}

-        model = auto_model.auto_model_class.from_pretrained(model_name, **loading_kwargs)
+        if model_type == "llava":
+            model = MODEL_TYPE_TO_CLS_MAPPING[model_type].auto_model_class.from_pretrained(
+                model_name, **loading_kwargs
+            )
+        else:
+            model = auto_model.auto_model_class.from_pretrained(model_name, **loading_kwargs)

model.generation_config.top_k = 42
model.generation_config.do_sample = True
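The llava branches above are needed because AutoModelForCausalLM cannot instantiate llava checkpoints; the reference model must come from the task-specific class recorded on the OV wrapper. A short sketch of that lookup (the checkpoint id is an assumption for illustration):

from optimum.intel.openvino.modeling_visual_language import MODEL_TYPE_TO_CLS_MAPPING

# _OVLlavaForCausalLM.auto_model_class is LlavaForConditionalGeneration when the
# installed transformers provides it, unlike the AutoModelForCausalLM default.
reference_cls = MODEL_TYPE_TO_CLS_MAPPING["llava"].auto_model_class
reference_model = reference_cls.from_pretrained("katuni4ka/tiny-random-llava")  # assumed tiny checkpoint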
tests/openvino/test_exporters_cli.py (4 additions & 0 deletions)
@@ -36,6 +36,7 @@
OVModelForSeq2SeqLM,
OVModelForSequenceClassification,
OVModelForTokenClassification,
+    OVModelForVisualCausalLM,
OVModelOpenCLIPForZeroShotImageClassification,
OVModelOpenCLIPText,
OVModelOpenCLIPVisual,
@@ -92,6 +93,7 @@ class OVCLIExportTestCase(unittest.TestCase):
        "stable-diffusion-xl": 4 if is_tokenizers_version("<", "0.20") else 0,
        "stable-diffusion-3": 6 if is_tokenizers_version("<", "0.20") else 2,
        "flux": 4 if is_tokenizers_version("<", "0.20") else 0,
+        "llava": 2 if is_tokenizers_version("<", "0.20") else 0,
}

SUPPORTED_SD_HYBRID_ARCHITECTURES = [
@@ -222,6 +224,8 @@ def test_exporters_cli_int8(self, task: str, model_type: str):
elif model_type.startswith("stable-diffusion") or model_type.startswith("flux"):
models = [model.unet or model.transformer, model.vae_encoder, model.vae_decoder]
models.append(model.text_encoder if model_type == "stable-diffusion" else model.text_encoder_2)
+        elif task.startswith("image-text-to-text"):
+            models = [model.language_model, model.vision_embeddings]
else:
models = [model]

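For reference, a hedged sketch of the CLI invocation that test_exporters_cli_int8 exercises for the new task; the model id and output directory are assumptions for illustration:

import subprocess

subprocess.run(
    [
        "optimum-cli", "export", "openvino",
        "--model", "katuni4ka/tiny-random-llava",  # assumed tiny llava checkpoint
        "--task", "image-text-to-text",
        "--weight-format", "int8",
        "llava_ov_int8",  # output directory
    ],
    check=True,
)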
tests/openvino/test_modeling.py (68 additions & 3 deletions)
@@ -95,7 +95,13 @@
OVModelWithEmbedForCausalLM,
OVVisionEmbedding,
)
-from optimum.intel.openvino.utils import TemporaryDirectory, _print_compiled_model_properties
+from optimum.intel.openvino.utils import (
+    OV_LANGUAGE_MODEL_NAME,
+    OV_TEXT_EMBEDDINGS_MODEL_NAME,
+    OV_VISION_EMBEDDINGS_MODEL_NAME,
+    TemporaryDirectory,
+    _print_compiled_model_properties,
+)
from optimum.intel.pipelines import pipeline as optimum_pipeline
from optimum.intel.utils.import_utils import is_openvino_version, is_transformers_version
from optimum.intel.utils.modeling_utils import _find_files_matching_pattern
@@ -134,6 +140,7 @@ def __init__(self, *args, **kwargs):
self.OV_DECODER_MODEL_ID = "helenai/gpt2-ov"
self.OV_SEQ2SEQ_MODEL_ID = "echarlaix/t5-small-openvino"
self.OV_DIFFUSION_MODEL_ID = "hf-internal-testing/tiny-stable-diffusion-openvino"
+        self.OV_VLM_MODEL_ID = "katuni4ka/tiny-random-llava-ov"

def test_load_from_hub_and_save_model(self):
tokenizer = AutoTokenizer.from_pretrained(self.OV_MODEL_ID)
@@ -222,6 +229,64 @@ def test_load_from_hub_and_save_decoder_model(self, use_cache):
del model
gc.collect()

+    def test_load_from_hub_and_save_visual_language_model(self):
+        model_id = self.OV_VLM_MODEL_ID
+        processor = get_preprocessor(model_id)
+        prompt = "<image>\n What is shown in this image?"
+        image = Image.open(
+            requests.get(
+                "http://images.cocodataset.org/val2017/000000039769.jpg",
+                stream=True,
+            ).raw
+        )
+        loaded_model = OVModelForVisualCausalLM.from_pretrained(model_id)
+        self.assertIsInstance(loaded_model.config, PretrainedConfig)
+        self.assertIsInstance(loaded_model, MODEL_TYPE_TO_CLS_MAPPING[loaded_model.config.model_type])
+        self.assertIsInstance(loaded_model.vision_embeddings, OVVisionEmbedding)
+        self.assertIsInstance(loaded_model.language_model, OVModelWithEmbedForCausalLM)
+        for additional_part in loaded_model.additional_parts:
+            self.assertTrue(hasattr(loaded_model, additional_part))
+            self.assertIsInstance(getattr(loaded_model, additional_part), MODEL_PARTS_CLS_MAPPING[additional_part])
+        self.assertIsInstance(loaded_model.config, PretrainedConfig)
+        # Test that PERFORMANCE_HINT is set to LATENCY by default
+        self.assertEqual(loaded_model.ov_config.get("PERFORMANCE_HINT"), "LATENCY")
+        self.assertEqual(
+            loaded_model.language_model.request.get_compiled_model().get_property("PERFORMANCE_HINT"), "LATENCY"
+        )
+        self.assertEqual(loaded_model.language_model.text_emb_request.get_property("PERFORMANCE_HINT"), "LATENCY")
+        self.assertEqual(loaded_model.vision_embeddings.request.get_property("PERFORMANCE_HINT"), "LATENCY")
+        inputs = processor(images=image, text=prompt, return_tensors="pt")
+        set_seed(SEED)
+        loaded_model_outputs = loaded_model(**inputs)
+
+        with TemporaryDirectory() as tmpdirname:
+            loaded_model.save_pretrained(tmpdirname)
+            folder_contents = os.listdir(tmpdirname)
+            for xml_file_name in [
+                OV_LANGUAGE_MODEL_NAME,
+                OV_TEXT_EMBEDDINGS_MODEL_NAME,
+                OV_VISION_EMBEDDINGS_MODEL_NAME,
+            ]:
+                self.assertTrue(xml_file_name in folder_contents)
+                self.assertTrue(xml_file_name.replace(".xml", ".bin") in folder_contents)
+            model = OVModelForVisualCausalLM.from_pretrained(tmpdirname)
+            compile_only_model = OVModelForVisualCausalLM.from_pretrained(tmpdirname, compile_only=True)
+            self.assertIsInstance(compile_only_model.language_model.model, ov.runtime.CompiledModel)
+            self.assertIsInstance(compile_only_model.language_model.request, ov.runtime.InferRequest)
+            self.assertIsInstance(compile_only_model.language_model.text_emb_model, ov.runtime.CompiledModel)
+            self.assertIsInstance(compile_only_model.language_model.text_emb_request, ov.runtime.CompiledModel)
+            self.assertIsInstance(compile_only_model.vision_embeddings.model, ov.runtime.CompiledModel)
+            self.assertIsInstance(compile_only_model.vision_embeddings.request, ov.runtime.CompiledModel)
+            outputs = compile_only_model(**inputs)
+            self.assertTrue(torch.equal(loaded_model_outputs.logits, outputs.logits))
+            del compile_only_model
+
+        outputs = model(**inputs)
+        self.assertTrue(torch.equal(loaded_model_outputs.logits, outputs.logits))
+        del loaded_model
+        del model
+        gc.collect()
+
def test_load_from_hub_and_save_seq2seq_model(self):
tokenizer = AutoTokenizer.from_pretrained(self.OV_SEQ2SEQ_MODEL_ID)
tokens = tokenizer("This is a sample input", return_tensors="pt")
@@ -2040,9 +2105,9 @@ def test_generate_utils(self, model_arch):
def test_model_can_be_loaded_after_saving(self, model_arch):
model_id = MODEL_NAMES[model_arch]
with TemporaryDirectory() as save_dir:
-            ov_model = OVModelForVisualCausalLM.from_pretrained(model_id, compile=False)
+            ov_model = OVModelForVisualCausalLM.from_pretrained(model_id, compile=False, trust_remote_code=model_arch in self.REMOTE_CODE_MODELS)
ov_model.save_pretrained(save_dir)
-            ov_restored_model = OVModelForVisualCausalLM.from_pretrained(save_dir, compile=False)
+            ov_restored_model = OVModelForVisualCausalLM.from_pretrained(save_dir, compile=False, trust_remote_code=model_arch in self.REMOTE_CODE_MODELS)
self.assertIsInstance(ov_restored_model, type(ov_model))


tests/openvino/utils_tests.py (1 addition & 0 deletions)
@@ -176,6 +176,7 @@
    "open-clip": (20, 28),
    "stable-diffusion-3": (66, 42, 58, 30),
    "flux": (56, 24, 28, 64),
+    "llava": (30, 18),
}


