diff --git a/py/core/main/api/v3/base_router.py b/py/core/main/api/v3/base_router.py new file mode 100644 index 000000000..b44fad92e --- /dev/null +++ b/py/core/main/api/v3/base_router.py @@ -0,0 +1,110 @@ +import functools +import logging +from abc import abstractmethod +from typing import Callable, Union + +from fastapi import APIRouter, HTTPException +from fastapi.responses import StreamingResponse + +from core.base import R2RException, manage_run +from core.base.logger.base import RunType +from core.providers import ( + HatchetOrchestrationProvider, + SimpleOrchestrationProvider, +) + +from ...services.base import Service + +logger = logging.getLogger() + + +class BaseRouterV3: + def __init__( + self, + providers, + services: dict[str, Service], + orchestration_provider: Union[ + HatchetOrchestrationProvider, SimpleOrchestrationProvider + ], + run_type: RunType, + ): + self.providers = providers + self.services = services + self.run_type = run_type + self.orchestration_provider = orchestration_provider + self.router = APIRouter() + self.openapi_extras = self._load_openapi_extras() + self._setup_routes() + self._register_workflows() + + def get_router(self): + return self.router + + def base_endpoint(self, func: Callable): + @functools.wraps(func) + async def wrapper(*args, **kwargs): + async with manage_run( + self.services["ingestion"].run_manager, func.__name__ + ) as run_id: + auth_user = kwargs.get("auth_user") + if auth_user: + await self.services[ + "ingestion" + ].run_manager.log_run_info( # TODO - this is a bit of a hack + run_type=self.run_type, + user=auth_user, + ) + + try: + func_result = await func(*args, **kwargs) + if ( + isinstance(func_result, tuple) + and len(func_result) == 2 + ): + results, outer_kwargs = func_result + else: + results, outer_kwargs = func_result, {} + + if isinstance(results, StreamingResponse): + return results + return {"results": results, **outer_kwargs} + + except R2RException: + raise + + except Exception as e: + await self.services["ingestion"].logging_connection.log( + run_id=run_id, + key="error", + value=str(e), + ) + + logger.error( + f"Error in base endpoint {func.__name__}() - \n\n{str(e)}", + exc_info=True, + ) + + raise HTTPException( + status_code=500, + detail={ + "message": f"An error '{e}' occurred during {func.__name__}", + "error": str(e), + "error_type": type(e).__name__, + }, + ) from e + + return wrapper + + @classmethod + def build_router(cls, engine): + return cls(engine).router + + @abstractmethod + def _setup_routes(self): + pass + + def _register_workflows(self): + pass + + def _load_openapi_extras(self): + return {} diff --git a/py/core/main/api/v3/chunk_responses.py b/py/core/main/api/v3/chunk_responses.py new file mode 100644 index 000000000..eb5e7c994 --- /dev/null +++ b/py/core/main/api/v3/chunk_responses.py @@ -0,0 +1,14 @@ +from typing import Any, Optional +from uuid import UUID + +from pydantic import BaseModel + + +class ChunkResponse(BaseModel): + document_id: UUID + extraction_id: UUID + user_id: UUID + collection_ids: list[UUID] + text: str + metadata: dict[str, Any] + vector: Optional[list[float]] = None diff --git a/py/core/main/api/v3/document_responses.py b/py/core/main/api/v3/document_responses.py index d97f2e995..883fa47af 100644 --- a/py/core/main/api/v3/document_responses.py +++ b/py/core/main/api/v3/document_responses.py @@ -45,3 +45,11 @@ class DocumentResponse(BaseModel): version: str collection_ids: list[UUID] metadata: dict[str, Any] + + +class CollectionResponse(BaseModel): + collection_id: UUID + name: str + description: Optional[str] + created_at: datetime + updated_at: datetime diff --git 
a/py/core/main/api/v3/document_router.py b/py/core/main/api/v3/document_router.py index 83f99e21b..1ff844d3d 100644 --- a/py/core/main/api/v3/document_router.py +++ b/py/core/main/api/v3/document_router.py @@ -17,13 +17,18 @@ ) from shared.api.models.base import PaginatedResultsWrapper, ResultsWrapper -from ..v2.base_router import BaseRouter -from .document_responses import DocumentIngestionResponse, DocumentResponse +from .base_router import BaseRouterV3 +from .chunk_responses import ChunkResponse +from .document_responses import ( + CollectionResponse, + DocumentIngestionResponse, + DocumentResponse, +) logger = logging.getLogger() -class DocumentRouter(BaseRouter): +class DocumentRouter(BaseRouterV3): def __init__( self, providers, @@ -33,40 +38,20 @@ def __init__( ], run_type: RunType = RunType.INGESTION, ): - super().__init__(services, orchestration_provider, run_type) - self.providers = providers - self.services = services - - def _register_workflows(self): - self.orchestration_provider.register_workflows( - Workflow.INGESTION, - self.services.ingestion, - { - "ingest-document": ( - "Ingest document task queued successfully." - if self.orchestration_provider.config.provider != "simple" - else "Ingestion task completed successfully." - ), - "update-document": ( - "Update file task queued successfully." - if self.orchestration_provider.config.provider != "simple" - else "Update task queued successfully." - ), - }, - ) + super().__init__(providers, services, orchestration_provider, run_type) def _setup_routes(self): @self.router.post("/documents") @self.base_endpoint - async def ingest_documents( - file: Optional[UploadFile] = File(...), + async def ingest_document( + file: Optional[UploadFile] = File(None), content: Optional[str] = Form(None), document_id: Optional[Json[UUID]] = Form(None), metadata: Optional[Json[dict]] = Form(None), ingestion_config: Optional[Json[dict]] = Form(None), run_with_orchestration: Optional[bool] = Form(True), auth_user=Depends(self.providers.auth.auth_wrapper), - ) -> ResultsWrapper[list[DocumentIngestionResponse]]: + ) -> ResultsWrapper[DocumentIngestionResponse]: """ Creates a new `Document` object from an input file or text content. Each document has corresponding `Chunk` objects which are used in vector indexing and search. 
@@ -85,6 +70,7 @@ async def ingest_documents( message="Both a file and content cannot be provided.", ) # Check if the user is a superuser + metadata = metadata or {} if not auth_user.is_superuser: if "user_id" in metadata and ( not auth_user.is_superuser @@ -99,7 +85,6 @@ async def ingest_documents( if file: file_data = await self._process_file(file) - content_length = len(file_data["content"]) file_content = BytesIO(base64.b64decode(file_data["content"])) @@ -107,24 +92,29 @@ async def ingest_documents( document_id = document_id or generate_document_id( file_data["filename"], auth_user.id ) - else: + elif content: content_length = len(content) file_content = BytesIO(content.encode("utf-8")) document_id = document_id or generate_document_id( content, auth_user.id ) file_data = { - "filename": f"{document_id}.txt", + "filename": "N/A", "content_type": "text/plain", } + else: + raise R2RException( + status_code=422, + message="Either a file or content must be provided.", + ) workflow_input = { - "file_datas": [file_data], - "document_ids": [str(document_id)], - "metadatas": [metadata], + "file_data": file_data, + "document_id": str(document_id), + "metadata": metadata, "ingestion_config": ingestion_config, "user": auth_user.model_dump_json(), - "file_sizes_in_bytes": [content_length], + "size_in_bytes": content_length, "is_update": False, } @@ -137,7 +127,7 @@ async def ingest_documents( ) if run_with_orchestration: raw_message: dict[str, Union[str, None]] = await self.orchestration_provider.run_workflow( # type: ignore - "ingest-documents", + "ingest-files", {"request": workflow_input}, options={ "additional_metadata": { @@ -146,7 +136,7 @@ async def ingest_documents( }, ) raw_message["document_id"] = str(document_id) - return raw_message + return raw_message # type: ignore else: logger.info( f"Running ingestion without orchestration for file {file_name} and document_id {document_id}." @@ -154,9 +144,11 @@ async def ingest_documents( # TODO - Clean up implementation logic here to be more explicitly `synchronous` from core.main.orchestration import simple_ingestion_factory - simple_ingestor = simple_ingestion_factory(self.service) + simple_ingestor = simple_ingestion_factory( + self.services["ingestion"] + ) await simple_ingestor["ingest-files"](workflow_input) - return { + return { # type: ignore "message": "Ingestion task completed successfully.", "document_id": str(document_id), "task_id": None, @@ -167,14 +159,14 @@ async def ingest_documents( ) @self.base_endpoint async def update_document( - file: Optional[UploadFile] = File(...), + file: Optional[UploadFile] = File(None), content: Optional[str] = Form(None), document_id: UUID = Path(...), - metadata: Optional[Json[list[dict]]] = Form(None), + metadata: Optional[Json[dict]] = Form(None), ingestion_config: Optional[Json[dict]] = Form(None), run_with_orchestration: Optional[bool] = Form(True), auth_user=Depends(self.providers.auth.auth_wrapper), - ) -> ResultsWrapper[list[DocumentIngestionResponse]]: + ) -> ResultsWrapper[DocumentIngestionResponse]: """ Ingests updated files into R2R, updating the corresponding `Document` and `Chunk` objects from previous ingestion. 
@@ -192,6 +184,8 @@ async def update_document( status_code=422, message="Both a file and content cannot be provided.", ) + metadata = metadata or {} # type: ignore + # Check if the user is a superuser if not auth_user.is_superuser: if "user_id" in metadata and metadata["user_id"] != str( @@ -208,11 +202,11 @@ async def update_document( content_length = len(file_data["content"]) file_content = BytesIO(base64.b64decode(file_data["content"])) file_data.pop("content", None) - else: + elif content: content_length = len(content) file_content = BytesIO(content.encode("utf-8")) file_data = { - "filename": f"{document_id}.txt", + "filename": "N/A", "content_type": "text/plain", } @@ -222,6 +216,11 @@ async def update_document( file_content, file_data["content_type"], ) + else: + raise R2RException( + status_code=422, + message="Either a file or content must be provided.", + ) workflow_input = { "file_datas": [file_data], @@ -237,10 +236,10 @@ async def update_document( if run_with_orchestration: raw_message: dict[str, Union[str, None]] = await self.orchestration_provider.run_workflow( # type: ignore - "update-documents", {"request": workflow_input}, {} + "update-files", {"request": workflow_input}, {} ) raw_message["message"] = "Update task queued successfully." - raw_message["document_ids"] = workflow_input["document_ids"] + raw_message["document_id"] = workflow_input["document_ids"][0] return raw_message # type: ignore else: @@ -248,11 +247,13 @@ async def update_document( # TODO - Clean up implementation logic here to be more explicitly `synchronous` from core.main.orchestration import simple_ingestion_factory - simple_ingestor = simple_ingestion_factory(self.service) + simple_ingestor = simple_ingestion_factory( + self.services["ingestion"] + ) await simple_ingestor["update-files"](workflow_input) return { # type: ignore "message": "Update task completed successfully.", - "document_ids": workflow_input["document_ids"], + "document_id": workflow_input["document_ids"][0], "task_id": None, } @@ -277,17 +278,17 @@ async def get_documents( document_uuids = [ UUID(document_id) for document_id in document_ids ] - documents_overview_response = ( - await self.service.management.documents_overview( - user_ids=request_user_ids, - collection_ids=filter_collection_ids, - document_ids=document_uuids, - offset=offset, - limit=limit, - ) + documents_overview_response = await self.services[ + "management" + ].documents_overview( + user_ids=request_user_ids, + collection_ids=filter_collection_ids, + document_ids=document_uuids, + offset=offset, + limit=limit, ) - return ( + return ( # type: ignore documents_overview_response["results"], { "total_entries": documents_overview_response[ @@ -312,12 +313,12 @@ async def get_document( None if auth_user.is_superuser else auth_user.collection_ids ) - documents_overview_response = ( - await self.service.management.documents_overview( - user_ids=request_user_ids, - collection_ids=filter_collection_ids, - document_ids=[document_id], - ) + documents_overview_response = await self.services[ + "management" + ].documents_overview( + user_ids=request_user_ids, + collection_ids=filter_collection_ids, + document_ids=[document_id], ) results = documents_overview_response["results"] if len(results) == 0: @@ -325,46 +326,54 @@ async def get_document( return results[0] - # Put this onto the chunk page - - # @self.router.get("/documents/{document_id}/chunks") - # @self.base_endpoint - # async def get_document_chunks( - # document_id: UUID = Path(...), - # offset: Optional[int] = 
Query(0, ge=0), - # limit: Optional[int] = Query(100, ge=0), - # include_vectors: Optional[bool] = Query(False), - # auth_user=Depends(self.providers.auth.auth_wrapper), - # ) -> PaginatedResultsWrapper[list[DocumentResponse]]: - # """ - # Get chunks for a specific document. - # """ - # document_chunks = await self.service.document_chunks( - # document_id, offset, limit, include_vectors - # ) - - # if not document_chunks["results"]: - # raise R2RException("No chunks found for the given document ID.", 404) - - # is_owner = str(document_chunks["results"][0].get("user_id")) == str(auth_user.id) - # document_collections = await self.service.document_collections(document_id, 0, -1) - - # user_has_access = ( - # is_owner or - # set(auth_user.collection_ids).intersection( - # {ele.collection_id for ele in document_collections["results"]} - # ) != set() - # ) - - # if not user_has_access and not auth_user.is_superuser: - # raise R2RException( - # "Not authorized to access this document's chunks.", 403 - # ) - - # return ( - # document_chunks["results"], - # {"total_entries": document_chunks["total_entries"]} - # ) + @self.router.get("/documents/{document_id}/chunks") + @self.base_endpoint + async def get_document_chunks( + document_id: UUID = Path(...), + offset: Optional[int] = Query(0, ge=0), + limit: Optional[int] = Query(100, ge=0), + include_vectors: Optional[bool] = Query(False), + auth_user=Depends(self.providers.auth.auth_wrapper), + ) -> PaginatedResultsWrapper[list[ChunkResponse]]: + """ + Get chunks for a specific document. + """ + document_chunks = await self.services[ + "management" + ].document_chunks(document_id, offset, limit, include_vectors) + + if not document_chunks["results"]: + raise R2RException( + "No chunks found for the given document ID.", 404 + ) + + is_owner = str( + document_chunks["results"][0].get("user_id") + ) == str(auth_user.id) + document_collections = await self.services[ + "management" + ].document_collections(document_id, 0, -1) + + user_has_access = ( + is_owner + or set(auth_user.collection_ids).intersection( + { + ele.collection_id + for ele in document_collections["results"] + } + ) + != set() + ) + + if not user_has_access and not auth_user.is_superuser: + raise R2RException( + "Not authorized to access this document's chunks.", 403 + ) + + return ( # type: ignore + document_chunks["results"], + {"total_entries": document_chunks["total_entries"]}, + ) @self.router.get( "/documents/{document_id}/download", @@ -373,7 +382,7 @@ async def get_document( @self.base_endpoint async def get_document_file( document_id: str = Path(..., description="Document ID"), - auth_user=Depends(self.service.providers.auth.auth_wrapper), + auth_user=Depends(self.providers.auth.auth_wrapper), ): """ Download a file by its corresponding document ID. @@ -387,7 +396,7 @@ async def get_document_file( status_code=422, message="Invalid document ID format." 
) - file_tuple = await self.service.management.download_file( + file_tuple = await self.services["management"].download_file( document_uuid ) if not file_tuple: @@ -431,14 +440,14 @@ async def delete_document_by_id( {"document_id": {"$eq": document_id}}, ] } - await self.services.management.delete(filters=filters) + await self.services["management"].delete(filters=filters) return None - @self.router.delete("/documents/filtered") + @self.router.delete("/documents/by-filter") @self.base_endpoint async def delete_document_by_id( filters: str = Query(..., description="JSON-encoded filters"), - auth_user=Depends(self.service.providers.auth.auth_wrapper), + auth_user=Depends(self.providers.auth.auth_wrapper), ) -> ResultsWrapper[None]: """ Delete documents based on provided filters. @@ -467,20 +476,26 @@ async def delete_document_by_id( - return await self.service.management.delete(filters=filters_dict) + return await self.services["management"].delete(filters=filters_dict) - # Put this onto the collection page - # @self.router.get("/documents/{document_id}/collections") - # @self.base_endpoint - # async def get_document_collections( - # document_id: UUID = Path(...), - # offset: int = Query(0, ge=0), - # limit: int = Query(100, ge=1, le=1000), - # auth_user=Depends(self.providers.auth.auth_wrapper), - # ) -> WrappedCollectionListResponse: - # """ - # Get collections that contain a specific document. - # """ - # if not auth_user.is_superuser: - # document = await self.service.get + @self.router.get("/documents/{document_id}/collections") + @self.base_endpoint + async def get_document_collections( + document_id: str = Path(..., description="Document ID"), + offset: int = Query(0, ge=0), + limit: int = Query(100, ge=1, le=1000), + auth_user=Depends(self.providers.auth.auth_wrapper), + ) -> ResultsWrapper[list[CollectionResponse]]: + if not auth_user.is_superuser: + raise R2RException( + "Only a superuser can get the collections belonging to a document.", + 403, + ) + document_collections_response = await self.services[ + "management" + ].document_collections(document_id, offset, limit) + + return document_collections_response["results"], { # type: ignore + "total_entries": document_collections_response["total_entries"] + } @staticmethod async def _process_file(file): diff --git a/py/core/main/assembly/builder.py b/py/core/main/assembly/builder.py index 84e92e257..52d7753ef 100644 --- a/py/core/main/assembly/builder.py +++ b/py/core/main/assembly/builder.py @@ -251,10 +251,10 @@ async def build(self, *args, **kwargs) -> R2RApp: orchestration_provider=orchestration_provider, ).get_router(), "document_router": DocumentRouter( - providers, - services, + providers=providers, + services=services, orchestration_provider=orchestration_provider, - ), + ).get_router(), } return R2RApp( diff --git a/py/core/main/services/ingestion_service.py b/py/core/main/services/ingestion_service.py index 72637b1c0..2263160c8 100644 --- 
a/py/core/main/services/ingestion_service.py +++ b/py/core/main/services/ingestion_service.py @@ -115,12 +115,12 @@ async def ingest_file_ingress( ): raise R2RException( status_code=409, - message=f"Must increment version number before attempting to overwrite document {document_id}. Use the `update_files` endpoint if you are looking to update the existing version.", + message=f"Document {document_id} already exists. Increment the version to overwrite existing document. Otherwise, submit a POST request to `/documents/{document_id}` to update the existing version.", ) elif existing_doc.ingestion_status != IngestionStatus.FAILED: raise R2RException( status_code=409, - message=f"Document {document_id} was already ingested and is not in a failed state.", + message=f"Document {document_id} is currently ingesting.", ) await self.providers.database.upsert_documents_overview( @@ -147,7 +147,9 @@ def _create_document_info_from_file( version: str, size_in_bytes: int, ) -> DocumentInfo: - file_extension = file_name.split(".")[-1].lower() + file_extension = ( + file_name.split(".")[-1].lower() if file_name != "N/A" else "txt" + ) if file_extension.upper() not in DocumentType.__members__: raise R2RException( status_code=415, @@ -162,7 +164,11 @@ def _create_document_info_from_file( user_id=user.id, collection_ids=metadata.get("collection_ids", []), document_type=DocumentType[file_extension.upper()], - title=metadata.get("title", file_name.split("/")[-1]), + title=( + metadata.get("title", file_name.split("/")[-1]) + if file_name != "N/A" + else "N/A" + ), metadata=metadata, version=version, size_in_bytes=size_in_bytes, diff --git a/py/sdk/async_client.py b/py/sdk/async_client.py index df48089d0..040d18593 100644 --- a/py/sdk/async_client.py +++ b/py/sdk/async_client.py @@ -6,7 +6,7 @@ from shared.abstractions import R2RException from .base.base_client import BaseClient -from .mixins import ( +from .v2.mixins import ( AuthMixins, IngestionMixins, KGMixins, @@ -14,6 +14,7 @@ RetrievalMixins, ServerMixins, ) +from .v3.document import DocumentSDK class R2RAsyncClient( @@ -44,6 +45,7 @@ def __init__( ): super().__init__(base_url, prefix, timeout) self.client = custom_client or httpx.AsyncClient(timeout=timeout) + self.documents = DocumentSDK(self) async def _make_request(self, method: str, endpoint: str, **kwargs): url = self._get_full_url(endpoint) diff --git a/py/sdk/base/base_client.py b/py/sdk/base/base_client.py index 096cd69aa..8bafa5e2e 100644 --- a/py/sdk/base/base_client.py +++ b/py/sdk/base/base_client.py @@ -1,8 +1,37 @@ +import asyncio +from functools import wraps from typing import Optional from shared.abstractions import R2RException +def sync_wrapper(async_func): + """Decorator to convert async methods to sync methods""" + + @wraps(async_func) + def wrapper(*args, **kwargs): + loop = asyncio.get_event_loop() + return loop.run_until_complete(async_func(*args, **kwargs)) + + return wrapper + + +def sync_generator_wrapper(async_gen_func): + """Decorator to convert async generators to sync generators""" + + @wraps(async_gen_func) + def wrapper(*args, **kwargs): + async_gen = async_gen_func(*args, **kwargs) + loop = asyncio.get_event_loop() + try: + while True: + yield loop.run_until_complete(async_gen.__anext__()) + except StopAsyncIteration: + pass + + return wrapper + + class BaseClient: def __init__( self, diff --git a/py/sdk/sync_client.py b/py/sdk/sync_client.py index b5371ebb2..5e00aa1a1 100644 --- a/py/sdk/sync_client.py +++ b/py/sdk/sync_client.py @@ -2,6 +2,7 @@ from .async_client import 
R2RAsyncClient from .utils import SyncClientMetaclass +from .v3.document import SyncDocumentSDK class R2RClient(R2RAsyncClient, metaclass=SyncClientMetaclass): @@ -17,6 +18,7 @@ class R2RClient(R2RAsyncClient, metaclass=SyncClientMetaclass): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) + self.documents = SyncDocumentSDK(self.documents) def _make_streaming_request(self, method: str, endpoint: str, **kwargs): async_gen = super()._make_streaming_request(method, endpoint, **kwargs) diff --git a/py/sdk/mixins/__init__.py b/py/sdk/v2/mixins/__init__.py similarity index 100% rename from py/sdk/mixins/__init__.py rename to py/sdk/v2/mixins/__init__.py diff --git a/py/sdk/mixins/auth.py b/py/sdk/v2/mixins/auth.py similarity index 99% rename from py/sdk/mixins/auth.py rename to py/sdk/v2/mixins/auth.py index 03dc5db5c..f6d24dc78 100644 --- a/py/sdk/mixins/auth.py +++ b/py/sdk/v2/mixins/auth.py @@ -1,7 +1,7 @@ from typing import Optional, Union from uuid import UUID -from ..models import Token, UserResponse +from ...models import Token, UserResponse class AuthMixins: diff --git a/py/sdk/mixins/ingestion.py b/py/sdk/v2/mixins/ingestion.py similarity index 100% rename from py/sdk/mixins/ingestion.py rename to py/sdk/v2/mixins/ingestion.py diff --git a/py/sdk/mixins/kg.py b/py/sdk/v2/mixins/kg.py similarity index 99% rename from py/sdk/mixins/kg.py rename to py/sdk/v2/mixins/kg.py index 87c090636..4310dac62 100644 --- a/py/sdk/mixins/kg.py +++ b/py/sdk/v2/mixins/kg.py @@ -1,7 +1,7 @@ from typing import Optional, Union from uuid import UUID -from ..models import ( +from ...models import ( KGCreationSettings, KGEnrichmentSettings, KGEntityDeduplicationResponse, diff --git a/py/sdk/mixins/management.py b/py/sdk/v2/mixins/management.py similarity index 99% rename from py/sdk/mixins/management.py rename to py/sdk/v2/mixins/management.py index ccdd95e53..450cc393d 100644 --- a/py/sdk/mixins/management.py +++ b/py/sdk/v2/mixins/management.py @@ -2,7 +2,7 @@ from typing import Any, Optional, Union from uuid import UUID -from ..models import Message +from ...models import Message class ManagementMixins: diff --git a/py/sdk/mixins/retrieval.py b/py/sdk/v2/mixins/retrieval.py similarity index 99% rename from py/sdk/mixins/retrieval.py rename to py/sdk/v2/mixins/retrieval.py index 8c94a1fa1..230257e3b 100644 --- a/py/sdk/mixins/retrieval.py +++ b/py/sdk/v2/mixins/retrieval.py @@ -1,7 +1,7 @@ import logging from typing import AsyncGenerator, Optional, Union -from ..models import ( +from ...models import ( GenerationConfig, KGSearchSettings, Message, diff --git a/py/sdk/mixins/server.py b/py/sdk/v2/mixins/server.py similarity index 100% rename from py/sdk/mixins/server.py rename to py/sdk/v2/mixins/server.py diff --git a/py/sdk/v3/document.py b/py/sdk/v3/document.py new file mode 100644 index 000000000..1b5bfa12a --- /dev/null +++ b/py/sdk/v3/document.py @@ -0,0 +1,303 @@ +import asyncio +import json +import os +from inspect import getmembers, isasyncgenfunction, iscoroutinefunction +from io import BytesIO +from typing import Any, Dict, List, Optional, Union +from uuid import UUID + +from ..base.base_client import sync_generator_wrapper, sync_wrapper + + +class DocumentSDK: + """ + SDK for interacting with documents in the v3 API. 
+ """ + + def __init__(self, client): + self.client = client + + async def create( + self, + file_path: Optional[str] = None, + content: Optional[str] = None, + document_id: Optional[Union[str, UUID]] = None, + metadata: Optional[dict] = None, + ingestion_config: Optional[dict] = None, + run_with_orchestration: Optional[bool] = True, + ) -> dict: + """ + Create a new document from either a file or content. + """ + if not file_path and not content: + raise ValueError("Either file_path or content must be provided") + if file_path and content: + raise ValueError("Cannot provide both file_path and content") + + data = {} + files = None + + if document_id: + data["document_id"] = json.dumps(str(document_id)) + if metadata: + data["metadata"] = json.dumps(metadata) + if ingestion_config: + data["ingestion_config"] = json.dumps(ingestion_config) + if run_with_orchestration is not None: + data["run_with_orchestration"] = str(run_with_orchestration) + + if file_path: + # Create a new file instance that will remain open during the request + file_instance = open(file_path, "rb") + files = [ + ( + "file", + (file_path, file_instance, "application/octet-stream"), + ) + ] + try: + result = await self.client._make_request( + "POST", "documents", data=data, files=files + ) + finally: + # Ensure we close the file after the request is complete + file_instance.close() + return result + else: + data["content"] = content # type: ignore + return await self.client._make_request( + "POST", "documents", data=data + ) + + async def update( + self, + document_id: Union[str, UUID], + file_path: Optional[str] = None, + content: Optional[str] = None, + metadata: Optional[dict] = None, + ingestion_config: Optional[dict] = None, + run_with_orchestration: Optional[bool] = True, + ) -> dict: + """ + Update an existing document. + + Args: + document_id (Union[str, UUID]): ID of document to update + file_path (Optional[str]): Path to the new file + content (Optional[str]): New text content + metadata (Optional[dict]): Updated metadata + ingestion_config (Optional[dict]): Custom ingestion configuration + run_with_orchestration (Optional[bool]): Whether to run with orchestration + + Returns: + dict: Update results containing processed document information + """ + if not file_path and not content: + raise ValueError("Either file_path or content must be provided") + if file_path and content: + raise ValueError("Cannot provide both file_path and content") + + data = {} + files = None + + if metadata: + data["metadata"] = json.dumps([metadata]) + if ingestion_config: + data["ingestion_config"] = json.dumps(ingestion_config) + if run_with_orchestration is not None: + data["run_with_orchestration"] = str(run_with_orchestration) + + if file_path: + # Create a new file instance that will remain open during the request + file_instance = open(file_path, "rb") + files = [ + ( + "file", + (file_path, file_instance, "application/octet-stream"), + ) + ] + try: + result = await self.client._make_request( + "POST", + f"documents/{str(document_id)}", + data=data, + files=files, + ) + finally: + # Ensure we close the file after the request is complete + file_instance.close() + return result + else: + data["content"] = content # type: ignore + return await self.client._make_request( + "POST", f"documents/{str(document_id)}", data=data + ) + + async def retrieve( + self, + document_id: Union[str, UUID], + ) -> dict: + """ + Get a specific document by ID. 
+ + Args: + document_id (Union[str, UUID]): ID of document to retrieve + + Returns: + dict: Document information + """ + return await self.client._make_request( + "GET", f"documents/{str(document_id)}" + ) + + async def list( + self, + document_ids: Optional[List[Union[str, UUID]]] = None, + offset: Optional[int] = 0, + limit: Optional[int] = 100, + ) -> dict: + """ + List documents with pagination. + + Args: + document_ids (Optional[List[Union[str, UUID]]]): Optional list of document IDs to filter by + offset (Optional[int]): Pagination offset + limit (Optional[int]): Maximum number of documents to return + + Returns: + dict: List of documents and pagination information + """ + params = { + "offset": offset, + "limit": limit, + } + if document_ids: + params["document_ids"] = [str(doc_id) for doc_id in document_ids] # type: ignore + + return await self.client._make_request( + "GET", "documents", params=params + ) + + async def download( + self, + document_id: Union[str, UUID], + ) -> BytesIO: + """ + Download a document's file content. + + Args: + document_id (Union[str, UUID]): ID of document to download + + Returns: + BytesIO: File content as a binary stream + """ + return await self.client._make_request( + "GET", f"documents/{str(document_id)}/download" + ) + + async def delete( + self, + document_id: Union[str, UUID], + ) -> None: + """ + Delete a specific document. + + Args: + document_id (Union[str, UUID]): ID of document to delete + """ + await self.client._make_request( + "DELETE", f"documents/{str(document_id)}" + ) + + async def list_chunks( + self, + document_id: Union[str, UUID], + offset: Optional[int] = 0, + limit: Optional[int] = 100, + include_vectors: Optional[bool] = False, + ) -> dict: + """ + Get chunks for a specific document. + + Args: + document_id (Union[str, UUID]): ID of document to retrieve chunks for + offset (Optional[int]): Pagination offset + limit (Optional[int]): Maximum number of chunks to return + include_vectors (Optional[bool]): Whether to include vector embeddings in the response + + Returns: + dict: List of document chunks and pagination information + """ + params = { + "offset": offset, + "limit": limit, + "include_vectors": include_vectors, + } + + return await self.client._make_request( + "GET", f"documents/{str(document_id)}/chunks", params=params + ) + + async def list_collections( + self, + document_id: Union[str, UUID], + offset: Optional[int] = 0, + limit: Optional[int] = 100, + ) -> dict: + """ + Get collections containing a specific document. + + Args: + document_id (Union[str, UUID]): ID of document to retrieve collections for + offset (Optional[int]): Pagination offset + limit (Optional[int]): Maximum number of collections to return + + Returns: + dict: List of collections and pagination information + """ + params = { + "offset": offset, + "limit": limit, + } + + return await self.client._make_request( + "GET", f"documents/{str(document_id)}/collections", params=params + ) + + async def delete_by_filter( + self, + filters: Dict[str, Any], + ) -> None: + """ + Delete documents based on filters. 
+ + Args: + filters (Dict[str, Any]): Filters to apply when selecting documents to delete + """ + filters_json = json.dumps(filters) + await self.client._make_request( + "DELETE", "documents/by-filter", params={"filters": filters_json} + ) + + +class SyncDocumentSDK: + """Synchronous wrapper for DocumentSDK""" + + def __init__(self, async_sdk: DocumentSDK): + self._async_sdk = async_sdk + + # Get all attributes from the instance + for name in dir(async_sdk): + if not name.startswith("_"): # Skip private attributes + attr = getattr(async_sdk, name) + # Check if it's a method and if it's async + if callable(attr) and ( + iscoroutinefunction(attr) or isasyncgenfunction(attr) + ): + if isasyncgenfunction(attr): + setattr(self, name, sync_generator_wrapper(attr)) + else: + setattr(self, name, sync_wrapper(attr))
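One convention in this patch deserves a callout for reviewers: several handlers (`get_documents`, `get_document_chunks`, `get_document_collections`) return a `(results, extra)` tuple, and `BaseRouterV3.base_endpoint` merges the extra dict into the response envelope. A distilled sketch of that wrapping logic, with the run logging and `R2RException` passthrough omitted for brevity:

```python
import functools

from fastapi import HTTPException


def base_endpoint(func):
    @functools.wraps(func)
    async def wrapper(*args, **kwargs):
        try:
            func_result = await func(*args, **kwargs)
            # Handlers may return a bare payload or a (payload, extra)
            # tuple; extra is merged into the envelope, which is how the
            # paginated endpoints attach "total_entries" next to "results".
            if isinstance(func_result, tuple) and len(func_result) == 2:
                results, outer_kwargs = func_result
            else:
                results, outer_kwargs = func_result, {}
            return {"results": results, **outer_kwargs}
        except Exception as e:
            raise HTTPException(status_code=500, detail=str(e)) from e

    return wrapper
```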
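The sync facade in `SyncDocumentSDK` relies on runtime reflection: it walks the async SDK instance and rebinds every public coroutine method through `sync_wrapper`. A minimal, self-contained sketch of the same mechanism (the `AsyncAPI`/`SyncAPI` names are illustrative, not part of the patch; the patch itself calls `loop.run_until_complete` on the current event loop, while this sketch uses `asyncio.run` for simplicity):

```python
import asyncio
from functools import wraps
from inspect import iscoroutinefunction


def sync_wrapper(async_func):
    """Convert an async method into a blocking call."""

    @wraps(async_func)
    def wrapper(*args, **kwargs):
        # The patch uses asyncio.get_event_loop().run_until_complete();
        # asyncio.run() is the simplest stand-in for a standalone sketch.
        return asyncio.run(async_func(*args, **kwargs))

    return wrapper


class AsyncAPI:
    async def ping(self) -> str:
        await asyncio.sleep(0)
        return "pong"


class SyncAPI:
    def __init__(self, inner: AsyncAPI):
        # Rebind every public coroutine method as a blocking callable,
        # mirroring what SyncDocumentSDK does with DocumentSDK.
        for name in dir(inner):
            if name.startswith("_"):
                continue
            attr = getattr(inner, name)
            if callable(attr) and iscoroutinefunction(attr):
                setattr(self, name, sync_wrapper(attr))


assert SyncAPI(AsyncAPI()).ping() == "pong"
```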
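Taken together, the router and SDK changes give the client a `documents` namespace. A hedged end-to-end usage sketch (the `r2r` import path and default port follow the project's existing conventions; the file name and the exact envelope shape returned by `create` are assumptions):

```python
from r2r import R2RClient

client = R2RClient("http://localhost:7272")

# Create a document from a local file (multipart upload)...
created = client.documents.create(
    file_path="report.pdf",  # hypothetical local file
    metadata={"title": "Q3 Report"},
)
document_id = created["results"]["document_id"]  # assumed envelope shape

# ...or from raw text, sent as form data (stored with filename "N/A").
client.documents.create(content="Some text to index.")

# Paginated listing, retrieval, and chunk inspection.
client.documents.list(offset=0, limit=10)
client.documents.retrieve(document_id)
client.documents.list_chunks(document_id, include_vectors=False)

# Delete by ID, or by a JSON-encoded filter expression.
client.documents.delete(document_id)
client.documents.delete_by_filter({"document_id": {"$eq": document_id}})
```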