Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

User Management #1562

Draft
wants to merge 6 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions py/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,9 @@
# LLM provider
"CompletionConfig",
"CompletionProvider",
# User management provider
"UserManagementConfig",
"UserManagementProvider",
## UTILS
"RecursiveCharacterTextSplitter",
"TextSplitter",
Expand Down
3 changes: 3 additions & 0 deletions py/core/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,9 @@
# LLM provider
"CompletionConfig",
"CompletionProvider",
# User management provider
"UserManagementConfig",
"UserManagementProvider",
## UTILS
"RecursiveCharacterTextSplitter",
"TextSplitter",
Expand Down
4 changes: 2 additions & 2 deletions py/core/base/api/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
from shared.api.models.auth.responses import (
GenericMessageResponse,
TokenResponse,
UserResponse,
WrappedGenericMessageResponse,
WrappedTokenResponse,
WrappedUserResponse,
)
from shared.api.models.ingestion.responses import (
CreateVectorIndexResponse,
Expand Down Expand Up @@ -46,6 +44,7 @@
ScoreCompletionResponse,
ServerStats,
UserOverviewResponse,
UserResponse,
WrappedAddUserResponse,
WrappedAnalyticsResponse,
WrappedAppSettingsResponse,
Expand All @@ -63,6 +62,7 @@
WrappedServerStatsResponse,
WrappedUserCollectionResponse,
WrappedUserOverviewResponse,
WrappedUserResponse,
WrappedUsersInCollectionResponse,
WrappedVerificationResult,
)
Expand Down
4 changes: 4 additions & 0 deletions py/core/base/providers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from .ingestion import ChunkingStrategy, IngestionConfig, IngestionProvider
from .llm import CompletionConfig, CompletionProvider
from .orchestration import OrchestrationConfig, OrchestrationProvider, Workflow
from .user_management import UserManagementConfig, UserManagementProvider

__all__ = [
# Auth provider
Expand Down Expand Up @@ -64,4 +65,7 @@
"OrchestrationConfig",
"OrchestrationProvider",
"Workflow",
# User management provider
"UserManagementConfig",
"UserManagementProvider",
]
96 changes: 42 additions & 54 deletions py/core/base/providers/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@
Any,
AsyncGenerator,
BinaryIO,
Dict,
List,
Optional,
Sequence,
Tuple,
Expand Down Expand Up @@ -50,7 +48,6 @@
KGEnrichmentEstimationResponse,
UserResponse,
)
from core.base.utils import _decorate_vector_type

from ..logger import RunInfoLog
from ..logger.base import RunType
Expand Down Expand Up @@ -79,15 +76,6 @@
logger = logging.getLogger()


def escape_braces(s: str) -> str:
"""
Escape braces in a string.
This is a placeholder function - implement the actual logic as needed.
"""
# Implement your escape_braces logic here
return s.replace("{", "{{").replace("}", "}}")


logger = logging.getLogger()


Expand Down Expand Up @@ -658,14 +646,14 @@ async def add_triples(
@abstractmethod
async def get_entity_map(
self, offset: int, limit: int, document_id: UUID
) -> Dict[str, Dict[str, List[Dict[str, Any]]]]:
) -> dict[str, dict[str, list[dict[str, Any]]]]:
"""Get entity map for a document."""
pass

@abstractmethod
async def upsert_embeddings(
self,
data: List[Tuple[Any]],
data: list[Tuple[Any]],
table_name: str,
) -> None:
"""Upsert embeddings into storage."""
Expand All @@ -680,7 +668,7 @@ async def vector_query(

# Community management
@abstractmethod
async def add_community_info(self, communities: List[Any]) -> None:
async def add_community_info(self, communities: list[Any]) -> None:
"""Add communities to storage."""
pass

Expand Down Expand Up @@ -713,22 +701,22 @@ async def get_community_details(
@abstractmethod
async def get_community_reports(
self, collection_id: UUID
) -> List[CommunityReport]:
) -> list[CommunityReport]:
"""Get community reports for a collection."""
pass

@abstractmethod
async def check_community_reports_exist(
self, collection_id: UUID, offset: int, limit: int
) -> List[int]:
) -> list[int]:
"""Check which community reports exist."""
pass

