Skip to content

Commit

Permalink
expose user verification code
Browse files Browse the repository at this point in the history
  • Loading branch information
emrgnt-cmplxty committed Oct 31, 2024
1 parent 5913612 commit a0ad656
Show file tree
Hide file tree
Showing 7 changed files with 123 additions and 0 deletions.
2 changes: 2 additions & 0 deletions py/core/base/api/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
WrappedUserCollectionResponse,
WrappedUserOverviewResponse,
WrappedUsersInCollectionResponse,
WrappedVerificationResult,
)
from shared.api.models.retrieval.responses import (
RAGAgentResponse,
Expand All @@ -81,6 +82,7 @@
"UserResponse",
"WrappedTokenResponse",
"WrappedUserResponse",
"WrappedVerificationResult",
"WrappedGenericMessageResponse",
# Ingestion Responses
"IngestionResponse",
Expand Down
15 changes: 15 additions & 0 deletions py/core/base/providers/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,6 +485,16 @@ async def get_users_overview(
) -> dict[str, Union[list[UserStats], int]]:
pass

@abstractmethod
async def get_user_verification_data(
self, user_id: UUID, *args, **kwargs
) -> dict:
"""
Get verification data for a specific user.
This method should be called after superuser authorization has been verified.
"""
pass


class VectorHandler(Handler):
def __init__(self, *args, **kwargs):
Expand Down Expand Up @@ -1368,6 +1378,11 @@ async def get_users_overview(
user_ids, offset, limit
)

async def get_user_verification_data(
self, user_id: UUID, *args, **kwargs
) -> dict:
return await self.user_handler.get_user_verification_data(user_id)

# Vector handler methods
async def upsert(self, entry: VectorEntry) -> None:
return await self.vector_handler.upsert(entry)
Expand Down
26 changes: 26 additions & 0 deletions py/core/main/api/management_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
WrappedUserCollectionResponse,
WrappedUserOverviewResponse,
WrappedUsersInCollectionResponse,
WrappedVerificationResult,
)
from core.base.logger import AnalysisTypes, LogFilterCriteria
from core.providers import (
Expand Down Expand Up @@ -887,3 +888,28 @@ async def delete_conversation(
) -> WrappedDeleteResponse:
await self.service.delete_conversation(conversation_id)
return None # type: ignore

@self.router.get("/user/{user_id}/verification_data")
@self.base_endpoint
async def get_user_verification_code(
user_id: str = Path(..., description="User ID"),
auth_user=Depends(self.service.providers.auth.auth_wrapper),
) -> WrappedVerificationResult:
"""
Get only the verification code for a specific user.
Only accessible by superusers.
"""
if not auth_user.is_superuser:
raise R2RException(
status_code=403,
message="Only superusers can access verification codes",
)

try:
user_uuid = UUID(user_id)
except ValueError:
raise R2RException(
status_code=400, message="Invalid user ID format"
)

return await self.service.get_user_verification_data(user_uuid)
20 changes: 20 additions & 0 deletions py/core/main/services/management_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -740,3 +740,23 @@ async def branch_at_message(self, message_id: str, auth_user=None) -> str:
@telemetry_event("DeleteConversation")
async def delete_conversation(self, conversation_id: str, auth_user=None):
await self.logging_connection.delete_conversation(conversation_id)

@telemetry_event("GetUserVerificationCode")
async def get_user_verification_data(
self, user_id: UUID, *args, **kwargs
) -> dict:
"""
Get only the verification code data for a specific user.
This method should be called after superuser authorization has been verified.
"""
verification_data = (
await self.providers.database.get_user_verification_data(user_id)
)
return {
"verification_code": verification_data["verification_data"][
"verification_code"
],
"expiry": verification_data["verification_data"][
"verification_code_expiry"
],
}
38 changes: 38 additions & 0 deletions py/core/providers/database/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -575,3 +575,41 @@ async def _collection_exists(self, collection_id: UUID) -> bool:
query, [collection_id]
)
return result is not None

async def get_user_verification_data(
self, user_id: UUID, *args, **kwargs
) -> dict:
"""
Get verification data for a specific user.
This method should be called after superuser authorization has been verified.
"""
query = f"""
SELECT
verification_code,
verification_code_expiry,
reset_token,
reset_token_expiry
FROM {self._get_table_name("users")}
WHERE user_id = $1
"""
result = await self.connection_manager.fetchrow_query(query, [user_id])

if not result:
raise R2RException(status_code=404, message="User not found")

return {
"verification_data": {
"verification_code": result["verification_code"],
"verification_code_expiry": (
result["verification_code_expiry"].isoformat()
if result["verification_code_expiry"]
else None
),
"reset_token": result["reset_token"],
"reset_token_expiry": (
result["reset_token_expiry"].isoformat()
if result["reset_token_expiry"]
else None
),
}
}
16 changes: 16 additions & 0 deletions py/sdk/mixins/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,3 +196,19 @@ async def login_with_token(
self.access_token = None
self._refresh_token = None
raise ValueError("Invalid tokens provided")

async def get_user_verification_code(
self, user_id: Union[str, UUID]
) -> dict:
"""
Retrieves only the verification code for a specific user. Requires superuser access.
Args:
user_id (Union[str, UUID]): The ID of the user to get verification code for.
Returns:
dict: Contains verification code and its expiry date
"""
return await self._make_request( # type: ignore
"GET", f"user/{user_id}/verification_data"
)
6 changes: 6 additions & 0 deletions py/shared/api/models/management/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,11 @@ class ConversationOverviewResponse(BaseModel):
created_at: datetime


class VerificationResult(BaseModel):
verification_code: str
expiry: datetime


class AddUserResponse(BaseModel):
result: bool

Expand Down Expand Up @@ -165,6 +170,7 @@ class AddUserResponse(BaseModel):
list[DocumentChunkResponse]
]
WrappedDeleteResponse = ResultsWrapper[None]
WrappedVerificationResult = ResultsWrapper[VerificationResult]
WrappedConversationsOverviewResponse = PaginatedResultsWrapper[
list[ConversationOverviewResponse]
]

0 comments on commit a0ad656

Please sign in to comment.