Skip to content

Commit

Permalink
fix: Fix jina clip text v1
Browse files Browse the repository at this point in the history
  • Loading branch information
hh-space-invader committed Dec 10, 2024
1 parent 81739cd commit 74a381e
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 16 deletions.
19 changes: 18 additions & 1 deletion fastembed/text/onnx_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,17 @@
},
"model_file": "onnx/model.onnx",
},
{
"model": "jinaai/jina-clip-v1",
"dim": 768,
"description": "Text embeddings, Multimodal (text&image), English, Prefixes for queries/documents: not necessary, 2024 year",
"license": "apache-2.0",
"size_in_GB": 0.55,
"sources": {
"hf": "jinaai/jina-clip-v1",
},
"model_file": "onnx/text_model.onnx",
},
]


Expand Down Expand Up @@ -285,7 +296,13 @@ def _preprocess_onnx_input(

def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[np.ndarray]:
    """Turn raw ONNX model output into normalized float32 embeddings.

    The rank of ``output.model_output`` decides how one vector per input
    is obtained:

    * 3-D ``(batch_size, seq_len, embedding_dim)``: the entry at
      position 0 along the sequence axis is taken for every batch item.
    * 2-D ``(batch_size, embedding_dim)``: the output is used as-is,
      since there is no sequence axis to index into.

    Args:
        output: Context object whose ``model_output`` attribute holds the
            array produced by the ONNX session.

    Returns:
        The selected embeddings passed through ``normalize`` and cast to
        ``np.float32``.

    Raises:
        ValueError: If ``model_output`` is neither 2-D nor 3-D.
    """
    embeddings = output.model_output
    # Dispatch on rank *before* slicing: applying ``[:, 0]`` to a 2-D
    # output would keep only the first embedding component per item.
    if embeddings.ndim == 3:  # (batch_size, seq_len, embedding_dim)
        processed_embeddings = embeddings[:, 0]
    elif embeddings.ndim == 2:  # (batch_size, embedding_dim)
        processed_embeddings = embeddings
    else:
        raise ValueError(f"Unsupported embedding shape: {embeddings.shape}")
    return normalize(processed_embeddings).astype(np.float32)

def load_onnx_model(self) -> None:
self._load_onnx_model(
Expand Down
4 changes: 0 additions & 4 deletions fastembed/text/pooled_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,6 @@ def _get_worker_class(cls) -> Type[TextEmbeddingWorker]:

@classmethod
def mean_pooling(cls, model_output: np.ndarray, attention_mask: np.ndarray) -> np.ndarray:
if model_output.ndim == 2: # (batch, embedding_dim)
seq_length = attention_mask.shape[1]
# (batch, seq_length, embedding_dim)
model_output = np.tile(np.expand_dims(model_output, axis=1), (1, seq_length, 1))
token_embeddings = model_output
input_mask_expanded = np.expand_dims(attention_mask, axis=-1)
input_mask_expanded = np.tile(input_mask_expanded, (1, 1, token_embeddings.shape[-1]))
Expand Down
11 changes: 0 additions & 11 deletions fastembed/text/pooled_normalized_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,17 +75,6 @@
"sources": {"hf": "jinaai/jina-embeddings-v2-base-es"},
"model_file": "onnx/model.onnx",
},
{
"model": "jinaai/jina-clip-v1",
"dim": 768,
"description": "Text embeddings, Multimodal (text&image), English, Prefixes for queries/documents: not necessary, 2024 year",
"license": "apache-2.0",
"size_in_GB": 0.55,
"sources": {
"hf": "jinaai/jina-clip-v1",
},
"model_file": "onnx/text_model.onnx",
},
]


Expand Down

0 comments on commit 74a381e

Please sign in to comment.