Skip to content

Commit

Permalink
feat: encode text first when both text and uri are presented (#795)
Browse files Browse the repository at this point in the history
* fix: encode text first when both text and uri are presented

* fix: encode text first when both text and uri are presented

* test: add split da test

* fix: typo

* test: test split da
  • Loading branch information
ZiniuYu authored Aug 5, 2022
1 parent 7c6708f commit 65032f0
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 5 deletions.
8 changes: 3 additions & 5 deletions server/clip_server/executors/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,12 +73,10 @@ def preproc_text(


def split_img_txt_da(doc: 'Document', img_da: 'DocumentArray', txt_da: 'DocumentArray'):
    """Append ``doc`` to the text or the image collection, in place.

    Text takes precedence: a document with a non-empty ``text`` is always
    routed to ``txt_da``, even when it also carries a ``uri``, ``blob`` or
    ``tensor``. Otherwise, a document with a truthy ``blob``, a ``tensor``
    that is not ``None``, or a truthy ``uri`` is routed to ``img_da``.
    A document with none of these fields is appended to neither collection.

    :param doc: the document to classify; its ``text``, ``blob``, ``tensor``
        and ``uri`` attributes are read.
    :param img_da: collection that receives image documents (mutated).
    :param txt_da: collection that receives text documents (mutated).
    """
    if doc.text:
        txt_da.append(doc)
    # `tensor` is compared against None rather than tested for truthiness,
    # so any tensor value that is present counts as image content.
    elif doc.blob or (doc.tensor is not None) or doc.uri:
        img_da.append(doc)


def set_rank(docs, _logit_scale=np.exp(4.60517)):
Expand Down
60 changes: 60 additions & 0 deletions tests/test_helper.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import pytest
import numpy as np
from clip_server.executors.helper import numpy_softmax
from clip_server.executors.helper import split_img_txt_da
from docarray import Document, DocumentArray


@pytest.mark.parametrize('shape', [(5, 10), (5, 10, 10)])
Expand All @@ -17,3 +19,61 @@ def test_numpy_softmax(shape, axis):
np_softmax = numpy_softmax(logits, axis=axis)
torch_softmax = torch.from_numpy(logits).softmax(dim=axis).numpy()
np.testing.assert_array_almost_equal(np_softmax, torch_softmax)


# Each case is a pair: (documents to split, (expected text count, expected image count)).
@pytest.mark.parametrize(
    'inputs',
    [
        # Text wins whenever it is present, even alongside a uri.
        (
            DocumentArray(
                [
                    Document(text='hello, world'),
                    Document(text='goodbye, world'),
                    Document(
                        text='hello, world',
                        uri='https://docarray.jina.ai/_static/favicon.png',
                    ),
                    Document(
                        uri='https://docarray.jina.ai/_static/favicon.png',
                    ),
                ]
            ),
            (3, 1),
        ),
        # Tensor, blob and uri (alone or combined) all count as image content.
        (
            DocumentArray(
                [
                    Document(text='hello, world'),
                    Document(tensor=np.array([0, 1, 2])),
                    # NOTE(review): this call runs while the parametrize
                    # arguments are built, i.e. at import/collection time;
                    # presumably it fetches the uri over the network - confirm.
                    Document(
                        uri='https://docarray.jina.ai/_static/favicon.png'
                    ).load_uri_to_blob(),
                    Document(
                        tensor=np.array([0, 1, 2]),
                        uri='https://docarray.jina.ai/_static/favicon.png',
                    ),
                    Document(
                        uri='https://docarray.jina.ai/_static/favicon.png',
                    ),
                ]
            ),
            (1, 4),
        ),
        # Minimal mix: one text document and one uri-only document.
        (
            DocumentArray(
                [
                    Document(text='hello, world'),
                    Document(uri='https://docarray.jina.ai/_static/favicon.png'),
                ]
            ),
            (1, 1),
        ),
    ],
)
def test_split_img_txt_da(inputs):
    """Check how many documents land in the text and image collections."""
    docs, (expected_txt, expected_img) = inputs
    txt_da = DocumentArray()
    img_da = DocumentArray()
    for doc in docs:
        split_img_txt_da(doc, img_da, txt_da)
    assert len(txt_da) == expected_txt
    assert len(img_da) == expected_img

0 comments on commit 65032f0

Please sign in to comment.