From ab1d20561b8490843ab22e453f1459854e2faee0 Mon Sep 17 00:00:00 2001 From: Nolan Tremelling <34580718+NolanTrem@users.noreply.github.com> Date: Thu, 5 Dec 2024 12:06:01 -0800 Subject: [PATCH] Fix agent bug, remove alias (#1661) --- js/sdk/package-lock.json | 2 +- js/sdk/package.json | 2 +- js/sdk/src/models.tsx | 75 ----------- js/sdk/src/r2rClient.ts | 27 ++-- js/sdk/src/types.ts | 54 ++++++-- js/sdk/src/v3/clients/retrieval.ts | 60 ++++----- py/core/main/services/retrieval_service.py | 138 +++++++++++++-------- py/shared/abstractions/graph.py | 2 - py/shared/abstractions/kg.py | 1 - py/shared/abstractions/llm.py | 4 - py/shared/abstractions/search.py | 15 --- py/shared/abstractions/user.py | 2 - 12 files changed, 172 insertions(+), 210 deletions(-) diff --git a/js/sdk/package-lock.json b/js/sdk/package-lock.json index 168df2af0..c7db76bbc 100644 --- a/js/sdk/package-lock.json +++ b/js/sdk/package-lock.json @@ -1,6 +1,6 @@ { "name": "r2r-js", - "version": "0.4.3", + "version": "0.4.5", "lockfileVersion": 3, "requires": true, "packages": { diff --git a/js/sdk/package.json b/js/sdk/package.json index 0f88a0bb3..eb8008feb 100644 --- a/js/sdk/package.json +++ b/js/sdk/package.json @@ -1,6 +1,6 @@ { "name": "r2r-js", - "version": "0.4.4", + "version": "0.4.5", "description": "", "main": "dist/index.js", "browser": "dist/index.browser.js", diff --git a/js/sdk/src/models.tsx b/js/sdk/src/models.tsx index f0d467a27..973af86bd 100644 --- a/js/sdk/src/models.tsx +++ b/js/sdk/src/models.tsx @@ -17,80 +17,11 @@ export interface RefreshTokenResponse { }; } -export interface GenerationConfig { - model?: string; - temperature?: number; - topP?: number; - maxTokensToSample?: number; - stream?: boolean; - functions?: Array>; - tools?: Array>; - addGenerationKwargs?: Record; - apiBase?: string; - responseFormat?: string; -} - -export interface HybridSearchSettings { - fullTextWeight: number; - semanticWeight: number; - fullTextLimit: number; - rrfK: number; -} - -export interface ChunkSearchSettings { - useVectorSearch?: boolean; - useHybridSearch?: boolean; - filters?: Record; - searchLimit?: number; - offset?: number; - selectedCollectionIds?: string[]; - indexMeasure: IndexMeasure; - includeScores?: boolean; - includeMetadatas?: boolean; - probes?: number; - efSearch?: number; - hybridSearchSettings?: HybridSearchSettings; - searchStrategy?: string; -} - -export interface KGSearchSettings { - useKgSearch?: boolean; - filters?: Record; - selectedCollectionIds?: string[]; - graphragMapSystemPrompt?: string; - kgSearchType?: "local"; - kgSearchLevel?: number | null; - generationConfig?: GenerationConfig; - maxCommunityDescriptionLength?: number; - maxLlmQueriesForGlobalSearch?: number; - localSearchLimits?: Record; -} - export enum KGRunType { ESTIMATE = "estimate", RUN = "run", } -export interface KGCreationSettings { - kgRelationshipsExtractionPrompt?: string; - kgEntityDescriptionPrompt?: string; - forceKgCreation?: boolean; - entityTypes?: string[]; - relationTypes?: string[]; - extractionsMergeCount?: number; - maxKnowledgeRelationships?: number; - maxDescriptionInputLength?: number; - generationConfig?: GenerationConfig; -} - -export interface KGEnrichmentSettings { - forceKgEnrichment?: boolean; - communityReportsPrompt?: string; - maxSummaryInputLength?: number; - generationConfig?: GenerationConfig; - leidenParams?: Record; -} - export interface KGEntityDeduplicationSettings { kgEntityDeduplicationType?: KGEntityDeduplicationType; } @@ -121,12 +52,6 @@ export interface R2RDocumentChunksRequest { documentId: string; } -export enum IndexMeasure { - COSINE_DISTANCE = "cosine_distance", - L2_DISTANCE = "l2_distance", - MAX_INNER_PRODUCT = "max_inner_product", -} - export interface RawChunk { text: string; } diff --git a/js/sdk/src/r2rClient.ts b/js/sdk/src/r2rClient.ts index 3f9ed4bb4..d363c62cb 100644 --- a/js/sdk/src/r2rClient.ts +++ b/js/sdk/src/r2rClient.ts @@ -27,13 +27,8 @@ import { TokenInfo, Message, RefreshTokenResponse, - ChunkSearchSettings, - KGSearchSettings, KGRunType, - KGCreationSettings, - KGEnrichmentSettings, KGEntityDeduplicationSettings, - GenerationConfig, RawChunk, } from "./models"; @@ -1615,7 +1610,7 @@ export class r2rClient extends BaseClient { async createGraph( collection_id?: string, run_type?: KGRunType, - graph_creation_settings?: KGCreationSettings | Record, + graph_creation_settings?: Record, ): Promise> { this._ensureAuthenticated(); @@ -1643,7 +1638,7 @@ export class r2rClient extends BaseClient { async enrichGraph( collection_id?: string, run_type?: KGRunType, - graph_enrichment_settings?: KGEnrichmentSettings | Record, + graph_enrichment_settings?: Record, ): Promise { this._ensureAuthenticated(); @@ -1866,7 +1861,7 @@ export class r2rClient extends BaseClient { @feature("searchDocuments") async searchDocuments( query: string, - vector_search_settings?: ChunkSearchSettings | Record, + vector_search_settings?: Record, ): Promise { this._ensureAuthenticated(); const json_data: Record = { @@ -1894,8 +1889,8 @@ export class r2rClient extends BaseClient { @feature("search") async search( query: string, - vector_search_settings?: ChunkSearchSettings | Record, - graph_search_settings?: KGSearchSettings | Record, + vector_search_settings?: Record, + graph_search_settings?: Record, ): Promise { this._ensureAuthenticated(); @@ -1926,9 +1921,9 @@ export class r2rClient extends BaseClient { @feature("rag") async rag( query: string, - vector_search_settings?: ChunkSearchSettings | Record, - graph_search_settings?: KGSearchSettings | Record, - rag_generation_config?: GenerationConfig | Record, + vector_search_settings?: Record, + graph_search_settings?: Record, + rag_generation_config?: Record, task_prompt_override?: string, include_title_if_available?: boolean, ): Promise> { @@ -1986,9 +1981,9 @@ export class r2rClient extends BaseClient { @feature("agent") async agent( messages: Message[], - rag_generation_config?: GenerationConfig | Record, - vector_search_settings?: ChunkSearchSettings | Record, - graph_search_settings?: KGSearchSettings | Record, + rag_generation_config?: Record, + vector_search_settings?: Record, + graph_search_settings?: Record, task_prompt_override?: string, include_title_if_available?: boolean, conversation_id?: string, diff --git a/js/sdk/src/types.ts b/js/sdk/src/types.ts index b1fda1a61..7dee4c40a 100644 --- a/js/sdk/src/types.ts +++ b/js/sdk/src/types.ts @@ -1,9 +1,3 @@ -import { - ChunkSearchSettings, - GenerationConfig, - HybridSearchSettings, -} from "./models"; - export interface UnprocessedChunk { id: string; document_id?: string; @@ -134,6 +128,13 @@ export interface GraphResponse { updated_at: string; } +// Index types +export enum IndexMeasure { + COSINE_DISTANCE = "cosine_distance", + L2_DISTANCE = "l2_distance", + MAX_INNER_PRODUCT = "max_inner_product", +} + // Ingestion types export interface IngestionResponse { message: string; @@ -184,8 +185,41 @@ export interface RelationshipResponse { } // Retrieval types +export interface ChunkSearchSettings { + index_measure?: IndexMeasure; + probes?: number; + ef_search?: number; + enabled?: boolean; +} + +export interface GenerationConfig { + model?: string; + temperature?: number; + top_p?: number; + max_tokens_to_sample?: number; + stream?: boolean; + functions?: Array>; + tools?: Array>; + add_generation_kwargs?: Record; + api_base?: string; + response_format?: string; +} + +export interface HybridSearchSettings { + full_text_weight?: number; + semantic_weight?: number; + full_text_limit?: number; + rrf_k?: number; +} + export interface GraphSearchSettings { generation_config?: GenerationConfig; + graphrag_map_system?: string; + graphrag_reduce_system?: string; + max_community_description_length?: number; + max_llm_queries_for_global_search?: number; + limits?: Record; + enabled?: boolean; } export interface SearchSettings { @@ -196,7 +230,7 @@ export interface SearchSettings { limit?: number; offset?: number; include_metadata?: boolean; - include_scores: boolean; + include_scores?: boolean; search_strategy?: string; hybrid_settings?: HybridSearchSettings; chunk_settings?: ChunkSearchSettings; @@ -212,7 +246,11 @@ export interface VectorSearchResult { metadata?: Record; } -export type KGSearchResultType = "entity" | "relationship" | "community" | "global"; +export type KGSearchResultType = + | "entity" + | "relationship" + | "community" + | "global"; export interface GraphSearchResult { content: any; diff --git a/js/sdk/src/v3/clients/retrieval.ts b/js/sdk/src/v3/clients/retrieval.ts index a248bb541..405ae47b8 100644 --- a/js/sdk/src/v3/clients/retrieval.ts +++ b/js/sdk/src/v3/clients/retrieval.ts @@ -1,13 +1,12 @@ import { r2rClient } from "../../r2rClient"; +import { Message } from "../../models"; +import { feature } from "../../feature"; import { - Message, - ChunkSearchSettings, - KGSearchSettings, + SearchSettings, + WrappedSearchResponse, GenerationConfig, -} from "../../models"; -import { feature } from "../../feature"; -import { SearchSettings, WrappedSearchResponse } from "../../types"; +} from "../../types"; export class RetrievalClient { constructor(private client: r2rClient) {} @@ -23,8 +22,7 @@ export class RetrievalClient { * Allowed operators include: `eq`, `neq`, `gt`, `gte`, `lt`, `lte`, * `like`, `ilike`, `in`, and `nin`. * @param query Search query to find relevant documents - * @param VectorSearchSettings Settings for vector-based search - * @param KGSearchSettings Settings for knowledge graph search + * @param searchSettings Settings for the search * @returns */ @feature("retrieval.search") @@ -35,7 +33,7 @@ export class RetrievalClient { const data = { query: options.query, ...(options.searchSettings && { - searchSettings: options.searchSettings, + search_settings: options.searchSettings, }), }; @@ -53,9 +51,8 @@ export class RetrievalClient { * * The generation process can be customized using the `rag_generation_config` parameter. * @param query + * @param searchSettings Settings for the search * @param ragGenerationConfig Configuration for RAG generation - * @param vectorSearchSettings Settings for vector-based search - * @param kgSearchSettings Settings for knowledge graph search * @param taskPromptOverride Optional custom prompt to override default * @param includeTitleIfAvailable Include document titles in responses when available * @returns @@ -63,28 +60,24 @@ export class RetrievalClient { @feature("retrieval.rag") async rag(options: { query: string; + searchSettings?: SearchSettings | Record; ragGenerationConfig?: GenerationConfig | Record; - vectorSearchSettings?: ChunkSearchSettings | Record; - kgSearchSettings?: KGSearchSettings | Record; taskPromptOverride?: string; includeTitleIfAvailable?: boolean; }): Promise> { const data = { query: options.query, - ...(options.vectorSearchSettings && { - vectorSearchSettings: options.vectorSearchSettings, + ...(options.searchSettings && { + search_settings: options.searchSettings, }), ...(options.ragGenerationConfig && { - ragGenerationConfig: options.ragGenerationConfig, - }), - ...(options.kgSearchSettings && { - kgSearchSettings: options.kgSearchSettings, + rag_generation_config: options.ragGenerationConfig, }), ...(options.taskPromptOverride && { - taskPromptOverride: options.taskPromptOverride, + task_prompt_override: options.taskPromptOverride, }), ...(options.includeTitleIfAvailable && { - includeTitleIfAvailable: options.includeTitleIfAvailable, + include_title_if_available: options.includeTitleIfAvailable, }), }; @@ -151,9 +144,8 @@ export class RetrievalClient { * find and synthesize information, providing detailed, factual responses * with proper attribution to source documents. * @param message Current message to process + * @param searchSettings Settings for the search * @param ragGenerationConfig Configuration for RAG generation - * @param vectorSearchSettings Settings for vector-based search - * @param kgSearchSettings Settings for knowledge graph search * @param taskPromptOverride Optional custom prompt to override default * @param includeTitleIfAvailable Include document titles in responses when available * @param conversationId ID of the conversation @@ -163,9 +155,8 @@ export class RetrievalClient { @feature("retrieval.agent") async agent(options: { message: Message; + searchSettings?: SearchSettings | Record; ragGenerationConfig?: GenerationConfig | Record; - vectorSearchSettings?: ChunkSearchSettings | Record; - kgSearchSettings?: KGSearchSettings | Record; taskPromptOverride?: string; includeTitleIfAvailable?: boolean; conversationId?: string; @@ -173,26 +164,23 @@ export class RetrievalClient { }): Promise> { const data: Record = { message: options.message, - ...(options.vectorSearchSettings && { - vectorSearchSettings: options.vectorSearchSettings, - }), - ...(options.kgSearchSettings && { - kgSearchSettings: options.kgSearchSettings, + ...(options.searchSettings && { + search_settings: options.searchSettings, }), ...(options.ragGenerationConfig && { - ragGenerationConfig: options.ragGenerationConfig, + rag_generation_config: options.ragGenerationConfig, }), ...(options.taskPromptOverride && { - taskPromptOverride: options.taskPromptOverride, + task_prompt_override: options.taskPromptOverride, }), ...(options.includeTitleIfAvailable && { - includeTitleIfAvailable: options.includeTitleIfAvailable, + include_title_if_available: options.includeTitleIfAvailable, }), ...(options.conversationId && { - conversationId: options.conversationId, + conversation_id: options.conversationId, }), ...(options.branchId && { - branchId: options.branchId, + branch_id: options.branchId, }), }; @@ -243,7 +231,7 @@ export class RetrievalClient { const data = { messages: options.messages, ...(options.generationConfig && { - generationConfig: options.generationConfig, + generation_config: options.generationConfig, }), }; diff --git a/py/core/main/services/retrieval_service.py b/py/core/main/services/retrieval_service.py index 461b6d6ae..9d15b4a3c 100644 --- a/py/core/main/services/retrieval_service.py +++ b/py/core/main/services/retrieval_service.py @@ -270,67 +270,96 @@ async def agent( message="Either message or messages should be provided", ) + # Ensure 'message' is a Message instance + if message and not isinstance(message, Message): + if isinstance(message, dict): + message = Message.from_dict(message) + else: + raise R2RException( + status_code=400, + message="Invalid message format", + ) + + # Ensure 'messages' is a list of Message instances + if messages: + messages = [ + ( + msg + if isinstance(msg, Message) + else Message.from_dict(msg) + ) + for msg in messages + ] + else: + messages = [] + # Transform UUID filters to strings - for filter, value in search_settings.filters.items(): + for filter_key, value in search_settings.filters.items(): if isinstance(value, UUID): - search_settings.filters[filter] = str(value) + search_settings.filters[filter_key] = str(value) - ids = None + ids = [] - if not messages: - if not message: - raise R2RException( - status_code=400, - message="Message not provided", + if conversation_id: + # Fetch existing conversation + conversation = ( + await self.logging_connection.get_conversation( + conversation_id, branch_id ) - # Fetch or create conversation - if conversation_id: - conversation = ( - await self.logging_connection.get_conversation( - conversation_id, branch_id - ) + ) + if not conversation: + logger.error( + f"No conversation found for ID: {conversation_id}" + ) + raise R2RException( + status_code=404, + message=f"Conversation not found: {conversation_id}", ) - if not conversation: + # Assuming 'conversation' is a list of dicts with 'id' and 'message' keys + messages_from_conversation = [] + for resp in conversation: + if isinstance(resp, dict): + msg = Message.from_dict(resp["message"]) + messages_from_conversation.append(msg) + ids.append(resp["id"]) + else: logger.error( - f"No conversation found for ID: {conversation_id}" - ) - raise R2RException( - status_code=404, - message=f"Conversation not found: {conversation_id}", + f"Unexpected type in conversation: {type(resp)}" ) - messages = [resp.message for resp in conversation] + [ # type: ignore - message - ] - ids = [resp.id for resp in conversation] - else: - conversation = ( - await self.logging_connection.create_conversation() - ) - conversation_id = conversation["id"] + messages = messages_from_conversation + messages + else: + # Create new conversation + conversation_id = ( + await self.logging_connection.create_conversation() + ) + ids = [] + # messages already initialized earlier - parent_id = None - if conversation_id and messages: - for inner_message in messages[:-1]: - parent_id = await self.logging_connection.add_message( - conversation_id, # Use the stored conversation_id - inner_message, - parent_id, - ) - messages = messages or [] + # Append 'message' to 'messages' if provided + if message: + messages.append(message) - if message and not messages: - messages = [message] + if not messages: + raise R2RException( + status_code=400, + message="No messages to process", + ) - current_message = messages[-1] # type: ignore + current_message = messages[-1] # Save the new message to the conversation - message = await self.logging_connection.add_message( - conversation_id, # type: ignore - current_message, # type: ignore - parent_id=str(ids[-2]) if (ids and len(ids) > 1) else None, # type: ignore + parent_id = ids[-1] if ids else None + + message_response = await self.logging_connection.add_message( + conversation_id, + current_message, + parent_id=parent_id, ) - if message is not None: - message_id = message["id"] # type: ignore + + if message_response is not None: + message_id = message_response["id"] + else: + message_id = None if rag_generation_config.stream: t1 = time.time() @@ -372,9 +401,20 @@ async def stream_response(): *args, **kwargs, ) + + # Save the assistant's reply to the conversation + if isinstance(results[-1], dict): + assistant_message = Message(**results[-1]) + elif isinstance(results[-1], Message): + assistant_message = results[-1] + else: + assistant_message = Message( + role="assistant", content=str(results[-1]) + ) + await self.logging_connection.add_message( conversation_id=conversation_id, - content=Message(**results[-1]), + content=assistant_message, parent_id=message_id, ) @@ -387,7 +427,7 @@ async def stream_response(): value=latency, ) return { - "messages": results, + "messages": [msg.to_dict() for msg in results], "conversation_id": str( conversation_id ), # Ensure it's a string diff --git a/py/shared/abstractions/graph.py b/py/shared/abstractions/graph.py index 4bdc9dd62..f14dba462 100644 --- a/py/shared/abstractions/graph.py +++ b/py/shared/abstractions/graph.py @@ -113,11 +113,9 @@ class Graph(R2RSerializable): name: str description: Optional[str] = None created_at: datetime = Field( - alias="createdAt", default_factory=datetime.utcnow, ) updated_at: datetime = Field( - alias="updatedAt", default_factory=datetime.utcnow, ) status: str = "pending" diff --git a/py/shared/abstractions/kg.py b/py/shared/abstractions/kg.py index bb9a06fcf..adfb46c57 100644 --- a/py/shared/abstractions/kg.py +++ b/py/shared/abstractions/kg.py @@ -171,7 +171,6 @@ class GraphCommunitySettings(R2RSerializable): graphrag_communities: str = Field( default="graphrag_communities", description="The prompt to use for knowledge graph enrichment.", - alias="graphrag_communities", # TODO - mark deprecated & remove ) max_summary_input_length: int = Field( diff --git a/py/shared/abstractions/llm.py b/py/shared/abstractions/llm.py index 2f747953e..677ef865a 100644 --- a/py/shared/abstractions/llm.py +++ b/py/shared/abstractions/llm.py @@ -52,13 +52,11 @@ class GenerationConfig(R2RSerializable): ) top_p: float = Field( default_factory=lambda: GenerationConfig._defaults["top_p"], - alias="topP", ) max_tokens_to_sample: int = Field( default_factory=lambda: GenerationConfig._defaults[ "max_tokens_to_sample" ], - alias="maxTokensToSample", ) stream: bool = Field( default_factory=lambda: GenerationConfig._defaults["stream"] @@ -73,11 +71,9 @@ class GenerationConfig(R2RSerializable): default_factory=lambda: GenerationConfig._defaults[ "add_generation_kwargs" ], - alias="addGenerationKwargs", ) api_base: Optional[str] = Field( default_factory=lambda: GenerationConfig._defaults["api_base"], - alias="apiBase", ) response_format: Optional[dict | BaseModel] = None diff --git a/py/shared/abstractions/search.py b/py/shared/abstractions/search.py index a27d546ad..59a7398b5 100644 --- a/py/shared/abstractions/search.py +++ b/py/shared/abstractions/search.py @@ -257,7 +257,6 @@ class ChunkSearchSettings(R2RSerializable): """Settings specific to chunk/vector search.""" index_measure: IndexMeasure = Field( - alias="indexMeasure", default=IndexMeasure.cosine_distance, description="The distance measure to use for indexing", ) @@ -266,7 +265,6 @@ class ChunkSearchSettings(R2RSerializable): description="Number of ivfflat index lists to query. Higher increases accuracy but decreases speed.", ) ef_search: int = Field( - alias="efSearch", default=40, description="Size of the dynamic candidate list for HNSW index search. Higher increases accuracy but decreases speed.", ) @@ -280,30 +278,24 @@ class GraphSearchSettings(R2RSerializable): """Settings specific to knowledge graph search.""" generation_config: GenerationConfig = Field( - alias="generationConfig", default_factory=GenerationConfig, description="Configuration for text generation during graph search.", ) graphrag_map_system: str = Field( - alias="graphragMapSystem", default="graphrag_map_system", description="The system prompt for the graphrag map prompt.", ) graphrag_reduce_system: str = Field( - alias="graphragReduceSystem", default="graphrag_reduce_system", description="The system prompt for the graphrag reduce prompt.", ) max_community_description_length: int = Field( - alias="maxCommunityDescriptionLength", default=65536, ) max_llm_queries_for_global_search: int = Field( - alias="maxLLMQueriesForGlobalSearch", default=250, ) limits: dict[str, int] = Field( - alias="localSearchLimits", default={}, ) enabled: bool = Field( @@ -319,17 +311,14 @@ class SearchSettings(R2RSerializable): use_hybrid_search: bool = Field( default=False, description="Whether to perform a hybrid search. This is equivalent to setting `use_semantic_search=True` and `use_fulltext_search=True`, e.g. combining vector and keyword search.", - alias="useHybridSearch", ) use_semantic_search: bool = Field( default=True, description="Whether to use semantic search", - alias="useSemanticSearch", ) use_fulltext_search: bool = Field( default=False, description="Whether to use full-text search", - alias="useFulltextSearch", ) # Common search parameters @@ -359,24 +348,20 @@ class SearchSettings(R2RSerializable): description="Offset to paginate search results", ) include_metadatas: bool = Field( - alias="includeMetadatas", default=True, description="Whether to include element metadata in the search results", ) include_scores: bool = Field( - alias="includeScores", default=True, description="Whether to include search score values in the search results", ) # Search strategy and settings search_strategy: str = Field( - alias="searchStrategy", default="vanilla", description="Search strategy to use (e.g., 'vanilla', 'query_fusion', 'hyde')", ) hybrid_settings: HybridSearchSettings = Field( - alias="hybridSearchSettings", default_factory=HybridSearchSettings, description="Settings for hybrid search (only used if `use_semantic_search` and `use_fulltext_search` are both true)", ) diff --git a/py/shared/abstractions/user.py b/py/shared/abstractions/user.py index d6ad8e2fc..31d68bf30 100644 --- a/py/shared/abstractions/user.py +++ b/py/shared/abstractions/user.py @@ -14,11 +14,9 @@ class Collection(BaseModel): name: str description: Optional[str] = None created_at: datetime = Field( - alias="createdAt", default_factory=datetime.utcnow, ) updated_at: datetime = Field( - alias="updatedAt", default_factory=datetime.utcnow, )