KG hatchet orchestration (#1286)
* up

* up

* cleanup kg migration

* up

* up

* up

* Kg testing (#1280)

* up

* up

* up

* up

* rename

* project name

* up

* add chunk order

* fragments => extractions

* bug squash

* up

* up

* up

* change postgres project name

* up

* up

---------

Co-authored-by: emrgnt-cmplxty <owen@algofi.org>
shreyaspimpalgaonkar and emrgnt-cmplxty authored Oct 1, 2024
1 parent d34d40f commit 42ad9b8
Showing 8 changed files with 167 additions and 96 deletions.
27 changes: 15 additions & 12 deletions py/cli/utils/docker_utils.py
@@ -180,7 +180,7 @@ def check_llm_reqs(llm_provider, model_provider, include_ollama=False):
"env_vars": [
"AWS_ACCESS_KEY_ID",
"AWS_SECRET_ACCESS_KEY",
"AWS_REGION_NAME",
"AWS_REGION_NAME",
]
},
"groq": {"env_vars": ["GROQ_API_KEY"]},
@@ -236,17 +236,16 @@ def check_external_ollama(ollama_url="http://localhost:11434/api/version"):


def check_set_docker_env_vars():
env_vars = []

postgres_vars = [
"POSTGRES_HOST",
"POSTGRES_USER",
"POSTGRES_PASSWORD",
"POSTGRES_PORT",
"POSTGRES_DBNAME",
# "POSTGRES_PROJECT_NAME", TODO - uncomment in next release
]
env_vars.extend(postgres_vars)

env_vars = {
"POSTGRES_PROJECT_NAME": "r2r",
"POSTGRES_HOST": "localhost",
"POSTGRES_PORT": "5432",
"POSTGRES_DBNAME": "postgres",
"POSTGRES_USER": "postgres",
"POSTGRES_PASSWORD": "postgres",
}


is_test = (
"pytest" in sys.modules
@@ -258,6 +257,10 @@ def check_set_docker_env_vars():
for var in env_vars:
if value := os.environ.get(var):
warning_text = click.style("Warning:", fg="red", bold=True)

if value == env_vars[var]:
continue

prompt = (
f"{warning_text} It's only necessary to set this environment variable when connecting to an instance not managed by R2R.\n"
f"Environment variable {var} is set to '{value}'. Unset it?"
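The change above replaces the flat list of required Postgres variables with a mapping of R2R's Docker defaults, so the CLI only prompts when a variable is set to something other than the stock value. A minimal sketch of that default-aware check, assuming only `os` and `click`; the prompt wording here is illustrative rather than the CLI's exact text:

```python
import os
import click

# Defaults shipped with the R2R Docker compose stack (mirrors the diff above).
R2R_DEFAULTS = {
    "POSTGRES_PROJECT_NAME": "r2r",
    "POSTGRES_HOST": "localhost",
    "POSTGRES_PORT": "5432",
    "POSTGRES_DBNAME": "postgres",
    "POSTGRES_USER": "postgres",
    "POSTGRES_PASSWORD": "postgres",
}

def warn_on_overridden_vars() -> None:
    for var, default in R2R_DEFAULTS.items():
        value = os.environ.get(var)
        if not value or value == default:
            continue  # unset, or already matching the Docker default
        warning = click.style("Warning:", fg="red", bold=True)
        if click.confirm(
            f"{warning} {var} is set to '{value}'. Unset it?", default=True
        ):
            del os.environ[var]
```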
1 change: 1 addition & 0 deletions py/core/base/logging/run_logger.py
@@ -373,6 +373,7 @@ async def _init(self):
statement_cache_size=0, # Disable statement caching
)
async with self.pool.acquire() as conn:

await conn.execute(
f"""
CREATE TABLE IF NOT EXISTS {self.project_name}.{self.log_table} (
21 changes: 13 additions & 8 deletions py/core/main/api/kg_router.py
@@ -63,7 +63,9 @@ async def create_graph(
collection_id: str = Body(
description="Collection ID to create graph for.",
),
kg_creation_settings: Optional[dict] = Body(
kg_creation_settings: Optional[
Union[dict, KGCreationSettings]
] = Body(
default='{}',
description="Settings for the graph creation process.",
),
@@ -107,9 +109,9 @@ async def enrich_graph(
description="Collection name to enrich graph for.",
),
kg_enrichment_settings: Optional[
Json[KGEnrichmentSettings]
Union[dict, KGEnrichmentSettings]
] = Body(
default=None,
default='{}',
description="Settings for the graph enrichment process.",
),
auth_user=Depends(self.service.providers.auth.auth_wrapper),
@@ -122,14 +124,17 @@ async def enrich_graph(
if not auth_user.is_superuser:
logger.warning("Implement permission checks here.")

if kg_enrichment_settings is None:
kg_enrichment_settings = (
self.service.providers.kg.config.kg_enrichment_settings
)
server_kg_enrichment_settings = (
self.service.providers.kg.config.kg_enrichment_settings
)

for key, value in kg_enrichment_settings.items():
if value is not None:
setattr(server_kg_enrichment_settings, key, value)

workflow_input = {
"collection_id": collection_id,
"kg_enrichment_settings": kg_enrichment_settings.json(),
"kg_enrichment_settings": server_kg_enrichment_settings.json(),
"user": auth_user.json(),
}

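With this change, `enrich_graph` accepts either a plain dict (defaulting to `'{}'`) or a `KGEnrichmentSettings` object and overlays any client-supplied keys on the server-side defaults before handing the merged settings to the workflow. A hedged sketch of that merge; the fields shown on `KGEnrichmentSettings` are placeholders, not the real schema:

```python
from pydantic import BaseModel

class KGEnrichmentSettings(BaseModel):
    # Stand-in fields for the server-side defaults; illustrative only.
    max_summary_input_length: int = 65536
    leiden_params: dict = {}

def merge_enrichment_settings(
    client_settings: dict, server_settings: KGEnrichmentSettings
) -> KGEnrichmentSettings:
    # Client-supplied keys win; anything omitted keeps the server default.
    for key, value in client_settings.items():
        if value is not None:
            setattr(server_settings, key, value)
    return server_settings

merged = merge_enrichment_settings(
    {"max_summary_input_length": 32768}, KGEnrichmentSettings()
)
print(merged.json())  # serialized and placed into the workflow input
```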
122 changes: 93 additions & 29 deletions py/core/main/orchestration/hatchet/kg_workflow.py
@@ -23,24 +23,58 @@ def hatchet_kg_factory(
orchestration_provider: OrchestrationProvider, service: KgService
) -> list["Hatchet.Workflow"]:

def get_input_data_dict(input_data):
for key, value in input_data.items():
if key == "kg_creation_settings":
input_data[key] = json.loads(value)
input_data[key]['generation_config'] = GenerationConfig(
**input_data[key]['generation_config']
)
if key == "kg_enrichment_settings":
input_data[key] = json.loads(value)

if key == "generation_config":
input_data[key] = GenerationConfig(**input_data[key])
return input_data


@orchestration_provider.workflow(name="kg-extract", timeout="360m")
class KGExtractDescribeEmbedWorkflow:
def __init__(self, kg_service: KgService):
self.kg_service = kg_service

@orchestration_provider.step(retries=3, timeout="360m")
@orchestration_provider.step(retries=1, timeout="360m")
async def kg_extract(self, context: Context) -> dict:
return await self.kg_service.kg_extraction(
**context.workflow_input()["request"]

context.log(f"Running KG Extraction for input: {context.workflow_input()['request']}")

input_data = get_input_data_dict(context.workflow_input()["request"])

# context.log(f"Running KG Extraction for collection ID: {input_data['collection_id']}")
document_id = input_data["document_id"]

await self.kg_service.kg_extraction(
document_id=uuid.UUID(document_id),
logger=context.log,
**input_data["kg_creation_settings"]
)

@orchestration_provider.step(retries=3, timeout="360m")
return {"result": f"successfully ran kg triples extraction for document {document_id}"}

@orchestration_provider.step(retries=1, timeout="360m", parents=["kg_extract"])
async def kg_node_description(self, context: Context) -> dict:
return await self.kg_service.kg_node_description(
**context.workflow_input()["request"]

input_data = get_input_data_dict(context.workflow_input()["request"])
document_id = input_data["document_id"]

await self.kg_service.kg_node_description(
document_id=uuid.UUID(document_id),
**input_data["kg_creation_settings"]
)

@orchestration_provider.workflow(name="create-graph", timeout="60m")
return {"result": f"successfully ran kg node description for document {document_id}"}

@orchestration_provider.workflow(name="create-graph", timeout="360m")
class CreateGraphWorkflow:
def __init__(self, kg_service: KgService):
self.kg_service = kg_service
@@ -49,21 +83,39 @@ def __init__(self, kg_service: KgService):
async def get_document_ids_for_create_graph(
self, context: Context
) -> dict:
return await self.kg_service.get_document_ids_for_create_graph(
**context.workflow_input()["request"]
)

input_data = get_input_data_dict(context.workflow_input()["request"])
collection_id = input_data["collection_id"]

return_val = {
"document_ids": [
str(doc_id)
for doc_id in await self.kg_service.get_document_ids_for_create_graph(
collection_id=collection_id,
**input_data["kg_creation_settings"]
)
]
}

if len(return_val["document_ids"]) == 0:
raise ValueError("No documents to process, either all documents to create the graph were already created or in progress, or the collection is empty.")

return return_val

@orchestration_provider.step(
retries=1, parents=["get_document_ids_for_create_graph"]
)
async def kg_extraction_ingress(self, context: Context) -> dict:

document_ids = context.step_output(
"get_document_ids_for_create_graph"
)
document_ids = [
uuid.UUID(doc_id)
for doc_id in context.step_output(
"get_document_ids_for_create_graph"
)["document_ids"]
]
results = []
for cnt, document_id in enumerate(document_ids):
context.logger.info(
context.log(
f"Running Graph Creation Workflow for document ID: {document_id}"
)
results.append(
@@ -86,35 +138,42 @@ async def kg_extraction_ingress(self, context: Context) -> dict:
)

if not document_ids:
logger.info(
context.log(
"No documents to process, either all graphs were created or in progress, or no documents were provided. Skipping graph creation."
)
return {"result": "No documents to process"}

logger.info(f"Ran {len(results)} workflows for graph creation")
context.log(f"Ran {len(results)} workflows for graph creation")
results = await asyncio.gather(*results)
return {
"result": f"successfully ran graph creation workflows for {len(results)} documents"
}

@orchestration_provider.workflow(name="enrich-graph", timeout="60m")
@orchestration_provider.workflow(name="enrich-graph", timeout="360m")
class EnrichGraphWorkflow:
def __init__(self, kg_service: KgService):
self.kg_service = kg_service

@orchestration_provider.step(retries=1, parents=[], timeout="360m")
async def kg_clustering(self, context: Context) -> dict:
return await self.kg_service.kg_clustering(
**context.workflow_input()["request"]

input_data = get_input_data_dict(context.workflow_input()["request"])
collection_id = input_data["collection_id"]

kg_clustering_results = await self.kg_service.kg_clustering(
collection_id=collection_id,
**input_data["kg_enrichment_settings"]
)

context.log(f"Successfully ran kg clustering for collection {collection_id}: {json.dumps(kg_clustering_results)}")
return {"result": f"successfully ran kg clustering for collection {collection_id}", "kg_clustering": kg_clustering_results}

@orchestration_provider.step(retries=1, parents=["kg_clustering"])
async def kg_community_summary(self, context: Context) -> dict:

input_data = context.workflow_input()["request"]
num_communities = context.step_output("kg_clustering")[0][
"num_communities"
]
input_data = get_input_data_dict(context.workflow_input()["request"])
collection_id = input_data["collection_id"]
num_communities = context.step_output("kg_clustering")["kg_clustering"][0]["num_communities"]

parallel_communities = min(100, num_communities)
total_workflows = math.ceil(num_communities / parallel_communities)
@@ -129,29 +188,34 @@ async def kg_community_summary(self, context: Context) -> dict:
"request": {
"offset": offset,
"limit": parallel_communities,
**input_data,
"collection_id": collection_id,
**input_data["kg_enrichment_settings"]
}
},
key=f"{i}/{total_workflows}_community_summary",
)
)
await asyncio.gather(*workflows)
return {
"result": "successfully ran kg community summary workflows"
"result": f"successfully ran kg community summary workflows for {num_communities} communities"
}

@orchestration_provider.workflow(
name="kg-community-summary", timeout="60m"
name="kg-community-summary", timeout="360m"
)
class KGCommunitySummaryWorkflow:
def __init__(self, kg_service: KgService):
self.kg_service = kg_service

@orchestration_provider.step(retries=1, timeout="60m")
@orchestration_provider.step(retries=1, timeout="360m")
async def kg_community_summary(self, context: Context) -> dict:
return await self.kg_service.kg_community_summary(
**context.workflow_input()["request"]
input_data = get_input_data_dict(context.workflow_input()["request"])

community_summary = await self.kg_service.kg_community_summary(
**input_data
)
context.log(f"Successfully ran kg community summary for communities {input_data['offset']} to {input_data['offset'] + len(community_summary)}")
return {"result": f"successfully ran kg community summary for communities {input_data['offset']} to {input_data['offset'] + len(community_summary)}"}

return {
"kg-extract": KGExtractDescribeEmbedWorkflow(service),
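A recurring pattern in this workflow file is that settings cross the Hatchet boundary as JSON strings and are rehydrated by `get_input_data_dict` before each step runs. A small round-trip sketch; `GenerationConfig` is modeled as a dataclass here purely for illustration, and the field values are placeholders:

```python
import json
import uuid
from dataclasses import dataclass

@dataclass
class GenerationConfig:
    model: str = "placeholder-model"
    max_tokens: int = 1024

def get_input_data_dict(input_data: dict) -> dict:
    # Settings arrive as JSON strings; the nested generation_config is rebuilt.
    for key, value in input_data.items():
        if key == "kg_creation_settings":
            input_data[key] = json.loads(value)
            input_data[key]["generation_config"] = GenerationConfig(
                **input_data[key]["generation_config"]
            )
        elif key == "kg_enrichment_settings":
            input_data[key] = json.loads(value)
    return input_data

request = {
    "document_id": str(uuid.uuid4()),
    "kg_creation_settings": json.dumps(
        {
            "generation_config": {"model": "placeholder-model"},
            "max_knowledge_triples": 100,
        }
    ),
}
parsed = get_input_data_dict(request)
print(parsed["kg_creation_settings"]["generation_config"].model)
```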
40 changes: 19 additions & 21 deletions py/core/main/services/kg_service.py
@@ -1,6 +1,6 @@
import logging
import math
from typing import AsyncGenerator
from typing import AsyncGenerator, Optional
from uuid import UUID

from core.base import KGCreationStatus, KGCreationSettings
@@ -53,6 +53,7 @@ async def kg_extraction(
max_knowledge_triples: int,
entity_types: list[str],
relation_types: list[str],
hatchet_logger: Optional = None,
**kwargs,
):
try:
@@ -74,6 +75,7 @@ async def kg_extraction(
"max_knowledge_triples": max_knowledge_triples,
"entity_types": entity_types,
"relation_types": relation_types,
"hatchet_logger": hatchet_logger,
}
),
state=None,
@@ -140,32 +142,28 @@ async def kg_node_description(

# process 50 entities at a time
num_batches = math.ceil(entity_count / 50)
workflows = []

all_results = []
for i in range(num_batches):
logger.info(
f"Running kg_node_description for batch {i+1}/{num_batches} for document {document_id}"
)
# await self.kg_service.kg_node_description(
# offset=i * 50,
# limit=50,
# document_id=document_id,
# max_description_input_length=max_description_input_length,
# )

node_extractions = await self.pipes.kg_node_description_pipe.run(
input=self.pipes.kg_node_description_pipe.Input(
message={
"offset": i * 50,
"limit": 50,
"max_description_input_length": max_description_input_length,
"document_id": document_id,
}
),
state=None,
run_manager=self.run_manager,
)
return await _collect_results(node_extractions)
input=self.pipes.kg_node_description_pipe.Input(
message={
"offset": i * 50,
"limit": 50,
"max_description_input_length": max_description_input_length,
"document_id": document_id,
}
),
state=None,
run_manager=self.run_manager,
)

all_results.append(await _collect_results(node_extractions))

return all_results

@telemetry_event("kg_clustering")
async def kg_clustering(
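The `kg_node_description` rewrite processes entities in batches of 50 and accumulates results across batches instead of returning after the first one. A minimal sketch of the batching pattern; `describe_batch` is a hypothetical stand-in for the node-description pipe:

```python
import asyncio
import math

async def describe_batch(offset: int, limit: int) -> list[str]:
    # Placeholder for kg_node_description_pipe.run over one slice of entities.
    return [f"description-{i}" for i in range(offset, offset + limit)]

async def describe_all_entities(entity_count: int, batch_size: int = 50) -> list:
    num_batches = math.ceil(entity_count / batch_size)
    all_results = []
    for i in range(num_batches):
        # Previously the function returned inside the loop, so only the first
        # batch was ever processed; now each batch's results are collected.
        batch = await describe_batch(offset=i * batch_size, limit=batch_size)
        all_results.append(batch)
    return all_results

print(len(asyncio.run(describe_all_entities(120))))  # 3 batches
```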
4 changes: 1 addition & 3 deletions py/core/pipes/kg/extraction.py
@@ -182,9 +182,7 @@ def parse_fn(response_str: str) -> Any:
str(extraction.id)
for extraction in extractions
],
attributes={
"extraction_text": combined_extraction
},
attributes={},
)
)
