From 4cc9e41f53bdcb2cfcfb132019aa34f4da7c65d0 Mon Sep 17 00:00:00 2001 From: Joan Martinez Date: Thu, 7 Mar 2024 17:54:28 +0100 Subject: [PATCH] fix: hack for LegacyDocument --- jina/_docarray.py | 2 + jina/_docarray_legacy.py | 49 +++++++++++++++++++ jina/serve/executors/__init__.py | 4 +- .../runtimes/gateway/graph/topology_graph.py | 10 ++-- jina/serve/runtimes/head/request_handling.py | 11 +++-- .../serve/runtimes/worker/request_handling.py | 7 +-- 6 files changed, 67 insertions(+), 16 deletions(-) create mode 100644 jina/_docarray_legacy.py diff --git a/jina/_docarray.py b/jina/_docarray.py index 96ae3dd1d352d..499fdb67e6f40 100644 --- a/jina/_docarray.py +++ b/jina/_docarray.py @@ -4,6 +4,8 @@ docarray_v2 = True + from jina._docarray_legacy import LegacyDocumentJina + except ImportError: from docarray import Document, DocumentArray diff --git a/jina/_docarray_legacy.py b/jina/_docarray_legacy.py new file mode 100644 index 0000000000000..61a2347f852b7 --- /dev/null +++ b/jina/_docarray_legacy.py @@ -0,0 +1,49 @@ +from __future__ import annotations + +from docarray import BaseDoc +from docarray import DocList + +docarray_v2 = True + +from typing import Any, Dict, Optional, List, Union + +from docarray.typing import AnyEmbedding, AnyTensor + + +class LegacyDocumentJina(BaseDoc): + """ + This Document is the LegacyDocumentJina. It follows the same schema as in DocArray <=0.21. + It can be useful to start migrating a codebase from v1 to v2. + + Nevertheless, the API is not totally compatible with DocArray <=0.21 `Document`. + Indeed, none of the method associated with `Document` are present. Only the schema + of the data is similar. + + ```python + from docarray import DocList + from docarray.documents.legacy import LegacyDocument + import numpy as np + + doc = LegacyDocument(text='hello') + doc.url = 'http://myimg.png' + doc.tensor = np.zeros((3, 224, 224)) + doc.embedding = np.zeros((100, 1)) + + doc.tags['price'] = 10 + + doc.chunks = DocList[Document]([Document() for _ in range(10)]) + + doc.chunks = DocList[Document]([Document() for _ in range(10)]) + ``` + + """ + + tensor: Optional[AnyTensor] = None + chunks: Optional[Union[DocList[LegacyDocumentJina], List[LegacyDocumentJina]]] = None + matches: Optional[Union[DocList[LegacyDocumentJina], List[LegacyDocumentJina]]] = None + blob: Optional[bytes] = None + text: Optional[str] = None + url: Optional[str] = None + embedding: Optional[AnyEmbedding] = None + tags: Dict[str, Any] = dict() + scores: Optional[Dict[str, Any]] = None diff --git a/jina/serve/executors/__init__.py b/jina/serve/executors/__init__.py index c5c8f72a8e6c1..3f3c83e5a2083 100644 --- a/jina/serve/executors/__init__.py +++ b/jina/serve/executors/__init__.py @@ -52,7 +52,7 @@ from jina.serve.instrumentation import MetricsTimer if docarray_v2: - from docarray.documents.legacy import LegacyDocument + from jina._docarray import LegacyDocumentJina if TYPE_CHECKING: # pragma: no cover from opentelemetry.context.context import Context @@ -257,7 +257,7 @@ def get_function_with_schema(fn: Callable) -> T: from docarray import BaseDoc, DocList default_annotations = ( - DocList[LegacyDocument] if is_batch_docs else LegacyDocument + DocList[LegacyDocumentJina] if is_batch_docs else LegacyDocumentJina ) else: from jina import Document, DocumentArray diff --git a/jina/serve/runtimes/gateway/graph/topology_graph.py b/jina/serve/runtimes/gateway/graph/topology_graph.py index 2a6c01ea9b633..2c20e803bb73b 100644 --- a/jina/serve/runtimes/gateway/graph/topology_graph.py +++ b/jina/serve/runtimes/gateway/graph/topology_graph.py @@ -18,7 +18,7 @@ if docarray_v2: from docarray import DocList - from docarray.documents.legacy import LegacyDocument + from jina._docarray import LegacyDocumentJina if not is_pydantic_v2: from jina.serve.runtimes.helper import _create_pydantic_model_from_schema as create_base_doc_from_schema @@ -26,7 +26,7 @@ from docarray.utils.create_dynamic_doc_class import create_base_doc_from_schema - legacy_doc_schema = LegacyDocument.schema() + legacy_doc_schema = LegacyDocumentJina.schema() class TopologyGraph: @@ -222,8 +222,6 @@ async def task(): endp, _ = endpoints_proto self.endpoints = endp.endpoints if docarray_v2: - from docarray.documents.legacy import LegacyDocument - schemas = json_format.MessageToDict(endp.schemas) self._pydantic_models_by_endpoint = {} models_created_by_name = {} @@ -240,7 +238,7 @@ async def task(): else: if input_model_name not in models_created_by_name: if input_model_schema == legacy_doc_schema: - input_model = LegacyDocument + input_model = LegacyDocumentJina else: input_model = ( create_base_doc_from_schema( @@ -270,7 +268,7 @@ async def task(): else: if output_model_name not in models_created_by_name: if output_model_name == legacy_doc_schema: - output_model = LegacyDocument + output_model = LegacyDocumentJina else: output_model = ( create_base_doc_from_schema( diff --git a/jina/serve/runtimes/head/request_handling.py b/jina/serve/runtimes/head/request_handling.py index 417c7a865ac6d..e883b901a55ae 100644 --- a/jina/serve/runtimes/head/request_handling.py +++ b/jina/serve/runtimes/head/request_handling.py @@ -26,6 +26,10 @@ from docarray import DocList from docarray.base_doc.any_doc import AnyDoc + from jina._docarray import LegacyDocumentJina + + legacy_doc_schema = LegacyDocumentJina.schema() + if TYPE_CHECKING: # pragma: no cover from prometheus_client import CollectorRegistry @@ -333,9 +337,6 @@ def _get_endpoints_from_workers( self, connection_pool: GrpcConnectionPool, name: str, retries: int, stop_event ): from google.protobuf import json_format - from docarray.documents.legacy import LegacyDocument - - legacy_doc_schema = LegacyDocument.schema() async def task(): self.logger.debug( @@ -359,7 +360,7 @@ async def task(): if input_model_schema == legacy_doc_schema: models_created_by_name[input_model_name] = ( - LegacyDocument + LegacyDocumentJina ) elif input_model_name not in models_created_by_name: input_model = create_base_doc_from_schema( @@ -369,7 +370,7 @@ async def task(): if output_model_name == legacy_doc_schema: models_created_by_name[output_model_name] = ( - LegacyDocument + LegacyDocumentJina ) elif output_model_name not in models_created_by_name: output_model = create_base_doc_from_schema( diff --git a/jina/serve/runtimes/worker/request_handling.py b/jina/serve/runtimes/worker/request_handling.py index 65472cd6d406f..16b79d1e0e047 100644 --- a/jina/serve/runtimes/worker/request_handling.py +++ b/jina/serve/runtimes/worker/request_handling.py @@ -33,6 +33,9 @@ if docarray_v2: from docarray import DocList + from jina._docarray import LegacyDocumentJina + legacy_doc_schema = LegacyDocumentJina.schema() + if TYPE_CHECKING: # pragma: no cover import grpc @@ -1011,14 +1014,12 @@ async def endpoint_discovery(self, empty, context) -> jina_pb2.EndpointsProto: endpoints_proto.write_endpoints.extend(list(self._executor.write_endpoints)) schemas = self._executor._get_endpoint_models_dict() if docarray_v2: - from docarray.documents.legacy import LegacyDocument - if not is_pydantic_v2: from jina.serve.runtimes.helper import _create_aux_model_doc_list_to_list as create_pure_python_type_model else: from docarray.utils.create_dynamic_doc_class import create_pure_python_type_model - legacy_doc_schema = LegacyDocument.schema() + for endpoint_name, inner_dict in schemas.items(): if inner_dict['input']['model'].schema() == legacy_doc_schema: inner_dict['input']['model'] = legacy_doc_schema