From 367178b904092707f562a823381a82f3c22a7ad5 Mon Sep 17 00:00:00 2001
From: Dmitrii Rudenko
Date: Fri, 29 Nov 2024 11:44:34 +0100
Subject: [PATCH] Tests draft

---
 experiments/colpali_convert_lang_model.py  |  31 ++
 experiments/colpali_image_test.ipynb       | 381 ++++++++++++++++++
 experiments/colpali_text_test.ipynb        | 444 +++++++++++++++++++++
 experiments/late_interaction_colpali.ipynb | 202 ++++++++++
 fastembed/common/model_management.py       |   2 -
 fastembed/late_interaction/colpali.py      |   6 +-
 6 files changed, 1062 insertions(+), 4 deletions(-)
 create mode 100644 experiments/colpali_convert_lang_model.py
 create mode 100644 experiments/colpali_image_test.ipynb
 create mode 100644 experiments/colpali_text_test.ipynb
 create mode 100644 experiments/late_interaction_colpali.ipynb

diff --git a/experiments/colpali_convert_lang_model.py b/experiments/colpali_convert_lang_model.py
new file mode 100644
index 00000000..691f0cba
--- /dev/null
+++ b/experiments/colpali_convert_lang_model.py
@@ -0,0 +1,31 @@
+import torch
+from colpali_engine.models import ColPali, ColPaliProcessor
+import onnxruntime as ort
+
+model_name = "vidore/colpali-v1.2"
+original_model = ColPali.from_pretrained(model_name).eval()
+processor = ColPaliProcessor.from_pretrained(model_name)
+
+dummy_query = ["Is attention really all you need?"]
+
+# Process the input query
+processed_query = processor.process_queries(dummy_query).to(original_model.device)
+
+# Prepare input tensors
+input_query_tensor = processed_query["input_ids"].type(torch.long)
+attention_mask_tensor = processed_query["attention_mask"].type(torch.long)
+
+# Export the model to ONNX with the required inputs and dynamic shapes
+torch.onnx.export(
+    original_model.model.language_model,
+    (input_query_tensor, attention_mask_tensor),
+    "experiments/colpali_text_encoder_dir/model.onnx",
+    input_names=["input_ids", "attention_mask"],
+    output_names=["logits"],
+    dynamo=True,
+    opset_version=14,
+)
+
+# Smoke-test the export: sessions are invoked via run(), with numpy inputs
+text_session = ort.InferenceSession("experiments/colpali_text_encoder_dir/model.onnx")
+print("Session output", text_session.run(None, {"input_ids": input_query_tensor.cpu().numpy(), "attention_mask": attention_mask_tensor.cpu().numpy()}))
diff --git a/experiments/colpali_image_test.ipynb b/experiments/colpali_image_test.ipynb
new file mode 100644
index 00000000..354bc596
--- /dev/null
+++ b/experiments/colpali_image_test.ipynb
@@ -0,0 +1,381 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "id": "initial_id",
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2024-11-28T10:02:39.315496Z",
+     "start_time": "2024-11-28T10:02:39.290846Z"
+    },
+    "collapsed": true
+   },
+   "outputs": [],
+   "source": [
+    "from PIL import Image\n",
+    "\n",
+    "images = [\n",
+    "    Image.open(\"/Users/d.rudenko/PycharmProjects/opensource/fastembed/tests/misc/image.jpeg\"),\n",
+    "    Image.open(\n",
+    "        \"/Users/d.rudenko/PycharmProjects/opensource/fastembed/tests/misc/small_image.jpeg\"\n",
+    "    ),\n",
+    "]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "e46189ce4b8b0677",
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2024-11-28T09:58:37.254586Z",
+     "start_time": "2024-11-28T09:58:22.754066Z"
+    }
+   },
+   "outputs": [
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "ba9856c5109643049718592a236b2206",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Fetching 7 files: 0%| | 0/7 [00:00<?, ?it/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "You have passed both `text` and `images` to `PaliGemmaProcessor`. The processor expects special image tokens in the text, as many tokens as there are images per each text. It is recommended to add `<image>` tokens in the very beginning of your text and `<bos>` token after that. 
For this call, we will infer how many images each text has and add special tokens.\n"
+     ]
+    }
+   ],
+   "source": [
+    "from colpali_engine.models import ColPaliProcessor\n",
+    "\n",
+    "model_name = \"vidore/colpali-v1.2-merged\"\n",
+    "\n",
+    "processor = ColPaliProcessor.from_pretrained(model_name)\n",
+    "# Process the inputs\n",
+    "batch_images_onnx = processor.process_images(images)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "id": "89c2fbe3d64964fc",
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2024-11-28T10:03:56.766986Z",
+     "start_time": "2024-11-28T10:02:43.893495Z"
+    }
+   },
+   "outputs": [],
+   "source": [
+    "import onnxruntime as ort\n",
+    "\n",
+    "sess = ort.InferenceSession(\"/Users/d.rudenko/dev/qdrant/colpali-v1.2-merged-onnx/model.onnx\")\n",
+    "image_embeddings_onnx = sess.run(\n",
+    "    [sess.get_outputs()[0].name],\n",
+    "    {\n",
+    "        \"input_ids\": batch_images_onnx[\"input_ids\"].numpy(),\n",
+    "        \"pixel_values\": batch_images_onnx[\"pixel_values\"].numpy(),\n",
+    "        \"attention_mask\": batch_images_onnx[\"attention_mask\"].numpy(),\n",
+    "    },\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "id": "61b43dd6caaa0909",
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2024-11-28T10:06:23.238770Z",
+     "start_time": "2024-11-28T10:06:23.235457Z"
+    }
+   },
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "(1, 2, 1030, 128)"
+      ]
+     },
+     "execution_count": 8,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "import numpy as np\n",
+    "\n",
+    "np.array(image_embeddings_onnx).shape"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 18,
+   "id": "5be8ebb15c6dfaa6",
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2024-11-28T10:59:48.765049Z",
+     "start_time": "2024-11-28T10:59:48.761122Z"
+    }
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "[[[ 0.015 0.051 0.059 0.026 -0.061 -0.027 -0.014]\n",
+      " [-0.22 -0.111 0.046 0.081 -0.048 -0.052 -0.086]\n",
+      " [-0.184 -0.131 0.004 0.062 -0.038 -0.059 -0.127]\n",
+      " [-0.209 -0.113 0.015 0.059 -0.035 -0.035 -0.072]\n",
+      " [-0.031 -0.044 0.092 -0.005 0.006 -0.057 -0.061]\n",
+      " [-0.18 -0.039 0.031 0.003 0.083 -0.041 0.088]\n",
+      " [-0.091 0.023 0.116 -0.02 0.039 -0.064 -0.026]]\n",
+      "\n",
+      " [[-0.25 -0.112 -0.065 -0.014 0.005 -0.092 0.024]\n",
+      " [-0.22 -0.096 -0.014 0.039 -0.02 -0.12 -0.004]\n",
+      " [-0.228 -0.114 0.031 0.019 0.034 -0.052 -0.031]\n",
+      " [-0.274 -0.186 0.095 -0.019 0.017 0.021 -0.016]\n",
+      " [-0.186 -0.061 -0.01 0.065 -0.058 -0.05 0.019]\n",
+      " [-0.183 -0.11 -0.034 -0.042 0.026 -0.071 0.02 ]\n",
+      " [-0.153 -0.072 -0.015 0.088 -0.081 -0.043 0.04 ]]]\n"
+     ]
+    }
+   ],
+   "source": [
+    "print(np.round(image_embeddings_onnx[0][:, :7, :7], decimals=3))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 13,
+   "id": "bc9f7ffda971d3ba",
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2024-11-28T10:59:02.286294Z",
+     "start_time": "2024-11-28T10:59:02.264997Z"
+    }
+   },
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "array([ 0.01533 , 0.05118 , 0.05948 , 0.02583 , -0.06128 , -0.02682 ,\n",
+       " -0.013565, 0.10254 , -0.0983 , 0.1109 , -0.00342 , -0.0344 ,\n",
+       " -0.00887 , -0.1616 , 0.09814 , 0.2257 , 0.03976 , 0.03687 ,\n",
+       " 0.1648 , 0.06866 , 0.0396 , 0.1672 , 0.1455 , -0.1387 ,\n",
+       " 0.1203 , 0.04907 , -0.07965 , -0.0885 , 0.01982 , 0.0404 ,\n",
+       " -0.07513 , -0.02844 , 0.04337 , 0.03857 , -0.1065 , 0.0288 ,\n",
+       " -0.1279 , -0.1126 , 0.03363 , -0.0507 , 0.11584 , 0.0483 ,\n",
+       " 0.035 , -0.08417 , 
-0.0907 , 0.0279 , 0.1394 , -0.10364 ,\n",
+       " -0.1471 , -0.07135 , -0.136 , 0.1289 , 0.082 , 0.02232 ,\n",
+       " -0.00571 , -0.02547 , 0.1053 , 0.0377 , 0.0148 , 0.02795 ,\n",
+       " -0.01859 , -0.11066 , -0.12195 , 0.0583 , 0.0995 , 0.01086 ,\n",
+       " 0.0859 , 0.1302 , -0.10126 , 0.005417, 0.05423 , -0.1808 ,\n",
+       " 0.1444 , 0.1885 , 0.09247 , -0.04718 , 0.1018 , -0.02997 ,\n",
+       " -0.0598 , -0.011284, 0.1203 , -0.1313 , -0.04584 , -0.02725 ,\n",
+       " -0.1277 , -0.04236 , -0.08466 , -0.0861 , 0.1131 , 0.02806 ,\n",
+       " -0.0947 , 0.04388 , 0.04263 , 0.03598 , -0.06866 , -0.06018 ,\n",
+       " -0.02763 , -0.0972 , 0.11505 , -0.1097 , -0.04166 , 0.0742 ,\n",
+       " -0.06683 , -0.02188 , -0.1663 , -0.0902 , 0.02594 , -0.03802 ,\n",
+       " -0.034 , -0.04828 , -0.05765 , 0.0633 , -0.02515 , -0.08826 ,\n",
+       " -0.09753 , -0.10974 , -0.074 , -0.02083 , -0.1301 , 0.1383 ,\n",
+       " 0.1428 , 0.0935 , 0.0949 , 0.03876 , 0.08514 , -0.12256 ,\n",
+       " -0.0451 , -0.002306], dtype=float16)"
+      ]
+     },
+     "execution_count": 13,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "np.array(image_embeddings_onnx)[0][0][0]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "id": "34a238b20e5fcab2",
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2024-11-20T15:39:41.314176Z",
+     "start_time": "2024-11-20T15:39:41.308579Z"
+    }
+   },
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "True"
+      ]
+     },
+     "execution_count": 5,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "import numpy as np\n",
+    "\n",
+    "np.allclose(image_embeddings_onnx[0][0], fastembed_i_embeddings[0])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "id": "5ca3b11eb3813a87",
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2024-11-20T15:39:42.081408Z",
+     "start_time": "2024-11-20T15:39:42.078582Z"
+    }
+   },
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "array([ 0.01533 , 0.05118 , 0.05948 , 0.02583 , -0.06128 , -0.02682 ,\n",
+       " -0.013565, 0.10254 , -0.0983 , 0.1109 , -0.00342 , -0.0344 ,\n",
+       " -0.00887 , -0.1616 , 0.09814 , 0.2257 , 0.03976 , 0.03687 ,\n",
+       " 0.1648 , 0.06866 , 0.0396 , 0.1672 , 0.1455 , -0.1387 ,\n",
+       " 0.1203 , 0.04907 , -0.07965 , -0.0885 , 0.01982 , 0.0404 ,\n",
+       " -0.07513 , -0.02844 , 0.04337 , 0.03857 , -0.1065 , 0.0288 ,\n",
+       " -0.1279 , -0.1126 , 0.03363 , -0.0507 , 0.11584 , 0.0483 ,\n",
+       " 0.035 , -0.08417 , -0.0907 , 0.0279 , 0.1394 , -0.10364 ,\n",
+       " -0.1471 , -0.07135 , -0.136 , 0.1289 , 0.082 , 0.02232 ,\n",
+       " -0.00571 , -0.02547 , 0.1053 , 0.0377 , 0.0148 , 0.02795 ,\n",
+       " -0.01859 , -0.11066 , -0.12195 , 0.0583 , 0.0995 , 0.01086 ,\n",
+       " 0.0859 , 0.1302 , -0.10126 , 0.005417, 0.05423 , -0.1808 ,\n",
+       " 0.1444 , 0.1885 , 0.09247 , -0.04718 , 0.1018 , -0.02997 ,\n",
+       " -0.0598 , -0.011284, 0.1203 , -0.1313 , -0.04584 , -0.02725 ,\n",
+       " -0.1277 , -0.04236 , -0.08466 , -0.0861 , 0.1131 , 0.02806 ,\n",
+       " -0.0947 , 0.04388 , 0.04263 , 0.03598 , -0.06866 , -0.06018 ,\n",
+       " -0.02763 , -0.0972 , 0.11505 , -0.1097 , -0.04166 , 0.0742 ,\n",
+       " -0.06683 , -0.02188 , -0.1663 , -0.0902 , 0.02594 , -0.03802 ,\n",
+       " -0.034 , -0.04828 , -0.05765 , 0.0633 , -0.02515 , -0.08826 ,\n",
+       " -0.09753 , -0.10974 , -0.074 , -0.02083 , -0.1301 , 0.1383 ,\n",
+       " 0.1428 , 0.0935 , 0.0949 , 0.03876 , 0.08514 , -0.12256 ,\n",
+       " -0.0451 , -0.002306], dtype=float16)"
+      ]
+     },
+     "execution_count": 6,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "image_embeddings_onnx[0][0][0]"
+   ]
+  },
+  {
+   "cell_type": "code",
"execution_count": 7, + "id": "2c52a4d7d83aeda7", + "metadata": { + "ExecuteTime": { + "end_time": "2024-11-20T16:05:32.115768Z", + "start_time": "2024-11-20T16:05:32.090218Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "[array([[ 0.01532745, 0.05117798, 0.05947876, ..., -0.12255859,\n", + " -0.04510498, -0.00230598],\n", + " [-0.22009277, -0.11071777, 0.04562378, ..., 0.00257111,\n", + " -0.06988525, 0.12384033],\n", + " [-0.18371582, -0.13085938, 0.00393677, ..., -0.02949524,\n", + " -0.05444336, 0.1295166 ],\n", + " ...,\n", + " [-0.1418457 , 0.01023102, 0.1239624 , ..., -0.00460434,\n", + " 0.17321777, 0.09454346],\n", + " [-0.24572754, -0.06878662, 0.11834717, ..., -0.02763367,\n", + " -0.03022766, 0.08917236],\n", + " [-0.2211914 , -0.04171753, 0.19519043, ..., -0.01535797,\n", + " -0.02432251, -0.03561401]], dtype=float32),\n", + " array([[-2.49877930e-01, -1.11511230e-01, -6.51855469e-02, ...,\n", + " 3.19519043e-02, 3.44543457e-02, -1.33666992e-02],\n", + " [-2.20336914e-01, -9.56420898e-02, -1.39694214e-02, ...,\n", + " -8.88705254e-05, -1.57318115e-02, -1.00555420e-02],\n", + " [-2.28271484e-01, -1.14501953e-01, 3.10058594e-02, ...,\n", + " 7.59277344e-02, -4.28466797e-02, 1.19262695e-01],\n", + " ...,\n", + " [-2.04589844e-01, -4.86755371e-02, 8.46557617e-02, ...,\n", + " -3.98254395e-02, 1.66625977e-01, 9.71679688e-02],\n", + " [-2.88085938e-01, -4.50439453e-02, 7.69653320e-02, ...,\n", + " -4.36096191e-02, -1.28784180e-02, 6.26220703e-02],\n", + " [-2.67578125e-01, -3.25317383e-02, 1.66625977e-01, ...,\n", + " -2.90679932e-03, -1.52282715e-02, -3.62243652e-02]], dtype=float32)]" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "fastembed_i_embeddings" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "786bfac25eb7704a", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "2.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/experiments/colpali_text_test.ipynb b/experiments/colpali_text_test.ipynb new file mode 100644 index 00000000..b3392089 --- /dev/null +++ b/experiments/colpali_text_test.ipynb @@ -0,0 +1,444 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 6, + "id": "54b3bfd4ad5b9ee6", + "metadata": { + "ExecuteTime": { + "end_time": "2024-11-28T10:44:32.841758Z", + "start_time": "2024-11-28T10:44:32.830025Z" + } + }, + "outputs": [], + "source": [ + "# Your inputs\n", + "queries = [\n", + " # \"Is attention really all you need?\",\n", + " # \"Are Benjamin, Antoine, Merve, and Jo best friends?\",\n", + " # \"Long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long long\"\n", + " \"hello world\",\n", + " \"flag 
embedding\",\n", + "]" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "86ee1b68fb88b11d", + "metadata": { + "ExecuteTime": { + "end_time": "2024-11-27T22:54:23.016952Z", + "start_time": "2024-11-27T22:33:14.872976Z" + } + }, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "f1b463d5ae47404f951fecc6629e8008", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Fetching 7 files: 0%| | 0/7 [00:00 None: """ @@ -223,6 +224,7 @@ def load_onnx_model(self) -> None: cuda=self.cuda, device_id=self.device_id, ) + self.tokenizer.enable_truncation(max_length=maxsize) class ColPaliEmbeddingWorker(TextEmbeddingWorker):