Skip to content

Commit

Permalink
fix: Remove global state (#1216)
Browse files Browse the repository at this point in the history
* Remove all global settings state

* chore: remove autogenerated class

* chore: cleanup

* chore: merge conflicts
  • Loading branch information
pabloogc authored Nov 12, 2023
1 parent f394ca6 commit 022bd71
Show file tree
Hide file tree
Showing 24 changed files with 286 additions and 190 deletions.
2 changes: 1 addition & 1 deletion private_gpt/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,4 @@
# Set log_config=None to do not use the uvicorn logging configuration, and
# use ours instead. For reference, see below:
# https://github.com/tiangolo/fastapi/discussions/7457#discussioncomment-5141108
uvicorn.run(app, host="0.0.0.0", port=settings.server.port, log_config=None)
uvicorn.run(app, host="0.0.0.0", port=settings().server.port, log_config=None)
4 changes: 2 additions & 2 deletions private_gpt/components/embedding/embedding_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,15 @@
from llama_index.embeddings.base import BaseEmbedding

from private_gpt.paths import models_cache_path
from private_gpt.settings.settings import settings
from private_gpt.settings.settings import Settings


@singleton
class EmbeddingComponent:
embedding_model: BaseEmbedding

@inject
def __init__(self) -> None:
def __init__(self, settings: Settings) -> None:
match settings.llm.mode:
case "local":
from llama_index.embeddings import HuggingFaceEmbedding
Expand Down
4 changes: 2 additions & 2 deletions private_gpt/components/llm/llm_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,15 @@
from llama_index.llms.llama_utils import completion_to_prompt, messages_to_prompt

from private_gpt.paths import models_path
from private_gpt.settings.settings import settings
from private_gpt.settings.settings import Settings


@singleton
class LLMComponent:
llm: LLM

@inject
def __init__(self) -> None:
def __init__(self, settings: Settings) -> None:
match settings.llm.mode:
case "local":
from llama_index.llms import LlamaCPP
Expand Down
16 changes: 13 additions & 3 deletions private_gpt/di.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,19 @@
from injector import Injector

from private_gpt.settings.settings import Settings, unsafe_typed_settings


def create_application_injector() -> Injector:
injector = Injector(auto_bind=True)
return injector
_injector = Injector(auto_bind=True)
_injector.binder.bind(Settings, to=unsafe_typed_settings)
return _injector


"""
Global injector for the application.
Avoid using this reference, it will make your code harder to test.
root_injector: Injector = create_application_injector()
Instead, use the `request.state.injector` reference, which is bound to every request
"""
global_injector: Injector = create_application_injector()
128 changes: 128 additions & 0 deletions private_gpt/launcher.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
"""FastAPI app creation, logger configuration and main API routes."""
import logging
from typing import Any

from fastapi import Depends, FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.openapi.utils import get_openapi
from injector import Injector

from private_gpt.paths import docs_path
from private_gpt.server.chat.chat_router import chat_router
from private_gpt.server.chunks.chunks_router import chunks_router
from private_gpt.server.completions.completions_router import completions_router
from private_gpt.server.embeddings.embeddings_router import embeddings_router
from private_gpt.server.health.health_router import health_router
from private_gpt.server.ingest.ingest_router import ingest_router
from private_gpt.settings.settings import Settings

logger = logging.getLogger(__name__)


def create_app(root_injector: Injector) -> FastAPI:

# Start the API
with open(docs_path / "description.md") as description_file:
description = description_file.read()

tags_metadata = [
{
"name": "Ingestion",
"description": "High-level APIs covering document ingestion -internally "
"managing document parsing, splitting,"
"metadata extraction, embedding generation and storage- and ingested "
"documents CRUD."
"Each ingested document is identified by an ID that can be used to filter the "
"context"
"used in *Contextual Completions* and *Context Chunks* APIs.",
},
{
"name": "Contextual Completions",
"description": "High-level APIs covering contextual Chat and Completions. They "
"follow OpenAI's format, extending it to "
"allow using the context coming from ingested documents to create the "
"response. Internally"
"manage context retrieval, prompt engineering and the response generation.",
},
{
"name": "Context Chunks",
"description": "Low-level API that given a query return relevant chunks of "
"text coming from the ingested"
"documents.",
},
{
"name": "Embeddings",
"description": "Low-level API to obtain the vector representation of a given "
"text, using an Embeddings model."
"Follows OpenAI's embeddings API format.",
},
{
"name": "Health",
"description": "Simple health API to make sure the server is up and running.",
},
]

async def bind_injector_to_request(request: Request) -> None:
request.state.injector = root_injector

app = FastAPI(dependencies=[Depends(bind_injector_to_request)])

def custom_openapi() -> dict[str, Any]:
if app.openapi_schema:
return app.openapi_schema
openapi_schema = get_openapi(
title="PrivateGPT",
description=description,
version="0.1.0",
summary="PrivateGPT is a production-ready AI project that allows you to "
"ask questions to your documents using the power of Large Language "
"Models (LLMs), even in scenarios without Internet connection. "
"100% private, no data leaves your execution environment at any point.",
contact={
"url": "https://github.com/imartinez/privateGPT",
},
license_info={
"name": "Apache 2.0",
"url": "https://www.apache.org/licenses/LICENSE-2.0.html",
},
routes=app.routes,
tags=tags_metadata,
)
openapi_schema["info"]["x-logo"] = {
"url": "https://lh3.googleusercontent.com/drive-viewer"
"/AK7aPaD_iNlMoTquOBsw4boh4tIYxyEuhz6EtEs8nzq3yNkNAK00xGj"
"E1KUCmPJSk3TYOjcs6tReG6w_cLu1S7L_gPgT9z52iw=s2560"
}

app.openapi_schema = openapi_schema
return app.openapi_schema

app.openapi = custom_openapi # type: ignore[method-assign]

app.include_router(completions_router)
app.include_router(chat_router)
app.include_router(chunks_router)
app.include_router(ingest_router)
app.include_router(embeddings_router)
app.include_router(health_router)

settings = root_injector.get(Settings)
if settings.server.cors.enabled:
logger.debug("Setting up CORS middleware")
app.add_middleware(
CORSMiddleware,
allow_credentials=settings.server.cors.allow_credentials,
allow_origins=settings.server.cors.allow_origins,
allow_origin_regex=settings.server.cors.allow_origin_regex,
allow_methods=settings.server.cors.allow_methods,
allow_headers=settings.server.cors.allow_headers,
)

if settings.ui.enabled:
logger.debug("Importing the UI module")
from private_gpt.ui.ui import PrivateGptUi

ui = root_injector.get(PrivateGptUi)
ui.mount_in_app(app, settings.ui.path)

return app
119 changes: 3 additions & 116 deletions private_gpt/main.py
Original file line number Diff line number Diff line change
@@ -1,124 +1,11 @@
"""FastAPI app creation, logger configuration and main API routes."""
import logging
from typing import Any

import llama_index
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.openapi.utils import get_openapi

from private_gpt.paths import docs_path
from private_gpt.server.chat.chat_router import chat_router
from private_gpt.server.chunks.chunks_router import chunks_router
from private_gpt.server.completions.completions_router import completions_router
from private_gpt.server.embeddings.embeddings_router import embeddings_router
from private_gpt.server.health.health_router import health_router
from private_gpt.server.ingest.ingest_router import ingest_router
from private_gpt.settings.settings import settings

logger = logging.getLogger(__name__)
from private_gpt.di import global_injector
from private_gpt.launcher import create_app

# Add LlamaIndex simple observability
llama_index.set_global_handler("simple")

# Start the API
with open(docs_path / "description.md") as description_file:
description = description_file.read()

tags_metadata = [
{
"name": "Ingestion",
"description": "High-level APIs covering document ingestion -internally "
"managing document parsing, splitting,"
"metadata extraction, embedding generation and storage- and ingested "
"documents CRUD."
"Each ingested document is identified by an ID that can be used to filter the "
"context"
"used in *Contextual Completions* and *Context Chunks* APIs.",
},
{
"name": "Contextual Completions",
"description": "High-level APIs covering contextual Chat and Completions. They "
"follow OpenAI's format, extending it to "
"allow using the context coming from ingested documents to create the "
"response. Internally"
"manage context retrieval, prompt engineering and the response generation.",
},
{
"name": "Context Chunks",
"description": "Low-level API that given a query return relevant chunks of "
"text coming from the ingested"
"documents.",
},
{
"name": "Embeddings",
"description": "Low-level API to obtain the vector representation of a given "
"text, using an Embeddings model."
"Follows OpenAI's embeddings API format.",
},
{
"name": "Health",
"description": "Simple health API to make sure the server is up and running.",
},
]

app = FastAPI()


def custom_openapi() -> dict[str, Any]:
if app.openapi_schema:
return app.openapi_schema
openapi_schema = get_openapi(
title="PrivateGPT",
description=description,
version="0.1.0",
summary="PrivateGPT is a production-ready AI project that allows you to "
"ask questions to your documents using the power of Large Language "
"Models (LLMs), even in scenarios without Internet connection. "
"100% private, no data leaves your execution environment at any point.",
contact={
"url": "https://github.com/imartinez/privateGPT",
},
license_info={
"name": "Apache 2.0",
"url": "https://www.apache.org/licenses/LICENSE-2.0.html",
},
routes=app.routes,
tags=tags_metadata,
)
openapi_schema["info"]["x-logo"] = {
"url": "https://lh3.googleusercontent.com/drive-viewer"
"/AK7aPaD_iNlMoTquOBsw4boh4tIYxyEuhz6EtEs8nzq3yNkNAK00xGj"
"E1KUCmPJSk3TYOjcs6tReG6w_cLu1S7L_gPgT9z52iw=s2560"
}

app.openapi_schema = openapi_schema
return app.openapi_schema


app.openapi = custom_openapi # type: ignore[method-assign]

app.include_router(completions_router)
app.include_router(chat_router)
app.include_router(chunks_router)
app.include_router(ingest_router)
app.include_router(embeddings_router)
app.include_router(health_router)

if settings.server.cors.enabled:
logger.debug("Setting up CORS middleware")
app.add_middleware(
CORSMiddleware,
allow_credentials=settings.server.cors.allow_credentials,
allow_origins=settings.server.cors.allow_origins,
allow_origin_regex=settings.server.cors.allow_origin_regex,
allow_methods=settings.server.cors.allow_methods,
allow_headers=settings.server.cors.allow_headers,
)


if settings.ui.enabled:
logger.debug("Importing the UI module")
from private_gpt.ui.ui import PrivateGptUi

PrivateGptUi().mount_in_app(app)
app = create_app(global_injector)
4 changes: 3 additions & 1 deletion private_gpt/paths.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,6 @@ def _absolute_or_from_project_root(path: str) -> Path:
models_path: Path = PROJECT_ROOT_PATH / "models"
models_cache_path: Path = models_path / "cache"
docs_path: Path = PROJECT_ROOT_PATH / "docs"
local_data_path: Path = _absolute_or_from_project_root(settings.data.local_data_folder)
local_data_path: Path = _absolute_or_from_project_root(
settings().data.local_data_folder
)
9 changes: 5 additions & 4 deletions private_gpt/server/chat/chat_router.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
from fastapi import APIRouter, Depends
from fastapi import APIRouter, Depends, Request
from llama_index.llms import ChatMessage, MessageRole
from pydantic import BaseModel
from starlette.responses import StreamingResponse

from private_gpt.di import root_injector
from private_gpt.open_ai.extensions.context_filter import ContextFilter
from private_gpt.open_ai.openai_models import (
OpenAICompletion,
Expand Down Expand Up @@ -52,7 +51,9 @@ class ChatBody(BaseModel):
responses={200: {"model": OpenAICompletion}},
tags=["Contextual Completions"],
)
def chat_completion(body: ChatBody) -> OpenAICompletion | StreamingResponse:
def chat_completion(
request: Request, body: ChatBody
) -> OpenAICompletion | StreamingResponse:
"""Given a list of messages comprising a conversation, return a response.
If `use_context` is set to `true`, the model will use context coming
Expand All @@ -72,7 +73,7 @@ def chat_completion(body: ChatBody) -> OpenAICompletion | StreamingResponse:
"finish_reason":null}]}
```
"""
service = root_injector.get(ChatService)
service = request.state.injector.get(ChatService)
all_messages = [
ChatMessage(content=m.content, role=MessageRole(m.role)) for m in body.messages
]
Expand Down
7 changes: 3 additions & 4 deletions private_gpt/server/chunks/chunks_router.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
from typing import Literal

from fastapi import APIRouter, Depends
from fastapi import APIRouter, Depends, Request
from pydantic import BaseModel, Field

from private_gpt.di import root_injector
from private_gpt.open_ai.extensions.context_filter import ContextFilter
from private_gpt.server.chunks.chunks_service import Chunk, ChunksService
from private_gpt.server.utils.auth import authenticated
Expand All @@ -25,7 +24,7 @@ class ChunksResponse(BaseModel):


@chunks_router.post("/chunks", tags=["Context Chunks"])
def chunks_retrieval(body: ChunksBody) -> ChunksResponse:
def chunks_retrieval(request: Request, body: ChunksBody) -> ChunksResponse:
"""Given a `text`, returns the most relevant chunks from the ingested documents.
The returned information can be used to generate prompts that can be
Expand All @@ -45,7 +44,7 @@ def chunks_retrieval(body: ChunksBody) -> ChunksResponse:
`/ingest/list` endpoint. If you want all ingested documents to be used,
remove `context_filter` altogether.
"""
service = root_injector.get(ChunksService)
service = request.state.injector.get(ChunksService)
results = service.retrieve_relevant(
body.text, body.context_filter, body.limit, body.prev_next_chunks
)
Expand Down
Loading

0 comments on commit 022bd71

Please sign in to comment.