Skip to content

Commit

Permalink
vec index creation endpoint (#1373)
Browse files Browse the repository at this point in the history
* Update graphrag.mdx

* upload files

* create vector index endpoint

* add to fastapi background task

* pre-commit

* move logging

* add api spec, support for all vecs

* pre-commit

* add workflow
  • Loading branch information
shreyaspimpalgaonkar authored Oct 10, 2024
1 parent fde366d commit 0071467
Show file tree
Hide file tree
Showing 16 changed files with 327 additions and 121 deletions.
4 changes: 4 additions & 0 deletions docs/api-reference/endpoint/create_vector_index.mdx
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
---
title: Create Vector Index
openapi: 'POST /v2/create_vector_index'
---
2 changes: 1 addition & 1 deletion docs/api-reference/openapi.json

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion docs/mint.json
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,8 @@
"pages": [
"api-reference/endpoint/ingest_files",
"api-reference/endpoint/ingest_chunks",
"api-reference/endpoint/update_files"
"api-reference/endpoint/update_files",
"api-reference/endpoint/create_vector_index"
]
},
{
Expand Down
4 changes: 4 additions & 0 deletions py/core/base/api/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@
)
from shared.api.models.ingestion.responses import (
IngestionResponse,
CreateVectorIndexResponse,
WrappedIngestionResponse,
WrappedUpdateResponse,
WrappedCreateVectorIndexResponse,
)
from shared.api.models.kg.responses import (
KGCreationResponse,
Expand Down Expand Up @@ -71,6 +73,8 @@
"IngestionResponse",
"WrappedIngestionResponse",
"WrappedUpdateResponse",
"CreateVectorIndexResponse",
"WrappedCreateVectorIndexResponse",
# Restructure Responses
"KGCreationResponse",
"WrappedKGCreationResponse",
Expand Down
71 changes: 71 additions & 0 deletions py/core/main/api/ingestion_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,26 @@
from pydantic import Json

from core.base import R2RException, RawChunk, generate_document_id

from core.base.api.models import (
CreateVectorIndexResponse,
WrappedIngestionResponse,
WrappedUpdateResponse,
WrappedCreateVectorIndexResponse,
)
from core.base.providers import OrchestrationProvider, Workflow

from ..services.ingestion_service import IngestionService
from .base_router import BaseRouter, RunType

from shared.abstractions.vector import (
IndexMethod,
IndexArgsIVFFlat,
IndexArgsHNSW,
VectorTableName,
IndexMeasure,
)

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -52,6 +63,11 @@ def _register_workflows(self):
if self.orchestration_provider.config.provider != "simple"
else "Update task queued successfully."
),
"create-vector-index": (
"Vector index creation task queued successfully."
if self.orchestration_provider.config.provider != "simple"
else "Vector index creation task completed successfully."
),
},
)

Expand Down Expand Up @@ -310,6 +326,61 @@ async def ingest_chunks_app(
raw_message["document_id"] = str(document_id)
return raw_message # type: ignore

@self.router.post("/create_vector_index")
@self.base_endpoint
async def create_vector_index_app(
    table_name: Optional[VectorTableName] = Body(
        default=VectorTableName.CHUNKS,
        description="The name of the vector table to create.",
    ),
    index_method: IndexMethod = Body(
        default=IndexMethod.hnsw,
        description="The type of vector index to create.",
    ),
    measure: IndexMeasure = Body(
        default=IndexMeasure.cosine_distance,
        description="The measure for the index.",
    ),
    index_arguments: Optional[
        Union[IndexArgsIVFFlat, IndexArgsHNSW]
    ] = Body(
        None,
        description="The arguments for the index method.",
    ),
    replace: bool = Body(
        default=False,
        description="Whether to replace an existing index.",
    ),
    concurrently: bool = Body(
        default=False,
        description="Whether to create the index concurrently.",
    ),
    auth_user=Depends(self.service.providers.auth.auth_wrapper),
) -> WrappedCreateVectorIndexResponse:
    """
    Submit a "create-vector-index" workflow run for a vector table.

    The request body fields are forwarded unchanged (apart from the
    ``table_name`` fallback described below) under the ``"request"`` key
    of the workflow input, and whatever the orchestration provider's
    ``run_workflow`` returns is handed back to the caller.
    """
    # The annotation is Optional, so a client may send an explicit null.
    # The workflow input is later converted with VectorTableName(...),
    # which cannot take None, so fall back to the declared default here.
    if table_name is None:
        table_name = VectorTableName.CHUNKS

    logger.info(
        f"Creating vector index for {table_name} with method {index_method}, measure {measure}, replace {replace}, concurrently {concurrently}"
    )

    raw_message = await self.orchestration_provider.run_workflow(
        "create-vector-index",
        {
            "request": {
                "table_name": table_name,
                "index_method": index_method,
                "measure": measure,
                "index_arguments": index_arguments,
                "replace": replace,
                "concurrently": concurrently,
            },
        },
        options={
            "additional_metadata": {},
        },
    )

    return raw_message  # type: ignore

@staticmethod
async def _process_files(files):
import base64
Expand Down
27 changes: 27 additions & 0 deletions py/core/main/orchestration/hatchet/ingestion_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,11 +437,38 @@ async def on_failure(self, context: Context) -> None:
f"Failed to update document status for {document_id}: {e}"
)

