Skip to content

Commit

Permalink
checkin document router
Browse files Browse the repository at this point in the history
  • Loading branch information
emrgnt-cmplxty committed Oct 31, 2024
1 parent 7635180 commit 1a71b1b
Show file tree
Hide file tree
Showing 18 changed files with 618 additions and 133 deletions.
103 changes: 103 additions & 0 deletions py/core/main/api/v3/base_router.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
import functools
import logging
from abc import abstractmethod
from typing import Callable, Union

from fastapi import APIRouter, HTTPException
from fastapi.responses import StreamingResponse

from core.base import R2RException, manage_run
from core.base.logger.base import RunType
from core.providers import (
HatchetOrchestrationProvider,
SimpleOrchestrationProvider,
)

from ...services.base import Service

logger = logging.getLogger()


class BaseRouterV3:
def __init__(self, providers, services, orchestration_provider, run_type):
self.providers = providers
self.services = services
self.run_type = run_type
self.orchestration_provider = orchestration_provider
self.router = APIRouter()
self.openapi_extras = self._load_openapi_extras()
self._setup_routes()
self._register_workflows()

def get_router(self):
return self.router

def base_endpoint(self, func: Callable):
@functools.wraps(func)
async def wrapper(*args, **kwargs):
async with manage_run(
self.services["ingestion"].run_manager, func.__name__
) as run_id:
auth_user = kwargs.get("auth_user")
if auth_user:
await self.services[
"ingestion"
].run_manager.log_run_info( # TODO - this is a bit of a hack
run_type=self.run_type,
user=auth_user,
)

try:
func_result = await func(*args, **kwargs)
if (
isinstance(func_result, tuple)
and len(func_result) == 2
):
results, outer_kwargs = func_result
else:
results, outer_kwargs = func_result, {}

if isinstance(results, StreamingResponse):
return results
return {"results": results, **outer_kwargs}

except R2RException:
raise

except Exception as e:

await self.services["ingestion"].logging_connection.log(
run_id=run_id,
key="error",
value=str(e),
)

logger.error(
f"Error in base endpoint {func.__name__}() - \n\n{str(e)}",
exc_info=True,
)

raise HTTPException(
status_code=500,
detail={
"message": f"An error '{e}' occurred during {func.__name__}",
"error": str(e),
"error_type": type(e).__name__,
},
) from e

return wrapper

@classmethod
def build_router(cls, engine):
return cls(engine).router

@abstractmethod
def _setup_routes(self):
pass

def _register_workflows(self):
pass

def _load_openapi_extras(self):
return {}
14 changes: 14 additions & 0 deletions py/core/main/api/v3/chunk_responses.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from typing import Any, Optional
from uuid import UUID

from pydantic import BaseModel


class ChunkResponse(BaseModel):
document_id: UUID
extraction_id: UUID
user_id: UUID
collection_ids: list[UUID]
text: str
metadata: dict[str, Any]
vector: Optional[list[float]] = None
8 changes: 8 additions & 0 deletions py/core/main/api/v3/document_responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,11 @@ class DocumentResponse(BaseModel):
version: str
collection_ids: list[UUID]
metadata: dict[str, Any]


class CollectionResponse(BaseModel):
collection_id: UUID
name: str
description: Optional[str]
created_at: datetime
updated_at: datetime
Loading

0 comments on commit 1a71b1b

Please sign in to comment.