diff --git a/py/core/__init__.py b/py/core/__init__.py index cb8c270bc..98d6ebdc8 100644 --- a/py/core/__init__.py +++ b/py/core/__init__.py @@ -147,6 +147,9 @@ # LLM provider "CompletionConfig", "CompletionProvider", + # User management provider + "UserManagementConfig", + "UserManagementProvider", ## UTILS "RecursiveCharacterTextSplitter", "TextSplitter", diff --git a/py/core/base/__init__.py b/py/core/base/__init__.py index 6e14b3edc..d9c3e22c4 100644 --- a/py/core/base/__init__.py +++ b/py/core/base/__init__.py @@ -124,6 +124,9 @@ # LLM provider "CompletionConfig", "CompletionProvider", + # User management provider + "UserManagementConfig", + "UserManagementProvider", ## UTILS "RecursiveCharacterTextSplitter", "TextSplitter", diff --git a/py/core/base/api/models/__init__.py b/py/core/base/api/models/__init__.py index 46d3007db..157534c8d 100644 --- a/py/core/base/api/models/__init__.py +++ b/py/core/base/api/models/__init__.py @@ -1,10 +1,8 @@ from shared.api.models.auth.responses import ( GenericMessageResponse, TokenResponse, - UserResponse, WrappedGenericMessageResponse, WrappedTokenResponse, - WrappedUserResponse, ) from shared.api.models.ingestion.responses import ( CreateVectorIndexResponse, @@ -46,6 +44,7 @@ ScoreCompletionResponse, ServerStats, UserOverviewResponse, + UserResponse, WrappedAddUserResponse, WrappedAnalyticsResponse, WrappedAppSettingsResponse, @@ -63,6 +62,7 @@ WrappedServerStatsResponse, WrappedUserCollectionResponse, WrappedUserOverviewResponse, + WrappedUserResponse, WrappedUsersInCollectionResponse, WrappedVerificationResult, ) diff --git a/py/core/base/providers/__init__.py b/py/core/base/providers/__init__.py index 37af2b8f3..f46ee0748 100644 --- a/py/core/base/providers/__init__.py +++ b/py/core/base/providers/__init__.py @@ -21,6 +21,7 @@ from .ingestion import ChunkingStrategy, IngestionConfig, IngestionProvider from .llm import CompletionConfig, CompletionProvider from .orchestration import OrchestrationConfig, OrchestrationProvider, Workflow +from .user_management import UserManagementConfig, UserManagementProvider __all__ = [ # Auth provider @@ -64,4 +65,7 @@ "OrchestrationConfig", "OrchestrationProvider", "Workflow", + # User management provider + "UserManagementConfig", + "UserManagementProvider", ] diff --git a/py/core/base/providers/database.py b/py/core/base/providers/database.py index ddea49556..1452f8069 100644 --- a/py/core/base/providers/database.py +++ b/py/core/base/providers/database.py @@ -6,8 +6,6 @@ Any, AsyncGenerator, BinaryIO, - Dict, - List, Optional, Sequence, Tuple, @@ -50,7 +48,6 @@ KGEnrichmentEstimationResponse, UserResponse, ) -from core.base.utils import _decorate_vector_type from ..logger import RunInfoLog from ..logger.base import RunType @@ -79,15 +76,6 @@ logger = logging.getLogger() -def escape_braces(s: str) -> str: - """ - Escape braces in a string. - This is a placeholder function - implement the actual logic as needed. - """ - # Implement your escape_braces logic here - return s.replace("{", "{{").replace("}", "}}") - - logger = logging.getLogger() @@ -658,14 +646,14 @@ async def add_triples( @abstractmethod async def get_entity_map( self, offset: int, limit: int, document_id: UUID - ) -> Dict[str, Dict[str, List[Dict[str, Any]]]]: + ) -> dict[str, dict[str, list[dict[str, Any]]]]: """Get entity map for a document.""" pass @abstractmethod async def upsert_embeddings( self, - data: List[Tuple[Any]], + data: list[Tuple[Any]], table_name: str, ) -> None: """Upsert embeddings into storage.""" @@ -680,7 +668,7 @@ async def vector_query( # Community management @abstractmethod - async def add_community_info(self, communities: List[Any]) -> None: + async def add_community_info(self, communities: list[Any]) -> None: """Add communities to storage.""" pass @@ -713,14 +701,14 @@ async def get_community_details( @abstractmethod async def get_community_reports( self, collection_id: UUID - ) -> List[CommunityReport]: + ) -> list[CommunityReport]: """Get community reports for a collection.""" pass @abstractmethod async def check_community_reports_exist( self, collection_id: UUID, offset: int, limit: int - ) -> List[int]: + ) -> list[int]: """Check which community reports exist.""" pass @@ -728,7 +716,7 @@ async def check_community_reports_exist( async def perform_graph_clustering( self, collection_id: UUID, - leiden_params: Dict[str, Any], + leiden_params: dict[str, Any], ) -> int: """Perform graph clustering.""" pass @@ -753,10 +741,10 @@ async def delete_node_via_document_id( async def get_entities( self, collection_id: Optional[UUID] = None, - entity_ids: Optional[List[str]] = None, - entity_names: Optional[List[str]] = None, + entity_ids: Optional[list[str]] = None, + entity_names: Optional[list[str]] = None, entity_table_name: str = "document_entity", - extra_columns: Optional[List[str]] = None, + extra_columns: Optional[list[str]] = None, offset: int = 0, limit: int = -1, ) -> dict: @@ -767,8 +755,8 @@ async def get_entities( async def get_triples( self, collection_id: Optional[UUID] = None, - entity_names: Optional[List[str]] = None, - triple_ids: Optional[List[str]] = None, + entity_names: Optional[list[str]] = None, + triple_ids: Optional[list[str]] = None, offset: int = 0, limit: int = -1, ) -> dict: @@ -865,7 +853,7 @@ async def get_existing_entity_extraction_ids( @abstractmethod async def get_all_triples( self, collection_id: UUID, document_ids: Optional[list[UUID]] = None - ) -> List[Triple]: + ) -> list[Triple]: raise NotImplementedError @abstractmethod @@ -1002,8 +990,8 @@ async def info_log( @abstractmethod async def get_logs( - self, run_ids: List[UUID], limit_per_run: int = 10 - ) -> List[Dict]: + self, run_ids: list[UUID], limit_per_run: int = 10 + ) -> list[dict]: """Retrieve logs for specified run IDs.""" pass @@ -1013,8 +1001,8 @@ async def get_info_logs( offset: int = 0, limit: int = 100, run_type_filter: Optional[RunType] = None, - user_ids: Optional[List[UUID]] = None, - ) -> List[RunInfoLog]: + user_ids: Optional[list[UUID]] = None, + ) -> list[RunInfoLog]: """Retrieve run information logs with filtering options.""" pass @@ -1032,10 +1020,10 @@ async def delete_conversation(self, conversation_id: str) -> None: @abstractmethod async def get_conversations_overview( self, - conversation_ids: Optional[List[UUID]] = None, + conversation_ids: Optional[list[UUID]] = None, offset: int = 0, limit: int = -1, - ) -> Dict[str, Union[List[Dict], int]]: + ) -> dict[str, Union[list[dict], int]]: """Get an overview of conversations with pagination.""" pass @@ -1046,7 +1034,7 @@ async def add_message( conversation_id: str, content: Message, parent_id: Optional[str] = None, - metadata: Optional[Dict] = None, + metadata: Optional[dict] = None, ) -> str: """Add a message to a conversation.""" pass @@ -1061,13 +1049,13 @@ async def edit_message( @abstractmethod async def get_conversation( self, conversation_id: str, branch_id: Optional[str] = None - ) -> List[Tuple[str, Message]]: + ) -> list[Tuple[str, Message]]: """Retrieve all messages in a conversation branch.""" pass # Branch management methods @abstractmethod - async def get_branches_overview(self, conversation_id: str) -> List[Dict]: + async def get_branches_overview(self, conversation_id: str) -> list[dict]: """Get an overview of all branches in a conversation.""" pass @@ -1556,20 +1544,20 @@ async def add_triples( async def get_entity_map( self, offset: int, limit: int, document_id: UUID - ) -> Dict[str, Dict[str, List[Dict[str, Any]]]]: + ) -> dict[str, dict[str, list[dict[str, Any]]]]: """Forward to KG handler get_entity_map method.""" return await self.kg_handler.get_entity_map(offset, limit, document_id) async def upsert_embeddings( self, - data: List[Tuple[Any]], + data: list[Tuple[Any]], table_name: str, ) -> None: """Forward to KG handler upsert_embeddings method.""" return await self.kg_handler.upsert_embeddings(data, table_name) # Community methods - async def add_community_info(self, communities: List[Any]) -> None: + async def add_community_info(self, communities: list[Any]) -> None: """Forward to KG handler add_communities method.""" return await self.kg_handler.add_community_info(communities) @@ -1606,13 +1594,13 @@ async def get_community_details( async def get_community_reports( self, collection_id: UUID - ) -> List[CommunityReport]: + ) -> list[CommunityReport]: """Forward to KG handler get_community_reports method.""" return await self.kg_handler.get_community_reports(collection_id) async def check_community_reports_exist( self, collection_id: UUID, offset: int, limit: int - ) -> List[int]: + ) -> list[int]: """Forward to KG handler check_community_reports_exist method.""" return await self.kg_handler.check_community_reports_exist( collection_id, offset, limit @@ -1621,7 +1609,7 @@ async def check_community_reports_exist( async def perform_graph_clustering( self, collection_id: UUID, - leiden_params: Dict[str, Any], + leiden_params: dict[str, Any], ) -> int: """Forward to KG handler perform_graph_clustering method.""" return await self.kg_handler.perform_graph_clustering( @@ -1649,10 +1637,10 @@ async def delete_node_via_document_id( async def get_entities( self, collection_id: Optional[UUID], - entity_ids: Optional[List[str]] = None, - entity_names: Optional[List[str]] = None, + entity_ids: Optional[list[str]] = None, + entity_names: Optional[list[str]] = None, entity_table_name: str = "document_entity", - extra_columns: Optional[List[str]] = None, + extra_columns: Optional[list[str]] = None, offset: int = 0, limit: int = -1, ) -> dict: @@ -1670,8 +1658,8 @@ async def get_entities( async def get_triples( self, collection_id: Optional[UUID] = None, - entity_names: Optional[List[str]] = None, - triple_ids: Optional[List[str]] = None, + entity_names: Optional[list[str]] = None, + triple_ids: Optional[list[str]] = None, offset: int = 0, limit: int = -1, ) -> dict: @@ -1735,7 +1723,7 @@ async def get_deduplication_estimate( async def get_all_triples( self, collection_id: UUID, document_ids: Optional[list[UUID]] = None - ) -> List[Triple]: + ) -> list[Triple]: return await self.kg_handler.get_all_triples( collection_id, document_ids ) @@ -1874,8 +1862,8 @@ async def get_info_logs( offset: int = 0, limit: int = 100, run_type_filter: Optional[RunType] = None, - user_ids: Optional[List[UUID]] = None, - ) -> List[RunInfoLog]: + user_ids: Optional[list[UUID]] = None, + ) -> list[RunInfoLog]: """Retrieve log info entries with filtering and pagination.""" return await self.logging_handler.get_info_logs( offset, limit, run_type_filter, user_ids @@ -1883,9 +1871,9 @@ async def get_info_logs( async def get_logs( self, - run_ids: List[UUID], + run_ids: list[UUID], limit_per_run: int = 10, - ) -> List[Dict[str, Any]]: + ) -> list[dict[str, Any]]: """Retrieve logs for specified run IDs with a per-run limit.""" return await self.logging_handler.get_logs(run_ids, limit_per_run) @@ -1899,10 +1887,10 @@ async def delete_conversation(self, conversation_id: str) -> None: async def get_conversations_overview( self, - conversation_ids: Optional[List[UUID]] = None, + conversation_ids: Optional[list[UUID]] = None, offset: int = 0, limit: int = -1, - ) -> Dict[str, Union[List[Dict], int]]: + ) -> dict[str, Union[list[dict], int]]: """Get an overview of conversations with pagination.""" return await self.logging_handler.get_conversations_overview( conversation_ids, offset, limit @@ -1913,7 +1901,7 @@ async def add_message( conversation_id: str, content: Message, parent_id: Optional[str] = None, - metadata: Optional[Dict] = None, + metadata: Optional[dict] = None, ) -> str: """Add a message to a conversation.""" return await self.logging_handler.add_message( @@ -1928,13 +1916,13 @@ async def edit_message( async def get_conversation( self, conversation_id: str, branch_id: Optional[str] = None - ) -> List[Tuple[str, Message]]: + ) -> list[Tuple[str, Message]]: """Retrieve all messages in a conversation branch.""" return await self.logging_handler.get_conversation( conversation_id, branch_id ) - async def get_branches_overview(self, conversation_id: str) -> List[Dict]: + async def get_branches_overview(self, conversation_id: str) -> list[dict]: """Get an overview of all branches in a conversation.""" return await self.logging_handler.get_branches_overview( conversation_id diff --git a/py/core/base/providers/user_management.py b/py/core/base/providers/user_management.py new file mode 100644 index 000000000..f6ff6a670 --- /dev/null +++ b/py/core/base/providers/user_management.py @@ -0,0 +1,110 @@ +from typing import Optional +from pydantic import Field, field_validator, BaseModel +from .base import Provider, ProviderConfig +from abc import ABC + + +Limit = Optional[int] + + +class RoleLimits(BaseModel): + max_files: Limit = Field( + default=None, + description="Number of files allowed (None for no limit)", + ) + max_chunks: Limit = Field( + default=None, + description="Number of chunks allowed (None for no limit)", + ) + max_queries: Limit = Field( + default=None, + description="Number of queries allowed (None for no limit)", + ) + max_queries_window: Limit = Field( + default=None, description="Query window size (None for no limit)" + ) + + @field_validator( + "max_files", + "max_chunks", + "max_queries", + "max_queries_window", + mode="before", + ) + def parse_limit(cls, v): + return None if v is None or v == "inf" or v == float("inf") else v + + def has_limit(self, field: str) -> bool: + """Check if a particular field has a numerical limit.""" + return getattr(self, field) is not None + + def get_limit(self, field: str) -> Optional[int]: + """Get the numerical limit for a field, or None if no limit.""" + return getattr(self, field) + + +class UserManagementConfig(ProviderConfig): + default_role: str = "default" + roles: dict[str, RoleLimits] = { + "default": RoleLimits(), + } + + @property + def supported_providers(self) -> list[str]: + return ["r2r"] + + def validate_config(self) -> None: + if self.default_role not in self.roles: + raise ValueError( + f"Default role '{self.default_role}' not found in roles configuration" + ) + + def get_role_limits(self, role: str) -> RoleLimits: + default_limits = self.roles[self.default_role] + + if role == self.default_role: + return default_limits + + custom_limits = self.roles.get(role, RoleLimits()) + + return RoleLimits( + max_files=( + custom_limits.max_files + if custom_limits.has_limit("max_files") + else default_limits.max_files + ), + max_chunks=( + custom_limits.max_chunks + if custom_limits.has_limit("max_chunks") + else default_limits.max_chunks + ), + max_queries=( + custom_limits.max_queries + if custom_limits.has_limit("max_queries") + else default_limits.max_queries + ), + max_queries_window=( + custom_limits.max_queries_window + if custom_limits.has_limit("max_queries_window") + else default_limits.max_queries_window + ), + ) + + +class UserManagementProvider(Provider, ABC): + def __init__(self, config: UserManagementConfig): + if not isinstance(config, UserManagementConfig): + raise ValueError( + "UserManagementProvider must be initialized with a UserManagementConfig" + ) + print(f"UserManagementProvider config: {config}") + super().__init__(config) + self.config: UserManagementConfig = config + + def check_limit( + self, role: str, limit_type: str, current_value: int + ) -> bool: + """Check if a particular action would exceed the role's limits.""" + limits = self.config.get_role_limits(role) + limit = limits.get_limit(limit_type) + return limit is None or current_value < limit diff --git a/py/core/main/assembly/builder.py b/py/core/main/assembly/builder.py index 9f15578df..111d46ae5 100644 --- a/py/core/main/assembly/builder.py +++ b/py/core/main/assembly/builder.py @@ -11,6 +11,7 @@ DatabaseProvider, EmbeddingProvider, OrchestrationProvider, + UserManagementProvider, RunManager, ) from core.pipelines import KGEnrichmentPipeline, RAGPipeline, SearchPipeline @@ -47,6 +48,7 @@ class ProviderOverrides: llm: Optional[CompletionProvider] = None crypto: Optional[CryptoProvider] = None orchestration: Optional[OrchestrationProvider] = None + user_management: Optional[UserManagementProvider] = None @dataclass diff --git a/py/core/main/assembly/factory.py b/py/core/main/assembly/factory.py index 8ba3efc76..d061206bf 100644 --- a/py/core/main/assembly/factory.py +++ b/py/core/main/assembly/factory.py @@ -15,6 +15,7 @@ EmbeddingProvider, IngestionConfig, OrchestrationConfig, + UserManagementConfig, ) from core.pipelines import RAGPipeline, SearchPipeline from core.pipes import GeneratorPipe, MultiSearchPipe, SearchPipe @@ -43,6 +44,7 @@ SupabaseAuthProvider, UnstructuredIngestionConfig, UnstructuredIngestionProvider, + R2RUserManagementProvider, ) @@ -147,10 +149,22 @@ def create_orchestration_provider( f"Orchestration provider {config.provider} not supported" ) + @staticmethod + def create_user_management_provider( + user_management_config: UserManagementConfig, *args, **kwargs + ) -> R2RUserManagementProvider: + if user_management_config.provider == "r2r": + return R2RUserManagementProvider(user_management_config) + else: + raise ValueError( + f"User management provider {user_management_config.provider} not supported" + ) + async def create_database_provider( self, db_config: DatabaseConfig, crypto_provider: BCryptProvider, + user_management_provider: R2RUserManagementProvider, *args, **kwargs, ) -> PostgresDBProvider: @@ -171,6 +185,7 @@ async def create_database_provider( dimension, crypto_provider=crypto_provider, quantization_type=quantization_type, + user_management_provider=user_management_provider, ) await database_provider.initialize() return database_provider @@ -239,7 +254,7 @@ async def create_email_provider( """Creates an email provider based on configuration.""" if not email_config: raise ValueError( - f"No email configuration provided for email provider, please add `[email]` to your `r2r.toml`." + "No email configuration provided for email provider, please add `[email]` to your `r2r.toml`." ) if email_config.provider == "smtp": @@ -295,10 +310,19 @@ async def create_providers( crypto_provider_override or self.create_crypto_provider(self.config.crypto, *args, **kwargs) ) + + user_management_provider = self.create_user_management_provider( + self.config.user_management + ) + database_provider = ( database_provider_override or await self.create_database_provider( - self.config.database, crypto_provider, *args, **kwargs + self.config.database, + crypto_provider, + user_management_provider, + *args, + **kwargs, ) ) diff --git a/py/core/main/config.py b/py/core/main/config.py index 4b914d6da..ad19c5dd9 100644 --- a/py/core/main/config.py +++ b/py/core/main/config.py @@ -18,6 +18,7 @@ from ..base.providers.ingestion import IngestionConfig from ..base.providers.llm import CompletionConfig from ..base.providers.orchestration import OrchestrationConfig +from ..base.providers.user_management import UserManagementConfig logger = logging.getLogger() @@ -56,6 +57,7 @@ class R2RConfig: "database": ["provider"], "agent": ["generation_config"], "orchestration": ["provider"], + "user_management": ["default_role"], } app: AppConfig @@ -69,6 +71,7 @@ class R2RConfig: logging: PersistentLoggingConfig agent: AgentConfig orchestration: OrchestrationConfig + user_management: UserManagementConfig def __init__(self, config_data: dict[str, Any]): """ @@ -122,6 +125,7 @@ def __init__(self, config_data: dict[str, Any]): self.logging = PersistentLoggingConfig.create(**self.logging, app=self.app) # type: ignore self.agent = AgentConfig.create(**self.agent, app=self.app) # type: ignore self.orchestration = OrchestrationConfig.create(**self.orchestration, app=self.app) # type: ignore + self.user_management = UserManagementConfig.create(**self.user_management, app=self.app) # type: ignore # override GenerationConfig defaults GenerationConfig.set_default( diff --git a/py/core/providers/__init__.py b/py/core/providers/__init__.py index a970f83eb..1320afe16 100644 --- a/py/core/providers/__init__.py +++ b/py/core/providers/__init__.py @@ -19,6 +19,7 @@ HatchetOrchestrationProvider, SimpleOrchestrationProvider, ) +from .user_management import R2RUserManagementProvider __all__ = [ # Auth @@ -49,4 +50,6 @@ "LiteLLMCompletionProvider", # Logging "SqlitePersistentLoggingProvider", + # User Management + "R2RUserManagementProvider", ] diff --git a/py/core/providers/database/postgres.py b/py/core/providers/database/postgres.py index 655473192..8c1bd1f86 100644 --- a/py/core/providers/database/postgres.py +++ b/py/core/providers/database/postgres.py @@ -10,6 +10,7 @@ DatabaseProvider, PostgresConfigurationSettings, VectorQuantizationType, + UserManagementConfig, ) from core.providers import BCryptProvider from core.providers.database.base import PostgresConnectionManager @@ -22,6 +23,7 @@ from core.providers.database.tokens import PostgresTokenHandler from core.providers.database.user import PostgresUserHandler from core.providers.database.vector import PostgresVectorHandler +from core.providers.user_management import R2RUserManagementProvider from .base import SemaphoreConnectionPool @@ -73,6 +75,7 @@ def __init__( config: DatabaseConfig, dimension: int, crypto_provider: BCryptProvider, + user_management_provider: R2RUserManagementProvider, quantization_type: VectorQuantizationType = VectorQuantizationType.FP32, *args, **kwargs, @@ -124,6 +127,7 @@ def __init__( self.conn = None self.config: DatabaseConfig = config self.crypto_provider = crypto_provider + self.user_management_provider = user_management_provider self.postgres_configuration_settings: PostgresConfigurationSettings = ( self._get_postgres_configuration_settings(config) ) @@ -146,7 +150,10 @@ def __init__( self.project_name, self.connection_manager, self.config ) self.user_handler = PostgresUserHandler( - self.project_name, self.connection_manager, self.crypto_provider + self.project_name, + self.connection_manager, + self.crypto_provider, + self.user_management_provider, ) self.vector_handler = PostgresVectorHandler( self.project_name, diff --git a/py/core/providers/database/user.py b/py/core/providers/database/user.py index 41654ff04..35936e524 100644 --- a/py/core/providers/database/user.py +++ b/py/core/providers/database/user.py @@ -3,7 +3,7 @@ from uuid import UUID from fastapi import HTTPException -from core.base import CryptoProvider, UserHandler +from core.base import CryptoProvider, UserHandler, UserManagementProvider from core.base.abstractions import R2RException, UserStats from core.base.api.models import UserResponse from core.utils import generate_user_id @@ -20,9 +20,11 @@ def __init__( project_name: str, connection_manager: PostgresConnectionManager, crypto_provider: CryptoProvider, + user_config: UserManagementProvider, ): super().__init__(project_name, connection_manager) self.crypto_provider = crypto_provider + self.user_config = user_config async def create_tables(self): query = f""" @@ -42,7 +44,8 @@ async def create_tables(self): reset_token_expiry TIMESTAMPTZ, collection_ids UUID[] NULL, created_at TIMESTAMPTZ DEFAULT NOW(), - updated_at TIMESTAMPTZ DEFAULT NOW() + updated_at TIMESTAMPTZ DEFAULT NOW(), + role TEXT ); """ await self.connection_manager.execute_query(query) @@ -130,7 +133,15 @@ async def get_user_by_email(self, email: str) -> UserResponse: collection_ids=result["collection_ids"], ) - async def create_user(self, email: str, password: str) -> UserResponse: + async def create_user( + self, email: str, password: str, role: str = "default" + ) -> UserResponse: + """Modified create_user to include role""" + # if role not in self.user_config.roles: + # raise R2RException( + # status_code=400, message=f"Invalid role: {role}" + # ) + try: if await self.get_user_by_email(email): raise R2RException( @@ -144,12 +155,12 @@ async def create_user(self, email: str, password: str) -> UserResponse: hashed_password = self.crypto_provider.get_password_hash(password) # type: ignore query = f""" INSERT INTO {self._get_table_name(PostgresUserHandler.TABLE_NAME)} - (email, user_id, hashed_password, collection_ids) - VALUES ($1, $2, $3, $4) - RETURNING user_id, email, is_superuser, is_active, is_verified, created_at, updated_at, collection_ids + (email, user_id, hashed_password, collection_ids, role) + VALUES ($1, $2, $3, $4, $5) + RETURNING user_id, email, is_superuser, is_active, is_verified, created_at, updated_at, collection_ids, role """ result = await self.connection_manager.fetchrow_query( - query, [email, generate_user_id(email), hashed_password, []] + query, [email, generate_user_id(email), hashed_password, [], role] ) if not result: @@ -162,6 +173,7 @@ async def create_user(self, email: str, password: str) -> UserResponse: id=result["user_id"], email=result["email"], is_superuser=result["is_superuser"], + role=result["role"], is_active=result["is_active"], is_verified=result["is_verified"], created_at=result["created_at"], @@ -616,3 +628,67 @@ async def get_user_verification_data( ), } } + + async def get_user_role(self, user_id: UUID) -> str: + """ + Check the assigned role of a user + """ + query = f""" + SELECT role FROM {self._get_table_name(PostgresUserHandler.TABLE_NAME)} + WHERE user_id = $1 + """ + result = await self.connection_manager.fetchrow_query(query, [user_id]) + if not result: + raise R2RException(status_code=404, message="User not found") + return result["role"] + + async def check_file_limit(self, user_id: UUID) -> bool: + """ + Check if a user has reached their role defined file limit + """ + user = await self.get_user_by_id(user_id) + + query = f""" + SELECT COUNT(*) FROM {self._get_table_name('document_info')} + WHERE user_id = $1 + """ + result = await self.connection_manager.fetchrow_query(query, [user_id]) + + return self.user_config.check_limit(user.role, "max_files", result[0]) + + async def check_query_limit(self, user_id: UUID) -> bool: + """ + Check if a user has reached their role defined query limit + """ + user = await self.get_user_by_id(user_id) + + query = f""" + SELECT COUNT(*) FROM {self._get_table_name('logs')} + WHERE user_id = $1 AND run_type = 'query' + AND created_at > NOW() - INTERVAL '1 day' + """ + result = await self.connection_manager.fetchrow_query(query, [user_id]) + + return self.user_config.check_limit( + user.role, "max_queries", result[0] + ) + + async def increment_query_count(self, user_id: UUID) -> None: + """Increment user's query count""" + query = f""" + INSERT INTO {self._get_table_name('logs')} + (user_id, run_type, created_at) + VALUES ($1, 'query', NOW()) + """ + await self.connection_manager.execute_query(query, [user_id]) + + async def update_user_role(self, user_id: UUID, role: str) -> None: + """ + Update the role of a user + """ + query = f""" + UPDATE {self._get_table_name(PostgresUserHandler.TABLE_NAME)} + SET role = $1 + WHERE user_id = $2 + """ + await self.connection_manager.execute_query(query, [role, user_id]) diff --git a/py/core/providers/database/vector.py b/py/core/providers/database/vector.py index 600d5c5e2..054ff3d6f 100644 --- a/py/core/providers/database/vector.py +++ b/py/core/providers/database/vector.py @@ -3,7 +3,7 @@ import logging import time import uuid -from typing import Any, Optional, Tuple, TypedDict, Union +from typing import Any, Optional, TypedDict, Union from uuid import UUID import numpy as np diff --git a/py/core/providers/user_management/__init__.py b/py/core/providers/user_management/__init__.py new file mode 100644 index 000000000..9439e1b32 --- /dev/null +++ b/py/core/providers/user_management/__init__.py @@ -0,0 +1 @@ +from .r2r_user_management import R2RUserManagementProvider diff --git a/py/core/providers/user_management/r2r_user_management.py b/py/core/providers/user_management/r2r_user_management.py new file mode 100644 index 000000000..a5b3c5286 --- /dev/null +++ b/py/core/providers/user_management/r2r_user_management.py @@ -0,0 +1,18 @@ +from core.base.providers.user_management import ( + UserManagementProvider, + UserManagementConfig, + RoleLimits, +) + + +class R2RUserManagementProvider(UserManagementProvider): + def __init__(self, config: UserManagementConfig): + super().__init__(config) + self.roles = config.roles + self.default_role = config.default_role + print( + f"Initialized R2RUserManagementProvider with roles: {self.roles}" + ) + + def get_role_limits(self, role: str) -> RoleLimits: + return self.config.get_role_limits(role) diff --git a/py/r2r.toml b/py/r2r.toml index e9f0023fa..ed82a00a8 100644 --- a/py/r2r.toml +++ b/py/r2r.toml @@ -119,9 +119,18 @@ log_info_table = "log_info" [orchestration] provider = "simple" - [prompt] provider = "r2r" [email] provider = "console_mock" + +[user_management] +provider = "r2r" +default_role = "default" + + [user_management.roles.default] + max_files = inf + max_chunks = inf + max_queries = inf + max_queries_window = inf diff --git a/py/shared/api/models/__init__.py b/py/shared/api/models/__init__.py index cc982cec8..92a8ddbad 100644 --- a/py/shared/api/models/__init__.py +++ b/py/shared/api/models/__init__.py @@ -1,10 +1,8 @@ from shared.api.models.auth.responses import ( GenericMessageResponse, TokenResponse, - UserResponse, WrappedGenericMessageResponse, WrappedTokenResponse, - WrappedUserResponse, ) from shared.api.models.ingestion.responses import ( IngestionResponse, @@ -33,6 +31,7 @@ ScoreCompletionResponse, ServerStats, UserOverviewResponse, + UserResponse, WrappedAddUserResponse, WrappedAnalyticsResponse, WrappedAppSettingsResponse, @@ -50,6 +49,7 @@ WrappedUserCollectionResponse, WrappedUserOverviewResponse, WrappedUsersInCollectionResponse, + WrappedUserResponse, ) from shared.api.models.retrieval.responses import ( RAGAgentResponse, @@ -65,9 +65,7 @@ # Auth Responses "GenericMessageResponse", "TokenResponse", - "UserResponse", "WrappedTokenResponse", - "WrappedUserResponse", "WrappedGenericMessageResponse", # Ingestion Responses "IngestionResponse", @@ -92,6 +90,7 @@ "CollectionResponse", "CollectionOverviewResponse", "ConversationOverviewResponse", + "UserResponse", "WrappedPromptMessageResponse", "WrappedServerStatsResponse", "WrappedLogResponse", @@ -112,6 +111,7 @@ "WrappedDocumentChunkResponse", "WrappedCollectionOverviewResponse", "WrappedConversationsOverviewResponse", + "WrappedUserResponse", # Retrieval Responses "SearchResponse", "RAGResponse", diff --git a/py/shared/api/models/auth/responses.py b/py/shared/api/models/auth/responses.py index 9e868272d..e3c56c055 100644 --- a/py/shared/api/models/auth/responses.py +++ b/py/shared/api/models/auth/responses.py @@ -1,10 +1,6 @@ -from datetime import datetime -from typing import Optional -from uuid import UUID - from pydantic import BaseModel -from shared.abstractions import R2RSerializable, Token +from shared.abstractions import Token from shared.api.models.base import ResultsWrapper @@ -13,29 +9,10 @@ class TokenResponse(BaseModel): refresh_token: Token -class UserResponse(R2RSerializable): - id: UUID - email: str - is_active: bool = True - is_superuser: bool = False - created_at: datetime = datetime.now() - updated_at: datetime = datetime.now() - is_verified: bool = False - collection_ids: list[UUID] = [] - - # Optional fields (to update or set at creation) - hashed_password: Optional[str] = None - verification_code_expiry: Optional[datetime] = None - name: Optional[str] = None - bio: Optional[str] = None - profile_picture: Optional[str] = None - - class GenericMessageResponse(BaseModel): message: str # Create wrapped versions of each response WrappedTokenResponse = ResultsWrapper[TokenResponse] -WrappedUserResponse = ResultsWrapper[UserResponse] WrappedGenericMessageResponse = ResultsWrapper[GenericMessageResponse] diff --git a/py/shared/api/models/management/responses.py b/py/shared/api/models/management/responses.py index cec9efd73..c72e4c40b 100644 --- a/py/shared/api/models/management/responses.py +++ b/py/shared/api/models/management/responses.py @@ -4,6 +4,7 @@ from pydantic import BaseModel +from shared.abstractions.base import R2RSerializable from shared.api.models.base import PaginatedResultsWrapper, ResultsWrapper from ....abstractions.llm import Message @@ -70,11 +71,12 @@ class UserOverviewResponse(BaseModel): document_ids: list[UUID] -class UserResponse(BaseModel): +class UserResponse(R2RSerializable): id: UUID email: str is_active: bool = True is_superuser: bool = False + role: str = "default" created_at: datetime = datetime.now() updated_at: datetime = datetime.now() is_verified: bool = False @@ -176,6 +178,7 @@ class AddUserResponse(BaseModel): list[DocumentChunkResponse] ] WrappedDeleteResponse = ResultsWrapper[None] +WrappedUserResponse = ResultsWrapper[UserResponse] WrappedVerificationResult = ResultsWrapper[VerificationResult] WrappedConversationsOverviewResponse = PaginatedResultsWrapper[ list[ConversationOverviewResponse]