Merge pull request #1687 from SciPhi-AI/feature/clustering-as-a-service
Feature/clustering as a service
emrgnt-cmplxty authored Dec 12, 2024
2 parents 87222ec + 0eaf32e commit 9cbabb1
Showing 75 changed files with 500 additions and 9,667 deletions.
51 changes: 51 additions & 0 deletions .github/workflows/build-cluster-docker.yml
@@ -0,0 +1,51 @@
name: Build and Publish Graph Clustering Docker Image

on:
workflow_dispatch:

env:
REGISTRY_BASE: ragtoriches

jobs:
build:
runs-on: ubuntu-latest
steps:
- name: Checkout Repository
uses: actions/checkout@v4

- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: '3.12'

- name: Install toml package
run: pip install toml

- name: Determine version
id: version
run: |
echo "REGISTRY_IMAGE=${{ env.REGISTRY_BASE }}/cluster-prod" >> $GITHUB_OUTPUT
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3

- name: Docker Auth
uses: docker/login-action@v3
with:
username: ${{ secrets.RAGTORICHES_DOCKER_UNAME }}
password: ${{ secrets.RAGTORICHES_DOCKER_TOKEN }}

- name: Build and push image
uses: docker/build-push-action@v5
with:
context: ./services/clustering
file: ./services/clustering/Dockerfile.clustering
platforms: linux/amd64,linux/arm64
push: true
tags: ${{ steps.version.outputs.REGISTRY_IMAGE }}:latest
provenance: false
sbom: false

- name: Verify manifest
run: |
docker buildx imagetools inspect ${{ steps.version.outputs.REGISTRY_IMAGE }}:latest
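The clustering service this workflow packages lives in services/clustering, which is not shown in this excerpt. For orientation, here is a minimal sketch of what such a service could look like: a FastAPI app exposing the /health endpoint that the compose file below probes, plus a hypothetical /cluster endpoint wrapping graspologic's hierarchical_leiden. The endpoint name, request shape, and response shape are assumptions, not the actual services/clustering code.

# Hypothetical sketch only; endpoint and payload shapes are assumptions,
# not the real services/clustering implementation.
import networkx as nx
from fastapi import FastAPI
from graspologic.partition import hierarchical_leiden
from pydantic import BaseModel

app = FastAPI()

class ClusterRequest(BaseModel):
    # assumed shape: [{"subject": "a", "object": "b", "weight": 1.0}, ...]
    relationships: list[dict]
    leiden_params: dict = {}

@app.get("/health")
def health():
    return {"status": "ok"}

@app.post("/cluster")
def cluster(req: ClusterRequest):
    graph = nx.Graph()
    for rel in req.relationships:
        graph.add_edge(rel["subject"], rel["object"], weight=rel.get("weight", 1.0))
    communities = hierarchical_leiden(graph, **req.leiden_params)
    # hierarchical_leiden returns HierarchicalCluster records with .node,
    # .cluster, .level, etc.; serializing them to plain dicts matches the
    # item["cluster"] / item["node"] access used for the remote path in
    # community_summary.py below.
    return [{"node": c.node, "cluster": c.cluster, "level": c.level} for c in communities]

# run with e.g.: uvicorn app:app --host 0.0.0.0 --port 7276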
20 changes: 17 additions & 3 deletions py/compose.full.yaml
@@ -260,8 +260,6 @@ services:

unstructured:
image: ${UNSTRUCTURED_IMAGE:-ragtoriches/unst-prod}
ports:
- "${R2R_UNSTRUCTURED_PORT:-7275}:7275"
networks:
- r2r-network
healthcheck:
@@ -270,6 +268,18 @@
timeout: 5s
retries: 5

graph_clustering:
image: ${GRAPH_CLUSTERING_IMAGE:-ragtoriches/clustering-prod}
ports:
- "${R2R_GRAPH_CLUSTERING_PORT:-7276}:7276"
networks:
- r2r-network
healthcheck:
test: ["CMD", "curl", "-f", "http://localhost:7276/health"]
interval: 10s
timeout: 5s
retries: 5

r2r:
image: ${R2R_IMAGE:-ragtoriches/prod:latest}
build:
@@ -342,13 +352,17 @@ services:
# Unstructured
- UNSTRUCTURED_API_KEY=${UNSTRUCTURED_API_KEY:-}
- UNSTRUCTURED_API_URL=${UNSTRUCTURED_API_URL:-https://api.unstructured.io/general/v0/general}
- UNSTRUCTURED_LOCAL_URL=${UNSTRUCTURED_LOCAL_URL:-http://unstructured:7275}
- UNSTRUCTURED_SERVICE_URL=${UNSTRUCTURED_SERVICE_URL:-http://unstructured:7275}
- UNSTRUCTURED_NUM_WORKERS=${UNSTRUCTURED_NUM_WORKERS:-10}

# Hatchet
- HATCHET_CLIENT_TLS_STRATEGY=none
- HATCHET_CLIENT_GRPC_MAX_RECV_MESSAGE_LENGTH=${HATCHET_CLIENT_GRPC_MAX_RECV_MESSAGE_LENGTH:-134217728}
- HATCHET_CLIENT_GRPC_MAX_SEND_MESSAGE_LENGTH=${HATCHET_CLIENT_GRPC_MAX_SEND_MESSAGE_LENGTH:-134217728}

# Graph clustering (graspologic)
- CLUSTERING_SERVICE_URL=http://graph_clustering:7276

command: >
sh -c '
if [ -z "$${HATCHET_CLIENT_TOKEN}" ]; then
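The new graph_clustering service is reached from r2r via CLUSTERING_SERVICE_URL. A quick way to confirm the container is healthy from inside the network, mirroring the compose healthcheck (the requests dependency is an assumption about the caller's environment):

import os
import requests  # assumption: available wherever this probe runs

base_url = os.environ.get("CLUSTERING_SERVICE_URL", "http://graph_clustering:7276")
resp = requests.get(f"{base_url}/health", timeout=5)
resp.raise_for_status()
print(resp.status_code)  # 200 once the service is up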
4 changes: 4 additions & 0 deletions py/core/configs/full.toml
@@ -2,6 +2,10 @@
provider = "litellm"
concurrent_request_limit = 128

[database]
[database.graph_creation_settings]
clustering_mode = "remote"

[ingestion]
provider = "unstructured_local"
strategy = "auto"
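clustering_mode = "remote" is what routes Leiden clustering to the new service instead of running graspologic in-process. A quick check of what a given config resolves to (tomllib ships with Python 3.11+; the "local" fallback here is an assumption about the default):

import tomllib  # standard library in Python 3.11+

with open("py/core/configs/full.toml", "rb") as f:
    config = tomllib.load(f)

mode = (
    config.get("database", {})
    .get("graph_creation_settings", {})
    .get("clustering_mode", "local")  # assumed default when the key is absent
)
print(mode)  # "remote" for full.toml after this change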
5 changes: 2 additions & 3 deletions py/core/configs/full_azure.toml
@@ -10,10 +10,9 @@ concurrent_request_limit = 128
[agent.generation_config]
model = "azure/gpt-4o"

# KG settings
batch_size = 256

[database]
[database.graph_creation_settings]
clustering_mode = "remote"
generation_config = { model = "azure/gpt-4o-mini" }

[database.graph_entity_deduplication_settings]
1 change: 1 addition & 0 deletions py/core/configs/full_local_llm.toml
@@ -22,6 +22,7 @@ concurrent_request_limit = 1
provider = "postgres"

[database.graph_creation_settings]
clustering_mode = "remote"
graph_entity_description_prompt = "graphrag_entity_description"
entity_types = [] # if empty, all entities are extracted
relation_types = [] # if empty, all relations are extracted
1 change: 1 addition & 0 deletions py/core/main/services/kg_service.py
@@ -581,6 +581,7 @@ async def kg_clustering(
"generation_config": generation_config,
"leiden_params": leiden_params,
"logger": logger,
"clustering_mode": self.config.database.graph_creation_settings.clustering_mode,
}
),
state=None,
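The service layer now threads clustering_mode from [database.graph_creation_settings] into the pipe's input message. A runnable sketch of the resulting payload shape, with hypothetical stand-in values for the surrounding fields:

import logging
from uuid import uuid4

# stand-in values; the real ones come from the request and the R2R config
collection_id = uuid4()
generation_config = {"model": "azure/gpt-4o-mini"}
leiden_params = {"max_cluster_size": 1000}

message = {
    "collection_id": collection_id,
    "generation_config": generation_config,
    "leiden_params": leiden_params,
    "logger": logging.getLogger(__name__),
    # new field, read from config.database.graph_creation_settings
    "clustering_mode": "remote",
}
print(message["clustering_mode"])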
4 changes: 4 additions & 0 deletions py/core/pipes/kg/clustering.py
@@ -44,6 +44,7 @@ async def cluster_kg(
self,
collection_id: UUID,
leiden_params: dict,
clustering_mode: str,
):
"""
Clusters the knowledge graph relationships into communities using the hierarchical Leiden algorithm, via the graspologic library.
@@ -52,6 +53,7 @@
num_communities = await self.database_provider.graph_handler.perform_graph_clustering(
collection_id=collection_id,
leiden_params=leiden_params,
clustering_mode=clustering_mode,
)

return {
@@ -72,8 +74,10 @@ async def _run_logic( # type: ignore

collection_id = input.message.get("collection_id", None)
leiden_params = input.message["leiden_params"]
clustering_mode = input.message["clustering_mode"]

yield await self.cluster_kg(
collection_id=collection_id,
leiden_params=leiden_params,
clustering_mode=clustering_mode,
)
12 changes: 10 additions & 2 deletions py/core/pipes/kg/community_summary.py
@@ -261,6 +261,7 @@ async def _run_logic( # type: ignore
generation_config = input.message["generation_config"]
max_summary_input_length = input.message["max_summary_input_length"]
collection_id = input.message.get("collection_id", None)
clustering_mode = input.message.get("clustering_mode", None)
community_summary_jobs = []
logger = input.message.get("logger", logging.getLogger())

@@ -295,16 +296,23 @@
relationship_ids_cache={},
leiden_params=leiden_params,
collection_id=collection_id,
clustering_mode=clustering_mode,
)
)

# Organize clusters
clusters: dict[Any, Any] = {}
for item in community_clusters:
cluster_id = item.cluster
cluster_id = (
item["cluster"]
if clustering_mode == "remote"
else item.cluster
)
if cluster_id not in clusters:
clusters[cluster_id] = []
clusters[cluster_id].append(item.node)
clusters[cluster_id].append(
item["node"] if clustering_mode == "remote" else item.node
)

# Now, process the clusters
for _, nodes in clusters.items():
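The branching above exists because the remote service returns plain JSON dicts while local graspologic returns HierarchicalCluster records with attribute access. A small normalizer could collapse the two patterns; a sketch, not code from this PR:

from typing import Any

def cluster_fields(item: Any, clustering_mode: str) -> tuple[Any, Any]:
    """Return (cluster_id, node) for either result shape: a JSON dict from
    the remote service, or a graspologic record with attribute access."""
    if clustering_mode == "remote":
        return item["cluster"], item["node"]
    return item.cluster, item.node

# usage mirroring the loop above:
#   cluster_id, node = cluster_fields(item, clustering_mode)
#   clusters.setdefault(cluster_id, []).append(node)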