@orchestration_provider.workflow(
    name="create-vector-index", timeout="360m"
)
class HatchetCreateVectorIndexWorkflow:
    """Workflow registered as "create-vector-index" that builds a vector index."""

    def __init__(self, ingestion_service: IngestionService):
        # Service whose database provider performs the index creation.
        self.ingestion_service = ingestion_service

    # NOTE(review): this step's 60m timeout is shorter than the 360m
    # workflow timeout above - confirm which limit is intended.
    @orchestration_provider.step(timeout="60m")
    async def create_vector_index(self, context: Context) -> dict:
        """
        Parse the workflow's ``"request"`` input and create the index.

        Returns:
            dict: a single ``"status"`` entry, produced only after the
            ``create_index`` call has returned.
        """
        input_data = context.workflow_input()["request"]
        parsed_data = (
            IngestionServiceAdapter.parse_create_vector_index_input(
                input_data
            )
        )

        # NOTE(review): called without ``await``; presumably a blocking,
        # synchronous call - confirm against the vector provider.
        self.ingestion_service.providers.database.vector.create_index(
            **parsed_data
        )

        # By this point the call above has returned, so report completion
        # rather than "queued".
        return {
            "status": "Vector index creation completed successfully.",
        }

ingest_files_workflow = HatchetIngestFilesWorkflow(service)
update_files_workflow = HatchetUpdateFilesWorkflow(service)
ingest_chunks_workflow = HatchetIngestChunksWorkflow(service)
create_vector_index_workflow = HatchetCreateVectorIndexWorkflow(service)

return {
"ingest_files": ingest_files_workflow,
"update_files": update_files_workflow,
"ingest_chunks": ingest_chunks_workflow,
"create_vector_index": create_vector_index_workflow,
}
20 changes: 20 additions & 0 deletions py/core/main/orchestration/simple/ingestion_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,8 +239,28 @@ async def ingest_chunks(input_data):
message=f"Error during chunk ingestion: {str(e)}",
)

async def create_vector_index(input_data):
    """
    Create a vector index from an already-unwrapped request payload.

    Args:
        input_data: mapping handed to
            ``IngestionServiceAdapter.parse_create_vector_index_input``.

    Raises:
        R2RException: with status code 500 for any failure while
            importing the adapter, parsing the input, or creating the
            index. The original exception is attached as ``__cause__``.
    """
    try:
        # Imported inside the function, as in the original code
        # (presumably to avoid a circular import - TODO confirm).
        from core.main import IngestionServiceAdapter

        parsed_data = (
            IngestionServiceAdapter.parse_create_vector_index_input(
                input_data
            )
        )

        service.providers.database.vector.create_index(**parsed_data)

    except Exception as e:
        # Chain explicitly so the root cause stays visible in tracebacks.
        raise R2RException(
            status_code=500,
            message=f"Error during vector index creation: {str(e)}",
        ) from e

return {
"ingest-files": ingest_files,
"update-files": update_files,
"ingest-chunks": ingest_chunks,
"create-vector-index": create_vector_index,
}
16 changes: 16 additions & 0 deletions py/core/main/services/ingestion_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,11 @@
from core.telemetry.telemetry_decorator import telemetry_event

from ..abstractions import R2RAgents, R2RPipelines, R2RPipes, R2RProviders
from shared.abstractions.vector import (
IndexMethod,
IndexMeasure,
VectorTableName,
)
from ..config import R2RConfig
from .base import Service

Expand Down Expand Up @@ -377,3 +382,14 @@ def parse_update_files_input(data: dict) -> dict:
"file_sizes_in_bytes": data["file_sizes_in_bytes"],
"file_datas": data["file_datas"],
}

@staticmethod
def parse_create_vector_index_input(data: dict) -> dict:
    """
    Turn a raw create-vector-index request into ``create_index`` kwargs.

    The three enum-valued fields are coerced through their enum
    constructors; the remaining fields are copied over untouched. Every
    key is required, so a missing one raises ``KeyError``.
    """
    # Coerce the enum-backed fields first, in the original lookup order.
    table = VectorTableName(data["table_name"])
    method = IndexMethod(data["index_method"])
    distance = IndexMeasure(data["measure"])

    parsed = dict(table_name=table, index_method=method, measure=distance)

    # Remaining fields are passed through as-is.
    for passthrough_key in ("index_arguments", "replace", "concurrently"):
        parsed[passthrough_key] = data[passthrough_key]

    return parsed
8 changes: 0 additions & 8 deletions py/core/providers/database/vecs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,13 @@
from .client import Client
from .collection import (
Collection,
IndexArgsHNSW,
IndexArgsIVFFlat,
IndexMeasure,
IndexMethod,
)

__project__ = "vecs"
__version__ = "0.4.2"


__all__ = [
"IndexArgsIVFFlat",
"IndexArgsHNSW",
"IndexMethod",
"IndexMeasure",
"Collection",
"Client",
"exc",
Expand Down
Loading

0 comments on commit 0071467

Please sign in to comment.