@abstractmethod
async def perform_graph_clustering(
self,
collection_id: UUID,
leiden_params: Dict[str, Any],
leiden_params: dict[str, Any],
) -> int:
"""Perform graph clustering."""
pass
Expand All @@ -753,10 +741,10 @@ async def delete_node_via_document_id(
async def get_entities(
self,
collection_id: Optional[UUID] = None,
entity_ids: Optional[List[str]] = None,
entity_names: Optional[List[str]] = None,
entity_ids: Optional[list[str]] = None,
entity_names: Optional[list[str]] = None,
entity_table_name: str = "document_entity",
extra_columns: Optional[List[str]] = None,
extra_columns: Optional[list[str]] = None,
offset: int = 0,
limit: int = -1,
) -> dict:
Expand All @@ -767,8 +755,8 @@ async def get_entities(
async def get_triples(
self,
collection_id: Optional[UUID] = None,
entity_names: Optional[List[str]] = None,
triple_ids: Optional[List[str]] = None,
entity_names: Optional[list[str]] = None,
triple_ids: Optional[list[str]] = None,
offset: int = 0,
limit: int = -1,
) -> dict:
Expand Down Expand Up @@ -865,7 +853,7 @@ async def get_existing_entity_extraction_ids(
@abstractmethod
async def get_all_triples(
self, collection_id: UUID, document_ids: Optional[list[UUID]] = None
) -> List[Triple]:
) -> list[Triple]:
raise NotImplementedError

@abstractmethod
Expand Down Expand Up @@ -1002,8 +990,8 @@ async def info_log(

@abstractmethod
async def get_logs(
self, run_ids: List[UUID], limit_per_run: int = 10
) -> List[Dict]:
self, run_ids: list[UUID], limit_per_run: int = 10
) -> list[dict]:
"""Retrieve logs for specified run IDs."""
pass

Expand All @@ -1013,8 +1001,8 @@ async def get_info_logs(
offset: int = 0,
limit: int = 100,
run_type_filter: Optional[RunType] = None,
user_ids: Optional[List[UUID]] = None,
) -> List[RunInfoLog]:
user_ids: Optional[list[UUID]] = None,
) -> list[RunInfoLog]:
"""Retrieve run information logs with filtering options."""
pass

Expand All @@ -1032,10 +1020,10 @@ async def delete_conversation(self, conversation_id: str) -> None:
@abstractmethod
async def get_conversations_overview(
self,
conversation_ids: Optional[List[UUID]] = None,
conversation_ids: Optional[list[UUID]] = None,
offset: int = 0,
limit: int = -1,
) -> Dict[str, Union[List[Dict], int]]:
) -> dict[str, Union[list[dict], int]]:
"""Get an overview of conversations with pagination."""
pass

Expand All @@ -1046,7 +1034,7 @@ async def add_message(
conversation_id: str,
content: Message,
parent_id: Optional[str] = None,
metadata: Optional[Dict] = None,
metadata: Optional[dict] = None,
) -> str:
"""Add a message to a conversation."""
pass
Expand All @@ -1061,13 +1049,13 @@ async def edit_message(
@abstractmethod
async def get_conversation(
self, conversation_id: str, branch_id: Optional[str] = None
) -> List[Tuple[str, Message]]:
) -> list[Tuple[str, Message]]:
"""Retrieve all messages in a conversation branch."""
pass

# Branch management methods
@abstractmethod
async def get_branches_overview(self, conversation_id: str) -> List[Dict]:
async def get_branches_overview(self, conversation_id: str) -> list[dict]:
"""Get an overview of all branches in a conversation."""
pass

Expand Down Expand Up @@ -1556,20 +1544,20 @@ async def add_triples(

async def get_entity_map(
self, offset: int, limit: int, document_id: UUID
) -> Dict[str, Dict[str, List[Dict[str, Any]]]]:
) -> dict[str, dict[str, list[dict[str, Any]]]]:
"""Forward to KG handler get_entity_map method."""
return await self.kg_handler.get_entity_map(offset, limit, document_id)

async def upsert_embeddings(
self,
data: List[Tuple[Any]],
data: list[Tuple[Any]],
table_name: str,
) -> None:
"""Forward to KG handler upsert_embeddings method."""
return await self.kg_handler.upsert_embeddings(data, table_name)

# Community methods
async def add_community_info(self, communities: List[Any]) -> None:
async def add_community_info(self, communities: list[Any]) -> None:
"""Forward to KG handler add_communities method."""
return await self.kg_handler.add_community_info(communities)

Expand Down Expand Up @@ -1606,13 +1594,13 @@ async def get_community_details(

async def get_community_reports(
self, collection_id: UUID
) -> List[CommunityReport]:
) -> list[CommunityReport]:
"""Forward to KG handler get_community_reports method."""
return await self.kg_handler.get_community_reports(collection_id)

