diff --git a/py/core/base/api/models/__init__.py b/py/core/base/api/models/__init__.py index b6eb410a7..e6312513a 100644 --- a/py/core/base/api/models/__init__.py +++ b/py/core/base/api/models/__init__.py @@ -63,6 +63,7 @@ WrappedUserCollectionResponse, WrappedUserOverviewResponse, WrappedUsersInCollectionResponse, + WrappedVerificationResult, ) from shared.api.models.retrieval.responses import ( RAGAgentResponse, @@ -81,6 +82,7 @@ "UserResponse", "WrappedTokenResponse", "WrappedUserResponse", + "WrappedVerificationResult", "WrappedGenericMessageResponse", # Ingestion Responses "IngestionResponse", diff --git a/py/core/base/providers/database.py b/py/core/base/providers/database.py index a035ea639..e95c59f87 100644 --- a/py/core/base/providers/database.py +++ b/py/core/base/providers/database.py @@ -485,6 +485,16 @@ async def get_users_overview( ) -> dict[str, Union[list[UserStats], int]]: pass + @abstractmethod + async def get_user_verification_data( + self, user_id: UUID, *args, **kwargs + ) -> dict: + """ + Get verification data for a specific user. + This method should be called after superuser authorization has been verified. + """ + pass + class VectorHandler(Handler): def __init__(self, *args, **kwargs): @@ -1368,6 +1378,11 @@ async def get_users_overview( user_ids, offset, limit ) + async def get_user_verification_data( + self, user_id: UUID, *args, **kwargs + ) -> dict: + return await self.user_handler.get_user_verification_data(user_id) + # Vector handler methods async def upsert(self, entry: VectorEntry) -> None: return await self.vector_handler.upsert(entry) diff --git a/py/core/main/api/management_router.py b/py/core/main/api/management_router.py index 606c943e9..628cf3457 100644 --- a/py/core/main/api/management_router.py +++ b/py/core/main/api/management_router.py @@ -30,6 +30,7 @@ WrappedUserCollectionResponse, WrappedUserOverviewResponse, WrappedUsersInCollectionResponse, + WrappedVerificationResult, ) from core.base.logger import AnalysisTypes, LogFilterCriteria from core.providers import ( @@ -887,3 +888,28 @@ async def delete_conversation( ) -> WrappedDeleteResponse: await self.service.delete_conversation(conversation_id) return None # type: ignore + + @self.router.get("/user/{user_id}/verification_data") + @self.base_endpoint + async def get_user_verification_code( + user_id: str = Path(..., description="User ID"), + auth_user=Depends(self.service.providers.auth.auth_wrapper), + ) -> WrappedVerificationResult: + """ + Get only the verification code for a specific user. + Only accessible by superusers. + """ + if not auth_user.is_superuser: + raise R2RException( + status_code=403, + message="Only superusers can access verification codes", + ) + + try: + user_uuid = UUID(user_id) + except ValueError: + raise R2RException( + status_code=400, message="Invalid user ID format" + ) + + return await self.service.get_user_verification_data(user_uuid) diff --git a/py/core/main/services/management_service.py b/py/core/main/services/management_service.py index 3490fbec0..0a01250aa 100644 --- a/py/core/main/services/management_service.py +++ b/py/core/main/services/management_service.py @@ -740,3 +740,23 @@ async def branch_at_message(self, message_id: str, auth_user=None) -> str: @telemetry_event("DeleteConversation") async def delete_conversation(self, conversation_id: str, auth_user=None): await self.logging_connection.delete_conversation(conversation_id) + + @telemetry_event("GetUserVerificationCode") + async def get_user_verification_data( + self, user_id: UUID, *args, **kwargs + ) -> dict: + """ + Get only the verification code data for a specific user. + This method should be called after superuser authorization has been verified. + """ + verification_data = ( + await self.providers.database.get_user_verification_data(user_id) + ) + return { + "verification_code": verification_data["verification_data"][ + "verification_code" + ], + "expiry": verification_data["verification_data"][ + "verification_code_expiry" + ], + } diff --git a/py/core/providers/database/user.py b/py/core/providers/database/user.py index 68113ff88..61e465388 100644 --- a/py/core/providers/database/user.py +++ b/py/core/providers/database/user.py @@ -575,3 +575,41 @@ async def _collection_exists(self, collection_id: UUID) -> bool: query, [collection_id] ) return result is not None + + async def get_user_verification_data( + self, user_id: UUID, *args, **kwargs + ) -> dict: + """ + Get verification data for a specific user. + This method should be called after superuser authorization has been verified. + """ + query = f""" + SELECT + verification_code, + verification_code_expiry, + reset_token, + reset_token_expiry + FROM {self._get_table_name("users")} + WHERE user_id = $1 + """ + result = await self.connection_manager.fetchrow_query(query, [user_id]) + + if not result: + raise R2RException(status_code=404, message="User not found") + + return { + "verification_data": { + "verification_code": result["verification_code"], + "verification_code_expiry": ( + result["verification_code_expiry"].isoformat() + if result["verification_code_expiry"] + else None + ), + "reset_token": result["reset_token"], + "reset_token_expiry": ( + result["reset_token_expiry"].isoformat() + if result["reset_token_expiry"] + else None + ), + } + } diff --git a/py/sdk/mixins/auth.py b/py/sdk/mixins/auth.py index 03dc5db5c..8cf4bd05e 100644 --- a/py/sdk/mixins/auth.py +++ b/py/sdk/mixins/auth.py @@ -196,3 +196,19 @@ async def login_with_token( self.access_token = None self._refresh_token = None raise ValueError("Invalid tokens provided") + + async def get_user_verification_code( + self, user_id: Union[str, UUID] + ) -> dict: + """ + Retrieves only the verification code for a specific user. Requires superuser access. + + Args: + user_id (Union[str, UUID]): The ID of the user to get verification code for. + + Returns: + dict: Contains verification code and its expiry date + """ + return await self._make_request( # type: ignore + "GET", f"user/{user_id}/verification_data" + ) diff --git a/py/shared/api/models/management/responses.py b/py/shared/api/models/management/responses.py index a285af620..e8c0bd8e0 100644 --- a/py/shared/api/models/management/responses.py +++ b/py/shared/api/models/management/responses.py @@ -133,6 +133,11 @@ class ConversationOverviewResponse(BaseModel): created_at: datetime +class VerificationResult(BaseModel): + verification_code: str + expiry: datetime + + class AddUserResponse(BaseModel): result: bool @@ -165,6 +170,7 @@ class AddUserResponse(BaseModel): list[DocumentChunkResponse] ] WrappedDeleteResponse = ResultsWrapper[None] +WrappedVerificationResult = ResultsWrapper[VerificationResult] WrappedConversationsOverviewResponse = PaginatedResultsWrapper[ list[ConversationOverviewResponse] ]