Skip to content

Commit

Permalink
Feature/add documents search (#1549)
Browse files Browse the repository at this point in the history
* up

* rm extra printout

* bump release

* add ot js

* return metadata
  • Loading branch information
emrgnt-cmplxty authored Nov 1, 2024
1 parent 6fe18bd commit 9b67dde
Show file tree
Hide file tree
Showing 21 changed files with 358 additions and 3 deletions.
41 changes: 41 additions & 0 deletions js/sdk/src/r2rClient.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1873,6 +1873,47 @@ export class r2rClient {
//
// -----------------------------------------------------------------------------

/**
 * Run a search over documents.
 *
 * Any option that is missing from `settings` (or is `null`/`undefined`)
 * is replaced by the default listed in the body before the request is sent.
 *
 * @param query The query to search for.
 * @param settings Optional settings for the document search.
 * @returns A promise that resolves to the response from the server.
 */
@feature("searchDocuments")
async searchDocuments(
  query: string,
  settings?: {
    searchOverMetadata?: boolean;
    metadataKeys?: string[];
    searchOverBody?: boolean;
    filters?: Record<string, any>;
    searchFilters?: Record<string, any>;
    offset?: number;
    limit?: number;
    titleWeight?: number;
    metadataWeight?: number;
  },
): Promise<any> {
  this._ensureAuthenticated();

  // Normalise a missing settings object once so each field below is a plain lookup.
  const opts: NonNullable<typeof settings> = settings ?? {};

  // camelCase client options are translated to the snake_case keys sent in the request body.
  const requestSettings: Record<string, any> = {
    search_over_metadata: opts.searchOverMetadata ?? true,
    metadata_keys: opts.metadataKeys ?? ["title"],
    search_over_body: opts.searchOverBody ?? false,
    filters: opts.filters ?? {},
    search_filters: opts.searchFilters ?? {},
    offset: opts.offset ?? 0,
    limit: opts.limit ?? 10,
    title_weight: opts.titleWeight ?? 0.5,
    metadata_weight: opts.metadataWeight ?? 0.5,
  };

  const payload: Record<string, any> = { query, settings: requestSettings };

  const response = await this._makeRequest("POST", "search_documents", {
    data: payload,
  });
  return response;
}

/**
* Conduct a vector and/or KG search.
* @param query The query to search for.
Expand Down
1 change: 1 addition & 0 deletions py/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@
"KGSearchSettings",
"VectorSearchResult",
"VectorSearchSettings",
"DocumentSearchSettings",
"HybridSearchSettings",
# User abstractions
"Token",
Expand Down
1 change: 1 addition & 0 deletions py/core/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
"KGSearchSettings",
"VectorSearchResult",
"VectorSearchSettings",
"DocumentSearchSettings",
"HybridSearchSettings",
# KG abstractions
"KGCreationSettings",
Expand Down
2 changes: 2 additions & 0 deletions py/core/base/abstractions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
from shared.abstractions.prompt import Prompt
from shared.abstractions.search import (
AggregateSearchResult,
DocumentSearchSettings,
HybridSearchSettings,
KGCommunityResult,
KGEntityResult,
Expand Down Expand Up @@ -129,6 +130,7 @@
"KGSearchSettings",
"VectorSearchResult",
"VectorSearchSettings",
"DocumentSearchSettings",
"HybridSearchSettings",
# KG abstractions
"KGCreationSettings",
Expand Down
2 changes: 2 additions & 0 deletions py/core/base/api/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@
RAGResponse,
SearchResponse,
WrappedCompletionResponse,
WrappedDocumentSearchResponse,
WrappedRAGAgentResponse,
WrappedRAGResponse,
WrappedSearchResponse,
Expand Down Expand Up @@ -143,6 +144,7 @@
"SearchResponse",
"RAGResponse",
"RAGAgentResponse",
"WrappedDocumentSearchResponse",
"WrappedSearchResponse",
"WrappedCompletionResponse",
"WrappedRAGResponse",
Expand Down
12 changes: 12 additions & 0 deletions py/core/base/providers/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
)
from core.base.abstractions import (
DocumentInfo,
DocumentSearchSettings,
IndexArgsHNSW,
IndexArgsIVFFlat,
IndexMeasure,
Expand Down Expand Up @@ -520,6 +521,12 @@ async def full_text_search(
) -> list[VectorSearchResult]:
pass

@abstractmethod
async def search_documents(
    self, query_text: str, settings: DocumentSearchSettings
) -> list[dict]:
    """Search documents for ``query_text`` using ``settings``.

    Abstract hook with no behavior of its own; implementations are
    expected to return a list of dicts, per the annotation.
    """

@abstractmethod
async def hybrid_search(
self,
Expand Down Expand Up @@ -1419,6 +1426,11 @@ async def hybrid_search(
query_text, query_vector, search_settings, *args, **kwargs
)

async def search_documents(
    self, query_text: str, settings: DocumentSearchSettings
) -> list[dict]:
    """Forward a document search to ``self.vector_handler``.

    Args:
        query_text: The text to search for.
        settings: Document search settings, passed through unchanged.

    Returns:
        Whatever ``vector_handler.search_documents`` returns.
    """
    handler = self.vector_handler
    matches = await handler.search_documents(query_text, settings)
    return matches

async def delete(
self, filters: dict[str, Any]
) -> dict[str, dict[str, str]]:
Expand Down
4 changes: 4 additions & 0 deletions py/core/configs/full_azure.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@
[completion.generation_config]
model = "azure/gpt-4o"

[agent]
[agent.generation_config]
model = "azure/gpt-4o"

# KG settings
batch_size = 256

Expand Down
1 change: 1 addition & 0 deletions py/core/configs/local_llm.toml
Original file line number Diff line number Diff line change
Expand Up @@ -64,5 +64,6 @@ provider = "simple"
[ingestion]
vision_img_model = "ollama/llama3.2-vision"
vision_pdf_model = "ollama/llama3.2-vision"

[ingestion.extra_parsers]
pdf = "zerox"
4 changes: 4 additions & 0 deletions py/core/configs/r2r_azure.toml
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
# A config which overrides all instances of `openai` with `azure` in the `r2r.toml` config
[agent]
[agent.generation_config]
model = "azure/gpt-4o"

[completion]
[completion.generation_config]
model = "azure/gpt-4o"
Expand Down
33 changes: 33 additions & 0 deletions py/core/main/api/retrieval_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from fastapi.responses import StreamingResponse

from core.base import (
DocumentSearchSettings,
GenerationConfig,
KGSearchSettings,
Message,
Expand All @@ -16,6 +17,7 @@
)
from core.base.api.models import (
WrappedCompletionResponse,
WrappedDocumentSearchResponse,
WrappedRAGAgentResponse,
WrappedRAGResponse,
WrappedSearchResponse,
Expand Down Expand Up @@ -100,6 +102,37 @@ def _setup_routes(self):
search_extras = self.openapi_extras.get("search", {})
search_descriptions = search_extras.get("input_descriptions", {})

@self.router.post(
    "/search_documents",
    openapi_extra=search_extras.get("openapi_extra"),
)
@self.base_endpoint
async def search_documents(
    query: str = Body(
        ..., description=search_descriptions.get("query")
    ),
    settings: DocumentSearchSettings = Body(
        default_factory=DocumentSearchSettings,
        description="Settings for document search",
    ),
    auth_user=Depends(self.service.providers.auth.auth_wrapper),
) -> WrappedDocumentSearchResponse:  # type: ignore
    """
    Perform a search query over documents, configured by the supplied document search settings.
    This endpoint allows for complex filtering of search results using PostgreSQL-based queries.
    Filters can be applied to various fields such as document_id, and internal metadata values.
    Allowed operators include `eq`, `neq`, `gt`, `gte`, `lt`, `lte`, `like`, `ilike`, `in`, and `nin`.
    """
    # NOTE(review): `openapi_extra` above reuses the extras registered under
    # the "search" key, so any examples it carries were written for `/search`
    # — confirm they also apply to this endpoint.
    # `auth_user` is not read below; it is declared only so the auth
    # dependency is part of this route's signature.
    results = await self.service.search_documents(
        query=query,
        settings=settings,
    )
    return results

@self.router.post(
"/search",
openapi_extra=search_extras.get("openapi_extra"),
Expand Down
14 changes: 14 additions & 0 deletions py/core/main/services/retrieval_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

from core import R2RStreamingRAGAgent
from core.base import (
DocumentSearchSettings,
EmbeddingPurpose,
GenerationConfig,
KGSearchSettings,
Message,
Expand Down Expand Up @@ -114,6 +116,18 @@ async def search(

return results.as_dict()

@telemetry_event("SearchDocuments")
async def search_documents(
    self,
    query: str,
    settings: DocumentSearchSettings,
) -> list[dict]:
    """Run a document search through the database provider.

    Args:
        query: The text to search for; passed on as ``query_text``.
        settings: Document search settings, passed through unchanged.

    Returns:
        The result of ``self.providers.database.search_documents``.
    """
    database = self.providers.database
    found = await database.search_documents(
        query_text=query,
        settings=settings,
    )
    return found

@telemetry_event("Completion")
async def completion(
self,
Expand Down
Loading

0 comments on commit 9b67dde

Please sign in to comment.