Updating Enrichment status #1544

Merged 16 commits · Oct 31, 2024
Changes from all commits
2 changes: 2 additions & 0 deletions .gitignore
@@ -32,3 +32,5 @@ dist/
*.test
go.work
go.work.sum

.vscode/
5 changes: 4 additions & 1 deletion py/core/base/providers/database.py
@@ -1130,7 +1130,10 @@ async def get_workflow_status(
return await self.document_handler.get_workflow_status(id, status_type)

async def set_workflow_status(
self, id: Union[UUID, list[UUID]], status_type: str, status: str
self,
id: Union[UUID, list[UUID]],
status_type: str,
status: str,
):
return await self.document_handler.set_workflow_status(
id, status_type, status
1 change: 1 addition & 0 deletions py/core/base/providers/ingestion.py
@@ -18,6 +18,7 @@ class IngestionConfig(ProviderConfig):
chunk_enrichment_settings: ChunkEnrichmentSettings = (
ChunkEnrichmentSettings()
)

extra_parsers: dict[str, str] = {}

audio_transcription_model: str = "openai/whisper-1"
2 changes: 1 addition & 1 deletion py/core/main/api/management_router.py
@@ -4,7 +4,7 @@
from datetime import datetime, timezone
from typing import Any, Optional, Set, Union
from uuid import UUID

import os
import psutil
from fastapi import Body, Depends, Path, Query
from fastapi.responses import StreamingResponse
28 changes: 19 additions & 9 deletions py/core/main/orchestration/hatchet/ingestion_workflow.py
@@ -15,7 +15,10 @@
increment_version,
)
from core.base.abstractions import DocumentInfo, R2RException
from core.utils import generate_default_user_collection_id
from core.utils import (
generate_default_user_collection_id,
update_settings_from_dict,
)

from ...services import IngestionService, IngestionServiceAdapter

@@ -163,15 +166,21 @@ async def parse(self, context: Context) -> dict:
document_id=document_info.id, collection_id=collection_id
)

chunk_enrichment_settings = getattr(
# get server chunk enrichment settings and override parts of it if provided in the ingestion config
server_chunk_enrichment_settings = getattr(
service.providers.ingestion.config,
"chunk_enrichment_settings",
None,
)

if chunk_enrichment_settings and getattr(
chunk_enrichment_settings, "enable_chunk_enrichment", False
):
if server_chunk_enrichment_settings:
chunk_enrichment_settings = update_settings_from_dict(
server_chunk_enrichment_settings,
ingestion_config.get("chunk_enrichment_settings", {})
or {},
)

if chunk_enrichment_settings.enable_chunk_enrichment:
Contributor comment: Add a check to ensure chunk_enrichment_settings is not None before accessing chunk_enrichment_settings.enable_chunk_enrichment to avoid potential AttributeError.

Suggested change:
- if chunk_enrichment_settings.enable_chunk_enrichment:
+ if chunk_enrichment_settings and chunk_enrichment_settings.enable_chunk_enrichment:
logger.info("Enriching document with contextual chunks")

@@ -192,6 +201,7 @@ async def parse(self, context: Context) -> dict:

await self.ingestion_service.chunk_enrichment(
document_id=document_info.id,
chunk_enrichment_settings=chunk_enrichment_settings,
)

await self.ingestion_service.update_document_status(
@@ -242,10 +252,10 @@ async def on_failure(self, context: Context) -> None:
document_info = documents_overview[0]

# Update the document status to FAILED
if (
not document_info.ingestion_status
== IngestionStatus.SUCCESS
):
if document_info.ingestion_status not in [
IngestionStatus.SUCCESS,
IngestionStatus.ENRICHED,
]:
await self.ingestion_service.update_document_status(
document_info,
status=IngestionStatus.FAILED,
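The override flow above layers request-level settings on top of the server defaults: the server's chunk_enrichment_settings are fetched with getattr, then any keys supplied in the request's ingestion_config take precedence via update_settings_from_dict. That helper lives in core.utils but its body is not shown in this diff; a minimal sketch of the assumed field-by-field merge:

import copy

def update_settings_from_dict(server_settings, settings_dict: dict):
    # Assumed behavior: copy the server-side settings object and
    # override any field supplied in the request-level dict.
    merged = copy.deepcopy(server_settings)
    for key, value in (settings_dict or {}).items():
        if hasattr(merged, key):
            setattr(merged, key, value)
    return merged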
66 changes: 47 additions & 19 deletions py/core/main/orchestration/hatchet/kg_workflow.py
@@ -9,7 +9,7 @@

from core import GenerationConfig
from core.base import OrchestrationProvider
from core.base.abstractions import KGExtractionStatus
from core.base.abstractions import KGEnrichmentStatus, KGExtractionStatus

from ...services import KgService

Expand All @@ -26,6 +26,10 @@ def hatchet_kg_factory(

def get_input_data_dict(input_data):
for key, value in input_data.items():

if key == "collection_id":
input_data[key] = uuid.UUID(value)

if key == "kg_creation_settings":
input_data[key] = json.loads(value)
input_data[key]["generation_config"] = GenerationConfig(
@@ -291,8 +295,8 @@ async def kg_entity_deduplication_setup(
key=f"{i}/{total_workflows}_entity_deduplication_part",
)
)
await asyncio.gather(*workflows)

await asyncio.gather(*workflows)
return {
"result": f"successfully queued kg entity deduplication for collection {collection_id} with {number_of_distinct_entities} distinct entities"
}
@@ -384,27 +388,51 @@ async def kg_community_summary(self, context: Context) -> dict:
for i in range(total_workflows):
offset = i * parallel_communities
workflows.append(
context.aio.spawn_workflow(
"kg-community-summary",
{
"request": {
"offset": offset,
"limit": min(
parallel_communities,
num_communities - offset,
),
"collection_id": collection_id,
**input_data["kg_enrichment_settings"],
}
},
key=f"{i}/{total_workflows}_community_summary",
)
(
await context.aio.spawn_workflow(
"kg-community-summary",
{
"request": {
"offset": offset,
"limit": min(
parallel_communities,
num_communities - offset,
),
"collection_id": str(collection_id),
**input_data["kg_enrichment_settings"],
}
},
key=f"{i}/{total_workflows}_community_summary",
)
).result()
)
await asyncio.gather(*workflows)

results = await asyncio.gather(*workflows)

logger.info(f"Ran {len(results)} workflows for community summary")

# set status to success
await self.kg_service.providers.database.set_workflow_status(
id=collection_id,
status_type="kg_enrichment_status",
status=KGEnrichmentStatus.SUCCESS,
)

return {
"result": f"Successfully spawned summary workflows for {num_communities} communities."
"result": f"Successfully completed enrichment for collection {collection_id} in {len(results)} workflows."
}

@orchestration_provider.failure()
async def on_failure(self, context: Context) -> None:
collection_id = context.workflow_input()["request"][
"collection_id"
]
await self.kg_service.providers.database.set_workflow_status(
id=collection_id,
status_type="kg_enrichment_status",
status=KGEnrichmentStatus.FAILED,
)

@orchestration_provider.workflow(
name="kg-community-summary", timeout="360m"
)
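Taken together, the hatchet changes give enrichment a clear status lifecycle: the success path now ends by writing KGEnrichmentStatus.SUCCESS, and the new @orchestration_provider.failure() hook writes KGEnrichmentStatus.FAILED. A sketch of the same pattern as a reusable wrapper (run_with_status is hypothetical, not part of the PR; set_workflow_status matches the database.py signature above):

from core.base.abstractions import KGEnrichmentStatus

async def run_with_status(db, collection_id, work):
    # Hypothetical helper: run an enrichment coroutine and record
    # the outcome under the collection's kg_enrichment_status.
    try:
        result = await work()
    except Exception:
        await db.set_workflow_status(
            id=collection_id,
            status_type="kg_enrichment_status",
            status=KGEnrichmentStatus.FAILED,
        )
        raise
    await db.set_workflow_status(
        id=collection_id,
        status_type="kg_enrichment_status",
        status=KGEnrichmentStatus.SUCCESS,
    )
    return result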
69 changes: 47 additions & 22 deletions py/core/main/orchestration/simple/kg_workflow.py
@@ -1,10 +1,13 @@
import json
import logging
import math
import uuid

from core import GenerationConfig
from core import R2RException

from ...services import KgService
from core.base.abstractions import KGEnrichmentStatus

logger = logging.getLogger()

@@ -13,6 +16,10 @@ def simple_kg_factory(service: KgService):

def get_input_data_dict(input_data):
for key, value in input_data.items():

if key == "collection_id":
input_data[key] = uuid.UUID(value)

if key == "kg_creation_settings":
input_data[key] = json.loads(value)
input_data[key]["generation_config"] = GenerationConfig(
@@ -61,32 +68,50 @@ async def enrich_graph(input_data):

Contributor comment: Add a check to ensure input_data['kg_enrichment_settings'] is not None before accessing its attributes to avoid potential AttributeError.
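No suggestion block accompanied this comment; a guard along these lines, placed right after get_input_data_dict below, would address it (a sketch, not part of the PR — it defaults a missing value to an empty dict so the later ** expansion still works):

if input_data.get("kg_enrichment_settings") is None:
    input_data["kg_enrichment_settings"] = {}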

input_data = get_input_data_dict(input_data)

num_communities = await service.kg_clustering(
collection_id=input_data["collection_id"],
**input_data["kg_enrichment_settings"],
)
num_communities = num_communities[0]["num_communities"]
# TODO - Do not hardcode the number of parallel communities,
# make it a configurable parameter at runtime & add server-side defaults
parallel_communities = min(100, num_communities)

total_workflows = math.ceil(num_communities / parallel_communities)
for i in range(total_workflows):
input_data_copy = input_data.copy()
input_data_copy["offset"] = i * parallel_communities
input_data_copy["limit"] = min(
parallel_communities,
num_communities - i * parallel_communities,
try:
num_communities = await service.kg_clustering(
collection_id=input_data["collection_id"],
**input_data["kg_enrichment_settings"],
)
# running i'th workflow out of total_workflows
logger.info(
f"Running kg community summary for {i+1}'th workflow out of total {total_workflows} workflows"
num_communities = num_communities[0]["num_communities"]
# TODO - Do not hardcode the number of parallel communities,
# make it a configurable parameter at runtime & add server-side defaults
parallel_communities = min(100, num_communities)

total_workflows = math.ceil(num_communities / parallel_communities)
for i in range(total_workflows):
input_data_copy = input_data.copy()
input_data_copy["offset"] = i * parallel_communities
input_data_copy["limit"] = min(
parallel_communities,
num_communities - i * parallel_communities,
)
# running i'th workflow out of total_workflows
logger.info(
f"Running kg community summary for {i+1}'th workflow out of total {total_workflows} workflows"
)
await kg_community_summary(
input_data=input_data_copy,
)

await service.providers.database.set_workflow_status(
id=input_data["collection_id"],
status_type="kg_enrichment_status",
status=KGEnrichmentStatus.SUCCESS,
)
await kg_community_summary(
input_data=input_data_copy,
return {
"result": "successfully ran kg community summary workflows"
}

except Exception as e:

await service.providers.database.set_workflow_status(
id=input_data["collection_id"],
status_type="kg_enrichment_status",
status=KGEnrichmentStatus.FAILED,
)

return {"result": "successfully ran kg community summary workflows"}
raise R2RException(f"Error in enriching graph: {e}")

async def kg_community_summary(input_data):

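To make the batching in enrich_graph concrete, here is the arithmetic for a hypothetical collection of 250 communities (100 is the hardcoded batch cap the TODO wants made configurable):

import math

num_communities = 250                                 # example value
parallel_communities = min(100, num_communities)      # -> 100
total_workflows = math.ceil(num_communities / parallel_communities)  # -> 3
# per-workflow (offset, limit): (0, 100), (100, 100), (200, 50)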
28 changes: 21 additions & 7 deletions py/core/main/services/ingestion_service.py
@@ -457,7 +457,21 @@ async def _get_enriched_chunk_text(
for neighbor in semantic_neighbors
)

context_chunk_ids = list(set(context_chunk_ids))
# weird behavior, sometimes we get UUIDs
# FIXME: figure out why
context_chunk_ids_str = list(
set(
[
str(context_chunk_id)
for context_chunk_id in context_chunk_ids
]
)
)

context_chunk_ids_uuid = [
UUID(context_chunk_id)
for context_chunk_id in context_chunk_ids_str
]

context_chunk_texts = [
(
Expand All @@ -466,7 +480,7 @@ async def _get_enriched_chunk_text(
"chunk_order"
],
)
for context_chunk_id in context_chunk_ids
for context_chunk_id in context_chunk_ids_uuid
]

# sort by chunk_order
@@ -521,13 +535,13 @@ async def _get_enriched_chunk_text(
metadata=chunk["metadata"],
)

async def chunk_enrichment(self, document_id: UUID) -> int:
async def chunk_enrichment(
self,
document_id: UUID,
chunk_enrichment_settings: ChunkEnrichmentSettings,
) -> int:
# just call the pipe on every chunk of the document

# TODO: Why is the config not recognized as an ingestionconfig but as a providerconfig?
chunk_enrichment_settings = (
self.providers.ingestion.config.chunk_enrichment_settings # type: ignore
)
# get all document_chunks
document_chunks = (
await self.providers.database.get_document_chunks(
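The str-then-UUID round trip above deduplicates context chunk ids that arrive as a mix of UUID objects and strings (the FIXME notes the root cause is still unknown). An equivalent, slightly tighter normalization (a sketch, not the PR's code):

from typing import Union
from uuid import UUID

def normalize_chunk_ids(ids: list[Union[UUID, str]]) -> list[UUID]:
    # Deduplicate ids that may arrive as UUIDs or strings and
    # return them uniformly as UUID objects.
    return [UUID(i) for i in {str(cid) for cid in ids}]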
6 changes: 6 additions & 0 deletions py/core/main/services/management_service.py
@@ -4,6 +4,7 @@
from uuid import UUID

import toml
import os

from core.base import (
AnalysisTypes,
@@ -26,6 +27,9 @@
from ..config import R2RConfig
from .base import Service

from importlib.metadata import version as get_version


logger = logging.getLogger()


@@ -196,6 +200,8 @@ async def app_settings(self, *args: Any, **kwargs: Any):
return {
"config": config_dict,
"prompts": prompts,
"r2r_project_name": os.environ["R2R_PROJECT_NAME"],
# "r2r_version": get_version("r2r"),
}

@telemetry_event("UsersOverview")
5 changes: 3 additions & 2 deletions py/core/providers/database/collection.py
@@ -352,14 +352,14 @@ async def get_collections_overview(
"""Get an overview of collections, optionally filtered by collection IDs, with pagination."""
query = f"""
WITH collection_overview AS (
SELECT g.collection_id, g.name, g.description, g.created_at, g.updated_at,
SELECT g.collection_id, g.name, g.description, g.created_at, g.updated_at, g.kg_enrichment_status,
COUNT(DISTINCT u.user_id) AS user_count,
COUNT(DISTINCT d.document_id) AS document_count
FROM {self._get_table_name(PostgresCollectionHandler.TABLE_NAME)} g
LEFT JOIN {self._get_table_name('users')} u ON g.collection_id = ANY(u.collection_ids)
LEFT JOIN {self._get_table_name('document_info')} d ON g.collection_id = ANY(d.collection_ids)
{' WHERE g.collection_id = ANY($1)' if collection_ids else ''}
GROUP BY g.collection_id, g.name, g.description, g.created_at, g.updated_at
GROUP BY g.collection_id, g.name, g.description, g.created_at, g.updated_at, g.kg_enrichment_status
),
counted_overview AS (
SELECT *, COUNT(*) OVER() AS total_entries
@@ -393,6 +393,7 @@ async def get_collections_overview(
updated_at=row["updated_at"],
user_count=row["user_count"],
document_count=row["document_count"],
kg_enrichment_status=row["kg_enrichment_status"],
)
for row in results
]