Skip to content

Commit

Permalink
Feature/run pre commit (#1209)
Browse files Browse the repository at this point in the history
* revamp agent streaming

* revamp agent streaming

* fix kg bug

* rm print

* add pre-commit
  • Loading branch information
emrgnt-cmplxty authored Sep 19, 2024
1 parent c7f5250 commit 1f94f11
Show file tree
Hide file tree
Showing 8 changed files with 80 additions and 31 deletions.
2 changes: 1 addition & 1 deletion docs/api-reference/openapi.json

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions py/cli/utils/docker_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ async def run_local_serve(
await r2r_instance.orchestration_provider.start_worker()
r2r_instance.serve(host, available_port)


def run_docker_serve(
host: str,
port: int,
Expand Down
19 changes: 13 additions & 6 deletions py/core/base/abstractions/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,12 @@ class KGSearchResultType(str, Enum):
RELATIONSHIP = "relationship"
COMMUNITY = "community"


class KGSearchMethod(str, Enum):
LOCAL = "local"
GLOBAL = "global"


class KGEntityResult(BaseModel):
name: str
description: str
Expand All @@ -76,6 +78,7 @@ class Config:
"metadata": {},
}


class KGRelationshipResult(BaseModel):
name: str
description: str
Expand All @@ -88,6 +91,7 @@ class Config:
"metadata": {},
}


class KGCommunityResult(BaseModel):
name: str
description: str
Expand All @@ -100,6 +104,7 @@ class Config:
"metadata": {},
}


class KGGlobalResult(BaseModel):
name: str
description: str
Expand All @@ -112,9 +117,12 @@ class Config:
"metadata": {},
}


class KGSearchResult(BaseModel):
method: KGSearchMethod
content: Union[KGEntityResult, KGRelationshipResult, KGCommunityResult, KGGlobalResult]
content: Union[
KGEntityResult, KGRelationshipResult, KGCommunityResult, KGGlobalResult
]
result_type: Optional[KGSearchResultType] = None
fragment_ids: Optional[list[UUID]] = None
document_ids: Optional[list[UUID]] = None
Expand All @@ -125,13 +133,12 @@ class Config:
"method": "local",
"content": KGEntityResult.Config.json_schema_extra,
"result_type": "entity",
"fragment_ids": [ 'c68dc72e-fc23-5452-8f49-d7bd46088a96'],
"document_ids": [ '3e157b3a-8469-51db-90d9-52e7d896b49b'],
"metadata": {
"associated_query": "What is the capital of France?"
},
"fragment_ids": ["c68dc72e-fc23-5452-8f49-d7bd46088a96"],
"document_ids": ["3e157b3a-8469-51db-90d9-52e7d896b49b"],
"metadata": {"associated_query": "What is the capital of France?"},
}


class AggregateSearchResult(BaseModel):
"""Result of an aggregate search operation."""

