Skip to content

Commit

Permalink
Fix bug in deletion, better validation error handling (#1374)
Browse files Browse the repository at this point in the history
* Update graphrag.mdx

* Fix bug in deletion, better validation error handling

---------

Co-authored-by: Shreyas Pimpalgaonkar <shreyas.gp.7@gmail.com>
  • Loading branch information
NolanTrem and shreyaspimpalgaonkar authored Oct 10, 2024
1 parent 87c1e9e commit fde366d
Show file tree
Hide file tree
Showing 8 changed files with 79 additions and 12 deletions.
1 change: 1 addition & 0 deletions py/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@
"RelationshipType",
"format_entity_types",
"format_relations",
"validate_uuid",
## MAIN
## R2R ABSTRACTIONS
"R2RProviders",
Expand Down
1 change: 1 addition & 0 deletions py/core/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@
"to_async_generator",
"format_search_results_for_llm",
"format_search_results_for_stream",
"validate_uuid",
# ID generation
"generate_run_id",
"generate_document_id",
Expand Down
2 changes: 2 additions & 0 deletions py/core/base/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
increment_version,
run_pipeline,
to_async_generator,
validate_uuid,
)

__all__ = [
Expand All @@ -38,4 +39,5 @@
"generate_default_prompt_id",
"RecursiveCharacterTextSplitter",
"TextSplitter",
"validate_uuid",
]
26 changes: 22 additions & 4 deletions py/core/main/api/management_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import json
import mimetypes
from datetime import datetime, timezone
from typing import Optional, Set
from typing import Optional, Set, Any
from uuid import UUID

import psutil
Expand Down Expand Up @@ -268,8 +268,26 @@ async def delete_app(
filters: str = Query(..., description="JSON-encoded filters"),
auth_user=Depends(self.service.providers.auth.auth_wrapper),
):
filters_dict = json.loads(filters) if filters else None
return await self.service.delete(filters=filters_dict) # type: ignore
try:
filters_dict = json.loads(filters)
except json.JSONDecodeError:
raise R2RException(
status_code=422, message="Invalid JSON in filters"
)

if not isinstance(filters_dict, dict):
raise R2RException(
status_code=422, message="Filters must be a JSON object"
)

for key, value in filters_dict.items():
if not isinstance(value, dict):
raise R2RException(
status_code=422,
message=f"Invalid filter format for key: {key}",
)

return await self.service.delete(filters=filters_dict)