async def check_community_reports_exist(
self, collection_id: UUID, offset: int, limit: int
) -> List[int]:
) -> list[int]:
"""Forward to KG handler check_community_reports_exist method."""
return await self.kg_handler.check_community_reports_exist(
collection_id, offset, limit
Expand All @@ -1621,7 +1609,7 @@ async def check_community_reports_exist(
async def perform_graph_clustering(
self,
collection_id: UUID,
leiden_params: Dict[str, Any],
leiden_params: dict[str, Any],
) -> int:
"""Forward to KG handler perform_graph_clustering method."""
return await self.kg_handler.perform_graph_clustering(
Expand Down Expand Up @@ -1649,10 +1637,10 @@ async def delete_node_via_document_id(
async def get_entities(
self,
collection_id: Optional[UUID],
entity_ids: Optional[List[str]] = None,
entity_names: Optional[List[str]] = None,
entity_ids: Optional[list[str]] = None,
entity_names: Optional[list[str]] = None,
entity_table_name: str = "document_entity",
extra_columns: Optional[List[str]] = None,
extra_columns: Optional[list[str]] = None,
offset: int = 0,
limit: int = -1,
) -> dict:
Expand All @@ -1670,8 +1658,8 @@ async def get_entities(
async def get_triples(
self,
collection_id: Optional[UUID] = None,
entity_names: Optional[List[str]] = None,
triple_ids: Optional[List[str]] = None,
entity_names: Optional[list[str]] = None,
triple_ids: Optional[list[str]] = None,
offset: int = 0,
limit: int = -1,
) -> dict:
Expand Down Expand Up @@ -1735,7 +1723,7 @@ async def get_deduplication_estimate(

async def get_all_triples(
self, collection_id: UUID, document_ids: Optional[list[UUID]] = None
) -> List[Triple]:
) -> list[Triple]:
return await self.kg_handler.get_all_triples(
collection_id, document_ids
)
Expand Down Expand Up @@ -1874,18 +1862,18 @@ async def get_info_logs(
offset: int = 0,
limit: int = 100,
run_type_filter: Optional[RunType] = None,
user_ids: Optional[List[UUID]] = None,
) -> List[RunInfoLog]:
user_ids: Optional[list[UUID]] = None,
) -> list[RunInfoLog]:
"""Retrieve log info entries with filtering and pagination."""
return await self.logging_handler.get_info_logs(
offset, limit, run_type_filter, user_ids
)

async def get_logs(
self,
run_ids: List[UUID],
run_ids: list[UUID],
limit_per_run: int = 10,
) -> List[Dict[str, Any]]:
) -> list[dict[str, Any]]:
"""Retrieve logs for specified run IDs with a per-run limit."""
return await self.logging_handler.get_logs(run_ids, limit_per_run)

Expand All @@ -1899,10 +1887,10 @@ async def delete_conversation(self, conversation_id: str) -> None:

async def get_conversations_overview(
self,
conversation_ids: Optional[List[UUID]] = None,
conversation_ids: Optional[list[UUID]] = None,
offset: int = 0,
limit: int = -1,
) -> Dict[str, Union[List[Dict], int]]:
) -> dict[str, Union[list[dict], int]]:
"""Get an overview of conversations with pagination."""
return await self.logging_handler.get_conversations_overview(
conversation_ids, offset, limit
Expand All @@ -1913,7 +1901,7 @@ async def add_message(
conversation_id: str,
content: Message,
parent_id: Optional[str] = None,
metadata: Optional[Dict] = None,
metadata: Optional[dict] = None,
) -> str:
"""Add a message to a conversation."""
return await self.logging_handler.add_message(
Expand All @@ -1928,13 +1916,13 @@ async def edit_message(

async def get_conversation(
self, conversation_id: str, branch_id: Optional[str] = None
) -> List[Tuple[str, Message]]:
) -> list[Tuple[str, Message]]:
"""Retrieve all messages in a conversation branch."""
return await self.logging_handler.get_conversation(
conversation_id, branch_id
)

async def get_branches_overview(self, conversation_id: str) -> List[Dict]:
async def get_branches_overview(self, conversation_id: str) -> list[dict]:
"""Get an overview of all branches in a conversation."""
return await self.logging_handler.get_branches_overview(
conversation_id
Expand Down
Loading
Loading