Expand Down
2 changes: 0 additions & 2 deletions py/core/base/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,8 +177,6 @@ async def handle_function_or_tool_call(
*args,
**kwargs,
) -> Union[str, AsyncGenerator[str, None]]:
print("args:", args)
print("kwargs:", kwargs)
(
self.conversation.append(
Message(
Expand Down
51 changes: 38 additions & 13 deletions py/core/pipes/retrieval/kg_search_pipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,18 +130,26 @@ async def local_search(
search_type
],
query_embedding=query_embedding,
property_names=["name", "description", "fragment_ids", "document_ids"],
property_names=[
"name",
"description",
"fragment_ids",
"document_ids",
],
):
print(search_result)
yield KGSearchResult(
content=KGEntityResult(name=search_result["name"], description=search_result["description"]),
content=KGEntityResult(
name=search_result["name"],
description=search_result["description"],
),
method=KGSearchMethod.LOCAL,
result_type=KGSearchResultType.ENTITY,
fragment_ids=search_result["fragment_ids"],
document_ids=search_result["document_ids"],
metadata={'associated_query': message},
metadata={"associated_query": message},
)

# relationship search
search_type = "__Relationship__"
async for search_result in self.kg_provider.vector_query(
Expand All @@ -151,15 +159,23 @@ async def local_search(
search_type
],
query_embedding=query_embedding,
property_names=["name", "description", "fragment_ids", "document_ids"],
property_names=[
"name",
"description",
"fragment_ids",
"document_ids",
],
):
yield KGSearchResult(
content=KGRelationshipResult(name=search_result["name"], description=search_result["description"]),
content=KGRelationshipResult(
name=search_result["name"],
description=search_result["description"],
),
method=KGSearchMethod.LOCAL,
result_type=KGSearchResultType.RELATIONSHIP,
fragment_ids=search_result["fragment_ids"],
document_ids=search_result["document_ids"],
metadata={'associated_query': message},
metadata={"associated_query": message},
)

# community search
Expand All @@ -174,7 +190,7 @@ async def local_search(
query_embedding=query_embedding,
property_names=["title", "summary"],
):

summary = search_result["summary"]

# try loading it as a json
Expand All @@ -183,17 +199,24 @@ async def local_search(
description = summary_json.get("summary", "")
name = summary_json.get("title", "")

description += "\n\n" + "\n".join([finding["summary"] for finding in summary_json.get("findings", [])])

description += "\n\n" + "\n".join(
[
finding["summary"]
for finding in summary_json.get("findings", [])
]
)

except json.JSONDecodeError:
logger.warning(f"Summary is not valid JSON: {summary}")
continue

yield KGSearchResult(
content=KGCommunityResult(name=name, description=description),
content=KGCommunityResult(
name=name, description=description
),
method=KGSearchMethod.LOCAL,
result_type=KGSearchResultType.COMMUNITY,
metadata={'associated_query': message},
metadata={"associated_query": message},
)

async def global_search(
Expand Down Expand Up @@ -283,7 +306,9 @@ async def process_community(merged_report):
output = output.choices[0].message.content

yield KGSearchResult(
content=KGGlobalResult(name="Global Result", description=output),
content=KGGlobalResult(
name="Global Result", description=output
),
method=KGSearchMethod.GLOBAL,
metadata={},
)
Expand Down
4 changes: 3 additions & 1 deletion py/core/pipes/retrieval/search_rag_pipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,9 @@ async def _collect_context(
context += f"Knowledge Graph ({iteration}):\n"
it = total_results + 1
for search_results in results.kg_search_results: # [1]:
context += f"Query: {search_results.metadata['associated_query']}\n\n"
context += (
f"Query: {search_results.metadata['associated_query']}\n\n"
)
context += f"Results:\n"
for search_result in search_results:
context += f"[{it}]: {search_result}\n\n"
Expand Down
13 changes: 10 additions & 3 deletions py/core/providers/kg/neo4j/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,9 @@ def get_schema(self, refresh: bool = False) -> str:
def retrieve_cache(self, cache_type: str, cache_id: str) -> bool:
return False

async def vector_query(self, query, **kwargs: Any) -> AsyncGenerator[dict[str, Any], None]:
async def vector_query(
self, query, **kwargs: Any
) -> AsyncGenerator[dict[str, Any], None]:

query_embedding = kwargs.get("query_embedding", None)
search_type = kwargs.get("search_type", "__Entity__")
Expand Down Expand Up @@ -436,9 +438,14 @@ async def vector_query(self, query, **kwargs: Any) -> AsyncGenerator[dict[str, A
# descriptions = [record['e']._properties[property_name] for record in neo4j_results.records for property_name in property_names]
# return descriptions, scores
if search_type == "__Entity__" and len(neo4j_results.records) == 0:
raise R2RException("No search results found. Please make sure you have run the KG enrichment step before running the search: r2r create-graph and r2r enrich-graph", 400)
raise R2RException(
"No search results found. Please make sure you have run the KG enrichment step before running the search: r2r create-graph and r2r enrich-graph",
400,
)

logger.info(f"Neo4j results: Returning {len(neo4j_results.records)} records for query of type {search_type}")
logger.info(
f"Neo4j results: Returning {len(neo4j_results.records)} records for query of type {search_type}"
)

for record in neo4j_results.records:
yield {
Expand Down
19 changes: 14 additions & 5 deletions py/sdk/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,15 +188,18 @@ class Config:
},
}


class KGSearchResultType(str, Enum):
ENTITY = "entity"
RELATIONSHIP = "relationship"
COMMUNITY = "community"


class KGSearchMethod(str, Enum):
LOCAL = "local"
GLOBAL = "global"


class KGEntityResult(BaseModel):
name: str
description: str
Expand All @@ -209,6 +212,7 @@ class Config:
"metadata": {},
}


class KGRelationshipResult(BaseModel):
name: str
description: str
Expand All @@ -221,6 +225,7 @@ class Config:
"metadata": {},
}


class KGCommunityResult(BaseModel):
name: str
description: str
Expand All @@ -245,10 +250,13 @@ class Config:
"description": "Global Result Description",
"metadata": {},
}



class KGSearchResult(BaseModel):
method: KGSearchMethod
content: Union[KGEntityResult, KGRelationshipResult, KGCommunityResult, KGGlobalResult]
content: Union[
KGEntityResult, KGRelationshipResult, KGCommunityResult, KGGlobalResult
]
result_type: Optional[KGSearchResultType] = None
fragment_ids: Optional[list[UUID]] = None
document_ids: Optional[list[UUID]] = None
Expand All @@ -259,11 +267,12 @@ class Config:
"method": "local",
"content": KGEntityResult.Config.json_schema_extra,
"result_type": "entity",
"fragment_ids": [ 'c68dc72e-fc23-5452-8f49-d7bd46088a96'],
"document_ids": [ '3e157b3a-8469-51db-90d9-52e7d896b49b'],
"metadata": { "associated_query": "What is the capital of France?" },
"fragment_ids": ["c68dc72e-fc23-5452-8f49-d7bd46088a96"],
"document_ids": ["3e157b3a-8469-51db-90d9-52e7d896b49b"],
"metadata": {"associated_query": "What is the capital of France?"},
}


class R2RException(Exception):
def __init__(
self, message: str, status_code: int, detail: Optional[Any] = None
Expand Down

0 comments on commit 1f94f11

Please sign in to comment.