From 3241c31c5a0ab3111020f7844c44a47e10fbadce Mon Sep 17 00:00:00 2001 From: emrgnt-cmplxty Date: Thu, 5 Dec 2024 10:14:55 -0800 Subject: [PATCH] add web search to rag agent --- py/core/agent/rag.py | 11 +++++------ py/core/agent/serper.py | 2 +- py/core/base/abstractions/__init__.py | 2 +- py/core/main/api/v3/retrieval_router.py | 7 +------ py/core/pipes/retrieval/search_rag_pipe.py | 10 ++++++---- py/shared/abstractions/search.py | 23 +++++++++++++--------- py/shared/utils/base_utils.py | 22 ++++++++++++--------- 7 files changed, 41 insertions(+), 36 deletions(-) diff --git a/py/core/agent/rag.py b/py/core/agent/rag.py index 65efe1214..67b8e4632 100644 --- a/py/core/agent/rag.py +++ b/py/core/agent/rag.py @@ -9,7 +9,7 @@ AggregateSearchResult, GraphSearchSettings, SearchSettings, - WebSearchResponse + WebSearchResponse, ) from core.base.agent import AgentConfig, Tool from core.base.providers import CompletionProvider @@ -37,7 +37,7 @@ def _register_tools(self): self._tools.append(self.web_search()) else: raise ValueError(f"Unsupported tool name: {tool_name}") - + def web_search(self) -> Tool: return Tool( name="web_search", @@ -65,19 +65,18 @@ async def _web_search( **kwargs, ) -> list[AggregateSearchResult]: from .serper import SerperClient + serper_client = SerperClient() # TODO - make async! # TODO - Move to search pipeline, make configurable. - raw_results = serper_client.get_raw(query) + raw_results = serper_client.get_raw(query) web_response = WebSearchResponse.from_serper_results(raw_results) return AggregateSearchResult( chunk_search_results=None, graph_search_results=None, - web_search_results=web_response.organic_results, # TODO - How do we feel about throwing away so much info? + web_search_results=web_response.organic_results, # TODO - How do we feel about throwing away so much info? ) - - def local_search(self) -> Tool: return Tool( name="local_search", diff --git a/py/core/agent/serper.py b/py/core/agent/serper.py index 4d7af18f4..bfed0cad3 100644 --- a/py/core/agent/serper.py +++ b/py/core/agent/serper.py @@ -101,4 +101,4 @@ def construct_context(results: list) -> str: context += f"Item {index}:\n" context += process_json(item) + "\n" - return context \ No newline at end of file + return context diff --git a/py/core/base/abstractions/__init__.py b/py/core/base/abstractions/__init__.py index 90216d707..edeb8b9d7 100644 --- a/py/core/base/abstractions/__init__.py +++ b/py/core/base/abstractions/__init__.py @@ -51,7 +51,6 @@ ) from shared.abstractions.prompt import Prompt from shared.abstractions.search import ( - WebSearchResponse, AggregateSearchResult, ChunkSearchResult, ChunkSearchSettings, @@ -64,6 +63,7 @@ KGRelationshipResult, KGSearchResultType, SearchSettings, + WebSearchResponse, ) from shared.abstractions.user import Token, TokenData, User from shared.abstractions.vector import ( diff --git a/py/core/main/api/v3/retrieval_router.py b/py/core/main/api/v3/retrieval_router.py index 6b3c0f888..7f226df0c 100644 --- a/py/core/main/api/v3/retrieval_router.py +++ b/py/core/main/api/v3/retrieval_router.py @@ -7,12 +7,7 @@ from fastapi import Body, Depends from fastapi.responses import StreamingResponse -from core.base import ( - GenerationConfig, - Message, - R2RException, - SearchSettings, -) +from core.base import GenerationConfig, Message, R2RException, SearchSettings from core.base.api.models import ( WrappedAgentResponse, WrappedCompletionResponse, diff --git a/py/core/pipes/retrieval/search_rag_pipe.py b/py/core/pipes/retrieval/search_rag_pipe.py index dd9416cbf..d2b13c820 100644 --- a/py/core/pipes/retrieval/search_rag_pipe.py +++ b/py/core/pipes/retrieval/search_rag_pipe.py @@ -7,7 +7,7 @@ AsyncState, CompletionProvider, DatabaseProvider, - KGSearchResultType + KGSearchResultType, ) from core.base.abstractions import GenerationConfig, RAGCompletion @@ -111,11 +111,13 @@ async def _collect_context( # context += f"Results:\n" if search_result.result_type == KGSearchResultType.ENTITY: context += f"[{it}]: Entity Name - {search_result.content.name}\n\nDescription - {search_result.content.description}\n\n" - elif search_result.result_type == KGSearchResultType.RELATIONSHIP: + elif ( + search_result.result_type + == KGSearchResultType.RELATIONSHIP + ): context += f"[{it}]: Relationship - {search_result.content.subject} - {search_result.content.predicate} - {search_result.content.object}\n\n" else: - context += f"[{it}]: Community Name - {search_result.content.name}\n\nDescription - {search_result.content.summary}\n\n" - + context += f"[{it}]: Community Name - {search_result.content.name}\n\nDescription - {search_result.content.summary}\n\n" it += 1 total_results = ( diff --git a/py/shared/abstractions/search.py b/py/shared/abstractions/search.py index cbd3a6798..a27d546ad 100644 --- a/py/shared/abstractions/search.py +++ b/py/shared/abstractions/search.py @@ -149,22 +149,25 @@ class Config: class WebSearchResult(R2RSerializable): title: str link: str - snippet: str + snippet: str position: int type: str = "organic" date: Optional[str] = None sitelinks: Optional[list[dict]] = None + class RelatedSearchResult(R2RSerializable): query: str type: str = "related" + class PeopleAlsoAskResult(R2RSerializable): question: str snippet: str link: str title: str - type: str = "peopleAlsoAsk" + type: str = "peopleAlsoAsk" + class WebSearchResponse(R2RSerializable): organic_results: list[WebSearchResult] = [] @@ -176,7 +179,7 @@ def from_serper_results(cls, results: list[dict]) -> "WebSearchResponse": organic = [] related = [] paa = [] - + for result in results: if result["type"] == "organic": organic.append(WebSearchResult(**result)) @@ -184,15 +187,14 @@ def from_serper_results(cls, results: list[dict]) -> "WebSearchResponse": related.append(RelatedSearchResult(**result)) elif result["type"] == "peopleAlsoAsk": paa.append(PeopleAlsoAskResult(**result)) - + return cls( organic_results=organic, related_searches=related, - people_also_ask=paa + people_also_ask=paa, ) - class AggregateSearchResult(R2RSerializable): """Result of an aggregate search operation.""" @@ -213,9 +215,12 @@ def as_dict(self) -> dict: if self.chunk_search_results else [] ), - "graph_search_results": [result.to_dict() for result in self.graph_search_results], - "web_search_results": [result.to_dict() for result in self.web_search_results] - + "graph_search_results": [ + result.to_dict() for result in self.graph_search_results + ], + "web_search_results": [ + result.to_dict() for result in self.web_search_results + ], } diff --git a/py/shared/utils/base_utils.py b/py/shared/utils/base_utils.py index b2c32a2e9..0db6e8ac0 100644 --- a/py/shared/utils/base_utils.py +++ b/py/shared/utils/base_utils.py @@ -84,25 +84,28 @@ def format_search_results_for_llm(results: AggregateSearchResult) -> str: if results.web_search_results: formatted_results.append("Web Search Results:") for result in results.web_search_results: - formatted_results.extend(( - f"Source [{source_counter}]:", - f"Title: {result.title}", - f"Link: {result.link}", - f"Snippet: {result.snippet}" - )) + formatted_results.extend( + ( + f"Source [{source_counter}]:", + f"Title: {result.title}", + f"Link: {result.link}", + f"Snippet: {result.snippet}", + ) + ) if result.date: formatted_results.append(f"Date: {result.date}") source_counter += 1 return "\n".join(formatted_results) + def format_search_results_for_stream(result: AggregateSearchResult) -> str: - CHUNK_SEARCH_STREAM_MARKER = "chunk_search" + CHUNK_SEARCH_STREAM_MARKER = "chunk_search" GRAPH_SEARCH_STREAM_MARKER = "graph_search" WEB_SEARCH_STREAM_MARKER = "web_search" - + context = "" - + if result.chunk_search_results: context += f"<{CHUNK_SEARCH_STREAM_MARKER}>" vector_results_list = [ @@ -129,6 +132,7 @@ def format_search_results_for_stream(result: AggregateSearchResult) -> str: return context + if TYPE_CHECKING: from ..pipeline.base_pipeline import AsyncPipeline