Skip to content

Commit

Permalink
add web search to rag agent
Browse files Browse the repository at this point in the history
  • Loading branch information
emrgnt-cmplxty committed Dec 5, 2024
1 parent d5cb93e commit 3241c31
Show file tree
Hide file tree
Showing 7 changed files with 41 additions and 36 deletions.
11 changes: 5 additions & 6 deletions py/core/agent/rag.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
AggregateSearchResult,
GraphSearchSettings,
SearchSettings,
WebSearchResponse
WebSearchResponse,
)
from core.base.agent import AgentConfig, Tool
from core.base.providers import CompletionProvider
Expand Down Expand Up @@ -37,7 +37,7 @@ def _register_tools(self):
self._tools.append(self.web_search())
else:
raise ValueError(f"Unsupported tool name: {tool_name}")

def web_search(self) -> Tool:
return Tool(
name="web_search",
Expand Down Expand Up @@ -65,19 +65,18 @@ async def _web_search(
**kwargs,
) -> list[AggregateSearchResult]:
from .serper import SerperClient

serper_client = SerperClient()
# TODO - make async!
# TODO - Move to search pipeline, make configurable.
raw_results = serper_client.get_raw(query)
raw_results = serper_client.get_raw(query)
web_response = WebSearchResponse.from_serper_results(raw_results)
return AggregateSearchResult(
chunk_search_results=None,
graph_search_results=None,
web_search_results=web_response.organic_results, # TODO - How do we feel about throwing away so much info?
web_search_results=web_response.organic_results, # TODO - How do we feel about throwing away so much info?
)



def local_search(self) -> Tool:
return Tool(
name="local_search",
Expand Down
2 changes: 1 addition & 1 deletion py/core/agent/serper.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,4 +101,4 @@ def construct_context(results: list) -> str:
context += f"Item {index}:\n"
context += process_json(item) + "\n"

return context
return context
2 changes: 1 addition & 1 deletion py/core/base/abstractions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@
)
from shared.abstractions.prompt import Prompt
from shared.abstractions.search import (
WebSearchResponse,
AggregateSearchResult,
ChunkSearchResult,
ChunkSearchSettings,
Expand All @@ -64,6 +63,7 @@
KGRelationshipResult,
KGSearchResultType,
SearchSettings,
WebSearchResponse,
)
from shared.abstractions.user import Token, TokenData, User
from shared.abstractions.vector import (
Expand Down
7 changes: 1 addition & 6 deletions py/core/main/api/v3/retrieval_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,7 @@
from fastapi import Body, Depends
from fastapi.responses import StreamingResponse

from core.base import (
GenerationConfig,
Message,
R2RException,
SearchSettings,
)
from core.base import GenerationConfig, Message, R2RException, SearchSettings
from core.base.api.models import (
WrappedAgentResponse,
WrappedCompletionResponse,
Expand Down
10 changes: 6 additions & 4 deletions py/core/pipes/retrieval/search_rag_pipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
AsyncState,
CompletionProvider,
DatabaseProvider,
KGSearchResultType
KGSearchResultType,
)
from core.base.abstractions import GenerationConfig, RAGCompletion

Expand Down Expand Up @@ -111,11 +111,13 @@ async def _collect_context(
# context += f"Results:\n"
if search_result.result_type == KGSearchResultType.ENTITY:
context += f"[{it}]: Entity Name - {search_result.content.name}\n\nDescription - {search_result.content.description}\n\n"
elif search_result.result_type == KGSearchResultType.RELATIONSHIP:
elif (
search_result.result_type
== KGSearchResultType.RELATIONSHIP
):
context += f"[{it}]: Relationship - {search_result.content.subject} - {search_result.content.predicate} - {search_result.content.object}\n\n"
else:
context += f"[{it}]: Community Name - {search_result.content.name}\n\nDescription - {search_result.content.summary}\n\n"

context += f"[{it}]: Community Name - {search_result.content.name}\n\nDescription - {search_result.content.summary}\n\n"

it += 1
total_results = (
Expand Down
23 changes: 14 additions & 9 deletions py/shared/abstractions/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,22 +149,25 @@ class Config:
class WebSearchResult(R2RSerializable):
title: str
link: str
snippet: str
snippet: str
position: int
type: str = "organic"
date: Optional[str] = None
sitelinks: Optional[list[dict]] = None


class RelatedSearchResult(R2RSerializable):
query: str
type: str = "related"


class PeopleAlsoAskResult(R2RSerializable):
question: str
snippet: str
link: str
title: str
type: str = "peopleAlsoAsk"
type: str = "peopleAlsoAsk"


class WebSearchResponse(R2RSerializable):
organic_results: list[WebSearchResult] = []
Expand All @@ -176,23 +179,22 @@ def from_serper_results(cls, results: list[dict]) -> "WebSearchResponse":
organic = []
related = []
paa = []

for result in results:
if result["type"] == "organic":
organic.append(WebSearchResult(**result))
elif result["type"] == "relatedSearches":
related.append(RelatedSearchResult(**result))
elif result["type"] == "peopleAlsoAsk":
paa.append(PeopleAlsoAskResult(**result))

return cls(
organic_results=organic,
related_searches=related,
people_also_ask=paa
people_also_ask=paa,
)



class AggregateSearchResult(R2RSerializable):
"""Result of an aggregate search operation."""

Expand All @@ -213,9 +215,12 @@ def as_dict(self) -> dict:
if self.chunk_search_results
else []
),
"graph_search_results": [result.to_dict() for result in self.graph_search_results],
"web_search_results": [result.to_dict() for result in self.web_search_results]

"graph_search_results": [
result.to_dict() for result in self.graph_search_results
],
"web_search_results": [
result.to_dict() for result in self.web_search_results
],
}


Expand Down
22 changes: 13 additions & 9 deletions py/shared/utils/base_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,25 +84,28 @@ def format_search_results_for_llm(results: AggregateSearchResult) -> str:
if results.web_search_results:
formatted_results.append("Web Search Results:")
for result in results.web_search_results:
formatted_results.extend((
f"Source [{source_counter}]:",
f"Title: {result.title}",
f"Link: {result.link}",
f"Snippet: {result.snippet}"
))
formatted_results.extend(
(
f"Source [{source_counter}]:",
f"Title: {result.title}",
f"Link: {result.link}",
f"Snippet: {result.snippet}",
)
)
if result.date:
formatted_results.append(f"Date: {result.date}")
source_counter += 1

return "\n".join(formatted_results)


def format_search_results_for_stream(result: AggregateSearchResult) -> str:
CHUNK_SEARCH_STREAM_MARKER = "chunk_search"
CHUNK_SEARCH_STREAM_MARKER = "chunk_search"
GRAPH_SEARCH_STREAM_MARKER = "graph_search"
WEB_SEARCH_STREAM_MARKER = "web_search"

context = ""

if result.chunk_search_results:
context += f"<{CHUNK_SEARCH_STREAM_MARKER}>"
vector_results_list = [
Expand All @@ -129,6 +132,7 @@ def format_search_results_for_stream(result: AggregateSearchResult) -> str:

return context


if TYPE_CHECKING:
from ..pipeline.base_pipeline import AsyncPipeline

Expand Down

0 comments on commit 3241c31

Please sign in to comment.