From 794c40871cd3d19fb4bcc9cd5cbbaa59a78cf1c4 Mon Sep 17 00:00:00 2001 From: Nolan Tremelling <34580718+NolanTrem@users.noreply.github.com> Date: Tue, 29 Oct 2024 17:07:02 -0700 Subject: [PATCH] Update community model (#1524) * Feature/tweak actions (#1507) * up * tweak actions * Sync JS SDK, Harmonize Python SDK KG Methods (#1511) * Feature/move logging (#1492) * move logging provider out * move logging provider to own directory, remove singleton * cleanup * fix refactoring tweak (#1496) * Fix JSON serialization and Prompt ID Bugs for Prompts (#1491) * Bug in get prompts * Add tests * Prevent verbose logging on standup * Remove kg as required key in config, await get_all_prompts * Remove reference to fragment id * comment out ingestion * complete logging port (#1499) * Feature/dev rebased (#1500) * Feature/move logging (#1493) * move logging provider out * move logging provider to own directory, remove singleton * cleanup * Update js package (#1498) * fix refactoring tweak (#1496) * Fix JSON serialization and Prompt ID Bugs for Prompts (#1491) * Bug in get prompts * Add tests * Prevent verbose logging on standup * Remove kg as required key in config, await get_all_prompts * Remove reference to fragment id * comment out ingestion * complete logging port (#1499) --------- Co-authored-by: Nolan Tremelling <34580718+NolanTrem@users.noreply.github.com> * Fix handling for R2R exceptions (#1501) * fix doc test (#1502) * Harmonize python SDK KG methods for optional params, add missing JS methods --------- Co-authored-by: emrgnt-cmplxty <68796651+emrgnt-cmplxty@users.noreply.github.com> Co-authored-by: emrgnt-cmplxty * Clean up pagination and offset around KG (#1519) * Move to R2R light for integration testing (#1521) * Update community model --------- Co-authored-by: emrgnt-cmplxty <68796651+emrgnt-cmplxty@users.noreply.github.com> Co-authored-by: emrgnt-cmplxty --- .../r2r-js-sdk-integration-tests.yml | 46 +- .../r2rClientIntegrationSuperUser.test.ts | 17 +- .../r2rClientIntegrationUser.test.ts | 29 +- js/sdk/src/models.tsx | 35 ++ js/sdk/src/r2rClient.ts | 510 ++++++++++++++---- py/core/base/providers/database.py | 61 ++- py/core/main/api/kg_router.py | 50 +- py/core/main/services/kg_service.py | 46 +- py/core/providers/database/kg.py | 139 +++-- py/sdk/mixins/kg.py | 45 +- py/shared/abstractions/graph.py | 36 +- 11 files changed, 694 insertions(+), 320 deletions(-) diff --git a/.github/workflows/r2r-js-sdk-integration-tests.yml b/.github/workflows/r2r-js-sdk-integration-tests.yml index 45cba8dbd..ec804761e 100644 --- a/.github/workflows/r2r-js-sdk-integration-tests.yml +++ b/.github/workflows/r2r-js-sdk-integration-tests.yml @@ -9,49 +9,45 @@ jobs: test: runs-on: ubuntu-latest - defaults: - run: - working-directory: ./js/sdk + env: + OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} + TELEMETRY_ENABLED: 'false' + R2R_POSTGRES_HOST: localhost + R2R_POSTGRES_DBNAME: postgres + R2R_POSTGRES_PORT: '5432' + R2R_POSTGRES_PASSWORD: postgres + R2R_POSTGRES_USER: postgres + R2R_PROJECT_NAME: r2r_default steps: - uses: actions/checkout@v2 - - name: Set up Python - uses: actions/setup-python@v2 + - name: Set up Python and install dependencies + uses: ./.github/actions/setup-python-light with: - python-version: "3.x" + os: ubuntu-latest - - name: Install R2R - run: | - python -m pip install --upgrade pip - pip install r2r + - name: Setup and start PostgreSQL + uses: ./.github/actions/setup-postgres-ext + with: + os: ubuntu-latest - - name: Start R2R server - env: - OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} - run: | - r2r serve --docker - sleep 60 + - name: Start R2R Light server + uses: ./.github/actions/start-r2r-light - name: Use Node.js uses: actions/setup-node@v2 with: node-version: "20.x" - - name: Install dependencies + - name: Install JS SDK dependencies + working-directory: ./js/sdk run: npm ci - name: Check if R2R server is running run: | curl http://localhost:7272/v2/health || echo "Server not responding" - - name: Display R2R server logs if server not responding - if: failure() - run: docker logs r2r-r2r-1 - - name: Run integration tests + working-directory: ./js/sdk run: npm run test - - - name: Display R2R server logs if tests fail - if: failure() - run: docker logs r2r-r2r-1 diff --git a/js/sdk/__tests__/r2rClientIntegrationSuperUser.test.ts b/js/sdk/__tests__/r2rClientIntegrationSuperUser.test.ts index c52b7e7ea..14ca868ad 100644 --- a/js/sdk/__tests__/r2rClientIntegrationSuperUser.test.ts +++ b/js/sdk/__tests__/r2rClientIntegrationSuperUser.test.ts @@ -34,9 +34,16 @@ let newCollectionId: string; * - updateFiles * - ingestChunks * - updateChunks + * X createVectorIndex + * X listVectorIndices + * X deleteVectorIndex * Management: * - serverStats * X updatePrompt + * X addPrompt + * X getPrompt + * X getAllPrompts + * X deletePrompt * - analytics * - logs * - appSettings @@ -45,7 +52,6 @@ let newCollectionId: string; * X downloadFile * - documentsOverview * - documentChunks - * X inspectKnowledgeGraph * X collectionsOverview * - createCollection * - getCollection @@ -70,8 +76,15 @@ let newCollectionId: string; * X getPreviousBranch * X branchAtMessage * - deleteConversation - * Restructure: + * Knowledge Graphs: + * X createGraph * X enrichGraph + * X getEntities + * X getTriples + * X getCommunities + * X getTunedPrompt + * X deduplicateEntities + * X deleteGraphForCollection * Retrieval: * - search * - rag diff --git a/js/sdk/__tests__/r2rClientIntegrationUser.test.ts b/js/sdk/__tests__/r2rClientIntegrationUser.test.ts index 2afa7f5dc..63113ae5a 100644 --- a/js/sdk/__tests__/r2rClientIntegrationUser.test.ts +++ b/js/sdk/__tests__/r2rClientIntegrationUser.test.ts @@ -30,9 +30,16 @@ const baseUrl = "http://localhost:7272"; * - updateFiles * X ingestChunks * X updateChunks + * X createVectorIndex + * X listVectorIndices + * X deleteVectorIndex * Management: * - serverStats * X updatePrompt + * X addPrompt + * X getPrompt + * X getAllPrompts + * X deletePrompt * X analytics * X logs * - appSettings @@ -41,7 +48,6 @@ const baseUrl = "http://localhost:7272"; * X downloadFile * - documentsOverview * - documentChunks - * X inspectKnowledgeGraph * X collectionsOverview * X createCollection * X getCollection @@ -66,8 +72,15 @@ const baseUrl = "http://localhost:7272"; * X getPreviousBranch * X branchAtMessage * X deleteConversation - * Restructure: + * Knowledge Graphs: + * X createGraph * X enrichGraph + * X getEntities + * X getTriples + * X getCommunities + * X getTunedPrompt + * X deduplicateEntities + * X deleteGraphForCollection * Retrieval: * - search * X rag @@ -169,21 +182,15 @@ describe("r2rClient Integration Tests", () => { }); test("Only an authorized user can call server stats", async () => { - await expect(client.serverStats()).rejects.toThrow( - "Status 403: Only an authorized user can call the `server_stats` endpoint.", - ); + await expect(client.serverStats()).rejects.toThrow(/Status 403/); }); test("Only a superuser can call app settings", async () => { - await expect(client.appSettings()).rejects.toThrow( - "Status 403: Only a superuser can call the `app_settings` endpoint.", - ); + await expect(client.appSettings()).rejects.toThrow(/Status 403/); }); test("Only a superuser can call users overview", async () => { - await expect(client.usersOverview()).rejects.toThrow( - "Status 403: Only a superuser can call the `users_overview` endpoint.", - ); + await expect(client.usersOverview()).rejects.toThrow(/Status 403/); }); test("Document chunks", async () => { diff --git a/js/sdk/src/models.tsx b/js/sdk/src/models.tsx index 6cc26341b..1684ba14d 100644 --- a/js/sdk/src/models.tsx +++ b/js/sdk/src/models.tsx @@ -27,6 +27,7 @@ export interface GenerationConfig { tools?: Array>; add_generation_kwargs?: Record; api_base?: string; + response_format?: string; } export interface HybridSearchSettings { @@ -67,6 +68,40 @@ export interface KGSearchSettings { local_search_limits?: Record; } +export enum KGRunType { + ESTIMATE = "estimate", + RUN = "run", +} + +export interface KGCreationSettings { + kg_triples_extraction_prompt?: string; + kg_entity_description_prompt?: string; + force_kg_creation?: boolean; + entity_types?: string[]; + relation_types?: string[]; + extractions_merge_count?: number; + max_knowledge_triples?: number; + max_description_input_length?: number; + generation_config?: GenerationConfig; +} + +export interface KGEnrichmentSettings { + force_kg_enrichment?: boolean; + community_reports_prompt?: string; + max_summary_input_length?: number; + generation_config?: GenerationConfig; + leiden_params?: Record; +} + +export interface KGEntityDeduplicationSettings { + kg_entity_deduplication_type?: KGEntityDeduplicationType; +} + +export enum KGEntityDeduplicationType { + BY_NAME = "by_name", + BY_DESCRIPTION = "by_description", +} + export interface KGLocalSearchResult { query: string; entities: Record; diff --git a/js/sdk/src/r2rClient.ts b/js/sdk/src/r2rClient.ts index 26480aaf8..d8f6b6fd0 100644 --- a/js/sdk/src/r2rClient.ts +++ b/js/sdk/src/r2rClient.ts @@ -21,6 +21,10 @@ import { RefreshTokenResponse, VectorSearchSettings, KGSearchSettings, + KGRunType, + KGCreationSettings, + KGEnrichmentSettings, + KGEntityDeduplicationSettings, GenerationConfig, RawChunk, } from "./models"; @@ -717,6 +721,89 @@ export class r2rClient { ); } + /** + * Create a vector index for similarity search. + * @param options The options for creating the vector index + * @returns Promise resolving to the creation response + */ + @feature("createVectorIndex") + async createVectorIndex(options: { + tableName: string; + indexMethod: "hnsw" | "ivfflat" | "auto"; + indexMeasure: "cosine_distance" | "l2_distance" | "max_inner_product"; + indexArguments?: { + m?: number; // HNSW: Number of connections per element + ef_construction?: number; // HNSW: Size of dynamic candidate list + n_lists?: number; // IVFFlat: Number of clusters/inverted lists + }; + indexName?: string; + concurrently?: boolean; + }): Promise> { + this._ensureAuthenticated(); + + const data = { + table_name: options.tableName, + index_method: options.indexMethod, + index_measure: options.indexMeasure, + index_arguments: options.indexArguments, + index_name: options.indexName, + concurrently: options.concurrently ?? true, + }; + + return await this._makeRequest("POST", "create_vector_index", { + data, + headers: { + "Content-Type": "application/json", + }, + }); + } + + /** + * List existing vector indices for a table. + * @param options The options for listing vector indices + * @returns Promise resolving to the list of indices + */ + @feature("listVectorIndices") + async listVectorIndices(options: { + tableName?: string; + }): Promise> { + this._ensureAuthenticated(); + + const params: Record = {}; + if (options.tableName) { + params.table_name = options.tableName; + } + + return await this._makeRequest("GET", "list_vector_indices", { params }); + } + + /** + * Delete a vector index from a table. + * @param options The options for deleting the vector index + * @returns Promise resolving to the deletion response + */ + @feature("deleteVectorIndex") + async deleteVectorIndex(options: { + indexName: string; + tableName?: string; + concurrently?: boolean; + }): Promise> { + this._ensureAuthenticated(); + + const data = { + index_name: options.indexName, + table_name: options.tableName, + concurrently: options.concurrently ?? true, + }; + + return await this._makeRequest("DELETE", "delete_vector_index", { + data, + headers: { + "Content-Type": "application/json", + }, + }); + } + // ----------------------------------------------------------------------------- // // Management @@ -772,6 +859,78 @@ export class r2rClient { }); } + /** + * Add a new prompt to the system. + * @returns A promise that resolves to the response from the server. + * @param name The name of the prompt. + * @param template The template for the prompt. + * @param input_types The input types for the prompt. + */ + @feature("addPrompt") + async addPrompt( + name: string, + template: string, + input_types: Record, + ): Promise> { + this._ensureAuthenticated(); + + const data: Record = { name, template, input_types }; + + return await this._makeRequest("POST", "add_prompt", { + data, + headers: { + "Content-Type": "application/json", + }, + }); + } + + /** + * Get a prompt from the system. + * @param name The name of the prompt to retrieve. + * @param inputs Inputs for the prompt. + * @param prompt_override Override for the prompt template. + * @returns + */ + @feature("getPrompt") + async getPrompt( + name: string, + inputs?: Record, + prompt_override?: string, + ): Promise> { + this._ensureAuthenticated(); + + const params: Record = {}; + if (inputs) { + params["inputs"] = JSON.stringify(inputs); + } + if (prompt_override) { + params["prompt_override"] = prompt_override; + } + + return await this._makeRequest("GET", `get_prompt/${name}`, { params }); + } + + /** + * Get all prompts from the system. + * @returns A promise that resolves to the response from the server. + */ + @feature("getAllPrompts") + async getAllPrompts(): Promise> { + this._ensureAuthenticated(); + return await this._makeRequest("GET", "get_all_prompts"); + } + + /** + * Delete a prompt from the system. + * @param prompt_name The name of the prompt to delete. + * @returns A promise that resolves to the response from the server. + */ + @feature("deletePrompt") + async deletePrompt(prompt_name: string): Promise> { + this._ensureAuthenticated(); + return await this._makeRequest("DELETE", `delete_prompt/${prompt_name}`); + } + /** * Get analytics data from the server. * @param filter_criteria The filter criteria to use. @@ -957,29 +1116,6 @@ export class r2rClient { }); } - // /** - // * Inspect the knowledge graph associated with your R2R deployment. - // * @param limit The maximum number of nodes to return. Defaults to 100. - // * @returns A promise that resolves to the response from the server. - // */ - // @feature("inspectKnowledgeGraph") - // async inspectKnowledgeGraph( - // offset?: number, - // limit?: number, - // ): Promise> { - // this._ensureAuthenticated(); - - // const params: Record = {}; - // if (offset !== undefined) { - // params.offset = offset; - // } - // if (limit !== undefined) { - // params.limit = limit; - // } - - // return this._makeRequest("GET", "inspect_knowledge_graph", { params }); - // } - /** * Get an overview of existing collections. * @param collectionIds List of collection IDs to get an overview for. @@ -1262,7 +1398,7 @@ export class r2rClient { return this._makeRequest( "GET", - `get_document_collections/${encodeURIComponent(documentId)}`, + `document_collections/${encodeURIComponent(documentId)}`, { params }, ); } @@ -1462,18 +1598,256 @@ export class r2rClient { // ----------------------------------------------------------------------------- // - // Restructure + // Knowledge Graphs // // ----------------------------------------------------------------------------- + /** + * Create a graph from the given settings. + * @returns A promise that resolves to the response from the server. + * + * @param collection_id The ID of the collection to create the graph for. + * @param run_type The type of run to perform. + * @param kg_creation_settings Settings for the graph creation process. + */ + @feature("createGraph") + async createGraph( + collection_id?: string, + run_type?: KGRunType, + kg_creation_settings?: KGCreationSettings | Record, + ): Promise> { + this._ensureAuthenticated(); + + const json_data: Record = { + collection_id, + run_type, + kg_creation_settings, + }; + + Object.keys(json_data).forEach( + (key) => json_data[key] === undefined && delete json_data[key], + ); + + return await this._makeRequest("POST", "create_graph", { data: json_data }); + } /** * Perform graph enrichment over the entire graph. * @returns A promise that resolves to the response from the server. + * + * @param collection_id The ID of the collection to enrich the graph for. + * @param run_type The type of run to perform. + * @param kg_enrichment_settings Settings for the graph enrichment process. */ @feature("enrichGraph") - async enrichGraph(): Promise { + async enrichGraph( + collection_id?: string, + run_type?: KGRunType, + kg_enrichment_settings?: KGEnrichmentSettings | Record, + ): Promise { this._ensureAuthenticated(); - return await this._makeRequest("POST", "enrich_graph"); + + const json_data: Record = { + collection_id, + run_type, + kg_enrichment_settings, + }; + + Object.keys(json_data).forEach( + (key) => json_data[key] === undefined && delete json_data[key], + ); + + return await this._makeRequest("POST", "enrich_graph", { data: json_data }); + } + + /** + * Retrieve entities from the knowledge graph. + * @returns A promise that resolves to the response from the server. + * @param collection_id The ID of the collection to retrieve entities for. + * @param offset The offset for pagination. + * @param limit The limit for pagination. + * @param entity_level The level of entity to filter by. + * @param entity_ids Entity IDs to filter by. + * @returns + */ + @feature("getEntities") + async getEntities( + collection_id?: string, + offset?: number, + limit?: number, + entity_level?: string, + entity_ids?: string[], + ): Promise { + this._ensureAuthenticated(); + + const params: Record = {}; + if (collection_id !== undefined) { + params.collection_id = collection_id; + } + if (offset !== undefined) { + params.offset = offset; + } + if (limit !== undefined) { + params.limit = limit; + } + if (entity_level !== undefined) { + params.entity_level = entity_level; + } + if (entity_ids !== undefined) { + params.entity_ids = entity_ids; + } + + return this._makeRequest("GET", `entities`, { params }); + } + + /** + * Retrieve triples from the knowledge graph. + * @returns A promise that resolves to the response from the server. + * @param collection_id The ID of the collection to retrieve entities for. + * @param offset The offset for pagination. + * @param limit The limit for pagination. + * @param entity_level The level of entity to filter by. + * @param triple_ids Triple IDs to filter by. + */ + @feature("getTriples") + async getTriples( + collection_id?: string, + offset?: number, + limit?: number, + entity_level?: string, + triple_ids?: string[], + ): Promise { + this._ensureAuthenticated(); + + const params: Record = {}; + if (collection_id !== undefined) { + params.collection_id = collection_id; + } + if (offset !== undefined) { + params.offset = offset; + } + if (limit !== undefined) { + params.limit = limit; + } + if (entity_level !== undefined) { + params.entity_level = entity_level; + } + if (triple_ids !== undefined) { + params.entity_ids = triple_ids; + } + + return this._makeRequest("GET", `triples`, { params }); + } + + /** + * Retrieve communities from the knowledge graph. + * @param collection_id The ID of the collection to retrieve entities for. + * @param offset The offset for pagination. + * @param limit The limit for pagination. + * @param levels Levels to filter by. + * @param community_numbers Community numbers to filter by. + * @returns + */ + @feature("getCommunities") + async getCommunities( + collection_id?: string, + offset?: number, + limit?: number, + levels?: number, + community_numbers?: number[], + ): Promise { + this._ensureAuthenticated(); + + const params: Record = {}; + if (collection_id !== undefined) { + params.collection_id = collection_id; + } + if (offset !== undefined) { + params.offset = offset; + } + if (limit !== undefined) { + params.limit = limit; + } + if (levels !== undefined) { + params.levels = levels; + } + if (community_numbers !== undefined) { + params.community_numbers = community_numbers; + } + + return this._makeRequest("GET", `communities`, { params }); + } + + @feature("getTunedPrompt") + async getTunedPrompt( + prompt_name: string, + collection_id?: string, + documents_offset?: number, + documents_limit?: number, + chunk_offset?: number, + chunk_limit?: number, + ): Promise { + this._ensureAuthenticated(); + + const params: Record = { prompt_name }; + if (collection_id !== undefined) { + params.collection_id = collection_id; + } + if (documents_offset !== undefined) { + params.documents_offset = documents_offset; + } + if (documents_limit !== undefined) { + params.documents_limit = documents_limit; + } + if (chunk_offset !== undefined) { + params.chunk_offset = chunk_offset; + } + if (chunk_limit !== undefined) { + params.chunk_limit = chunk_limit; + } + + return this._makeRequest("GET", `tuned_prompt`, { params }); + } + + @feature("deduplicateEntities") + async deduplicateEntities( + collections_id?: string, + run_type?: KGRunType, + deduplication_settings?: + | KGEntityDeduplicationSettings + | Record, + ): Promise { + this._ensureAuthenticated(); + + const json_data: Record = { + collections_id, + run_type, + deduplication_settings, + }; + + Object.keys(json_data).forEach( + (key) => json_data[key] === undefined && delete json_data[key], + ); + + return await this._makeRequest("POST", "deduplicate_entities", { + data: json_data, + }); + } + + @feature("deleteGraphForCollection") + async deleteGraphForCollection( + collection_id: string, + cascade: boolean = false, + ): Promise { + this._ensureAuthenticated(); + + const json_data: Record = { + collection_id, + cascade, + }; + + return await this._makeRequest("DELETE", `delete_graph`, { + data: json_data, + }); } // ----------------------------------------------------------------------------- @@ -1629,88 +2003,6 @@ export class r2rClient { responseType: "stream", }); } - /** - * Create a vector index for similarity search. - * @param options The options for creating the vector index - * @returns Promise resolving to the creation response - */ - @feature("createVectorIndex") - async createVectorIndex(options: { - tableName: string; - indexMethod: "hnsw" | "ivfflat" | "auto"; - indexMeasure: "cosine_distance" | "l2_distance" | "max_inner_product"; - indexArguments?: { - m?: number; // HNSW: Number of connections per element - ef_construction?: number; // HNSW: Size of dynamic candidate list - n_lists?: number; // IVFFlat: Number of clusters/inverted lists - }; - indexName?: string; - concurrently?: boolean; - }): Promise> { - this._ensureAuthenticated(); - - const data = { - table_name: options.tableName, - index_method: options.indexMethod, - index_measure: options.indexMeasure, - index_arguments: options.indexArguments, - index_name: options.indexName, - concurrently: options.concurrently ?? true, - }; - - return await this._makeRequest("POST", "create_vector_index", { - data, - headers: { - "Content-Type": "application/json", - }, - }); - } - - /** - * List existing vector indices for a table. - * @param options The options for listing vector indices - * @returns Promise resolving to the list of indices - */ - @feature("listVectorIndices") - async listVectorIndices(options: { - tableName?: string; - }): Promise> { - this._ensureAuthenticated(); - - const params: Record = {}; - if (options.tableName) { - params.table_name = options.tableName; - } - - return await this._makeRequest("GET", "list_vector_indices", { params }); - } - - /** - * Delete a vector index from a table. - * @param options The options for deleting the vector index - * @returns Promise resolving to the deletion response - */ - @feature("deleteVectorIndex") - async deleteVectorIndex(options: { - indexName: string; - tableName?: string; - concurrently?: boolean; - }): Promise> { - this._ensureAuthenticated(); - - const data = { - index_name: options.indexName, - table_name: options.tableName, - concurrently: options.concurrently ?? true, - }; - - return await this._makeRequest("DELETE", "delete_vector_index", { - data, - headers: { - "Content-Type": "application/json", - }, - }); - } } export default r2rClient; diff --git a/py/core/base/providers/database.py b/py/core/base/providers/database.py index 65e6ea6d4..a035ea639 100644 --- a/py/core/base/providers/database.py +++ b/py/core/base/providers/database.py @@ -670,20 +670,15 @@ async def add_communities(self, communities: List[Any]) -> None: @abstractmethod async def get_communities( self, - collection_id: UUID, - offset: int = 0, - limit: int = 100, + collection_id: Optional[UUID] = None, levels: Optional[list[int]] = None, community_numbers: Optional[list[int]] = None, + offset: int = 0, + limit: int = -1, ) -> dict: """Get communities for a collection.""" pass - @abstractmethod - async def get_community_count(self, collection_id: UUID) -> int: - """Get total number of communities for a collection.""" - pass - @abstractmethod async def add_community_report( self, community_report: CommunityReport @@ -740,12 +735,12 @@ async def delete_node_via_document_id( @abstractmethod async def get_entities( self, - collection_id: UUID, - offset: int = 0, - limit: int = -1, + collection_id: Optional[UUID] = None, entity_ids: Optional[List[str]] = None, entity_names: Optional[List[str]] = None, entity_table_name: str = "document_entity", + offset: int = 0, + limit: int = -1, ) -> dict: """Get entities from storage.""" pass @@ -753,11 +748,11 @@ async def get_entities( @abstractmethod async def get_triples( self, - collection_id: UUID, - offset: int = 0, - limit: int = 100, + collection_id: Optional[UUID] = None, entity_names: Optional[List[str]] = None, triple_ids: Optional[List[str]] = None, + offset: int = 0, + limit: int = -1, ) -> dict: """Get triples from storage.""" pass @@ -1547,21 +1542,21 @@ async def add_communities(self, communities: List[Any]) -> None: async def get_communities( self, - collection_id: UUID, - offset: int = 0, - limit: int = 100, + collection_id: Optional[UUID] = None, levels: Optional[list[int]] = None, community_numbers: Optional[list[int]] = None, + offset: int = 0, + limit: int = -1, ) -> dict: """Forward to KG handler get_communities method.""" return await self.kg_handler.get_communities( - collection_id, offset, limit, levels, community_numbers + collection_id, + levels, + community_numbers, + offset, + limit, ) - async def get_community_count(self, collection_id: UUID) -> int: - """Forward to KG handler get_community_count method.""" - return await self.kg_handler.get_community_count(collection_id) - async def add_community_report( self, community_report: CommunityReport ) -> None: @@ -1620,34 +1615,38 @@ async def delete_node_via_document_id( # Entity and Triple operations async def get_entities( self, - collection_id: UUID, - offset: int = 0, - limit: int = -1, + collection_id: Optional[UUID], entity_ids: Optional[List[str]] = None, entity_names: Optional[List[str]] = None, entity_table_name: str = "document_entity", + offset: int = 0, + limit: int = -1, ) -> dict: """Forward to KG handler get_entities method.""" return await self.kg_handler.get_entities( collection_id, - offset, - limit, entity_ids, entity_names, entity_table_name, + offset, + limit, ) async def get_triples( self, - collection_id: UUID, - offset: int = 0, - limit: int = 100, + collection_id: Optional[UUID] = None, entity_names: Optional[List[str]] = None, triple_ids: Optional[List[str]] = None, + offset: int = 0, + limit: int = -1, ) -> dict: """Forward to KG handler get_triples method.""" return await self.kg_handler.get_triples( - collection_id, offset, limit, entity_names, triple_ids + collection_id, + entity_names, + triple_ids, + offset, + limit, ) async def get_entity_count( diff --git a/py/core/main/api/kg_router.py b/py/core/main/api/kg_router.py index ccb7b401f..2d1f02617 100644 --- a/py/core/main/api/kg_router.py +++ b/py/core/main/api/kg_router.py @@ -219,19 +219,21 @@ async def enrich_graph( @self.router.get("/entities") @self.base_endpoint async def get_entities( + collection_id: Optional[UUID] = Query( + None, description="Collection ID to retrieve entities from." + ), entity_level: Optional[EntityLevel] = Query( default=EntityLevel.DOCUMENT, description="Type of entities to retrieve. Options are: raw, dedup_document, dedup_collection.", ), - collection_id: Optional[UUID] = Query( - None, description="Collection ID to retrieve entities from." + entity_ids: Optional[list[str]] = Query( + None, description="Entity IDs to filter by." ), offset: int = Query(0, ge=0, description="Offset for pagination."), limit: int = Query( - 100, ge=1, le=1000, description="Limit for pagination." - ), - entity_ids: Optional[list[str]] = Query( - None, description="Entity IDs to filter by." + 100, + ge=-1, + description="Number of items to return. Use -1 to return all items.", ), auth_user=Depends(self.service.providers.auth.auth_wrapper), ) -> WrappedKGEntitiesResponse: @@ -255,10 +257,10 @@ async def get_entities( return await self.service.get_entities( collection_id, - offset, - limit, entity_ids, entity_table_name, + offset, + limit, ) @self.router.get("/triples") @@ -267,16 +269,18 @@ async def get_triples( collection_id: Optional[UUID] = Query( None, description="Collection ID to retrieve triples from." ), - offset: int = Query(0, ge=0, description="Offset for pagination."), - limit: int = Query( - 100, ge=1, le=1000, description="Limit for pagination." - ), entity_names: Optional[list[str]] = Query( None, description="Entity names to filter by." ), triple_ids: Optional[list[str]] = Query( None, description="Triple IDs to filter by." ), + offset: int = Query(0, ge=0, description="Offset for pagination."), + limit: int = Query( + 100, + ge=-1, + description="Number of items to return. Use -1 to return all items.", + ), auth_user=Depends(self.service.providers.auth.auth_wrapper), ) -> WrappedKGTriplesResponse: """ @@ -292,10 +296,10 @@ async def get_triples( return await self.service.get_triples( collection_id, - offset, - limit, entity_names, triple_ids, + offset, + limit, ) @self.router.get("/communities") @@ -304,16 +308,18 @@ async def get_communities( collection_id: Optional[UUID] = Query( None, description="Collection ID to retrieve communities from." ), - offset: int = Query(0, ge=0, description="Offset for pagination."), - limit: int = Query( - 100, ge=1, le=1000, description="Limit for pagination." - ), levels: Optional[list[int]] = Query( None, description="Levels to filter by." ), community_numbers: Optional[list[int]] = Query( None, description="Community numbers to filter by." ), + offset: int = Query(0, ge=0, description="Offset for pagination."), + limit: int = Query( + 100, + ge=-1, + description="Number of items to return. Use -1 to return all items.", + ), auth_user=Depends(self.service.providers.auth.auth_wrapper), ) -> WrappedKGCommunitiesResponse: """ @@ -329,10 +335,10 @@ async def get_communities( return await self.service.get_communities( collection_id, - offset, - limit, levels, community_numbers, + offset, + limit, ) @self.router.post("/deduplicate_entities") @@ -444,10 +450,10 @@ async def get_tuned_prompt( @self.router.delete("/delete_graph_for_collection") @self.base_endpoint async def delete_graph_for_collection( - collection_id: UUID = Body( + collection_id: UUID = Body( # FIXME: This should be a path parameter ..., description="Collection ID to delete graph for." ), - cascade: bool = Body( + cascade: bool = Body( # FIXME: This should be a query parameter default=False, description="Whether to cascade the deletion, and delete entities and triples belonging to the collection.", ), diff --git a/py/core/main/services/kg_service.py b/py/core/main/services/kg_service.py index 92b5d0a63..f4bef4284 100644 --- a/py/core/main/services/kg_service.py +++ b/py/core/main/services/kg_service.py @@ -319,55 +319,55 @@ async def get_enrichment_estimate( @telemetry_event("get_entities") async def get_entities( self, - collection_id: UUID, - offset: int = 0, - limit: int = 100, + collection_id: Optional[UUID] = None, entity_ids: Optional[list[str]] = None, entity_table_name: str = "document_entity", + offset: Optional[int] = None, + limit: Optional[int] = None, **kwargs, ): return await self.providers.database.get_entities( - collection_id, - offset, - limit, - entity_ids, + collection_id=collection_id, + entity_ids=entity_ids, entity_table_name=entity_table_name, + offset=offset or 0, + limit=limit or -1, ) @telemetry_event("get_triples") async def get_triples( self, - collection_id: UUID, - offset: int = 0, - limit: int = 100, + collection_id: Optional[UUID] = None, entity_names: Optional[list[str]] = None, triple_ids: Optional[list[str]] = None, + offset: Optional[int] = None, + limit: Optional[int] = None, **kwargs, ): return await self.providers.database.get_triples( - collection_id, - offset, - limit, - entity_names, - triple_ids, + collection_id=collection_id, + entity_names=entity_names, + triple_ids=triple_ids, + offset=offset or 0, + limit=limit or -1, ) @telemetry_event("get_communities") async def get_communities( self, - collection_id: UUID, - offset: int = 0, - limit: int = 100, + collection_id: Optional[UUID] = None, levels: Optional[list[int]] = None, community_numbers: Optional[list[int]] = None, + offset: Optional[int] = None, + limit: Optional[int] = None, **kwargs, ): return await self.providers.database.get_communities( - collection_id, - offset, - limit, - levels, - community_numbers, + collection_id=collection_id, + levels=levels, + community_numbers=community_numbers, + offset=offset or 0, + limit=limit or -1, ) @telemetry_event("get_deduplication_estimate") diff --git a/py/core/providers/database/kg.py b/py/core/providers/database/kg.py index 1cdc2922f..e79a42b27 100644 --- a/py/core/providers/database/kg.py +++ b/py/core/providers/database/kg.py @@ -528,51 +528,57 @@ async def add_communities(self, communities: list[Any]) -> None: async def get_communities( self, - collection_id: UUID, - offset: int = 0, - limit: int = 100, + collection_id: Optional[UUID] = None, levels: Optional[list[int]] = None, community_numbers: Optional[list[int]] = None, + offset: Optional[int] = 0, + limit: Optional[int] = -1, ) -> dict: - - query_parts = [ - f""" - SELECT id, community_number, collection_id, level, name, summary, findings, rating, rating_explanation - FROM {self._get_table_name('community_report')} WHERE collection_id = $1 ORDER BY community_number LIMIT $2 OFFSET $3 - """ - ] - params = [collection_id, limit, offset] + conditions = [] + params: list = [collection_id] + param_index = 2 if levels is not None: - query_parts.append(f"AND level = ANY(${len(params) + 1})") + conditions.append(f"level = ANY(${param_index})") params.append(levels) + param_index += 1 if community_numbers is not None: - query_parts.append( - f"AND community_number = ANY(${len(params) + 1})" - ) + conditions.append(f"community_number = ANY(${param_index})") params.append(community_numbers) + param_index += 1 + + pagination_params = [] + if offset: + pagination_params.append(f"OFFSET ${param_index}") + params.append(offset) + param_index += 1 - QUERY = " ".join(query_parts) + if limit != -1: + pagination_params.append(f"LIMIT ${param_index}") + params.append(limit) + param_index += 1 - communities = await self.connection_manager.fetch_query(QUERY, params) - communities = [ - CommunityReport(**community) for community in communities - ] + pagination_clause = " ".join(pagination_params) + + query = f""" + SELECT id, community_number, collection_id, level, name, summary, findings, rating, rating_explanation, COUNT(*) OVER() AS total_entries + FROM {self._get_table_name('community_report')} + WHERE collection_id = $1 + {" AND " + " AND ".join(conditions) if conditions else ""} + ORDER BY community_number + {pagination_clause} + """ + + results = await self.connection_manager.fetch_query(query, params) + total_entries = results[0]["total_entries"] if results else 0 + communities = [CommunityReport(**community) for community in results] return { "communities": communities, - "total_entries": (await self.get_community_count(collection_id)), + "total_entries": total_entries, } - async def get_community_count(self, collection_id: UUID) -> int: - QUERY = f""" - SELECT COUNT(*) FROM {self._get_table_name("community_report")} WHERE collection_id = $1 - """ - return ( - await self.connection_manager.fetch_query(QUERY, [collection_id]) - )[0]["count"] - async def add_community_report( self, community_report: CommunityReport ) -> None: @@ -1115,43 +1121,48 @@ async def get_schema(self): async def get_entities( self, - collection_id: UUID, - offset: int = 0, - limit: int = -1, + collection_id: Optional[UUID] = None, entity_ids: Optional[list[str]] = None, entity_names: Optional[list[str]] = None, entity_table_name: str = "document_entity", + offset: int = 0, + limit: int = -1, ) -> dict: conditions = [] params: list = [collection_id] + param_index = 2 if entity_ids: - conditions.append(f"id = ANY(${len(params) + 1})") + conditions.append(f"id = ANY(${param_index})") params.append(entity_ids) + param_index += 1 if entity_names: - conditions.append(f"name = ANY(${len(params) + 1})") + conditions.append(f"name = ANY(${param_index})") params.append(entity_names) + param_index += 1 - if limit != -1: - params.extend([offset, limit]) - offset_limit_clause = ( - f"OFFSET ${len(params) - 1} LIMIT ${len(params)}" - ) - else: + pagination_params = [] + if offset: + pagination_params.append(f"OFFSET ${param_index}") params.append(offset) - offset_limit_clause = f"OFFSET ${len(params)}" + param_index += 1 + + if limit != -1: + pagination_params.append(f"LIMIT ${param_index}") + params.append(limit) + param_index += 1 + + pagination_clause = " ".join(pagination_params) if entity_table_name == "collection_entity": - # entity deduplicated table has document_ids, not document_id. - # we directly use the collection_id to get the entities list. query = f""" SELECT id, name, description, extraction_ids, document_ids FROM {self._get_table_name(entity_table_name)} WHERE collection_id = $1 {" AND " + " AND ".join(conditions) if conditions else ""} ORDER BY id - {offset_limit_clause} + {pagination_clause} """ else: query = f""" @@ -1163,11 +1174,10 @@ async def get_entities( ) {" AND " + " AND ".join(conditions) if conditions else ""} ORDER BY id - {offset_limit_clause} - """ + {pagination_clause} + """ results = await self.connection_manager.fetch_query(query, params) - entities = [Entity(**entity) for entity in results] total_entries = await self.get_entity_count( @@ -1178,24 +1188,40 @@ async def get_entities( async def get_triples( self, - collection_id: UUID, - offset: int = 0, - limit: int = 100, + collection_id: Optional[UUID] = None, entity_names: Optional[list[str]] = None, triple_ids: Optional[list[str]] = None, + offset: Optional[int] = 0, + limit: Optional[int] = -1, ) -> dict: conditions = [] - params = [str(collection_id)] + params: list = [str(collection_id)] + param_index = 2 if triple_ids: - conditions.append(f"id = ANY(${len(params) + 1})") - params.append([str(ele) for ele in triple_ids]) # type: ignore + conditions.append(f"id = ANY(${param_index})") + params.append(triple_ids) + param_index += 1 if entity_names: conditions.append( - f"subject = ANY(${len(params) + 1}) or object = ANY(${len(params) + 1})" + f"subject = ANY(${param_index}) or object = ANY(${param_index})" ) - params.append([str(ele) for ele in entity_names]) # type: ignore + params.append(entity_names) + param_index += 1 + + pagination_params = [] + if offset: + pagination_params.append(f"OFFSET ${param_index}") + params.append(offset) + param_index += 1 + + if limit != -1: + pagination_params.append(f"LIMIT ${param_index}") + params.append(limit) + param_index += 1 + + pagination_clause = " ".join(pagination_params) query = f""" SELECT id, subject, predicate, object, description @@ -1206,9 +1232,8 @@ async def get_triples( ) {" AND " + " AND ".join(conditions) if conditions else ""} ORDER BY id - OFFSET ${len(params) + 1} LIMIT ${len(params) + 2} + {pagination_clause} """ - params.extend([offset, limit]) # type: ignore triples = await self.connection_manager.fetch_query(query, params) triples = [Triple(**triple) for triple in triples] diff --git a/py/sdk/mixins/kg.py b/py/sdk/mixins/kg.py index a718c390d..3caaece16 100644 --- a/py/sdk/mixins/kg.py +++ b/py/sdk/mixins/kg.py @@ -67,11 +67,11 @@ async def enrich_graph( async def get_entities( self, - collection_id: str, - offset: int = 0, - limit: int = 100, - entity_level: Optional[str] = "collection", + collection_id: Optional[Union[UUID, str]] = None, + entity_level: Optional[str] = None, entity_ids: Optional[list[str]] = None, + offset: Optional[int] = None, + limit: Optional[int] = None, ) -> dict: """ Retrieve entities from the knowledge graph. @@ -86,24 +86,26 @@ async def get_entities( Returns: dict: A dictionary containing the retrieved entities and total count. """ + params = { - "entity_level": entity_level, "collection_id": collection_id, + "entity_level": entity_level, + "entity_ids": entity_ids, "offset": offset, "limit": limit, } - if entity_ids: - params["entity_ids"] = ",".join(entity_ids) + + params = {k: v for k, v in params.items() if v is not None} return await self._make_request("GET", "entities", params=params) # type: ignore async def get_triples( self, - collection_id: str, - offset: int = 0, - limit: int = 100, + collection_id: Optional[Union[UUID, str]] = None, entity_names: Optional[list[str]] = None, triple_ids: Optional[list[str]] = None, + offset: Optional[int] = None, + limit: Optional[int] = None, ) -> dict: """ Retrieve triples from the knowledge graph. @@ -118,27 +120,26 @@ async def get_triples( Returns: dict: A dictionary containing the retrieved triples and total count. """ + params = { "collection_id": collection_id, + "entity_names": entity_names, + "triple_ids": triple_ids, "offset": offset, "limit": limit, } - if entity_names: - params["entity_names"] = entity_names - - if triple_ids: - params["triple_ids"] = ",".join(triple_ids) + params = {k: v for k, v in params.items() if v is not None} return await self._make_request("GET", "triples", params=params) # type: ignore async def get_communities( self, - collection_id: str, - offset: int = 0, - limit: int = 100, + collection_id: Optional[Union[UUID, str]] = None, levels: Optional[list[int]] = None, community_numbers: Optional[list[int]] = None, + offset: Optional[int] = None, + limit: Optional[int] = None, ) -> dict: """ Retrieve communities from the knowledge graph. @@ -153,16 +154,16 @@ async def get_communities( Returns: dict: A dictionary containing the retrieved communities. """ + params = { "collection_id": collection_id, + "levels": levels, + "community_numbers": community_numbers, "offset": offset, "limit": limit, } - if levels: - params["levels"] = levels - if community_numbers: - params["community_numbers"] = community_numbers + params = {k: v for k, v in params.items() if v is not None} return await self._make_request("GET", "communities", params=params) # type: ignore diff --git a/py/shared/abstractions/graph.py b/py/shared/abstractions/graph.py index 0b01f8dbd..049679cdc 100644 --- a/py/shared/abstractions/graph.py +++ b/py/shared/abstractions/graph.py @@ -166,35 +166,35 @@ class Community(BaseModel): id: int | None = None """The ID of the community.""" - level: int | None = None - """Community level.""" - - entity_ids: list[str] | None = None - """List of entity IDs related to the community (optional).""" + community_number: int | None = None + """The community number.""" - relationship_ids: list[str] | None = None - """List of relationship IDs related to the community (optional).""" + collection_id: uuid.UUID | None = None + """The ID of the collection this community is associated with.""" - covariate_ids: dict[str, list[str]] | None = None - """Dictionary of different types of covariates related to the community (optional), e.g. claims""" + level: int | None = None + """Community level.""" - attributes: dict[str, Any] | None = None - """A dictionary of additional attributes associated with the community (optional). To be included in the search prompt.""" + name: str = "" + """The name of the community.""" summary: str = "" """Summary of the report.""" - full_content: str = "" - """Full content of the report.""" + findings: list[str] = [] + """Findings of the report.""" + + rating: float | None = None + """Rating of the report.""" - rank: float | None = 1.0 - """Rank of the report, used for sorting (optional). Higher means more important""" + rating_explanation: str | None = None + """Explanation of the rating.""" embedding: list[float] | None = None - """The semantic (i.e. text) embedding of the report summary (optional).""" + """Embedding of summary and findings.""" - full_content_embedding: list[float] | None = None - """The semantic (i.e. text) embedding of the full report content (optional).""" + attributes: dict[str, Any] | None = None + """A dictionary of additional attributes associated with the community (optional). To be included in the search prompt.""" def __init__(self, **kwargs): super().__init__(**kwargs)