@self.router.get(
"/download_file/{document_id}", response_class=StreamingResponse
Expand All @@ -288,7 +306,7 @@ async def download_file_app(
document_uuid = UUID(document_id)
except ValueError:
raise R2RException(
status_code=400, message="Invalid document ID format."
status_code=422, message="Invalid document ID format."
)

file_tuple = await self.service.download_file(document_uuid)
Expand Down
53 changes: 45 additions & 8 deletions py/core/main/services/management_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from collections import defaultdict
from typing import Any, BinaryIO, Dict, Optional, Tuple
from uuid import UUID
from core.base.utils import validate_uuid

import toml

Expand Down Expand Up @@ -273,6 +274,45 @@ async def delete(
NOTE: This method is not atomic and may result in orphaned entries in the documents overview table.
NOTE: This method assumes that filters delete entire contents of any touched documents.
"""

def validate_filters(filters: dict[str, Any]) -> None:
    """Reject malformed deletion filters before anything is deleted.

    Args:
        filters: Mapping of filter field to an ``{operator: operand}``
            dict, e.g. ``{"document_id": {"$eq": "<uuid>"}}``.

    Raises:
        R2RException: Status 422 when no filters are given, a field is
            not in the allow-list, a condition is not a non-empty dict,
            or an operand is not a valid UUID (or list of UUIDs for
            ``collection_ids``).
    """
    ALLOWED_FILTERS = {"document_id", "user_id", "collection_ids"}

    if not filters:
        raise R2RException(status_code=422, message="No filters provided")

    for field in filters:
        if field not in ALLOWED_FILTERS:
            raise R2RException(
                status_code=422,
                message=f"Invalid filter field: {field}",
            )

    def first_operand(field: str) -> Any:
        # Only the first operator of a condition is inspected.
        # An empty or non-dict condition used to escape as
        # StopIteration/AttributeError; report it as a client error.
        condition = filters[field]
        if not isinstance(condition, dict) or not condition:
            raise R2RException(
                status_code=422,
                message=f"Invalid filter format for key: {field}",
            )
        return next(iter(condition.values()))

    def check_uuid(value: Any) -> None:
        # UUID() raises AttributeError/TypeError (not ValueError) for
        # non-string input such as 123 or None; treat all of them as
        # an invalid UUID instead of letting them escape.
        try:
            validate_uuid(value)
        except (ValueError, TypeError, AttributeError):
            raise R2RException(
                status_code=422, message=f"Invalid UUID: {value}"
            ) from None

    for field in ("document_id", "user_id"):
        if field in filters:
            check_uuid(first_operand(field))

    if "collection_ids" in filters:
        collection_ids = first_operand("collection_ids")
        # A bare string would otherwise be iterated character by
        # character and produce a misleading "Invalid UUID: c" error.
        if isinstance(collection_ids, str):
            raise R2RException(
                status_code=422,
                message="Invalid filter format for key: collection_ids",
            )
        try:
            id_iter = iter(collection_ids)
        except TypeError:
            raise R2RException(
                status_code=422,
                message="Invalid filter format for key: collection_ids",
            ) from None
        for id_str in id_iter:
            check_uuid(id_str)

validate_filters(filters)

logger.info(f"Deleting entries with filters: {filters}")

try:
Expand All @@ -297,17 +337,14 @@ async def delete(
relational_filters = {}
if "document_id" in filters:
relational_filters["filter_document_ids"] = [
UUID(filters["document_id"]["$eq"])
filters["document_id"]["$eq"]
]
if "user_id" in filters:
relational_filters["filter_user_ids"] = [
UUID(filters["user_id"]["$eq"])
]
relational_filters["filter_user_ids"] = [filters["user_id"]["$eq"]]
if "collection_ids" in filters:
relational_filters["filter_collection_ids"] = [
UUID(collection_id)
for collection_id in filters["collection_ids"]["$in"]
]
relational_filters["filter_collection_ids"] = list(
filters["collection_ids"]["$in"]
)

try:
documents_overview = (
Expand Down
2 changes: 2 additions & 0 deletions py/core/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
increment_version,
run_pipeline,
to_async_generator,
validate_uuid,
)
from shared.utils.splitter.text import (
RecursiveCharacterTextSplitter,
Expand All @@ -36,6 +37,7 @@
"run_pipeline",
"to_async_generator",
"generate_default_user_collection_id",
"validate_uuid",
# Text splitter
"RecursiveCharacterTextSplitter",
"TextSplitter",
Expand Down
2 changes: 2 additions & 0 deletions py/shared/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
llm_cost_per_million_tokens,
run_pipeline,
to_async_generator,
validate_uuid,
)
from .splitter.text import RecursiveCharacterTextSplitter, TextSplitter

Expand All @@ -39,6 +40,7 @@
"run_pipeline",
"to_async_generator",
"llm_cost_per_million_tokens",
"validate_uuid",
# Text splitter
"RecursiveCharacterTextSplitter",
"TextSplitter",
Expand Down
4 changes: 4 additions & 0 deletions py/shared/utils/base_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,3 +239,7 @@ def llm_cost_per_million_tokens(
* input_output_ratio
* cost_dict["gpt-4o"][1]
) / (1 + input_output_ratio)


def validate_uuid(uuid_str: str) -> UUID:
    """Parse *uuid_str* into a ``UUID``.

    Args:
        uuid_str: Textual UUID. An existing ``UUID`` instance is
            accepted and returned unchanged.

    Returns:
        The parsed ``UUID``.

    Raises:
        ValueError: If the input is not a string or is not a
            well-formed UUID.
    """
    if isinstance(uuid_str, UUID):
        return uuid_str
    # UUID() raises AttributeError/TypeError for non-string input;
    # normalise to ValueError so a single except clause covers every
    # invalid input.
    if not isinstance(uuid_str, str):
        raise ValueError(f"Invalid UUID: {uuid_str!r}")
    return UUID(uuid_str)

0 comments on commit fde366d

Please sign in to comment.