diff --git a/langserve/playground.py b/langserve/playground.py index c8cd282f..66ca639c 100644 --- a/langserve/playground.py +++ b/langserve/playground.py @@ -7,10 +7,7 @@ from fastapi.responses import Response from langchain.schema.runnable import Runnable -try: - from pydantic.v1 import BaseModel -except ImportError: - from pydantic import BaseModel +from langserve.pydantic_v1 import BaseModel class PlaygroundTemplate(Template): diff --git a/langserve/pydantic.py b/langserve/pydantic.py deleted file mode 100644 index f0458eb0..00000000 --- a/langserve/pydantic.py +++ /dev/null @@ -1,10 +0,0 @@ -import pydantic - - -def _get_pydantic_version() -> int: - """Get the pydantic major version.""" - return int(pydantic.__version__.split(".")[0]) - - -# Code is written to support both version 1 and 2 -PYDANTIC_MAJOR_VERSION = _get_pydantic_version() diff --git a/langserve/pydantic_v1.py b/langserve/pydantic_v1.py new file mode 100644 index 00000000..b7cfc828 --- /dev/null +++ b/langserve/pydantic_v1.py @@ -0,0 +1,25 @@ +from importlib import metadata + +## Create namespaces for pydantic v1 and v2. +# This code must stay at the top of the file before other modules may +# attempt to import pydantic since it adds pydantic_v1 and pydantic_v2 to sys.modules. +# +# This hack is done for the following reasons: +# * Langchain will attempt to remain compatible with both pydantic v1 and v2 since +# both dependencies and dependents may be stuck on either version of v1 or v2. +# * Creating namespaces for pydantic v1 and v2 should allow us to write code that +# unambiguously uses either v1 or v2 API. +# * This change is easier to roll out and roll back. + +try: + # F401: imported but unused + from pydantic.v1 import BaseModel, Field, ValidationError # noqa: F401 +except ImportError: + from pydantic import BaseModel, Field, ValidationError # noqa: F401 + + +# This is not a pydantic v1 thing, but it feels too small to create a new module for. +try: + _PYDANTIC_MAJOR_VERSION: int = int(metadata.version("pydantic").split(".")[0]) +except metadata.PackageNotFoundError: + _PYDANTIC_MAJOR_VERSION = -1 diff --git a/langserve/schema.py b/langserve/schema.py index ca3ec92c..e2f7608e 100644 --- a/langserve/schema.py +++ b/langserve/schema.py @@ -2,13 +2,9 @@ from typing import Dict, List, Optional, Union from uuid import UUID -from langserve.pydantic import PYDANTIC_MAJOR_VERSION +from pydantic import BaseModel # Floats between v1 and v2 -if PYDANTIC_MAJOR_VERSION == 2: - from pydantic.v1 import BaseModel as BaseModelV1 -else: - from pydantic import BaseModel as BaseModelV1 -from pydantic import BaseModel +from langserve.pydantic_v1 import BaseModel as BaseModelV1 class CustomUserType(BaseModelV1): diff --git a/langserve/serialization.py b/langserve/serialization.py index cec27646..c79da288 100644 --- a/langserve/serialization.py +++ b/langserve/serialization.py @@ -39,14 +39,9 @@ LLMResult, ) +from langserve.pydantic_v1 import BaseModel, ValidationError from langserve.validation import CallbackEvent -try: - from pydantic.v1 import BaseModel, ValidationError -except ImportError: - from pydantic import BaseModel, ValidationError - - logger = logging.getLogger(__name__) diff --git a/langserve/server.py b/langserve/server.py index 60895ed6..9d02cf64 100644 --- a/langserve/server.py +++ b/langserve/server.py @@ -45,7 +45,7 @@ from langserve.callbacks import AsyncEventAggregatorCallback, CallbackEventDict from langserve.lzstring import LZString from langserve.playground import serve_playground -from langserve.pydantic import PYDANTIC_MAJOR_VERSION +from langserve.pydantic_v1 import _PYDANTIC_MAJOR_VERSION from langserve.schema import ( BatchResponseMetadata, CustomUserType, @@ -1045,8 +1045,8 @@ async def feedback(feedback_create_req: FeedbackCreateRequest) -> Feedback: ####################################### # Documentation variants of end points. ####################################### - # At the moment, we only support pydantic 1.x - if PYDANTIC_MAJOR_VERSION == 1: + # At the moment, we only support pydantic 1.x for documentation + if _PYDANTIC_MAJOR_VERSION == 1: @app.post( namespace + "/c/{config_hash}/invoke",