From 073511e66b69e2f40f3ab548432eb3c14a79634d Mon Sep 17 00:00:00 2001 From: Joan Fontanals Date: Thu, 15 Feb 2024 23:46:07 +0100 Subject: [PATCH] fix: fix subclass with DocList (#6138) Signed-off-by: Joan Martinez --- extra-requirements.txt | 2 +- jina/serve/runtimes/helper.py | 4 +- .../docarray_v2/issues/__init__.py | 0 .../issues/github_6137/__init__.py | 0 .../issues/github_6137/test_issue.py | 30 +++++++++++++ tests/unit/serve/runtimes/test_helper.py | 43 +++++++++++++++---- 6 files changed, 68 insertions(+), 11 deletions(-) create mode 100644 tests/integration/docarray_v2/issues/__init__.py create mode 100644 tests/integration/docarray_v2/issues/github_6137/__init__.py create mode 100644 tests/integration/docarray_v2/issues/github_6137/test_issue.py diff --git a/extra-requirements.txt b/extra-requirements.txt index 18dcfcdd2d4c9..089c12007756e 100644 --- a/extra-requirements.txt +++ b/extra-requirements.txt @@ -58,7 +58,7 @@ aiofiles: standard,devel aiohttp: standard,devel scipy>=1.6.1: test Pillow: test -pytest: test +pytest<8.0.0: test pytest-timeout: test pytest-mock: test pytest-cov==3.0.0: test diff --git a/jina/serve/runtimes/helper.py b/jina/serve/runtimes/helper.py index fd4633c6d9352..70bb75a485c1b 100644 --- a/jina/serve/runtimes/helper.py +++ b/jina/serve/runtimes/helper.py @@ -108,7 +108,8 @@ def _create_aux_model_doc_list_to_list(model): try: if issubclass(field, DocList): t: Any = field.doc_type - fields[field_name] = (List[t], field_info) + t_aux = _create_aux_model_doc_list_to_list(t) + fields[field_name] = (List[t_aux], field_info) else: fields[field_name] = (field, field_info) except TypeError: @@ -272,7 +273,6 @@ def _create_pydantic_model_from_schema( ) -> type: if not definitions: definitions = schema.get('definitions', {}) - cached_models = cached_models if cached_models is not None else {} fields: Dict[str, Any] = {} if model_name in cached_models: diff --git a/tests/integration/docarray_v2/issues/__init__.py b/tests/integration/docarray_v2/issues/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/integration/docarray_v2/issues/github_6137/__init__.py b/tests/integration/docarray_v2/issues/github_6137/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/integration/docarray_v2/issues/github_6137/test_issue.py b/tests/integration/docarray_v2/issues/github_6137/test_issue.py new file mode 100644 index 0000000000000..9dc91f9098963 --- /dev/null +++ b/tests/integration/docarray_v2/issues/github_6137/test_issue.py @@ -0,0 +1,30 @@ +from docarray import DocList, BaseDoc +from docarray.documents.text import TextDoc +from jina import Executor, requests, Flow + + +def test_issue(): + class QuoteFile(BaseDoc): + quote_file_id: int = None + texts: DocList[TextDoc] = None + + class SearchResult(BaseDoc): + results: DocList[QuoteFile] = None + + class InitialExecutor(Executor): + + @requests(on='/search') + async def search(self, docs: DocList[SearchResult], **kwargs) -> DocList[SearchResult]: + return docs + + f = ( + Flow(protocol='http') + .add(name='initial', uses=InitialExecutor) + ) + + with f: + resp = f.post(on='/search', inputs=DocList[SearchResult]([SearchResult(results=DocList[QuoteFile]( + [QuoteFile(quote_file_id=999, texts=DocList[TextDoc]([TextDoc(text='hey here')]))]))]), + return_type=DocList[SearchResult]) + assert resp[0].results[0].quote_file_id == 999 + assert resp[0].results[0].texts[0].text == 'hey here' diff --git a/tests/unit/serve/runtimes/test_helper.py b/tests/unit/serve/runtimes/test_helper.py index 5adcfa57b67c3..43abba91f2157 100644 --- a/tests/unit/serve/runtimes/test_helper.py +++ b/tests/unit/serve/runtimes/test_helper.py @@ -44,15 +44,15 @@ def test_split_key_executor_name(full_key, key, executor): 'param, parsed_param, executor_name', [ ( - {'key': 1, 'executor__key': 2, 'wrong_executor__key': 3}, - {'key': 2}, - 'executor', + {'key': 1, 'executor__key': 2, 'wrong_executor__key': 3}, + {'key': 2}, + 'executor', ), ({'executor__key': 2, 'wrong_executor__key': 3}, {'key': 2}, 'executor'), ( - {'a': 1, 'executor__key': 2, 'wrong_executor__key': 3}, - {'key': 2, 'a': 1}, - 'executor', + {'a': 1, 'executor__key': 2, 'wrong_executor__key': 3}, + {'key': 2, 'a': 1}, + 'executor', ), ({'key_1': 0, 'exec2__key_2': 1}, {'key_1': 0}, 'executor'), ], @@ -69,8 +69,8 @@ def test_get_name_from_replicas(name_w_replicas, name): def _custom_grpc_options( - call_recording_mock: Mock, - additional_options: Optional[Union[list, Dict[str, Any]]] = None, + call_recording_mock: Mock, + additional_options: Optional[Union[list, Dict[str, Any]]] = None, ) -> List[Tuple[str, Any]]: call_recording_mock() expected_grpc_option_keys = [ @@ -355,3 +355,30 @@ class ResultTestDoc(BaseDoc): assert len(original_back) == 0 assert len(custom_da) == 0 + + +@pytest.mark.skipif(not docarray_v2, reason='Test only working with docarray v2') +def test_dynamic_class_creation_multiple_doclist_nested(): + from docarray import BaseDoc, DocList + from jina.serve.runtimes.helper import _create_aux_model_doc_list_to_list + from jina.serve.runtimes.helper import _create_pydantic_model_from_schema + + class MyTextDoc(BaseDoc): + text: str + + class QuoteFile(BaseDoc): + texts: DocList[MyTextDoc] + + class SearchResult(BaseDoc): + results: DocList[QuoteFile] = None + + textlist = DocList[MyTextDoc]([MyTextDoc(text='hey')]) + models_created_by_name = {} + SearchResult_aux = _create_aux_model_doc_list_to_list(SearchResult) + _ = _create_pydantic_model_from_schema(SearchResult_aux.schema(), 'SearchResult', + models_created_by_name) + QuoteFile_reconstructed_in_gateway_from_Search_results = models_created_by_name['QuoteFile'] + + reconstructed_in_gateway_from_Search_results = QuoteFile_reconstructed_in_gateway_from_Search_results( + texts=textlist) + assert reconstructed_in_gateway_from_Search_results.texts[0].text == 'hey'